From 4e611ed4c71a1670e6d6f79b38c8f61427507007 Mon Sep 17 00:00:00 2001
From: Antoine Chaffin
Date: Mon, 17 Jun 2024 11:58:23 +0000
Subject: [PATCH] Adding the attend_to_expansion_tokens option

---
 giga_cherche/models/colbert.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/giga_cherche/models/colbert.py b/giga_cherche/models/colbert.py
index ec87040..f2dea7b 100644
--- a/giga_cherche/models/colbert.py
+++ b/giga_cherche/models/colbert.py
@@ -191,6 +191,7 @@ def __init__(
         add_special_tokens: Optional[bool] = True,
         truncation: Optional[bool] = True,
         query_length: Optional[int] = 32,
+        attend_to_expansion_tokens: Optional[bool] = False,
     ):
         # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
         self.prompts = prompts or {}
@@ -204,6 +205,7 @@ def __init__(
         self.query_prefix = query_prefix
         self.document_prefix = document_prefix
         self.query_length = query_length
+        self.attend_to_expansion_tokens = attend_to_expansion_tokens
         if use_auth_token is not None:
             warnings.warn(
                 "The `use_auth_token` argument is deprecated and will be removed in v3 of SentenceTransformers.",
@@ -634,8 +636,10 @@ def encode(
                 ):
                     last_mask_id = len(attention) - 1
                     # TODO: isn't torch.sum(attention) better?
-                    while last_mask_id > 0 and attention[last_mask_id].item() == 0:
-                        last_mask_id -= 1
+                    # We do not want to prune expansion tokens in queries even if we do not attend to them in attention layers
+                    if not is_query:
+                        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+                            last_mask_id -= 1
                     # TODO: normalize at the list level/use the module Normalize?
                     if normalize_embeddings:
                         token_emb = torch.nn.functional.normalize(
@@ -1105,16 +1109,16 @@ def tokenize(
                 features["token_type_ids"] = features["token_type_ids"][
                     :, : self.query_length
                 ]
-            # Fill the attention mask with ones (we attend to "padding" tokens)
-            features["attention_mask"].fill_(1)
-            # features["attention_mask"] = torch.ones_like(features["attention_mask"])
+            # In the original ColBERT, the original tokens do not attend to the expansion tokens (but the expansion tokens attend to original tokens)
+            if self.attend_to_expansion_tokens:
+                # Fill the attention mask with ones (we attend to "padding" tokens used for expansion)
+                features["attention_mask"].fill_(1)
             return features
         else:
             features = self._first_module().tokenize(texts)
             # Remplace the second token by the document prefix
             features["input_ids"][:, 1] = self.document_prefix_id
             return features
-        return self._first_module().tokenize(texts)
 
     def get_sentence_features(self, *features):
         return self._first_module().get_sentence_features(*features)
@@ -1222,6 +1226,7 @@ def save(
             config["query_prefix"] = self.query_prefix
             config["document_prefix"] = self.document_prefix
             config["query_length"] = self.query_length
+            config["attend_to_expansion_tokens"] = self.attend_to_expansion_tokens
             json.dump(config, fOut, indent=2)
 
         # Save modules
@@ -1627,6 +1632,10 @@ def _load_sbert_model(
             self.document_prefix = self._model_config["document_prefix"]
         if "query_length" in self._model_config:
             self.query_length = self._model_config["query_length"]
+        if "attend_to_expansion_tokens" in self._model_config:
+            self.attend_to_expansion_tokens = self._model_config[
+                "attend_to_expansion_tokens"
+            ]
 
         # Check if a readme exists
         model_card_path = load_file_path(
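
For reference, a minimal sketch of what the new option means at query tokenization time. This is a toy illustration written for this note, not code from giga_cherche: the mask values and the query_length of 8 are made up, and only the fill_(1) call mirrors the patched tokenize() behaviour.

import torch

# A query with 5 real tokens, padded with [MASK] expansion tokens up to query_length=8.
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]])

attend_to_expansion_tokens = True  # the new option introduced by this patch

if attend_to_expansion_tokens:
    # Same operation as in tokenize(): every position, including the expansion
    # ("padding") tokens, becomes visible to the attention layers.
    attention_mask.fill_(1)
# With the default False, the zeros are kept, matching the original ColBERT
# behaviour where regular tokens do not attend to the expansion tokens.

print(attention_mask)  # tensor([[1, 1, 1, 1, 1, 1, 1, 1]]) when the flag is set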
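
Likewise, a toy sketch of the encode() change: trailing zeros in the attention mask are now only used to prune tokens for documents, so query expansion tokens stay in the output embeddings even when they are not attended to. The tensor below is an invented example, not repository code.

import torch

attention = torch.tensor([1, 1, 1, 0, 0])  # made-up mask with two trailing zeros
is_query = True

last_mask_id = len(attention) - 1
# Only documents get their trailing padding pruned; queries keep every position.
if not is_query:
    while last_mask_id > 0 and attention[last_mask_id].item() == 0:
        last_mask_id -= 1

print(last_mask_id + 1)  # 5 positions kept for a query, 3 for a document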