Merge pull request #23 from lightonai/knowledge-distillation
Knowledge distillation
NohTow authored Jul 26, 2024
2 parents 123e9aa + 7b939ca commit 8775b7a
Showing 14 changed files with 421 additions and 109 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -144,6 +144,8 @@ wandb/

# output models
output_models/
output/

# datasets
evaluation_datasets/
evaluation_datasets/
datasets/
19 changes: 18 additions & 1 deletion README.md
@@ -28,7 +28,7 @@ The following parameters can be passed to the constructor to set different prope
- ```attend_to_expansion_tokens```, whether query tokens should attend to the MASK expansion tokens (the original ColBERT did not)
- ```skiplist_words```, a list of words to ignore in documents during scoring (defaults to punctuation)
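For instance, these options can be set when instantiating the model. This is only a sketch: the `giga_cherche.models.ColBERT` class path and the `model_name_or_path` argument are assumptions (following the sentence-transformers convention); only the two parameters above are documented here.

```python
from giga_cherche import models

# Assumed class path and model_name_or_path argument; only
# attend_to_expansion_tokens and skiplist_words are documented above.
model = models.ColBERT(
    model_name_or_path="NohTow/colbertv2_sentence_transformer",
    attend_to_expansion_tokens=False,
    skiplist_words=["!", "?", ",", "."],
)
```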

# Training
## Training

Given that giga-cherche ColBERT models are sentence-transformers models, we can benefit from all the bells and whistles of the latest update, including multi-GPU and BF16 training.
For now, you can train ColBERT models using triplet datasets (datasets containing a positive and a negative document for each query). The syntax is the same as sentence-transformers, using the elements adapted to ColBERT from giga-cherche:
@@ -83,6 +83,23 @@ trainer = SentenceTransformerTrainer(
trainer.train()
```

## Tokenization

```python
import ast


def add_queries_and_documents(example: dict) -> dict:
    """Add the query and document texts to a knowledge distillation example.

    Note: `queries` and `documents` are assumed to be mappings from ids to texts,
    built beforehand from the queries and documents files.
    """
    scores = ast.literal_eval(node_or_string=example["scores"])
    processed_example = {"scores": scores, "query": queries[example["query_id"]]}
    n_scores = len(scores)
    for i in range(n_scores):
        processed_example[f"document_{i}"] = documents[example[f"document_id_{i}"]]
    return processed_example
```
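For example, the function can be applied with the Hugging Face `datasets` `map` method. This is a sketch: the file name is hypothetical, and it assumes the raw scores dataset has `query_id`, `document_id_{i}` and `scores` columns and that `queries` and `documents` are the id-to-text mappings used by the function above.

```python
from datasets import load_dataset

# Hypothetical file name; the dataset is expected to expose "query_id",
# "document_id_{i}" and "scores" columns, matching the function above.
dataset = load_dataset("json", data_files="kd_scores.jsonl", split="train")
dataset = dataset.map(add_queries_and_documents)
```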

## Inference
Once trained, the model can then be loaded to perform inference (you can also load the models directly from Hugging Face, for example using the provided ColBERTv2 model [NohTow/colbertv2_sentence_transformer](https://huggingface.co/NohTow/colbertv2_sentence_transformer)):
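A minimal sketch of loading and encoding, assuming the same `giga_cherche.models.ColBERT` class path as above; the `encode(..., is_query=...)` signature is the one used in `evaluation_beir.py`:

```python
from giga_cherche import models

# Assumed class path and model_name_or_path argument (sentence-transformers convention).
model = models.ColBERT(model_name_or_path="NohTow/colbertv2_sentence_transformer")

# Queries and documents are encoded into one embedding per token.
queries_embeddings = model.encode(["what is late interaction?"], is_query=True)
documents_embeddings = model.encode(["ColBERT scores queries against documents token by token."], is_query=False)
```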

2 changes: 1 addition & 1 deletion evaluation_beir.py
@@ -31,7 +31,7 @@

for batch in utils.iter_batch(queries, batch_size=5):
queries_embeddings = model.encode(
sentences=[query["text"] for query in batch],
sentences=batch,
convert_to_numpy=True,
is_query=True,
)
2 changes: 1 addition & 1 deletion giga_cherche/data_collator/__init__.py
@@ -1,3 +1,3 @@
from .colbert_data_collator import ColBERT
from .colbert import ColBERT

__all__ = ["ColBERT"]
23 changes: 12 additions & 11 deletions giga_cherche/data_collator/colbert.py
@@ -14,12 +14,11 @@ class ColBERT:
"""

tokenize_fn: Callable
valid_label_columns: list[str] = field(default_factory=lambda: ["label", "score"])
valid_label_columns: list[str] = field(default_factory=lambda: ["label", "scores"])

def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]:
"""Collate a list of features into a batch."""
columns = list(features[0].keys())

# We should always be able to return a loss, label or not:
batch = {"return_loss": True}

@@ -36,14 +35,16 @@ def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]:

# Extract the feature columns
for column in columns:
# We tokenize the query differently than the documents, TODO: define a parameter "query_column"
is_query = "query" in column or "anchor" in column

tokenized = self.tokenize_fn(
[row[column] for row in features], is_query=is_query
)

for key, value in tokenized.items():
batch[f"{column}_{key}"] = value
# We do not tokenize columns containing the ids. It would be better to throw them away during the dataset processing (TODO), but this breaks the sentence-transformers dataset extraction.
if "_id" not in column:
# We tokenize the query differently than the documents, TODO: define a parameter "query_column"
is_query = "query" in column or "anchor" in column
tokenized = self.tokenize_fn(
[row[column] for row in features],
is_query=is_query,
pad_document=True,
)
for key, value in tokenized.items():
batch[f"{column}_{key}"] = value

return batch
4 changes: 2 additions & 2 deletions giga_cherche/losses/__init__.py
@@ -1,3 +1,3 @@
from .colbert import ColBERTLoss
from .colbert import ColBERTLossv1, ColBERTLossv2

__all__ = ["ColBERTLoss"]
__all__ = ["ColBERTLossv1", "ColBERTLossv2"]
125 changes: 122 additions & 3 deletions giga_cherche/losses/colbert.py
@@ -6,7 +6,7 @@
from sentence_transformers import SentenceTransformer
from torch import Tensor, nn

__all__ = ["ColBERTLoss", "ColBERTSimilarityMetric"]
__all__ = ["ColBERTSimilarityMetric", "ColBERTLossv1", "ColBERTLossv2"]


class ColBERTSimilarityMetric(Enum):
@@ -20,8 +20,21 @@ def COLBERT_SIMILARITY(x, y, mask):
simis = simis * mask.unsqueeze(0).unsqueeze(2)
return simis.max(axis=3).values.sum(axis=2)

def COLBERT_SIMILARITY_KD(x, y, mask):
    # x: (a, s, h) where a is num_queries, s is queries_seqlen and h is hidden_size
    # y: (a, b, t, h) where a is num_queries, b is num_documents_per_query, t is documents_seqlen and h is hidden_size
    # mask: (a, b, t)

    # Compute similarities
    simis = torch.einsum("ash,abth->abst", x, y)
    mask = mask.unsqueeze(2)
    simis = simis * mask

    # Compute max along t axis and sum along s axis
    return simis.max(axis=3).values.sum(axis=2)


class ColBERTLoss(nn.Module):
class ColBERTLossv1(nn.Module):
def __init__(
self,
model: SentenceTransformer,
@@ -78,7 +91,7 @@ def __init__(
)
trainer.train()
"""
super(ColBERTLoss, self).__init__()
super(ColBERTLossv1, self).__init__()
self.distance_metric = distance_metric
self.model = model
self.size_average = size_average
@@ -132,6 +145,7 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor
],
dim=1,
)

# create corresponding labels
# labels = torch.arange(0, rep_anchor.size(0), device=rep_anchor.device)
labels = torch.arange(0, reps[0].size(0), device=reps[0].device)
@@ -166,3 +180,108 @@ def citation(self) -> str:
abstract = "Neural information retrieval (IR) has greatly advanced search and other knowledge-intensive language tasks. While many neural IR methods encode queries and documents into single-vector representations, late interaction models produce multi-vector representations at the granularity of each token and decompose relevance modeling into scalable token-level computations. This decomposition has been shown to make late interaction more effective, but it inflates the space footprint of these models by an order of magnitude. In this work, we introduce ColBERTv2, a retriever that couples an aggressive residual compression mechanism with a denoised supervision strategy to simultaneously improve the quality and space footprint of late interaction. We evaluate ColBERTv2 across a wide range of benchmarks, establishing state-of-the-art quality within and outside the training domain while reducing the space footprint of late interaction models by 6{--}10x.",
}
"""


class ColBERTLossv2(nn.Module):
def __init__(
self,
model: SentenceTransformer,
distance_metric=ColBERTSimilarityMetric.COLBERT_SIMILARITY_KD,
size_average: bool = True,
):
"""
Args:
model: SentenceTransformer model
distance_metric: Function that returns a distance between
two embeddings. The class ColBERTSimilarityMetric contains
pre-defined metrics that can be used
size_average: Average by the size of the mini-batch.
Requirements:
1. query, list of documents, teacher scores
"""
super(ColBERTLossv2, self).__init__()
self.distance_metric = distance_metric
self.model = model
self.size_average = size_average

def get_config_dict(self):
distance_metric_name = self.distance_metric.__name__
for name, value in vars(ColBERTSimilarityMetric).items():
if value == self.distance_metric:
distance_metric_name = "ColBERTSimilarityMetric.{}".format(name)
break

return {
"distance_metric": distance_metric_name,
"size_average": self.size_average,
}

def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor):
reps = [
torch.nn.functional.normalize(
self.model(sentence_feature)["token_embeddings"], p=2, dim=-1
)
for sentence_feature in sentence_features
]

attention_masks = [
sentence_feature["attention_mask"] for sentence_feature in sentence_features
]
# We do not apply the skiplist mask to the queries
skiplist_masks = [
torch.ones_like(sentence_features[0]["input_ids"], dtype=torch.bool)
]
skiplist_masks.extend(
[
self.model.skiplist_mask(
sentence_feature["input_ids"], skiplist=self.model.skiplist
)
for sentence_feature in sentence_features[1:]
]
)

masks = [
torch.logical_and(skiplist_mask, attention_mask)
for skiplist_mask, attention_mask in zip(skiplist_masks, attention_masks)
]
# Compute the similarities between the query (reps[0]) and all of its documents (reps[1:])
# Note: the query mask is not used; if it is added, take care that the expansion tokens are not masked out of the scoring (they might be masked during encoding). We might not need to compute the mask for the queries, but the logic is left here for now.
documents = torch.stack(reps[1:], dim=1)

documents_mask = torch.stack(masks[1:], dim=1)

distances = self.distance_metric(reps[0], documents, documents_mask)
target_scores = torch.nn.functional.log_softmax(labels, dim=-1)
log_scores = torch.nn.functional.log_softmax(distances, dim=-1)
loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)(
log_scores, target_scores
)
return loss

@property
def citation(self) -> str:
return """
@inproceedings{santhanam-etal-2022-colbertv2,
title = "{C}ol{BERT}v2: Effective and Efficient Retrieval via Lightweight Late Interaction",
author = "Santhanam, Keshav and
Khattab, Omar and
Saad-Falcon, Jon and
Potts, Christopher and
Zaharia, Matei",
editor = "Carpuat, Marine and
de Marneffe, Marie-Catherine and
Meza Ruiz, Ivan Vladimir",
booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies",
month = jul,
year = "2022",
address = "Seattle, United States",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2022.naacl-main.272",
doi = "10.18653/v1/2022.naacl-main.272",
pages = "3715--3734",
abstract = "Neural information retrieval (IR) has greatly advanced search and other knowledge-intensive language tasks. While many neural IR methods encode queries and documents into single-vector representations, late interaction models produce multi-vector representations at the granularity of each token and decompose relevance modeling into scalable token-level computations. This decomposition has been shown to make late interaction more effective, but it inflates the space footprint of these models by an order of magnitude. In this work, we introduce ColBERTv2, a retriever that couples an aggressive residual compression mechanism with a denoised supervision strategy to simultaneously improve the quality and space footprint of late interaction. We evaluate ColBERTv2 across a wide range of benchmarks, establishing state-of-the-art quality within and outside the training domain while reducing the space footprint of late interaction models by 6{--}10x.",
}
"""