From 5e5e42c3d989b8efad29dce52155b026ef0f7d2a Mon Sep 17 00:00:00 2001 From: Raphael Sourty Date: Mon, 22 Jul 2024 16:19:59 +0200 Subject: [PATCH 1/5] knowledge-distillation-feature --- README.md | 19 +++- giga_cherche/data_collator/__init__.py | 2 +- giga_cherche/data_collator/colbert.py | 6 +- giga_cherche/losses/__init__.py | 4 +- giga_cherche/losses/colbert.py | 135 ++++++++++++++++++++++- giga_cherche/models/colbert.py | 99 +++-------------- giga_cherche/utils/__init__.py | 4 +- giga_cherche/utils/constants.py | 73 ++++++++++++ giga_cherche/utils/dataset_processing.py | 53 +++++++++ kd_training.py | 54 +++++++++ training.py | 2 +- 11 files changed, 355 insertions(+), 96 deletions(-) create mode 100644 giga_cherche/utils/constants.py create mode 100644 giga_cherche/utils/dataset_processing.py create mode 100644 kd_training.py diff --git a/README.md b/README.md index 05550a7..4b1711c 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ The following parameters can be passed to the constructor to set different prope - ```attend_to_expansion_tokens```, whether queries tokens should attend to MASK expansion tokens (original ColBERT did not) - ```skiplist_words```, a list of words to ignore in documents during scoring (default to punctuation) -# Training +## Training Given that giga-cherche ColBERT models are sentence-transformers models, we can benefit from all the bells and whistles from the latest update, including multi-gpu and BF16 training. For now, you can train ColBERT models using triplets dataset (datasets containing a positive and a negative for each query). The syntax is the same as sentence-transformers, using the specific elements adapted to ColBERT from giga-cherche: @@ -83,6 +83,23 @@ trainer = SentenceTransformerTrainer( trainer.train() ``` +## Tokenization + +``` +import ast + +def add_queries_and_documents(example: dict) -> dict: + """Add queries and documents text to the examples.""" + scores = ast.literal_eval(node_or_string=example["scores"]) + processed_example = {"scores": scores, "query": queries[example["query_id"]]} + + n_scores = len(scores) + for i in range(n_scores): + processed_example[f"document_{i}"] = documents[example[f"document_id_{i}"]] + + return processed_example +``` + ## Inference Once trained, the model can then be loaded to perform inference (you can also load the models directly from Hugging Face, for example using the provided ColBERTv2 model [NohTow/colbertv2_sentence_transformer](https://huggingface.co/NohTow/colbertv2_sentence_transformer)): diff --git a/giga_cherche/data_collator/__init__.py b/giga_cherche/data_collator/__init__.py index ccb3bc9..057670e 100644 --- a/giga_cherche/data_collator/__init__.py +++ b/giga_cherche/data_collator/__init__.py @@ -1,3 +1,3 @@ -from .colbert_data_collator import ColBERT +from .colbert import ColBERT __all__ = ["ColBERT"] diff --git a/giga_cherche/data_collator/colbert.py b/giga_cherche/data_collator/colbert.py index 9d15fb4..fe13da7 100644 --- a/giga_cherche/data_collator/colbert.py +++ b/giga_cherche/data_collator/colbert.py @@ -14,7 +14,7 @@ class ColBERT: """ tokenize_fn: Callable - valid_label_columns: list[str] = field(default_factory=lambda: ["label", "score"]) + valid_label_columns: list[str] = field(default_factory=lambda: ["label", "scores"]) def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]: """Collate a list of features into a batch.""" @@ -40,7 +40,9 @@ def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]: is_query = "query" in column or "anchor" in column tokenized = self.tokenize_fn( - [row[column] for row in features], is_query=is_query + [row[column] for row in features], + is_query=is_query, + pad_document=True, ) for key, value in tokenized.items(): diff --git a/giga_cherche/losses/__init__.py b/giga_cherche/losses/__init__.py index 9e2d1d7..8a392ad 100644 --- a/giga_cherche/losses/__init__.py +++ b/giga_cherche/losses/__init__.py @@ -1,3 +1,3 @@ -from .colbert import ColBERTLoss +from .colbert import ColBERTLossv1, ColBERTLossv2 -__all__ = ["ColBERTLoss"] +__all__ = ["ColBERTLossv1", "ColBERTLossv2"] diff --git a/giga_cherche/losses/colbert.py b/giga_cherche/losses/colbert.py index d068767..325b11b 100644 --- a/giga_cherche/losses/colbert.py +++ b/giga_cherche/losses/colbert.py @@ -6,7 +6,7 @@ from sentence_transformers import SentenceTransformer from torch import Tensor, nn -__all__ = ["ColBERTLoss", "ColBERTSimilarityMetric"] +__all__ = ["ColBERTSimilarityMetric", "ColBERTLossv1", "ColBERTLossv2"] class ColBERTSimilarityMetric(Enum): @@ -20,8 +20,28 @@ def COLBERT_SIMILARITY(x, y, mask): simis = simis * mask.unsqueeze(0).unsqueeze(2) return simis.max(axis=3).values.sum(axis=2) + def COLBERT_SIMILARITY_KD(x, y, mask): + # x: (a, s, h) where a is num_queries, s is queries_seqlen, h is hidden_size + # y: (b, t, h) where b is num_documents, t is documents_seqlen + # mask: (b, t) + # Reshape x to (a, 1, s, h) and y to (a, docs_per_query, t, h) + a, s, h = x.shape + b, t, _ = y.shape -class ColBERTLoss(nn.Module): + docs_per_query = b // a + y = y.view(a, docs_per_query, t, h) + # Compute similarities + simis = torch.einsum("ash,adth->adst", x, y) + + # Reshape mask to (a, docs_per_query, 1, t) and apply it + mask = mask.view(a, docs_per_query, 1, t) + simis = simis * mask + + # Compute max along t axis and sum along s axis + return simis.max(axis=3).values.sum(axis=2) + + +class ColBERTLossv1(nn.Module): def __init__( self, model: SentenceTransformer, @@ -78,7 +98,7 @@ def __init__( ) trainer.train() """ - super(ColBERTLoss, self).__init__() + super(ColBERTLossv1, self).__init__() self.distance_metric = distance_metric self.model = model self.size_average = size_average @@ -166,3 +186,112 @@ def citation(self) -> str: abstract = "Neural information retrieval (IR) has greatly advanced search and other knowledge-intensive language tasks. While many neural IR methods encode queries and documents into single-vector representations, late interaction models produce multi-vector representations at the granularity of each token and decompose relevance modeling into scalable token-level computations. This decomposition has been shown to make late interaction more effective, but it inflates the space footprint of these models by an order of magnitude. In this work, we introduce ColBERTv2, a retriever that couples an aggressive residual compression mechanism with a denoised supervision strategy to simultaneously improve the quality and space footprint of late interaction. We evaluate ColBERTv2 across a wide range of benchmarks, establishing state-of-the-art quality within and outside the training domain while reducing the space footprint of late interaction models by 6{--}10x.", } """ + + +class ColBERTLossv2(nn.Module): + def __init__( + self, + model: SentenceTransformer, + distance_metric=ColBERTSimilarityMetric.COLBERT_SIMILARITY_KD, + size_average: bool = True, + ): + """ + + Args: + model: SentenceTransformer model + distance_metric: Function that returns a distance between + two embeddings. The class ColBERTDistanceMetric contains + pre-defined metrices that can be used + size_average: Average by the size of the mini-batch. + + Requirements: + 1. query, list of documents, teacher scores + + """ + super(ColBERTLossv2, self).__init__() + self.distance_metric = distance_metric + self.model = model + self.size_average = size_average + + def get_config_dict(self): + distance_metric_name = self.distance_metric.__name__ + for name, value in vars(ColBERTSimilarityMetric).items(): + if value == self.distance_metric: + distance_metric_name = "ColBERTSimilarityMetric.{}".format(name) + break + + return { + "distance_metric": distance_metric_name, + "size_average": self.size_average, + } + + def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor): + reps = [ + torch.nn.functional.normalize( + self.model(sentence_feature)["token_embeddings"], p=2, dim=-1 + ) + for sentence_feature in sentence_features + ] + + attention_masks = [ + sentence_feature["attention_mask"] for sentence_feature in sentence_features + ] + # We are not applying skiplist mask to the queries + skiplist_masks = [ + torch.ones_like(sentence_features[0]["input_ids"], dtype=torch.bool) + ] + skiplist_masks.extend( + [ + self.model.skiplist_mask( + sentence_feature["input_ids"], skiplist=self.model.skiplist + ) + for sentence_feature in sentence_features[1:] + ] + ) + + masks = [ + torch.logical_and(skiplist_mask, attention_mask) + for skiplist_mask, attention_mask in zip(skiplist_masks, attention_masks) + ] + # Compute the distances between the anchor (0) and the positives (1) as well as the negatives (2) + # Note: the queries mask is not used, if added, take care that the expansion tokens are not masked from scoring (because they might be masked during encoding). We might not need to compute the mask for queries but I let the logic there for now + documents = torch.stack(reps[1:]) + documents_mask = torch.stack(masks[1:]) + distances = torch.cat( + self.distance_metric(reps[0], documents, documents_mask), + ) + + # create corresponding labels + # labels = torch.arange(0, rep_anchor.size(0), device=rep_anchor.device) + labels = torch.arange(0, reps[0].size(0), device=reps[0].device) + # compute contrastive loss using cross-entropy over the distances + loss = F.cross_entropy( + distances, labels, reduction="mean" if self.size_average else "sum" + ) + + return loss + + @property + def citation(self) -> str: + return """ + @inproceedings{santhanam-etal-2022-colbertv2, + title = "{C}ol{BERT}v2: Effective and Efficient Retrieval via Lightweight Late Interaction", + author = "Santhanam, Keshav and + Khattab, Omar and + Saad-Falcon, Jon and + Potts, Christopher and + Zaharia, Matei", + editor = "Carpuat, Marine and + de Marneffe, Marie-Catherine and + Meza Ruiz, Ivan Vladimir", + booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies", + month = jul, + year = "2022", + address = "Seattle, United States", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2022.naacl-main.272", + doi = "10.18653/v1/2022.naacl-main.272", + pages = "3715--3734", + abstract = "Neural information retrieval (IR) has greatly advanced search and other knowledge-intensive language tasks. While many neural IR methods encode queries and documents into single-vector representations, late interaction models produce multi-vector representations at the granularity of each token and decompose relevance modeling into scalable token-level computations. This decomposition has been shown to make late interaction more effective, but it inflates the space footprint of these models by an order of magnitude. In this work, we introduce ColBERTv2, a retriever that couples an aggressive residual compression mechanism with a denoised supervision strategy to simultaneously improve the quality and space footprint of late interaction. We evaluate ColBERTv2 across a wide range of benchmarks, establishing state-of-the-art quality within and outside the training domain while reducing the space footprint of late interaction models by 6{--}10x.", + } +""" diff --git a/giga_cherche/models/colbert.py b/giga_cherche/models/colbert.py index abfdadc..4a10d5d 100644 --- a/giga_cherche/models/colbert.py +++ b/giga_cherche/models/colbert.py @@ -17,7 +17,6 @@ Iterable, Literal, Optional, - Tuple, Union, overload, ) @@ -51,6 +50,7 @@ from tqdm.autonotebook import trange from transformers import is_torch_npu_available +from ..utils import MODELS_WITHOUT_FAMILY from .LinearProjection import LinearProjection logger = logging.getLogger(__name__) @@ -178,7 +178,7 @@ def __init__( device: str | None = None, prompts: dict[str, str] | None = None, default_prompt_name: str | None = None, - similarity_fn_name: Optional[Union[str, SimilarityFunction]] = None, + similarity_fn_name: Optional[str | SimilarityFunction] = None, cache_folder: str | None = None, trust_remote_code: bool = False, revision: str | None = None, @@ -245,78 +245,6 @@ def __init__( "Load pretrained SentenceTransformer: {}".format(model_name_or_path) ) - # Old models that don't belong to any organization - basic_transformer_models = [ - "albert-base-v1", - "albert-base-v2", - "albert-large-v1", - "albert-large-v2", - "albert-xlarge-v1", - "albert-xlarge-v2", - "albert-xxlarge-v1", - "albert-xxlarge-v2", - "bert-base-cased-finetuned-mrpc", - "bert-base-cased", - "bert-base-chinese", - "bert-base-german-cased", - "bert-base-german-dbmdz-cased", - "bert-base-german-dbmdz-uncased", - "bert-base-multilingual-cased", - "bert-base-multilingual-uncased", - "bert-base-uncased", - "bert-large-cased-whole-word-masking-finetuned-squad", - "bert-large-cased-whole-word-masking", - "bert-large-cased", - "bert-large-uncased-whole-word-masking-finetuned-squad", - "bert-large-uncased-whole-word-masking", - "bert-large-uncased", - "camembert-base", - "ctrl", - "distilbert-base-cased-distilled-squad", - "distilbert-base-cased", - "distilbert-base-german-cased", - "distilbert-base-multilingual-cased", - "distilbert-base-uncased-distilled-squad", - "distilbert-base-uncased-finetuned-sst-2-english", - "distilbert-base-uncased", - "distilgpt2", - "distilroberta-base", - "gpt2-large", - "gpt2-medium", - "gpt2-xl", - "gpt2", - "openai-gpt", - "roberta-base-openai-detector", - "roberta-base", - "roberta-large-mnli", - "roberta-large-openai-detector", - "roberta-large", - "t5-11b", - "t5-3b", - "t5-base", - "t5-large", - "t5-small", - "transfo-xl-wt103", - "xlm-clm-ende-1024", - "xlm-clm-enfr-1024", - "xlm-mlm-100-1280", - "xlm-mlm-17-1280", - "xlm-mlm-en-2048", - "xlm-mlm-ende-1024", - "xlm-mlm-enfr-1024", - "xlm-mlm-enro-1024", - "xlm-mlm-tlm-xnli15-1024", - "xlm-mlm-xnli15-1024", - "xlm-roberta-base", - "xlm-roberta-large-finetuned-conll02-dutch", - "xlm-roberta-large-finetuned-conll02-spanish", - "xlm-roberta-large-finetuned-conll03-english", - "xlm-roberta-large-finetuned-conll03-german", - "xlm-roberta-large", - "xlnet-base-cased", - "xlnet-large-cased", - ] - if not os.path.exists(model_name_or_path): # Not a path, load from hub if "\\" in model_name_or_path or model_name_or_path.count("/") > 1: @@ -324,7 +252,7 @@ def __init__( if ( "/" not in model_name_or_path - and model_name_or_path.lower() not in basic_transformer_models + and model_name_or_path.lower() not in MODELS_WITHOUT_FAMILY ): # A model from sentence-transformers model_name_or_path = ( @@ -481,12 +409,12 @@ def encode( is_query: bool = True, pool_factor: int = 1, protected_tokens: int = 1, - ) -> Union[list[torch.Tensor], ndarray, torch.Tensor]: + ) -> list[torch.Tensor] | ndarray | torch.Tensor: """ Computes sentence embeddings. Args: - sentences (Union[str, list[str]]): The sentences to embed. + sentences (str | list[str]): The sentences to embed. prompt_name (str | None, optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary, which is either set in the constructor or loaded from the model configuration. For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, then the sentence "What @@ -516,7 +444,7 @@ def encode( protected_tokens (int, optional): The number of tokens at the beginning of the sequence that should not be pooled. Defaults to 1 (CLS token). Returns: - Union[list[torch.Tensor], ndarray, torch.Tensor]: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned. + list[torch.Tensor] | ndarray | torch.Tensor: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned. If only one string input is provided, then the output is a 1d array with shape [output_dimension]. If ``convert_to_tensor``, a torch torch.Tensor is returned instead. If ``self.truncate_dim <= output_dimension`` then output_dimension is ``self.truncate_dim``. @@ -737,10 +665,7 @@ def encode( else: all_embeddings = [emb.numpy() for emb in all_embeddings] - if input_was_string: - all_embeddings = all_embeddings[0] - - return all_embeddings + return all_embeddings[0] if input_was_string else all_embeddings # TODO: add typing """ @@ -841,7 +766,7 @@ def similarity_fn_name(self) -> str | None: return self._similarity_fn_name @similarity_fn_name.setter - def similarity_fn_name(self, value: Union[str, SimilarityFunction]) -> None: + def similarity_fn_name(self, value: str | SimilarityFunction) -> None: if isinstance(value, SimilarityFunction): value = value.value self._similarity_fn_name = value @@ -1216,6 +1141,7 @@ def tokenize( self, texts: Union[list[str], list[dict], list[tuple[str, str]]], is_query: bool = True, + pad_document: bool = False, ) -> dict[str, torch.Tensor]: """ Tokenizes the texts. @@ -1245,7 +1171,10 @@ def tokenize( return features else: self._first_module().max_seq_length = self.document_length - features = self._first_module().tokenize(texts) + extra_parameters = {} + if pad_document: + extra_parameters["padding"] = "max_length" + features = self._first_module().tokenize(texts, **extra_parameters) # Remplace the second token by the document prefix features["input_ids"][:, 1] = self.document_prefix_id return features @@ -1925,7 +1854,7 @@ def device(self) -> torch.device: def find_tensor_attributes( module: nn.Module, - ) -> list[Tuple[str, torch.Tensor]]: + ) -> list[tuple[str, torch.Tensor]]: tuples = [ (k, v) for k, v in module.__dict__.items() if torch.is_tensor(v) ] diff --git a/giga_cherche/utils/__init__.py b/giga_cherche/utils/__init__.py index 3a97631..b24cab1 100644 --- a/giga_cherche/utils/__init__.py +++ b/giga_cherche/utils/__init__.py @@ -1,3 +1,5 @@ +from .constants import MODELS_WITHOUT_FAMILY +from .dataset_processing import DatasetProcessing from .iter_batch import iter_batch -__all__ = ["iter_batch"] +__all__ = ["MODELS_WITHOUT_FAMILY", "iter_batch", "DatasetProcessing"] diff --git a/giga_cherche/utils/constants.py b/giga_cherche/utils/constants.py new file mode 100644 index 0000000..91c0675 --- /dev/null +++ b/giga_cherche/utils/constants.py @@ -0,0 +1,73 @@ +"""Relevant constants for the project.""" + +# List of models that do not have a family in the Hugging Face model hub. +MODELS_WITHOUT_FAMILY = [ + "albert-base-v1", + "albert-base-v2", + "albert-large-v1", + "albert-large-v2", + "albert-xlarge-v1", + "albert-xlarge-v2", + "albert-xxlarge-v1", + "albert-xxlarge-v2", + "bert-base-cased-finetuned-mrpc", + "bert-base-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-base-german-dbmdz-cased", + "bert-base-german-dbmdz-uncased", + "bert-base-multilingual-cased", + "bert-base-multilingual-uncased", + "bert-base-uncased", + "bert-large-cased-whole-word-masking-finetuned-squad", + "bert-large-cased-whole-word-masking", + "bert-large-cased", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "bert-large-uncased-whole-word-masking", + "bert-large-uncased", + "camembert-base", + "ctrl", + "distilbert-base-cased-distilled-squad", + "distilbert-base-cased", + "distilbert-base-german-cased", + "distilbert-base-multilingual-cased", + "distilbert-base-uncased-distilled-squad", + "distilbert-base-uncased-finetuned-sst-2-english", + "distilbert-base-uncased", + "distilgpt2", + "distilroberta-base", + "gpt2-large", + "gpt2-medium", + "gpt2-xl", + "gpt2", + "openai-gpt", + "roberta-base-openai-detector", + "roberta-base", + "roberta-large-mnli", + "roberta-large-openai-detector", + "roberta-large", + "t5-11b", + "t5-3b", + "t5-base", + "t5-large", + "t5-small", + "transfo-xl-wt103", + "xlm-clm-ende-1024", + "xlm-clm-enfr-1024", + "xlm-mlm-100-1280", + "xlm-mlm-17-1280", + "xlm-mlm-en-2048", + "xlm-mlm-ende-1024", + "xlm-mlm-enfr-1024", + "xlm-mlm-enro-1024", + "xlm-mlm-tlm-xnli15-1024", + "xlm-mlm-xnli15-1024", + "xlm-roberta-base", + "xlm-roberta-large-finetuned-conll02-dutch", + "xlm-roberta-large-finetuned-conll02-spanish", + "xlm-roberta-large-finetuned-conll03-english", + "xlm-roberta-large-finetuned-conll03-german", + "xlm-roberta-large", + "xlnet-base-cased", + "xlnet-large-cased", +] diff --git a/giga_cherche/utils/dataset_processing.py b/giga_cherche/utils/dataset_processing.py new file mode 100644 index 0000000..b4395ba --- /dev/null +++ b/giga_cherche/utils/dataset_processing.py @@ -0,0 +1,53 @@ +import ast + +import datasets + +__all__ = ["DatasetProcessing"] + + +class DatasetProcessing: + """Preprocess the data by adding queries and documents text to the examples. + + Example: + -------- + from datasets import load_dataset + + from giga_cherche import utils + + train = load_dataset(path="./msmarco_fr", name="train", cache_dir="./msmarco_fr") + queries = load_dataset(path="./msmarco_fr", name="queries", cache_dir="./msmarco_fr") + documents = load_dataset( + path="./msmarco_fr", name="documents", cache_dir="./msmarco_fr" + ) + + train = train.map( + utils.DatasetProcessing( + queries=queries, documents=documents + ).add_queries_and_documents, + remove_columns=[feature for feature in train["train"].features if "id" in feature], + ) + """ + + def __init__(self, queries: datasets.Dataset, documents: datasets.Dataset) -> None: + self.queries = {query["query_id"]: query["text"] for query in queries["train"]} + self.documents = { + document["document_id"]: document["text"] for document in documents["train"] + } + + def add_queries_and_documents(self, example: dict) -> dict: + """Add queries and documents text to the examples.""" + scores = ast.literal_eval(node_or_string=example["scores"]) + + processed_example = { + "scores": scores, + "query": self.queries[example["query_id"]], + } + + n_scores = len(scores) + + for i in range(n_scores): + processed_example[f"document_{i}"] = self.documents[ + example[f"document_id_{i}"] + ] + + return processed_example diff --git a/kd_training.py b/kd_training.py new file mode 100644 index 0000000..a3b8c69 --- /dev/null +++ b/kd_training.py @@ -0,0 +1,54 @@ +from datasets import load_dataset +from sentence_transformers import ( + SentenceTransformerTrainer, + SentenceTransformerTrainingArguments, +) +from sentence_transformers.training_args import BatchSamplers + +from giga_cherche import data_collator, losses, models, utils + +train = load_dataset(path="./msmarco_fr", name="train", cache_dir="./msmarco_fr") +queries = load_dataset(path="./msmarco_fr", name="queries", cache_dir="./msmarco_fr") +documents = load_dataset( + path="./msmarco_fr", name="documents", cache_dir="./msmarco_fr" +) + + +train = train.map( + utils.DatasetProcessing( + queries=queries, documents=documents + ).add_queries_and_documents, + remove_columns=[feature for feature in train["train"].features if "id" in feature], +) + + +model_name = "NohTow/colbertv2_sentence_transformer" +batch_size = 2 +num_train_epochs = 1 +output_dir = "output/msmarco" + +model = models.ColBERT(model_name_or_path=model_name) + +args = SentenceTransformerTrainingArguments( + output_dir=output_dir, + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + fp16=False, + bf16=False, + batch_sampler=BatchSamplers.NO_DUPLICATES, + logging_steps=10, + run_name="colbert-st-evaluation", + learning_rate=3e-6, +) + +train_loss = losses.ColBERTLossv2(model=model) + +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train, + loss=train_loss, + data_collator=data_collator.ColBERT(tokenize_fn=model.tokenize), +) + +trainer.train() diff --git a/training.py b/training.py index dfb7929..d995226 100644 --- a/training.py +++ b/training.py @@ -29,7 +29,7 @@ MAX_EXAMPLES = 100000 train_dataset = train_dataset.shuffle(seed=21).select(range(MAX_EXAMPLES)) -train_loss = losses.ColBERTLoss(model=model) +train_loss = losses.ColBERTLossv1(model=model) # Subsample the evaluation dataset # max_samples = 1000 From af28007e42a46d9b86819211242915eb036e7c79 Mon Sep 17 00:00:00 2001 From: Antoine Chaffin Date: Wed, 24 Jul 2024 13:37:55 +0000 Subject: [PATCH 2/5] Working distillation training --- giga_cherche/data_collator/colbert.py | 23 ++++----- giga_cherche/losses/colbert.py | 38 +++++--------- giga_cherche/utils/dataset_processing.py | 66 ++++++++++++++++++++---- kd_training.py | 42 +++++++++------ 4 files changed, 109 insertions(+), 60 deletions(-) diff --git a/giga_cherche/data_collator/colbert.py b/giga_cherche/data_collator/colbert.py index fe13da7..37d9196 100644 --- a/giga_cherche/data_collator/colbert.py +++ b/giga_cherche/data_collator/colbert.py @@ -19,7 +19,6 @@ class ColBERT: def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]: """Collate a list of features into a batch.""" columns = list(features[0].keys()) - # We should always be able to return a loss, label or not: batch = {"return_loss": True} @@ -36,16 +35,16 @@ def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]: # Extract the feature columns for column in columns: - # We tokenize the query differently than the documents, TODO: define a parameter "query_column" - is_query = "query" in column or "anchor" in column - - tokenized = self.tokenize_fn( - [row[column] for row in features], - is_query=is_query, - pad_document=True, - ) - - for key, value in tokenized.items(): - batch[f"{column}_{key}"] = value + # We do not tokenize columns containing the ids. It would be better to throw them away during the dataset processing (TODO), but this break sentence transformers datasets extraction. + if "_id" not in column: + # We tokenize the query differently than the documents, TODO: define a parameter "query_column" + is_query = "query" in column or "anchor" in column + tokenized = self.tokenize_fn( + [row[column] for row in features], + is_query=is_query, + pad_document=True, + ) + for key, value in tokenized.items(): + batch[f"{column}_{key}"] = value return batch diff --git a/giga_cherche/losses/colbert.py b/giga_cherche/losses/colbert.py index 325b11b..1302e20 100644 --- a/giga_cherche/losses/colbert.py +++ b/giga_cherche/losses/colbert.py @@ -22,19 +22,12 @@ def COLBERT_SIMILARITY(x, y, mask): def COLBERT_SIMILARITY_KD(x, y, mask): # x: (a, s, h) where a is num_queries, s is queries_seqlen, h is hidden_size - # y: (b, t, h) where b is num_documents, t is documents_seqlen - # mask: (b, t) - # Reshape x to (a, 1, s, h) and y to (a, docs_per_query, t, h) - a, s, h = x.shape - b, t, _ = y.shape - - docs_per_query = b // a - y = y.view(a, docs_per_query, t, h) - # Compute similarities - simis = torch.einsum("ash,adth->adst", x, y) + # y: (a, b, t, h) a is number query, where b is num_documents_per_query, t is documents_seqlen and h is hidden_size + # mask: (a, b, t) - # Reshape mask to (a, docs_per_query, 1, t) and apply it - mask = mask.view(a, docs_per_query, 1, t) + # Compute similarities + simis = torch.einsum("ash,abth->abst", x, y) + mask = mask.unsqueeze(2) simis = simis * mask # Compute max along t axis and sum along s axis @@ -152,6 +145,7 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor ], dim=1, ) + # create corresponding labels # labels = torch.arange(0, rep_anchor.size(0), device=rep_anchor.device) labels = torch.arange(0, reps[0].size(0), device=reps[0].device) @@ -255,20 +249,16 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor ] # Compute the distances between the anchor (0) and the positives (1) as well as the negatives (2) # Note: the queries mask is not used, if added, take care that the expansion tokens are not masked from scoring (because they might be masked during encoding). We might not need to compute the mask for queries but I let the logic there for now - documents = torch.stack(reps[1:]) - documents_mask = torch.stack(masks[1:]) - distances = torch.cat( - self.distance_metric(reps[0], documents, documents_mask), - ) + documents = torch.stack(reps[1:], dim=1) - # create corresponding labels - # labels = torch.arange(0, rep_anchor.size(0), device=rep_anchor.device) - labels = torch.arange(0, reps[0].size(0), device=reps[0].device) - # compute contrastive loss using cross-entropy over the distances - loss = F.cross_entropy( - distances, labels, reduction="mean" if self.size_average else "sum" - ) + documents_mask = torch.stack(masks[1:], dim=1) + distances = self.distance_metric(reps[0], documents, documents_mask) + target_scores = torch.nn.functional.log_softmax(labels, dim=-1) + log_scores = torch.nn.functional.log_softmax(distances, dim=-1) + loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)( + log_scores, target_scores + ) return loss @property diff --git a/giga_cherche/utils/dataset_processing.py b/giga_cherche/utils/dataset_processing.py index b4395ba..dd7c0ef 100644 --- a/giga_cherche/utils/dataset_processing.py +++ b/giga_cherche/utils/dataset_processing.py @@ -29,9 +29,18 @@ class DatasetProcessing: """ def __init__(self, queries: datasets.Dataset, documents: datasets.Dataset) -> None: - self.queries = {query["query_id"]: query["text"] for query in queries["train"]} - self.documents = { - document["document_id"]: document["text"] for document in documents["train"] + # self.queries = {query["query_id"]: query["text"] for query in queries["train"]} + self.queries = queries + self.queries_index = { + query_id: i for i, query_id in enumerate(self.queries["train"]["query_id"]) + } + # self.documents = { + # document["document_id"]: document["text"] for document in documents["train"] + # } + self.documents = documents + self.documents_index = { + document_id: i + for i, document_id in enumerate(self.documents["train"]["document_id"]) } def add_queries_and_documents(self, example: dict) -> dict: @@ -40,14 +49,53 @@ def add_queries_and_documents(self, example: dict) -> dict: processed_example = { "scores": scores, - "query": self.queries[example["query_id"]], + "query": self.queries["train"][self.queries_index[example["query_id"]]][ + "text" + ], } n_scores = len(scores) - for i in range(n_scores): - processed_example[f"document_{i}"] = self.documents[ - example[f"document_id_{i}"] - ] - + try: + processed_example[f"document_{i}"] = self.documents["train"][ + self.documents_index[example[f"document_id_{i}"]] + ]["text"] + except KeyError: + processed_example[f"document_{i}"] = "" + print(f"KeyError: {example[f'document_id_{i}']}") return processed_example + + def add_queries_and_documents_transform(self, examples: dict) -> dict: + """Add queries and documents text to the examples.""" + examples["scores"] = [ + ast.literal_eval(node_or_string=score)[:32] for score in examples["scores"] + ] + examples["query"] = [ + self.queries["train"][self.queries_index[query_id]]["text"] + for query_id in examples["query_id"] + ] + n_scores = len(examples["scores"][0]) + for i in range(n_scores): + documents = [] + for document_id in examples[f"document_id_{i}"]: + try: + documents.append( + self.documents["train"][self.documents_index[document_id]][ + "text" + ] + ) + # print("loaded") + except KeyError: + documents.append("") + # print(f"KeyError: {document_id}") + examples[f"document_{i}"] = documents + # for i in range(n_scores): + # documents = [] + # try: + # processed_example[f"document_{i}"] = self.documents["train"][ + # self.documents_index[example[f"document_id_{i}"]] + # ]["text"] + # except KeyError: + # processed_example[f"document_{i}"] = "" + # print(f"KeyError: {example[f'document_id_{i}']}") + return examples diff --git a/kd_training.py b/kd_training.py index a3b8c69..a8404f6 100644 --- a/kd_training.py +++ b/kd_training.py @@ -1,31 +1,44 @@ -from datasets import load_dataset from sentence_transformers import ( SentenceTransformerTrainer, SentenceTransformerTrainingArguments, ) -from sentence_transformers.training_args import BatchSamplers +from datasets import load_dataset from giga_cherche import data_collator, losses, models, utils -train = load_dataset(path="./msmarco_fr", name="train", cache_dir="./msmarco_fr") -queries = load_dataset(path="./msmarco_fr", name="queries", cache_dir="./msmarco_fr") -documents = load_dataset( - path="./msmarco_fr", name="documents", cache_dir="./msmarco_fr" +train = load_dataset( + path="./datasets/msmarco_fr_full", + name="train", ) +queries = load_dataset( + path="./datasets/msmarco_fr_full", + name="queries", +) + +documents = load_dataset( + path="./datasets/msmarco_fr_full", + name="documents", +) -train = train.map( +train.set_transform( utils.DatasetProcessing( queries=queries, documents=documents - ).add_queries_and_documents, - remove_columns=[feature for feature in train["train"].features if "id" in feature], + ).add_queries_and_documents_transform, + # remove_columns=[feature for feature in train["train"].features if "id" in feature], ) +# train = train.map( +# utils.DatasetProcessing( +# queries=queries, documents=documents +# ).add_queries_and_documents, +# # remove_columns=[feature for feature in train["train"].features if "id" in feature], +# ) -model_name = "NohTow/colbertv2_sentence_transformer" -batch_size = 2 +model_name = "bert-base-uncased" +batch_size = 16 num_train_epochs = 1 -output_dir = "output/msmarco" +output_dir = "output/distillation_run-bert-base" model = models.ColBERT(model_name_or_path=model_name) @@ -35,10 +48,9 @@ per_device_train_batch_size=batch_size, fp16=False, bf16=False, - batch_sampler=BatchSamplers.NO_DUPLICATES, logging_steps=10, - run_name="colbert-st-evaluation", - learning_rate=3e-6, + run_name="distillation_run-bert-base", + learning_rate=1e-5, ) train_loss = losses.ColBERTLossv2(model=model) From 48a2428c2e693e45f430b13b4117cbea785317b0 Mon Sep 17 00:00:00 2001 From: Antoine Chaffin Date: Thu, 25 Jul 2024 12:39:25 +0000 Subject: [PATCH 3/5] Fixing iteration for BEIR eval --- evaluation_beir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evaluation_beir.py b/evaluation_beir.py index aa6d05d..7dc371a 100644 --- a/evaluation_beir.py +++ b/evaluation_beir.py @@ -32,7 +32,7 @@ scores = [] for batch in utils.iter_batch(queries, batch_size=5): queries_embeddings = model.encode( - sentences=[query["text"] for query in batch], + sentences=queries, convert_to_numpy=True, is_query=True, ) From 8214dabf019fd064bb49e657b3d74d84e0978d8f Mon Sep 17 00:00:00 2001 From: Antoine Chaffin Date: Thu, 25 Jul 2024 12:47:10 +0000 Subject: [PATCH 4/5] Fixing unfortunate replace --- .gitignore | 4 +++- giga_cherche/models/colbert.py | 2 +- giga_cherche/rerank/colbert.py | 13 +++++++++++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 8110fa2..f6f8166 100644 --- a/.gitignore +++ b/.gitignore @@ -144,6 +144,8 @@ wandb/ # output models output_models/ +output/ # datasets -evaluation_datasets/ \ No newline at end of file +evaluation_datasets/ +datasets/ \ No newline at end of file diff --git a/giga_cherche/models/colbert.py b/giga_cherche/models/colbert.py index 4a10d5d..20bd740 100644 --- a/giga_cherche/models/colbert.py +++ b/giga_cherche/models/colbert.py @@ -657,7 +657,7 @@ def encode( ] # Else, we already have a list of tensors, the expected output else: - all_embeddings = torch.torch.Tensor() + all_embeddings = torch.tensor() elif convert_to_numpy: # We return a list of numpy arrays and not a big numpy array because we cannot guarantee all element have the same sequence length if all_embeddings[0].dtype == torch.bfloat16: diff --git a/giga_cherche/rerank/colbert.py b/giga_cherche/rerank/colbert.py index cd0e11e..dff4d9d 100644 --- a/giga_cherche/rerank/colbert.py +++ b/giga_cherche/rerank/colbert.py @@ -34,14 +34,23 @@ def rerank( for query, query_documents_embeddings, query_doc_ids in zip( queries, batch_documents_embeddings, batch_doc_ids ): + print( + torch.tensor( + query_documents_embeddings[0], + dtype=torch.float32, + device=query.device, + ).shape + ) documents_embeddings = [ - torch.torch.Tensor(embeddings, dtype=torch.float32, device=query.device) + torch.tensor(embeddings, dtype=torch.float32, device=query.device) for embeddings in query_documents_embeddings ] documents_embeddings = torch.nn.utils.rnn.pad_sequence( documents_embeddings, batch_first=True, padding_value=0 ) - query_scores = colbert_score(query.unsqueeze(0), documents_embeddings)[0] + query_scores = colbert_score.colbert_score( + query.unsqueeze(0), documents_embeddings + )[0] reranked_query_scores, sorted_indices = torch.sort( query_scores, descending=True ) From 7b939caca6aa2dfe9a1d1aeef9dd9afba63a24be Mon Sep 17 00:00:00 2001 From: Antoine Chaffin Date: Thu, 25 Jul 2024 14:33:23 +0000 Subject: [PATCH 5/5] Removing unecessary print and actually fixing the iteration for BEIR eval (whoops) --- evaluation_beir.py | 2 +- giga_cherche/rerank/colbert.py | 7 ------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/evaluation_beir.py b/evaluation_beir.py index 7dc371a..b0cfc52 100644 --- a/evaluation_beir.py +++ b/evaluation_beir.py @@ -32,7 +32,7 @@ scores = [] for batch in utils.iter_batch(queries, batch_size=5): queries_embeddings = model.encode( - sentences=queries, + sentences=batch, convert_to_numpy=True, is_query=True, ) diff --git a/giga_cherche/rerank/colbert.py b/giga_cherche/rerank/colbert.py index dff4d9d..b77a38d 100644 --- a/giga_cherche/rerank/colbert.py +++ b/giga_cherche/rerank/colbert.py @@ -34,13 +34,6 @@ def rerank( for query, query_documents_embeddings, query_doc_ids in zip( queries, batch_documents_embeddings, batch_doc_ids ): - print( - torch.tensor( - query_documents_embeddings[0], - dtype=torch.float32, - device=query.device, - ).shape - ) documents_embeddings = [ torch.tensor(embeddings, dtype=torch.float32, device=query.device) for embeddings in query_documents_embeddings