Merge pull request #32 from lightonai/rename_variables
Renaming anchor, distances and pos/neg
NohTow authored Aug 9, 2024
2 parents ff338a8 + 7539f5c commit f290299
Showing 2 changed files with 30 additions and 35 deletions.
14 changes: 7 additions & 7 deletions giga_cherche/losses/contrastive.py
@@ -74,7 +74,7 @@ class Contrastive(nn.Module):
     ----------
     model
         ColBERT model.
-    distance_metric
+    score_metric
         ColBERT scoring function. Defaults to colbert_scores.
     size_average
         Average by the size of the mini-batch.
@@ -111,11 +111,11 @@ class Contrastive(nn.Module):
     def __init__(
         self,
         model: ColBERT,
-        distance_metric=colbert_scores,
+        score_metric=colbert_scores,
         size_average: bool = True,
     ) -> None:
         super(Contrastive, self).__init__()
-        self.distance_metric = distance_metric
+        self.score_metric = score_metric
         self.model = model
         self.size_average = size_average

@@ -145,9 +145,9 @@ def forward(

         # Note: the queries mask is not used; if added, take care that the expansion tokens are not masked from scoring (because they might be masked during encoding).
         # We might not need to compute the mask for queries, but I left the logic there for now.
-        distances = torch.cat(
+        scores = torch.cat(
             [
-                self.distance_metric(embeddings[0], group_embeddings, mask)
+                self.score_metric(embeddings[0], group_embeddings, mask)
                 for group_embeddings, mask in zip(embeddings[1:], masks[1:])
             ],
             dim=1,
@@ -156,10 +156,10 @@ def forward(
         # create corresponding labels
         # labels = torch.arange(0, rep_anchor.size(0), device=rep_anchor.device)
         labels = torch.arange(0, embeddings[0].size(0), device=embeddings[0].device)
-        # compute contrastive loss using cross-entropy over the distances
+        # compute contrastive loss using cross-entropy over the scores

         return F.cross_entropy(
-            input=distances,
+            input=scores,
             target=labels,
             reduction="mean" if self.size_average else "sum",
         )
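
Both contrastive.py hunks are a pure rename (distance_metric → score_metric, distances → scores); the objective itself is untouched: an in-batch cross-entropy where row i of the concatenated score matrix belongs to query i and, assuming colbert_scores returns one score per (query, document) pair, index i points at its positive document. A minimal, self-contained PyTorch sketch of that last step — the shapes and random scores below are illustrative stand-ins, not the library's code:

    import torch
    import torch.nn.functional as F

    batch_size, num_candidates = 4, 8  # illustrative shapes

    # Stand-in for the concatenated ColBERT scores: one row per query,
    # one column per candidate document (positives first, then negatives).
    scores = torch.randn(batch_size, num_candidates)

    # Query i's positive sits at column i, hence the arange targets.
    labels = torch.arange(batch_size)

    loss = F.cross_entropy(input=scores, target=labels, reduction="mean")

Under that layout, every other document in the batch acts as a negative for query i, which is why no explicit negative labels are needed.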
51 changes: 23 additions & 28 deletions giga_cherche/losses/distillation.py
@@ -14,8 +14,8 @@ class Distillation(torch.nn.Module):
     ----------
     model
         SentenceTransformer model.
-    distance_metric
-        Function that returns a distance between two embeddings.
+    score_metric
+        Function that returns a score between two sequences of embeddings.
     size_average
         Average by the size of the mini-batch or perform sum.
@@ -29,19 +29,16 @@ class Distillation(torch.nn.Module):
     >>> distillation = losses.Distillation(model=model)
-    >>> anchor = model.tokenize([
+    >>> query = model.tokenize([
     ...     "fruits are healthy.",
     ... ], is_query=True)
-    >>> positive = model.tokenize([
+    >>> documents = model.tokenize([
     ...     "fruits are good for health.",
+    ...     "fruits are bad for health."
     ... ], is_query=False)
-    >>> negative = model.tokenize([
-    ...     "fruits are bad for health.",
-    ... ], is_query=False)
-    >>> sentence_features = [anchor, positive, negative]
+    >>> sentence_features = [query, documents]
     >>> labels = torch.tensor([
     ...     [0.7, 0.3],
@@ -55,12 +52,12 @@ def __init__(
     def __init__(
         self,
         model: ColBERT,
-        distance_metric: Callable = colbert_kd_scores,
+        score_metric: Callable = colbert_kd_scores,
         size_average: bool = True,
         normalize_scores: bool = True,
     ) -> None:
         super(Distillation, self).__init__()
-        self.distance_metric = distance_metric
+        self.score_metric = score_metric
         self.model = model
         self.loss_function = torch.nn.KLDivLoss(
             reduction="batchmean" if size_average else "sum", log_target=True
@@ -75,49 +72,47 @@ def forward(
         Parameters
         ----------
         sentence_features
-            List of tokenized sentences. The first sentence is the anchor and the rest are the positive and negative examples.
+            List of tokenized sentences. The first sentence is the query and the rest are documents.
         labels
             The logits for the distillation loss.
         """
-        anchor_embeddings = torch.nn.functional.normalize(
+        queries_embeddings = torch.nn.functional.normalize(
             self.model(sentence_features[0])["token_embeddings"], p=2, dim=-1
         )
         # Compute the bs * n_ways embeddings
-        positive_negative_embeddings = torch.nn.functional.normalize(
+        documents_embeddings = torch.nn.functional.normalize(
             self.model(sentence_features[1])["token_embeddings"], p=2, dim=-1
         )

         # Reshape them to (bs, n_ways)
-        positive_negative_embeddings = positive_negative_embeddings.view(
-            anchor_embeddings.size(0), -1, *positive_negative_embeddings.shape[1:]
+        documents_embeddings = documents_embeddings.view(
+            queries_embeddings.size(0), -1, *documents_embeddings.shape[1:]
         )

         masks = extract_skiplist_mask(
             sentence_features=sentence_features, skiplist=self.model.skiplist
         )

-        positive_negative_embeddings_mask = masks[1].view(
-            anchor_embeddings.size(0), -1, *masks[1].shape[1:]
+        documents_embeddings_mask = masks[1].view(
+            queries_embeddings.size(0), -1, *masks[1].shape[1:]
         )
-        distances = self.distance_metric(
-            anchor_embeddings,
-            positive_negative_embeddings,
-            positive_negative_embeddings_mask,
+        scores = self.score_metric(
+            queries_embeddings,
+            documents_embeddings,
+            documents_embeddings_mask,
         )
         if self.normalize_scores:
             # Compute max and min along the num_scores dimension (dim=1)
-            max_distances, _ = torch.max(distances, dim=1, keepdim=True)
-            min_distances, _ = torch.min(distances, dim=1, keepdim=True)
+            max_scores, _ = torch.max(scores, dim=1, keepdim=True)
+            min_scores, _ = torch.min(scores, dim=1, keepdim=True)

             # Avoid division by zero by adding a small epsilon
             epsilon = 1e-8

             # Normalize the scores
-            distances = (distances - min_distances) / (
-                max_distances - min_distances + epsilon
-            )
+            scores = (scores - min_scores) / (max_scores - min_scores + epsilon)
         return self.loss_function(
-            torch.nn.functional.log_softmax(distances, dim=-1),
+            torch.nn.functional.log_softmax(scores, dim=-1),
             torch.nn.functional.log_softmax(labels, dim=-1),
         )
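
The distillation.py changes apply the same rename to the KD loss (anchor → queries, positive_negative → documents, distances → scores); numerically it still min-max rescales each row of scores before taking the KL divergence against the teacher logits. A minimal sketch of just that numeric step, with random tensors and hypothetical shapes standing in for the real model outputs:

    import torch

    torch.manual_seed(0)
    batch_size, n_ways = 2, 4  # e.g. one positive and three negatives per query

    # Stand-ins for the student scores (score_metric output) and teacher logits.
    scores = torch.randn(batch_size, n_ways)
    labels = torch.randn(batch_size, n_ways)

    # Min-max rescale each row; the epsilon guards against a zero denominator.
    max_scores, _ = torch.max(scores, dim=1, keepdim=True)
    min_scores, _ = torch.min(scores, dim=1, keepdim=True)
    scores = (scores - min_scores) / (max_scores - min_scores + 1e-8)

    # KLDivLoss with log_target=True expects both arguments in log space.
    loss_function = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
    loss = loss_function(
        torch.log_softmax(scores, dim=-1),
        torch.log_softmax(labels, dim=-1),
    )

Rescaling to [0, 1] before the softmax keeps the student's score distribution on a comparable scale across batches, which is the stated purpose of the normalize_scores flag.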
