diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml new file mode 100644 index 0000000..04a59ee --- /dev/null +++ b/.github/workflows/python-tests.yml @@ -0,0 +1,28 @@ +name: Python Tests + +on: + pull_request: + branches: + - '**' + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' # Specify the Python version you need + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ".[dev]" + + - name: Run tests + run: | + pytest giga_cherche --cov=giga_cherche diff --git a/README.md b/README.md index 4b1711c..be32489 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,9 @@ For example, to run the BEIR evaluations using giga-cherche indexes: # Modeling The modeling of giga-cherche is based on sentence-transformers which allow to build a ColBERT model from any encoder available by appending a projection layer applied to the output of the encoders to reduce the embeddings dimension. ``` -from giga_cherche.models import ColBERT +from giga_cherche import models model_name = "bert-base-uncased" -model = ColBERT(model_name_or_path=model_name) +model = models.ColBERT(model_name_or_path=model_name) ``` The following parameters can be passed to the constructor to set different properties of the model: - ```embedding_size```, the output size of the projection layer and so the dimension of the embeddings @@ -40,7 +40,7 @@ from sentence_transformers import ( SentenceTransformerTrainingArguments, ) -from giga_cherche import losses, models, data_collator, evaluation +from giga_cherche import losses, models, datasets, evaluation model_name = "bert-base-uncased" batch_size = 32 @@ -77,7 +77,7 @@ trainer = SentenceTransformerTrainer( eval_dataset=eval_dataset, loss=train_loss, evaluator=dev_evaluator, - data_collator=data_collator.ColBERT(model.tokenize), + data_collator=utils.ColBERTCollator(model.tokenize), ) trainer.train() @@ -88,7 +88,7 @@ trainer.train() ``` import ast -def add_queries_and_documents(example: dict) -> dict: +def add_queries_and_documents(Examples dict) -> dict: """Add queries and documents text to the examples.""" scores = ast.literal_eval(node_or_string=example["scores"]) processed_example = {"scores": scores, "query": queries[example["query_id"]]} @@ -135,7 +135,7 @@ You can then compute the ColBERT max-sim scores like this: ```python from giga_cherche import scores -similarity_scores = scores.colbert_score(query_embeddings, document_embeddings) +similarity_scores = scores.colbert_scores(query_embeddings, document_embeddings) ``` ## Indexing diff --git a/giga_cherche/__init__.py b/giga_cherche/__init__.py index cb4c52b..5d4011d 100644 --- a/giga_cherche/__init__.py +++ b/giga_cherche/__init__.py @@ -1,9 +1,10 @@ __all__ = [ - "models", - "losses", - "scores", "evaluation", "indexes", - "reranker", - "data_collator", + "losses", + "models", + "rerank", + "retrieve", + "scores", + "utils", ] diff --git a/giga_cherche/data_collator/__init__.py b/giga_cherche/data_collator/__init__.py deleted file mode 100644 index 057670e..0000000 --- a/giga_cherche/data_collator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .colbert import ColBERT - -__all__ = ["ColBERT"] diff --git a/giga_cherche/evaluation/beir.py b/giga_cherche/evaluation/beir.py index 229b0f8..f7fe783 100644 --- a/giga_cherche/evaluation/beir.py +++ b/giga_cherche/evaluation/beir.py @@ -3,11 +3,18 @@ import random from collections import defaultdict -__all__ = ["evaluate", "load_beir", "get_beir_triples"] - def add_duplicates(queries: list[str], scores: list[list[dict]]) -> list: - """Add back duplicates scores to the set of candidates.""" + """Add back duplicates scores to the set of candidates. + + Parameters + ---------- + queries + List of queries. + scores + Scores of the retrieval model. + + """ query_counts = defaultdict(int) for query in queries: query_counts[query] += 1 @@ -31,7 +38,9 @@ def load_beir(dataset_name: str, split: str = "test") -> tuple[list, list, dict] Parameters ---------- dataset_name - Dataset name: scifact. + Name of the beir dataset. + split + Split to load. """ from beir import util @@ -85,14 +94,14 @@ def get_beir_triples( Examples -------- - >>> from neural_cherche import utils + >>> from giga_cherche import evaluation - >>> documents, queries, qrels = utils.load_beir( + >>> documents, queries, qrels = evaluation.load_beir( ... "scifact", ... split="test", ... ) - >>> triples = utils.get_beir_triples( + >>> triples = evaluation.get_beir_triples( ... key="id", ... on=["title", "text"], ... documents=documents, @@ -146,59 +155,6 @@ def evaluate( metrics Metrics to compute. - Examples - -------- - >>> from neural_cherche import models, retrieve, utils - >>> import torch - - >>> _ = torch.manual_seed(42) - - >>> model = models.Splade( - ... model_name_or_path="raphaelsty/neural-cherche-sparse-embed", - ... device="cpu", - ... ) - - >>> documents, queries, qrels = utils.load_beir( - ... "scifact", - ... split="test", - ... ) - - >>> documents = documents[:10] - - >>> retriever = retrieve.Splade( - ... key="id", - ... on=["title", "text"], - ... model=model - ... ) - - >>> documents_embeddings = retriever.encode_documents( - ... documents=documents, - ... batch_size=1, - ... ) - - >>> documents_embeddings = retriever.add( - ... documents_embeddings=documents_embeddings, - ... ) - - >>> queries_embeddings = retriever.encode_queries( - ... queries=queries, - ... batch_size=1, - ... ) - - >>> scores = retriever( - ... queries_embeddings=queries_embeddings, - ... k=30, - ... batch_size=1, - ... ) - - >>> utils.evaluate( - ... scores=scores, - ... qrels=qrels, - ... queries=queries, - ... metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"] - ... ) - {'map': 0.0033333333333333335, 'ndcg@10': 0.0033333333333333335, 'ndcg@100': 0.0033333333333333335, 'recall@10': 0.0033333333333333335, 'recall@100': 0.0033333333333333335} - """ from ranx import Qrels, Run, evaluate diff --git a/giga_cherche/evaluation/colbert_triplet_evaluator.py b/giga_cherche/evaluation/colbert_triplet_evaluator.py index 5780ad2..b5ee6bc 100644 --- a/giga_cherche/evaluation/colbert_triplet_evaluator.py +++ b/giga_cherche/evaluation/colbert_triplet_evaluator.py @@ -10,19 +10,17 @@ from sentence_transformers.SentenceTransformer import SentenceTransformer from sentence_transformers.similarity_functions import SimilarityFunction -from giga_cherche.scores.colbert_score import colbert_score +from ..scores import colbert_scores logger = logging.getLogger(__name__) -__all__ = ["ColBERTTripletEvaluator"] - class ColBERTTripletEvaluator(SentenceEvaluator): """ Evaluate a model based on a triplet: (sentence, positive_example, negative_example). Checks if colbert distance(sentence, positive_example) < distance(sentence, negative_example). - Example: + Examples :: from sentence_transformers import SentenceTransformer @@ -198,15 +196,20 @@ def __call__( # Colbert distance # pos_colbert_distances = colbert_pairwise_score(embeddings_anchors, embeddings_positives) # neg_colbert_distances = colbert_pairwise_score(embeddings_anchors, embeddings_negatives) - pos_colbert_distances_full = colbert_score( - embeddings_anchors, embeddings_positives + pos_colbert_distances_full = colbert_scores( + queries_embeddings=embeddings_anchors, + documents_embeddings=embeddings_positives, ) - neg_colbert_distances_full = colbert_score( - embeddings_anchors, embeddings_negatives + + neg_colbert_distances_full = colbert_scores( + queries_embeddings=embeddings_anchors, + documents_embeddings=embeddings_negatives, ) + distances_full = torch.cat( [pos_colbert_distances_full, neg_colbert_distances_full], dim=1 ) + # print(distances_full.shape) labels = np.arange(0, len(embeddings_anchors)) indices = np.argsort(-distances_full.cpu().numpy(), axis=1) diff --git a/giga_cherche/indexes/base.py b/giga_cherche/indexes/base.py index 95861f9..b9078b7 100644 --- a/giga_cherche/indexes/base.py +++ b/giga_cherche/indexes/base.py @@ -1,7 +1,5 @@ from abc import ABC, abstractmethod -__all__ = ["Base"] - class Base(ABC): """Base class for all indexes. Indexes are used to store and retrieve embeddings.""" diff --git a/giga_cherche/indexes/weaviate.py b/giga_cherche/indexes/weaviate.py index 2b88512..db86ed5 100644 --- a/giga_cherche/indexes/weaviate.py +++ b/giga_cherche/indexes/weaviate.py @@ -1,13 +1,14 @@ import asyncio import time -import weaviate -import weaviate.classes as wvc +try: + import weaviate + import weaviate.classes as wvc +except ImportError: + pass from .base import Base -__all__ = ["Weaviate"] - # TODO: define Index metaclass # max_doc_length is used to set a limit in the fetch embeddings method as the speed is dependant on the number of embeddings fetched diff --git a/giga_cherche/losses/__init__.py b/giga_cherche/losses/__init__.py index 8a392ad..b53533c 100644 --- a/giga_cherche/losses/__init__.py +++ b/giga_cherche/losses/__init__.py @@ -1,3 +1,4 @@ -from .colbert import ColBERTLossv1, ColBERTLossv2 +from .contrastive import Contrastive +from .distillation import Distillation -__all__ = ["ColBERTLossv1", "ColBERTLossv2"] +__all__ = ["Contrastive", "Distillation"] diff --git a/giga_cherche/losses/colbert.py b/giga_cherche/losses/colbert.py deleted file mode 100644 index 1302e20..0000000 --- a/giga_cherche/losses/colbert.py +++ /dev/null @@ -1,287 +0,0 @@ -from enum import Enum -from typing import Iterable - -import torch -import torch.nn.functional as F -from sentence_transformers import SentenceTransformer -from torch import Tensor, nn - -__all__ = ["ColBERTSimilarityMetric", "ColBERTLossv1", "ColBERTLossv2"] - - -class ColBERTSimilarityMetric(Enum): - """The metric for the contrastive loss""" - - def COLBERT_SIMILARITY(x, y, mask): - # a num_queries, s queries_seqlen, h hidden_size, b num_documents, t documents_seqlen - # Take make along the t axis (get max similarity for each query tokens), then sum over all the query tokens - simis = torch.einsum("ash,bth->abst", x, y) - # Masking out the padding tokens using broadcasting (original mask has shape (b, t) -> (1, b, 1, t)) - simis = simis * mask.unsqueeze(0).unsqueeze(2) - return simis.max(axis=3).values.sum(axis=2) - - def COLBERT_SIMILARITY_KD(x, y, mask): - # x: (a, s, h) where a is num_queries, s is queries_seqlen, h is hidden_size - # y: (a, b, t, h) a is number query, where b is num_documents_per_query, t is documents_seqlen and h is hidden_size - # mask: (a, b, t) - - # Compute similarities - simis = torch.einsum("ash,abth->abst", x, y) - mask = mask.unsqueeze(2) - simis = simis * mask - - # Compute max along t axis and sum along s axis - return simis.max(axis=3).values.sum(axis=2) - - -class ColBERTLossv1(nn.Module): - def __init__( - self, - model: SentenceTransformer, - distance_metric=ColBERTSimilarityMetric.COLBERT_SIMILARITY, - size_average: bool = True, - ): - """ - Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the - two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased. - - Args: - model: SentenceTransformer model - distance_metric: Function that returns a distance between - two embeddings. The class ColBERTDistanceMetric contains - pre-defined metrices that can be used - size_average: Average by the size of the mini-batch. - - References: - * Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf - * `Training Examples > Quora Duplicate Questions <../../examples/training/quora_duplicate_questions/README.html>`_ - - Requirements: - 1. (anchor, positive/negative) pairs - - Relations: - - :class:`OnlineContrastiveLoss` is similar, but uses hard positive and hard negative pairs. - It often yields better results. - - Inputs: - +-----------------------------------------------+------------------------------+ - | Texts | Labels | - +===============================================+==============================+ - | (anchor, positive/negative) pairs | 1 if positive, 0 if negative | - +-----------------------------------------------+------------------------------+ - - Example: - :: - - from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses - from datasets import Dataset - - model = SentenceTransformer("microsoft/mpnet-base") - train_dataset = Dataset.from_dict({ - "sentence1": ["It's nice weather outside today.", "He drove to work."], - "sentence2": ["It's so sunny.", "She walked to the store."], - "label": [1, 0], - }) - loss = losses.ContrastiveLoss(model) - - trainer = SentenceTransformerTrainer( - model=model, - train_dataset=train_dataset, - loss=loss, - ) - trainer.train() - """ - super(ColBERTLossv1, self).__init__() - self.distance_metric = distance_metric - self.model = model - self.size_average = size_average - - def get_config_dict(self): - distance_metric_name = self.distance_metric.__name__ - for name, value in vars(ColBERTSimilarityMetric).items(): - if value == self.distance_metric: - distance_metric_name = "ColBERTSimilarityMetric.{}".format(name) - break - - return { - "distance_metric": distance_metric_name, - "size_average": self.size_average, - } - - def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor): - reps = [ - torch.nn.functional.normalize( - self.model(sentence_feature)["token_embeddings"], p=2, dim=-1 - ) - for sentence_feature in sentence_features - ] - - attention_masks = [ - sentence_feature["attention_mask"] for sentence_feature in sentence_features - ] - # We are not applying skiplist mask to the queries - skiplist_masks = [ - torch.ones_like(sentence_features[0]["input_ids"], dtype=torch.bool) - ] - skiplist_masks.extend( - [ - self.model.skiplist_mask( - sentence_feature["input_ids"], skiplist=self.model.skiplist - ) - for sentence_feature in sentence_features[1:] - ] - ) - - masks = [ - torch.logical_and(skiplist_mask, attention_mask) - for skiplist_mask, attention_mask in zip(skiplist_masks, attention_masks) - ] - # Compute the distances between the anchor (0) and the positives (1) as well as the negatives (2) - # Note: the queries mask is not used, if added, take care that the expansion tokens are not masked from scoring (because they might be masked during encoding). We might not need to compute the mask for queries but I let the logic there for now - distances = torch.cat( - [ - self.distance_metric(reps[0], rep, mask) - for rep, mask in zip(reps[1:], masks[1:]) - ], - dim=1, - ) - - # create corresponding labels - # labels = torch.arange(0, rep_anchor.size(0), device=rep_anchor.device) - labels = torch.arange(0, reps[0].size(0), device=reps[0].device) - # compute contrastive loss using cross-entropy over the distances - loss = F.cross_entropy( - distances, labels, reduction="mean" if self.size_average else "sum" - ) - - return loss - - @property - def citation(self) -> str: - return """ - @inproceedings{santhanam-etal-2022-colbertv2, - title = "{C}ol{BERT}v2: Effective and Efficient Retrieval via Lightweight Late Interaction", - author = "Santhanam, Keshav and - Khattab, Omar and - Saad-Falcon, Jon and - Potts, Christopher and - Zaharia, Matei", - editor = "Carpuat, Marine and - de Marneffe, Marie-Catherine and - Meza Ruiz, Ivan Vladimir", - booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies", - month = jul, - year = "2022", - address = "Seattle, United States", - publisher = "Association for Computational Linguistics", - url = "https://aclanthology.org/2022.naacl-main.272", - doi = "10.18653/v1/2022.naacl-main.272", - pages = "3715--3734", - abstract = "Neural information retrieval (IR) has greatly advanced search and other knowledge-intensive language tasks. While many neural IR methods encode queries and documents into single-vector representations, late interaction models produce multi-vector representations at the granularity of each token and decompose relevance modeling into scalable token-level computations. This decomposition has been shown to make late interaction more effective, but it inflates the space footprint of these models by an order of magnitude. In this work, we introduce ColBERTv2, a retriever that couples an aggressive residual compression mechanism with a denoised supervision strategy to simultaneously improve the quality and space footprint of late interaction. We evaluate ColBERTv2 across a wide range of benchmarks, establishing state-of-the-art quality within and outside the training domain while reducing the space footprint of late interaction models by 6{--}10x.", - } -""" - - -class ColBERTLossv2(nn.Module): - def __init__( - self, - model: SentenceTransformer, - distance_metric=ColBERTSimilarityMetric.COLBERT_SIMILARITY_KD, - size_average: bool = True, - ): - """ - - Args: - model: SentenceTransformer model - distance_metric: Function that returns a distance between - two embeddings. The class ColBERTDistanceMetric contains - pre-defined metrices that can be used - size_average: Average by the size of the mini-batch. - - Requirements: - 1. query, list of documents, teacher scores - - """ - super(ColBERTLossv2, self).__init__() - self.distance_metric = distance_metric - self.model = model - self.size_average = size_average - - def get_config_dict(self): - distance_metric_name = self.distance_metric.__name__ - for name, value in vars(ColBERTSimilarityMetric).items(): - if value == self.distance_metric: - distance_metric_name = "ColBERTSimilarityMetric.{}".format(name) - break - - return { - "distance_metric": distance_metric_name, - "size_average": self.size_average, - } - - def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor): - reps = [ - torch.nn.functional.normalize( - self.model(sentence_feature)["token_embeddings"], p=2, dim=-1 - ) - for sentence_feature in sentence_features - ] - - attention_masks = [ - sentence_feature["attention_mask"] for sentence_feature in sentence_features - ] - # We are not applying skiplist mask to the queries - skiplist_masks = [ - torch.ones_like(sentence_features[0]["input_ids"], dtype=torch.bool) - ] - skiplist_masks.extend( - [ - self.model.skiplist_mask( - sentence_feature["input_ids"], skiplist=self.model.skiplist - ) - for sentence_feature in sentence_features[1:] - ] - ) - - masks = [ - torch.logical_and(skiplist_mask, attention_mask) - for skiplist_mask, attention_mask in zip(skiplist_masks, attention_masks) - ] - # Compute the distances between the anchor (0) and the positives (1) as well as the negatives (2) - # Note: the queries mask is not used, if added, take care that the expansion tokens are not masked from scoring (because they might be masked during encoding). We might not need to compute the mask for queries but I let the logic there for now - documents = torch.stack(reps[1:], dim=1) - - documents_mask = torch.stack(masks[1:], dim=1) - - distances = self.distance_metric(reps[0], documents, documents_mask) - target_scores = torch.nn.functional.log_softmax(labels, dim=-1) - log_scores = torch.nn.functional.log_softmax(distances, dim=-1) - loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)( - log_scores, target_scores - ) - return loss - - @property - def citation(self) -> str: - return """ - @inproceedings{santhanam-etal-2022-colbertv2, - title = "{C}ol{BERT}v2: Effective and Efficient Retrieval via Lightweight Late Interaction", - author = "Santhanam, Keshav and - Khattab, Omar and - Saad-Falcon, Jon and - Potts, Christopher and - Zaharia, Matei", - editor = "Carpuat, Marine and - de Marneffe, Marie-Catherine and - Meza Ruiz, Ivan Vladimir", - booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies", - month = jul, - year = "2022", - address = "Seattle, United States", - publisher = "Association for Computational Linguistics", - url = "https://aclanthology.org/2022.naacl-main.272", - doi = "10.18653/v1/2022.naacl-main.272", - pages = "3715--3734", - abstract = "Neural information retrieval (IR) has greatly advanced search and other knowledge-intensive language tasks. While many neural IR methods encode queries and documents into single-vector representations, late interaction models produce multi-vector representations at the granularity of each token and decompose relevance modeling into scalable token-level computations. This decomposition has been shown to make late interaction more effective, but it inflates the space footprint of these models by an order of magnitude. In this work, we introduce ColBERTv2, a retriever that couples an aggressive residual compression mechanism with a denoised supervision strategy to simultaneously improve the quality and space footprint of late interaction. We evaluate ColBERTv2 across a wide range of benchmarks, establishing state-of-the-art quality within and outside the training domain while reducing the space footprint of late interaction models by 6{--}10x.", - } -""" diff --git a/giga_cherche/losses/contrastive.py b/giga_cherche/losses/contrastive.py new file mode 100644 index 0000000..e09c957 --- /dev/null +++ b/giga_cherche/losses/contrastive.py @@ -0,0 +1,161 @@ +from typing import Iterable + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ..models import ColBERT +from ..scores import colbert_scores + + +def extract_skiplist_mask( + sentence_features: Iterable[dict[str, torch.Tensor]], + skiplist: list[int], +) -> list[torch.Tensor]: + """Extracts the attention masks from the sentence features. We apply a skiplist mask to the documents. + We skip the first sentence feature because it is the query. + + Examples + -------- + >>> import torch + + >>> sentence_features = [ + ... { + ... "input_ids": torch.tensor([[1, 2, 3, 4]]), + ... "attention_mask": torch.tensor([[1, 1, 1, 1]]), + ... }, + ... { + ... "input_ids": torch.tensor([[1, 2, 3, 4]]), + ... "attention_mask": torch.tensor([[1, 1, 1, 1]]), + ... }, + ... { + ... "input_ids": torch.tensor([[1, 2, 3, 4]]), + ... "attention_mask": torch.tensor([[1, 1, 1, 1]]), + ... }, + ... ] + + >>> extract_skiplist_mask( + ... sentence_features=sentence_features, + ... skiplist=[1, 2, 3], + ... ) + [tensor([[True, True, True, True]]), tensor([[False, False, False, True]]), tensor([[False, False, False, True]])] + + """ + attention_masks = [ + sentence_feature["attention_mask"] for sentence_feature in sentence_features + ] + + skiplist_masks = [ + torch.ones_like(sentence_features[0]["input_ids"], dtype=torch.bool) + ] + + # We skip the first sentence feature because it is the query. + skiplist_masks.extend( + [ + ColBERT.skiplist_mask( + input_ids=sentence_feature["input_ids"], skiplist=skiplist + ) + for sentence_feature in sentence_features[1:] + ] + ) + + return [ + torch.logical_and(skiplist_mask, attention_mask) + for skiplist_mask, attention_mask in zip(skiplist_masks, attention_masks) + ] + + +class Contrastive(nn.Module): + """ + Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the + two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased. + + Parameters + ---------- + model + ColBERT model. + distance_metric + ColBERT scoring function. Defaults to colbert_scores. + size_average + Average by the size of the mini-batch. + + Examples + -------- + >>> from giga_cherche import models, losses + + >>> model = models.ColBERT( + ... model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu" + ... ) + + >>> loss = losses.Contrastive(model=model) + + >>> anchor = model.tokenize([ + ... "fruits are healthy.", + ... ], is_query=True) + + >>> positive = model.tokenize([ + ... "fruits are good for health.", + ... ], is_query=False) + + >>> negative = model.tokenize([ + ... "fruits are bad for health.", + ... ], is_query=False) + + >>> sentence_features = [anchor, positive, negative] + + >>> loss = loss(sentence_features=sentence_features) + >>> assert isinstance(loss.item(), float) + + """ + + def __init__( + self, + model: ColBERT, + distance_metric=colbert_scores, + size_average: bool = True, + ) -> None: + super(Contrastive, self).__init__() + self.distance_metric = distance_metric + self.model = model + self.size_average = size_average + + def forward( + self, sentence_features: Iterable[dict[str, Tensor]], **kwargs + ) -> torch.Tensor: + """Compute the Constrastive loss. + + Parameters + ---------- + sentence_features + List of tokenized sentences. The first sentence is the anchor and the rest are the positive and negative examples. + + """ + embeddings = [ + torch.nn.functional.normalize( + self.model(sentence_feature)["token_embeddings"], p=2, dim=-1 + ) + for sentence_feature in sentence_features + ] + + masks = extract_skiplist_mask( + sentence_features=sentence_features, skiplist=self.model.skiplist + ) + + # Note: the queries mask is not used, if added, take care that the expansion tokens are not masked from scoring (because they might be masked during encoding). + # We might not need to compute the mask for queries but I let the logic there for now + distances = torch.cat( + [ + self.distance_metric(embeddings[0], group_embeddings, mask) + for group_embeddings, mask in zip(embeddings[1:], masks[1:]) + ], + dim=1, + ) + + # create corresponding labels + # labels = torch.arange(0, rep_anchor.size(0), device=rep_anchor.device) + labels = torch.arange(0, embeddings[0].size(0), device=embeddings[0].device) + # compute constrastive loss using cross-entropy over the distances + + return F.cross_entropy( + distances, labels, reduction="mean" if self.size_average else "sum" + ) diff --git a/giga_cherche/losses/distillation.py b/giga_cherche/losses/distillation.py new file mode 100644 index 0000000..9df4f48 --- /dev/null +++ b/giga_cherche/losses/distillation.py @@ -0,0 +1,106 @@ +from typing import Callable, Iterable + +import torch + +from ..models import ColBERT +from ..scores import colbert_kd_scores +from .contrastive import extract_skiplist_mask + + +class Distillation(torch.nn.Module): + """Distillation loss for ColBERT model. The loss is computed with respect to the format of SentenceTransformer library. + + Parameters + ---------- + model + SentenceTransformer model. + distance_metric + Function that returns a distance between two embeddings. + size_average + Average by the size of the mini-batch or perform sum. + + Examples + -------- + >>> from giga_cherche import models, losses + + >>> model = models.ColBERT( + ... model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu" + ... ) + + >>> distillation = losses.Distillation(model=model) + + >>> anchor = model.tokenize([ + ... "fruits are healthy.", + ... ], is_query=True) + + >>> positive = model.tokenize([ + ... "fruits are good for health.", + ... ], is_query=False) + + >>> negative = model.tokenize([ + ... "fruits are bad for health.", + ... ], is_query=False) + + >>> sentence_features = [anchor, positive, negative] + + >>> labels = torch.tensor([ + ... [0.7, 0.3], + ... ], dtype=torch.float32) + + >>> loss = distillation(sentence_features=sentence_features, labels=labels) + + >>> assert isinstance(loss.item(), float) + """ + + def __init__( + self, + model: ColBERT, + distance_metric: Callable = colbert_kd_scores, + size_average: bool = True, + ) -> None: + super(Distillation, self).__init__() + self.distance_metric = distance_metric + self.model = model + self.loss_function = torch.nn.KLDivLoss( + reduction="batchmean" if size_average else "sum", log_target=True + ) + + def forward( + self, sentence_features: Iterable[dict[str, torch.Tensor]], labels: torch.Tensor + ) -> torch.Tensor: + """Computes the distillation loss with respect to SentenceTransformer. + + Parameters + ---------- + sentence_features + List of tokenized sentences. The first sentence is the anchor and the rest are the positive and negative examples. + labels + The logits for the distillation loss. + + """ + embeddings = [ + torch.nn.functional.normalize( + self.model(sentence_feature)["token_embeddings"], p=2, dim=-1 + ) + for sentence_feature in sentence_features + ] + + masks = extract_skiplist_mask( + sentence_features=sentence_features, skiplist=self.model.skiplist + ) + + # Compute the distance between the anchor and positive/negative embeddings. + anchor_embeddings = embeddings[0] + positive_negative_embeddings = torch.stack(embeddings[1:], dim=1) + positive_negative_embeddings_mask = torch.stack(masks[1:], dim=1) + + distances = self.distance_metric( + anchor_embeddings, + positive_negative_embeddings, + positive_negative_embeddings_mask, + ) + + return self.loss_function( + torch.nn.functional.log_softmax(distances, dim=-1), + torch.nn.functional.log_softmax(labels, dim=-1), + ) diff --git a/giga_cherche/models/colbert.py b/giga_cherche/models/colbert.py index 20bd740..e196d84 100644 --- a/giga_cherche/models/colbert.py +++ b/giga_cherche/models/colbert.py @@ -50,7 +50,7 @@ from tqdm.autonotebook import trange from transformers import is_torch_npu_available -from ..utils import MODELS_WITHOUT_FAMILY +from ..utils import HUGGINGFACE_MODELS from .LinearProjection import LinearProjection logger = logging.getLogger(__name__) @@ -252,7 +252,7 @@ def __init__( if ( "/" not in model_name_or_path - and model_name_or_path.lower() not in MODELS_WITHOUT_FAMILY + and model_name_or_path.lower() not in HUGGINGFACE_MODELS ): # A model from sentence-transformers model_name_or_path = ( @@ -722,7 +722,8 @@ def pool_embeddings_hierarchical( return pooled_embeddings - def skiplist_mask(self, input_ids, skiplist): + @staticmethod + def skiplist_mask(input_ids, skiplist): skiplist = torch.tensor(skiplist, dtype=torch.long, device=input_ids.device) # Create a tensor of ones with the same shape as input_ids @@ -757,11 +758,6 @@ def similarity_fn_name(self) -> str | None: Returns: str | None: The name of the similarity function. Can be None if not set, in which case any uses of :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise` default to "cosine". - - Example: - >>> model = SentenceTransformer("multi-qa-mpnet-base-dot-v1") - >>> model.similarity_fn_name - 'dot' """ return self._similarity_fn_name @@ -804,31 +800,6 @@ def similarity( Returns: torch.Tensor: A [num_embeddings_1, num_embeddings_2]-shaped torch tensor with similarity scores. - - Example: - :: - - >>> model = SentenceTransformer("all-mpnet-base-v2") - >>> sentences = [ - ... "The weather is so nice!", - ... "It's so sunny outside.", - ... "He's driving to the movie theater.", - ... "She's going to the cinema.", - ... ] - >>> embeddings = model.encode(sentences, normalize_embeddings=True) - >>> model.similarity(embeddings, embeddings) - tensor([[1.0000, 0.7235, 0.0290, 0.1309], - [0.7235, 1.0000, 0.0613, 0.1129], - [0.0290, 0.0613, 1.0000, 0.5027], - [0.1309, 0.1129, 0.5027, 1.0000]]) - >>> model.similarity_fn_name - "cosine" - >>> model.similarity_fn_name = "euclidean" - >>> model.similarity(embeddings, embeddings) - tensor([[-0.0000, -0.7437, -1.3935, -1.3184], - [-0.7437, -0.0000, -1.3702, -1.3320], - [-1.3935, -1.3702, -0.0000, -0.9973], - [-1.3184, -1.3320, -0.9973, -0.0000]]) """ if self.similarity_fn_name is None: self.similarity_fn_name = SimilarityFunction.COSINE @@ -860,25 +831,6 @@ def similarity_pairwise( Returns: torch.Tensor: A [num_embeddings]-shaped torch tensor with pairwise similarity scores. - - Example: - :: - - >>> model = SentenceTransformer("all-mpnet-base-v2") - >>> sentences = [ - ... "The weather is so nice!", - ... "It's so sunny outside.", - ... "He's driving to the movie theater.", - ... "She's going to the cinema.", - ... ] - >>> embeddings = model.encode(sentences, normalize_embeddings=True) - >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2]) - tensor([0.7235, 0.5027]) - >>> model.similarity_fn_name - "cosine" - >>> model.similarity_fn_name = "euclidean" - >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2]) - tensor([-0.7437, -0.9973]) """ if self.similarity_fn_name is None: self.similarity_fn_name = SimilarityFunction.COSINE @@ -1215,16 +1167,6 @@ def truncate_sentence_embeddings(self, truncate_dim: int | None): Args: truncate_dim (int, optional): The dimension to truncate sentence embeddings to. ``None`` does no truncation. - Example: - :: - - from sentence_transformers import SentenceTransformer - - model = SentenceTransformer("all-mpnet-base-v2") - - with model.truncate_sentence_embeddings(truncate_dim=16): - embeddings_truncated = model.encode(["hello there", "hiya"]) - assert embeddings_truncated.shape[-1] == 16 """ original_output_dim = self.truncate_dim try: diff --git a/giga_cherche/rerank/colbert.py b/giga_cherche/rerank/colbert.py index b77a38d..62f58b2 100644 --- a/giga_cherche/rerank/colbert.py +++ b/giga_cherche/rerank/colbert.py @@ -2,17 +2,11 @@ import torch from ..indexes import Base as BaseIndex -from ..scores import colbert_score - -__all__ = ["ColBERT"] +from ..scores import colbert_scores class ColBERT: - """Rerank - - Parameters - - """ + """ColBERT reranker.""" def __init__(self, index: BaseIndex) -> None: self.index = index @@ -38,12 +32,16 @@ def rerank( torch.tensor(embeddings, dtype=torch.float32, device=query.device) for embeddings in query_documents_embeddings ] + documents_embeddings = torch.nn.utils.rnn.pad_sequence( documents_embeddings, batch_first=True, padding_value=0 ) - query_scores = colbert_score.colbert_score( - query.unsqueeze(0), documents_embeddings + + query_scores = colbert_scores( + queries_embeddings=query.unsqueeze(0), + documents_embeddings=documents_embeddings, )[0] + reranked_query_scores, sorted_indices = torch.sort( query_scores, descending=True ) diff --git a/giga_cherche/retrieve/colbert.py b/giga_cherche/retrieve/colbert.py index 6605847..3c8f9c9 100644 --- a/giga_cherche/retrieve/colbert.py +++ b/giga_cherche/retrieve/colbert.py @@ -4,11 +4,18 @@ from ..indexes import Base as BaseIndex from ..rerank import ColBERT as ColBERTReranker -__all__ = ["ColBERT"] - -# TODO: define Retriever metaclass class ColBERT: + """ColBERT retriever. + + + Examples + -------- + + + + """ + def __init__(self, index: BaseIndex) -> None: self.index = index self.reranker = ColBERTReranker(index=index) diff --git a/giga_cherche/scores/__init__.py b/giga_cherche/scores/__init__.py new file mode 100644 index 0000000..f29e636 --- /dev/null +++ b/giga_cherche/scores/__init__.py @@ -0,0 +1,3 @@ +from .scores import colbert_scores, colbert_scores_pairwise, colbert_kd_scores + +__all__ = ["colbert_scores", "colbert_scores_pairwise", "colbert_kd_scores"] \ No newline at end of file diff --git a/giga_cherche/scores/colbert_score.py b/giga_cherche/scores/colbert_score.py deleted file mode 100644 index fb2a303..0000000 --- a/giga_cherche/scores/colbert_score.py +++ /dev/null @@ -1,87 +0,0 @@ -import logging - -import numpy as np -import torch -from sentence_transformers.util import _convert_to_batch_tensor - -logger = logging.getLogger(__name__) - -__all__ = ["colbert_score", "colbert_pairwise_score"] - - -def convert_to_tensor(data): - if not isinstance(data, torch.Tensor): - if isinstance(data[0], np.ndarray): - data = torch.from_numpy(np.array(data, dtype=np.float32)) - else: - data = torch.stack(data) - return data - - -def colbert_score( - a: list | np.ndarray | torch.Tensor, b: list | np.ndarray | torch.Tensor -) -> torch.Tensor: - """ - Computes the ColBERT score for all pairs of vectors in a and b. - - Args: - a (Union[list, np.ndarray, Tensor]): The first tensor. - b (Union[list, np.ndarray, Tensor]): The second tensor. - - Returns: - Tensor: Matrix with res[i][j] = colbert_score(a[i], b[j]) - """ - a = convert_to_tensor(a) - b = convert_to_tensor(b) - # We do not use explicit mask as padding tokens are full of zeros, thus will yield zero similarity - # a num_queries, s queries_seqlen, h hidden_size, b num_documents, t documents_seqlen - # Take make along the t axis (get max similarity for each query tokens), then sum over all the query tokens - return torch.einsum("ash,bth->abst", a, b).max(axis=3).values.sum(axis=2) - - -# TODO: only compute the diagonal -def colbert_pairwise_score(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """ - Computes the pairwise ColBERT score colbert_score(a[i], b[i]). - Args: - a (Union[list, np.ndarray, Tensor]): The first tensor. - b (Union[list, np.ndarray, Tensor]): The second tensor. - """ - a = _convert_to_batch_tensor(a) - b = _convert_to_batch_tensor(b) - - return torch.einsum("ash,bth->abst", a, b).max(axis=3).values.sum(axis=2).diag() - - -# def dot_score(a: Union[list, np.ndarray, Tensor], b: Union[list, np.ndarray, Tensor]) -> Tensor: -# """ -# Computes the dot-product dot_prod(a[i], b[j]) for all i and j. - -# Args: -# a (Union[list, np.ndarray, Tensor]): The first tensor. -# b (Union[list, np.ndarray, Tensor]): The second tensor. - -# Returns: -# Tensor: Matrix with res[i][j] = dot_prod(a[i], b[j]) -# """ -# a = _convert_to_batch_tensor(a) -# b = _convert_to_batch_tensor(b) - -# return torch.mm(a, b.transpose(0, 1)) - - -# def pairwise_dot_score(a: Tensor, b: Tensor) -> Tensor: -# """ -# Computes the pairwise dot-product dot_prod(a[i], b[i]). - -# Args: -# a (Union[list, np.ndarray, Tensor]): The first tensor. -# b (Union[list, np.ndarray, Tensor]): The second tensor. - -# Returns: -# Tensor: Vector with res[i] = dot_prod(a[i], b[i]) -# """ -# a = _convert_to_tensor(a) -# b = _convert_to_tensor(b) - -# return (a * b).sum(dim=-1) diff --git a/giga_cherche/scores/scores.py b/giga_cherche/scores/scores.py new file mode 100644 index 0000000..cc35fdf --- /dev/null +++ b/giga_cherche/scores/scores.py @@ -0,0 +1,155 @@ +"""ColBERT scores computation.""" + +import numpy as np +import torch + +from ..utils.tensor import convert_to_tensor + + +def colbert_scores( + queries_embeddings: list | np.ndarray | torch.Tensor, + documents_embeddings: list | np.ndarray | torch.Tensor, + mask: torch.Tensor = None, +) -> torch.Tensor: + """Computes the ColBERT scores between queries and documents embeddings. The score is computed as the sum of maximum similarities + between the query and the document. + + Parameters + ---------- + queries_embeddings + The first tensor. The queries embeddings. Shape: (batch_size, num tokens queries, embedding_size) + documents_embeddings + The second tensor. The documents embeddings. Shape: (batch_size, num tokens documents, embedding_size) + + Examples + -------- + >>> import torch + + >>> queries_embeddings = torch.tensor([ + ... [[1.], [0.], [0.], [0.]], + ... [[0.], [2.], [0.], [0.]], + ... [[0.], [0.], [3.], [0.]], + ... ]) + + >>> documents_embeddings = torch.tensor([ + ... [[10.], [0.], [1.]], + ... [[0.], [100.], [1.]], + ... [[1.], [0.], [1000.]], + ... ]) + + >>> scores = colbert_scores( + ... queries_embeddings=queries_embeddings, + ... documents_embeddings=documents_embeddings + ... ) + + >>> scores + tensor([[ 10., 100., 1000.], + [ 20., 200., 2000.], + [ 30., 300., 3000.]]) + + """ + queries_embeddings = convert_to_tensor(queries_embeddings) + documents_embeddings = convert_to_tensor(documents_embeddings) + + scores = torch.einsum( + "ash,bth->abst", + queries_embeddings, + documents_embeddings, + ) + + if mask is not None: + mask = convert_to_tensor(mask) + scores = scores * mask.unsqueeze(0).unsqueeze(2) + + return scores.max(axis=-1).values.sum(axis=-1) + + +def colbert_scores_pairwise( + queries_embeddings: list | np.ndarray | torch.Tensor, + documents_embeddings: list | np.ndarray | torch.Tensor, +) -> torch.Tensor: + """Computes the ColBERT score for each query-document pair. The score is computed as the sum of maximum similarities + between the query and the document for corresponding pairs. + + Parameters + ---------- + queries_embeddings + The first tensor. The queries embeddings. Shape: (batch_size, num tokens queries, embedding_size) + documents_embeddings + The second tensor. The documents embeddings. Shape: (batch_size, num tokens documents, embedding_size) + + Examples + -------- + >>> import torch + + >>> queries_embeddings = torch.tensor([ + ... [[1.], [0.], [0.], [0.]], + ... [[0.], [2.], [0.], [0.]], + ... [[0.], [0.], [3.], [0.]], + ... ]) + + >>> documents_embeddings = torch.tensor([ + ... [[10.], [0.], [1.]], + ... [[0.], [100.], [1.]], + ... [[1.], [0.], [1000.]], + ... ]) + + >>> scores = colbert_scores_pairwise( + ... queries_embeddings=queries_embeddings, + ... documents_embeddings=documents_embeddings + ... ) + + >>> scores + tensor([ 10., 200., 3000.]) + """ + return colbert_scores( + queries_embeddings=queries_embeddings, documents_embeddings=documents_embeddings + ).diagonal() + + +def colbert_kd_scores( + queries_embeddings: list | np.ndarray | torch.Tensor, + documents_embeddings: list | np.ndarray | torch.Tensor, + mask: torch.Tensor = None, +) -> torch.Tensor: + """Computes the ColBERT scores between queries and documents embeddings. This scoring function is dedicated to the knowledge distillation pipeline. + + Examples + -------- + >>> import torch + + >>> queries_embeddings = torch.tensor([ + ... [[1.], [0.], [0.], [0.]], + ... [[0.], [2.], [0.], [0.]], + ... [[0.], [0.], [3.], [0.]], + ... ]) + + >>> documents_embeddings = torch.tensor([ + ... [[[10.], [0.], [1.]], [[20.], [0.], [1.]], [[30.], [0.], [1.]]], + ... [[[0.], [100.], [1.]], [[0.], [200.], [1.]], [[0.], [300.], [1.]]], + ... [[[1.], [0.], [1000.]], [[1.], [0.], [2000.]], [[1.], [0.], [3000.]]], + ... ]) + + >>> colbert_kd_scores( + ... queries_embeddings=queries_embeddings, + ... documents_embeddings=documents_embeddings + ... ) + tensor([[ 10., 20., 30.], + [ 200., 400., 600.], + [3000., 6000., 9000.]]) + + """ + queries_embeddings = convert_to_tensor(queries_embeddings) + documents_embeddings = convert_to_tensor(documents_embeddings) + + scores = torch.einsum( + "ash,abth->abst", + queries_embeddings, + documents_embeddings, + ) + + if mask is not None: + mask = convert_to_tensor(mask) + scores = scores * mask.unsqueeze(2) + + return scores.max(axis=-1).values.sum(axis=-1) diff --git a/giga_cherche/utils/__init__.py b/giga_cherche/utils/__init__.py index b24cab1..e992a03 100644 --- a/giga_cherche/utils/__init__.py +++ b/giga_cherche/utils/__init__.py @@ -1,5 +1,13 @@ -from .constants import MODELS_WITHOUT_FAMILY -from .dataset_processing import DatasetProcessing +from .collator import ColBERTCollator +from .huggingface_models import HUGGINGFACE_MODELS from .iter_batch import iter_batch +from .processing import KDProcessing +from .tensor import convert_to_tensor -__all__ = ["MODELS_WITHOUT_FAMILY", "iter_batch", "DatasetProcessing"] +__all__ = [ + "HUGGINGFACE_MODELS", + "iter_batch", + "convert_to_tensor", + "ColBERTCollator", + "KDProcessing", +] diff --git a/giga_cherche/data_collator/colbert.py b/giga_cherche/utils/collator.py similarity index 98% rename from giga_cherche/data_collator/colbert.py rename to giga_cherche/utils/collator.py index 37d9196..bd0c762 100644 --- a/giga_cherche/data_collator/colbert.py +++ b/giga_cherche/utils/collator.py @@ -5,7 +5,7 @@ @dataclass -class ColBERT: +class ColBERTCollator: """Collator for a ColBERT model. This encodes the text columns to {column}_input_ids and {column}_attention_mask columns. The query and the documents are encoded differently. diff --git a/giga_cherche/utils/dataset_processing.py b/giga_cherche/utils/dataset_processing.py deleted file mode 100644 index dd7c0ef..0000000 --- a/giga_cherche/utils/dataset_processing.py +++ /dev/null @@ -1,101 +0,0 @@ -import ast - -import datasets - -__all__ = ["DatasetProcessing"] - - -class DatasetProcessing: - """Preprocess the data by adding queries and documents text to the examples. - - Example: - -------- - from datasets import load_dataset - - from giga_cherche import utils - - train = load_dataset(path="./msmarco_fr", name="train", cache_dir="./msmarco_fr") - queries = load_dataset(path="./msmarco_fr", name="queries", cache_dir="./msmarco_fr") - documents = load_dataset( - path="./msmarco_fr", name="documents", cache_dir="./msmarco_fr" - ) - - train = train.map( - utils.DatasetProcessing( - queries=queries, documents=documents - ).add_queries_and_documents, - remove_columns=[feature for feature in train["train"].features if "id" in feature], - ) - """ - - def __init__(self, queries: datasets.Dataset, documents: datasets.Dataset) -> None: - # self.queries = {query["query_id"]: query["text"] for query in queries["train"]} - self.queries = queries - self.queries_index = { - query_id: i for i, query_id in enumerate(self.queries["train"]["query_id"]) - } - # self.documents = { - # document["document_id"]: document["text"] for document in documents["train"] - # } - self.documents = documents - self.documents_index = { - document_id: i - for i, document_id in enumerate(self.documents["train"]["document_id"]) - } - - def add_queries_and_documents(self, example: dict) -> dict: - """Add queries and documents text to the examples.""" - scores = ast.literal_eval(node_or_string=example["scores"]) - - processed_example = { - "scores": scores, - "query": self.queries["train"][self.queries_index[example["query_id"]]][ - "text" - ], - } - - n_scores = len(scores) - for i in range(n_scores): - try: - processed_example[f"document_{i}"] = self.documents["train"][ - self.documents_index[example[f"document_id_{i}"]] - ]["text"] - except KeyError: - processed_example[f"document_{i}"] = "" - print(f"KeyError: {example[f'document_id_{i}']}") - return processed_example - - def add_queries_and_documents_transform(self, examples: dict) -> dict: - """Add queries and documents text to the examples.""" - examples["scores"] = [ - ast.literal_eval(node_or_string=score)[:32] for score in examples["scores"] - ] - examples["query"] = [ - self.queries["train"][self.queries_index[query_id]]["text"] - for query_id in examples["query_id"] - ] - n_scores = len(examples["scores"][0]) - for i in range(n_scores): - documents = [] - for document_id in examples[f"document_id_{i}"]: - try: - documents.append( - self.documents["train"][self.documents_index[document_id]][ - "text" - ] - ) - # print("loaded") - except KeyError: - documents.append("") - # print(f"KeyError: {document_id}") - examples[f"document_{i}"] = documents - # for i in range(n_scores): - # documents = [] - # try: - # processed_example[f"document_{i}"] = self.documents["train"][ - # self.documents_index[example[f"document_id_{i}"]] - # ]["text"] - # except KeyError: - # processed_example[f"document_{i}"] = "" - # print(f"KeyError: {example[f'document_id_{i}']}") - return examples diff --git a/giga_cherche/utils/constants.py b/giga_cherche/utils/huggingface_models.py similarity index 98% rename from giga_cherche/utils/constants.py rename to giga_cherche/utils/huggingface_models.py index 91c0675..4ac5174 100644 --- a/giga_cherche/utils/constants.py +++ b/giga_cherche/utils/huggingface_models.py @@ -1,7 +1,7 @@ """Relevant constants for the project.""" # List of models that do not have a family in the Hugging Face model hub. -MODELS_WITHOUT_FAMILY = [ +HUGGINGFACE_MODELS = [ "albert-base-v1", "albert-base-v2", "albert-large-v1", diff --git a/giga_cherche/utils/iter_batch.py b/giga_cherche/utils/iter_batch.py index 0e83f3a..76a3139 100644 --- a/giga_cherche/utils/iter_batch.py +++ b/giga_cherche/utils/iter_batch.py @@ -1,14 +1,12 @@ import tqdm -__all__ = ["iter_batch"] - def iter_batch( X: list[str], batch_size: int, tqdm_bar: bool = True, desc: str = "" ) -> list: """Iterate over a list of elements by batch. - Example + Examples ------- >>> from giga_cherche import utils diff --git a/giga_cherche/utils/processing.py b/giga_cherche/utils/processing.py new file mode 100644 index 0000000..ef18b24 --- /dev/null +++ b/giga_cherche/utils/processing.py @@ -0,0 +1,94 @@ +import ast + +import datasets + + +class KDProcessing: + """Dataset processing class for knowledge distillation training. + + Parameters + ---------- + queries + Queries dataset. + documents + Documents dataset. + n_scores + Number of scores to keep for the distillation. + + Examples + -------- + from datasets import load_dataset + from giga_cherche import utils + + train = load_dataset( + path="./msmarco_fr", + name="train", + cache_dir="./msmarco_fr" + ) + + queries = load_dataset( + path="./msmarco_fr", + name="queries", + cache_dir="./msmarco_fr" + ) + + documents = load_dataset( + path="./msmarco_fr", name="documents", cache_dir="./msmarco_fr" + ) + + train = train.map( + utils.DatasetProcessing( + queries=queries, documents=documents + ).transform, + ) + + """ + + def __init__( + self, queries: datasets.Dataset, documents: datasets.Dataset, n_scores: int = 32 + ) -> None: + self.queries = queries + self.documents = documents + self.n_scores = n_scores + + self.queries_index = { + query_id: i + for i, query_id in enumerate(iterable=self.queries["train"]["query_id"]) + } + + self.documents_index = { + document_id: i + for i, document_id in enumerate( + iterable=self.documents["train"]["document_id"] + ) + } + + def transform(self, examples: dict) -> dict: + """Update examples with queries and documents. Also""" + examples["scores"] = [ + ast.literal_eval(node_or_string=score)[: self.n_scores] + for score in examples["scores"] + ] + + examples["query"] = [ + self.queries["train"][self.queries_index[query_id]]["text"] + for query_id in examples["query_id"] + ] + + n_scores = len(examples["scores"][0]) + + for i in range(n_scores): + documents = [] + for document_id in examples[f"document_id_{i}"]: + try: + documents.append( + self.documents["train"][self.documents_index[document_id]][ + "text" + ] + ) + + except KeyError: + documents.append("") + examples[f"document_{i}"] = documents + + return examples diff --git a/giga_cherche/utils/tensor.py b/giga_cherche/utils/tensor.py new file mode 100644 index 0000000..d485ad3 --- /dev/null +++ b/giga_cherche/utils/tensor.py @@ -0,0 +1,63 @@ +import numpy as np +import torch + + +def convert_to_tensor( + x: torch.Tensor | np.ndarray | list[torch.Tensor | np.ndarray | list], +) -> torch.Tensor: + """Converts a list or numpy array to a torch tensor. + + Parameters + ---------- + x + The input data. It can be a torch tensor, a numpy array, or a list of torch tensors, numpy arrays, or lists. + + Examples + -------- + >>> import numpy as np + >>> import torch + + >>> x = [[1., 1., 1.]] + >>> convert_to_tensor(x) + tensor([[1., 1., 1.]]) + + >>> x = np.array([[1., 1., 1.], [2., 2., 2.]], dtype=np.float32) + >>> convert_to_tensor(x) + tensor([[1., 1., 1.], + [2., 2., 2.]]) + + >>> x = [torch.tensor([1., 1., 1.]), torch.tensor([2., 2., 2.])] + >>> convert_to_tensor(x) + tensor([[1., 1., 1.], + [2., 2., 2.]]) + + >>> x = torch.tensor([[1., 1., 1.], [2., 2., 2.]]) + >>> convert_to_tensor(x) + tensor([[1., 1., 1.], + [2., 2., 2.]]) + + >>> x = [] + >>> convert_to_tensor(x) + tensor([]) + + """ + if isinstance(x, torch.Tensor): + return x + + if isinstance(x, np.ndarray): + return torch.from_numpy(x) + + if isinstance(x, list): + if not x: + return torch.tensor([], dtype=torch.float32) + + if isinstance(x[0], np.ndarray): + return torch.from_numpy(np.array(x, dtype=np.float32)) + + if isinstance(x[0], list): + return torch.tensor(x, dtype=torch.float32) + + if isinstance(x[0], torch.Tensor): + return torch.stack(x) + + raise Exception("Unsupported data type") diff --git a/kd_training.py b/kd_training.py index a8404f6..5f0dbbd 100644 --- a/kd_training.py +++ b/kd_training.py @@ -1,10 +1,10 @@ +from datasets import load_dataset from sentence_transformers import ( SentenceTransformerTrainer, SentenceTransformerTrainingArguments, ) -from datasets import load_dataset -from giga_cherche import data_collator, losses, models, utils +from giga_cherche import losses, models, utils train = load_dataset( path="./datasets/msmarco_fr_full", @@ -22,17 +22,8 @@ ) train.set_transform( - utils.DatasetProcessing( - queries=queries, documents=documents - ).add_queries_and_documents_transform, - # remove_columns=[feature for feature in train["train"].features if "id" in feature], + utils.KDProcessing(queries=queries, documents=documents).transform, ) -# train = train.map( -# utils.DatasetProcessing( -# queries=queries, documents=documents -# ).add_queries_and_documents, -# # remove_columns=[feature for feature in train["train"].features if "id" in feature], -# ) model_name = "bert-base-uncased" @@ -53,14 +44,14 @@ learning_rate=1e-5, ) -train_loss = losses.ColBERTLossv2(model=model) +train_loss = losses.Distillation(model=model) trainer = SentenceTransformerTrainer( model=model, args=args, train_dataset=train, loss=train_loss, - data_collator=data_collator.ColBERT(tokenize_fn=model.tokenize), + data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize), ) trainer.train() diff --git a/setup.py b/setup.py index f2a51e0..d40aff5 100644 --- a/setup.py +++ b/setup.py @@ -11,13 +11,12 @@ "accelerate >= 0.31.0", ] -dev = ["ruff >= 0.4.9"] +weaviate = ["weaviate-client >= 4.6.7"] + +dev = ["ruff >= 0.4.9", "pytest-cov >= 5.0.0", "pytest >= 8.2.1"] eval = ["ranx >= 0.3.16", "beir >= 2.0.0"] -index = [ - "weaviate-client >= 4.7.0b0", -] setuptools.setup( name="giga_cherche", @@ -30,7 +29,11 @@ keywords=[], packages=setuptools.find_packages(), install_requires=base_packages, - extras_require={"dev": base_packages + dev + eval, "eval": base_packages + eval}, + extras_require={ + "weaviate": weaviate, + "eval": base_packages + weaviate + eval, + "dev": base_packages + dev + eval, + }, classifiers=[ "Programming Language :: Python :: 3", "Operating System :: OS Independent", diff --git a/training.py b/training.py index d995226..7dabd42 100644 --- a/training.py +++ b/training.py @@ -5,7 +5,7 @@ ) from sentence_transformers.training_args import BatchSamplers -from giga_cherche import data_collator, evaluation, losses, models +from giga_cherche import evaluation, losses, models, utils model_name = "NohTow/colbertv2_sentence_transformer" # "distilroberta-base" # Choose the model you want batch_size = 32 # The larger you select this, the better the results (usually). But it requires more GPU memory @@ -29,7 +29,7 @@ MAX_EXAMPLES = 100000 train_dataset = train_dataset.shuffle(seed=21).select(range(MAX_EXAMPLES)) -train_loss = losses.ColBERTLossv1(model=model) +train_loss = losses.Contrastive(model=model) # Subsample the evaluation dataset # max_samples = 1000 @@ -77,7 +77,7 @@ eval_dataset=eval_dataset, loss=train_loss, evaluator=dev_evaluator, - data_collator=data_collator.ColBERT(model.tokenize), + data_collator=utils.ColBERTCollator(model.tokenize), ) trainer.train()