Merge pull request #9 from lightonai/attend_expansion_tokens
Adding the attend_to_expansion_tokens option
raphaelsty authored Jun 17, 2024
2 parents 9ba70a7 + 4e611ed commit 21437e6
Showing 1 changed file with 15 additions and 6 deletions.
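
A minimal usage sketch of the new option. The query_length and attend_to_expansion_tokens keyword arguments come from this commit; the import path giga_cherche.models.ColBERT, the base checkpoint name, and the encode(..., is_query=True) call are assumptions made for illustration only:

    from giga_cherche.models import ColBERT  # assumed import path (the file lives at giga_cherche/models/colbert.py)

    model = ColBERT(
        "bert-base-uncased",              # assumed base checkpoint, passed SentenceTransformer-style
        query_length=32,
        attend_to_expansion_tokens=True,  # new option: also attend to the [MASK] expansion tokens in queries
    )

    query_embeddings = model.encode(["what is late interaction?"], is_query=True)  # assumed encode signature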
giga_cherche/models/colbert.py (21 changes: 15 additions & 6 deletions)
@@ -191,6 +191,7 @@ def __init__(
         add_special_tokens: Optional[bool] = True,
         truncation: Optional[bool] = True,
         query_length: Optional[int] = 32,
+        attend_to_expansion_tokens: Optional[bool] = False,
     ):
         # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
         self.prompts = prompts or {}
@@ -204,6 +205,7 @@ def __init__(
         self.query_prefix = query_prefix
         self.document_prefix = document_prefix
         self.query_length = query_length
+        self.attend_to_expansion_tokens = attend_to_expansion_tokens
         if use_auth_token is not None:
             warnings.warn(
                 "The `use_auth_token` argument is deprecated and will be removed in v3 of SentenceTransformers.",
@@ -634,8 +636,10 @@ def encode(
         ):
             last_mask_id = len(attention) - 1
             # TODO: isn't torch.sum(attention) better?
-            while last_mask_id > 0 and attention[last_mask_id].item() == 0:
-                last_mask_id -= 1
+            # We do not want to prune the expansion tokens in queries, even if we do not attend to them in the attention layers
+            if not is_query:
+                while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+                    last_mask_id -= 1
             # TODO: normalize at the list level / use the Normalize module?
             if normalize_embeddings:
                 token_emb = torch.nn.functional.normalize(
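
To make the change above concrete, here is a standalone sketch (not the library code) of the pruning behaviour in encode: trailing positions with a zero attention mask are pruned for documents, while queries keep their expansion tokens. The torch.sum shortcut hinted at by the TODO is noted as an equivalent under the usual contiguous-mask assumption:

    import torch

    def last_kept_index(attention: torch.Tensor, is_query: bool) -> int:
        # Mirrors the hunk above: documents drop trailing positions whose
        # attention mask is 0; queries keep their [MASK] expansion tokens
        # even when those positions are not attended to.
        last_mask_id = len(attention) - 1
        if not is_query:
            while last_mask_id > 0 and attention[last_mask_id].item() == 0:
                last_mask_id -= 1
            # Shortcut from the TODO above, valid when the mask is a block of
            # ones followed by zeros with at least one attended token:
            # last_mask_id = int(attention.sum().item()) - 1
        return last_mask_id

    attention = torch.tensor([1, 1, 1, 0, 0])
    print(last_kept_index(attention, is_query=False))  # 2: trailing padding pruned
    print(last_kept_index(attention, is_query=True))   # 4: expansion tokens kept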
@@ -1105,16 +1109,16 @@ def tokenize(
             features["token_type_ids"] = features["token_type_ids"][
                 :, : self.query_length
             ]
-            # Fill the attention mask with ones (we attend to "padding" tokens)
-            features["attention_mask"].fill_(1)
-            # features["attention_mask"] = torch.ones_like(features["attention_mask"])
+            # In the original ColBERT, the original tokens do not attend to the expansion tokens (but the expansion tokens attend to the original tokens)
+            if self.attend_to_expansion_tokens:
+                # Fill the attention mask with ones (we attend to the "padding" tokens used for expansion)
+                features["attention_mask"].fill_(1)
             return features
         else:
             features = self._first_module().tokenize(texts)
             # Replace the second token with the document prefix
             features["input_ids"][:, 1] = self.document_prefix_id
             return features
-        return self._first_module().tokenize(texts)

     def get_sentence_features(self, *features):
         return self._first_module().get_sentence_features(*features)
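
An illustrative sketch of what the new branch does to a query's attention mask. The tensor below is hypothetical; in the library the mask is produced by tokenize after padding the query to query_length with [MASK] expansion tokens:

    import torch

    # Hypothetical attention mask for one query padded to query_length=8; the
    # last three positions are [MASK] expansion tokens added by query augmentation.
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]])

    attend_to_expansion_tokens = True
    if attend_to_expansion_tokens:
        # Same operation as the new branch above: every position, including the
        # expansion tokens, becomes attendable.
        attention_mask.fill_(1)
    # With the default False, the mask is left untouched, reproducing the original
    # ColBERT behaviour where regular tokens do not attend to the expansion tokens
    # (while the expansion tokens still attend to the regular tokens).
    print(attention_mask)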
@@ -1222,6 +1226,7 @@ def save(
             config["query_prefix"] = self.query_prefix
             config["document_prefix"] = self.document_prefix
             config["query_length"] = self.query_length
+            config["attend_to_expansion_tokens"] = self.attend_to_expansion_tokens
             json.dump(config, fOut, indent=2)

         # Save modules
@@ -1627,6 +1632,10 @@ def _load_sbert_model(
             self.document_prefix = self._model_config["document_prefix"]
         if "query_length" in self._model_config:
             self.query_length = self._model_config["query_length"]
+        if "attend_to_expansion_tokens" in self._model_config:
+            self.attend_to_expansion_tokens = self._model_config[
+                "attend_to_expansion_tokens"
+            ]

         # Check if a readme exists
         model_card_path = load_file_path(
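
The save and _load_sbert_model hunks above round-trip the new flag through the model's JSON configuration. A minimal sketch of that round-trip, with a simplified config dictionary assumed for illustration (in the library the config is written with json.dump and read back field by field):

    import json

    config = {"query_length": 32, "attend_to_expansion_tokens": True}
    serialized = json.dumps(config, indent=2)   # what save() writes via json.dump

    loaded = json.loads(serialized)             # what _load_sbert_model() reads back
    attend_to_expansion_tokens = loaded.get("attend_to_expansion_tokens", False)
    print(attend_to_expansion_tokens)  # True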
