Skip to content

Commit

Permalink
NoDuplicatesDataLoader Compatability with Asymmetric models (#3220)
Browse files Browse the repository at this point in the history
* adapt NoDuplicatesDataLoader Compatability with Asymmetric models

* adapt NoDuplicatesBatchSampler

* Update documentation for Asym to v3+ training

---------

Co-authored-by: osama salem <[email protected]>
Co-authored-by: Tom Aarsen <[email protected]>
  • Loading branch information
3 people authored Feb 14, 2025
1 parent 8562d6f commit d4d198d
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 12 deletions.
4 changes: 4 additions & 0 deletions sentence_transformers/datasets/NoDuplicatesDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,17 @@ def __iter__(self):

valid_example = True
for text in example.texts:
if not isinstance(text, str):
text = str(text)
if text.strip().lower() in texts_in_batch:
valid_example = False
break

if valid_example:
batch.append(example)
for text in example.texts:
if not isinstance(text, str):
text = str(text)
texts_in_batch.add(text.strip().lower())

self.data_pointer += 1
Expand Down
58 changes: 47 additions & 11 deletions sentence_transformers/models/Asym.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,53 @@ def __init__(self, sub_modules: dict[str, list[nn.Module]], allow_empty_key: boo
Note, that when you call encode(), that only inputs of the same type can be encoded. Mixed-Types cannot be encoded.
Example::
word_embedding_model = models.Transformer(model_name)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
asym_model = models.Asym({'query': [models.Dense(word_embedding_model.get_word_embedding_dimension(), 128)], 'doc': [models.Dense(word_embedding_model.get_word_embedding_dimension(), 128)]})
model = SentenceTransformer(modules=[word_embedding_model, pooling_model, asym_model])
model.encode([{'query': 'Q1'}, {'query': 'Q2'}]
model.encode([{'doc': 'Doc1'}, {'doc': 'Doc2'}]
#You can train it with InputExample like this. Note, that the order must always be the same:
train_example = InputExample(texts=[{'query': 'Train query'}, {'doc': 'Document'}], label=1)
Example:
::
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
from datasets import Dataset
# Load a SentenceTransformer model (pretrained or not), and add an Asym module
model = SentenceTransformer("microsoft/mpnet-base")
dim = model.get_sentence_embedding_dimension()
asym_model = models.Asym({
'query': [models.Dense(dim, dim)],
'doc': [models.Dense(dim, dim)]
})
model.add_module("asym", asym_model)
train_dataset = Dataset.from_dict({
"query": ["is toprol xl the same as metoprolol?", "are eyes always the same size?"],
"answer": ["Metoprolol succinate is also known by the brand name Toprol XL.", "The eyes are always the same size from birth to death."],
})
# This mapper turns normal texts into a dictionary mapping Asym keys to the text
def mapper(sample):
return {
"question": {"query": sample["question"]},
"answer": {"doc": sample["answer"]},
}
train_dataset = train_dataset.map(mapper)
loss = losses.MultipleNegativesRankingLoss(model)
trainer = SentenceTransformerTrainer(
model=model,
train_dataset=train_dataset,
loss=loss,
)
trainer.train()
# For inference, you can pass dictionaries with the Asym keys:
model.encode([
{'query': 'how long do you have to wait to apply for cerb?'},
{'query': '<3 what does this symbol mean?'},
{'doc': 'The definition of <3 is "Love".'}]
)
Note:
These models are not necessarily stronger than non-asymmetric models. Rudimentary experiments indicate
that non-Asym models perform better in most cases.
Args:
sub_modules: Dict in the format str -> List[models]. The
Expand Down
2 changes: 1 addition & 1 deletion sentence_transformers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __iter__(self) -> Iterator[list[int]]:
batch_indices = []
for index in remaining_indices:
sample_values = {
value
str(value)
for key, value in self.dataset[index].items()
if not key.endswith("_prompt_length") and key != "dataset_name"
}
Expand Down

0 comments on commit d4d198d

Please sign in to comment.