Skip to content

Commit

Permalink
[feat] Update mine_hard_negatives to using a full corpus and multip…
Browse files Browse the repository at this point in the history
…le positives (#2848)

* updated mine_hard_negatives method to include a seperate corpus for mining hard negatives.

* Run 'make check'

* Update "corpus" to just a list of strings

* Prevent duplicate embeddings if no separate corpus

* Deduplicate corpus

Add a positive to corpus indices mapping, useful to get non-deduplicated positives and to filter away positives taken from the corpus

* Skip rescoring positive pairs via pos_to_corpus_indices instead

* Add a mine_hard_negatives_from_corpus util

* Speedup pos_to_corpus_indices for large corpora

* Fix range_max by number of max_positives in dataset

* encode in chunks, ensure at least one positive per query always

* Hard_negative_mining with corpus and multiple positives is possible

* docstring

* Fix for random sampling

* fix for return_triplets=False

* Typo on list

* Fix bug with multiple positives. More efficient creation of some tensors.

* Fix offset of positives scoring with multiple chunks

* fix pytorch copy warning

* Only embed each text once; no need for chunking if convert_to_numpy=True

* Undo unintended changes

* Fix mismatch in anchor/positive and negatives if multiple positives per query

* Don't repeat positive_scores as it inflates the positive score counts

* Remove the "Count" for Difference as it's rather confusing

---------

Co-authored-by: Christian Geishauser <[email protected]>
Co-authored-by: Tom Aarsen <[email protected]>
  • Loading branch information
3 people committed Sep 11, 2024
1 parent 8af7c5d commit 6e222cb
Showing 1 changed file with 159 additions and 60 deletions.
219 changes: 159 additions & 60 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ def semantic_search(
def mine_hard_negatives(
dataset: Dataset,
model: SentenceTransformer,
corpus: list[str] | None = None,
cross_encoder: CrossEncoder | None = None,
range_min: int = 0,
range_max: int | None = None,
Expand All @@ -528,7 +529,8 @@ def mine_hard_negatives(
num_negatives: int = 3,
sampling_strategy: Literal["random", "top"] = "top",
as_triplets: bool = True,
batch_size=32,
batch_size: int = 32,
faiss_batch_size: int = 16384,
use_faiss: bool = False,
verbose: bool = True,
) -> Dataset:
Expand Down Expand Up @@ -619,6 +621,9 @@ def mine_hard_negatives(
Args:
dataset (Dataset): A dataset containing (anchor, positive) pairs.
model (SentenceTransformer): A SentenceTransformer model to use for embedding the sentences.
corpus (List[str], optional): A list containing documents as strings that will be used as candidate negatives
in addition to the second column in `dataset`. Defaults to None, in which case the second column in
`dataset` will exclusively be used as the negative candidate corpus.
cross_encoder (CrossEncoder, optional): A CrossEncoder model to use for rescoring the candidates. Defaults to None.
range_min (int): Minimum rank of the closest matches to consider as negatives. Defaults to 0.
range_max (int, optional): Maximum rank of the closest matches to consider as negatives. Defaults to None.
Expand All @@ -628,7 +633,8 @@ def mine_hard_negatives(
sampling_strategy (Literal["random", "top"]): Sampling strategy for negatives: "top" or "random". Defaults to "top".
as_triplets (bool): If True, returns up to `num_negatives` (anchor, positive, negative) triplets for each input sample.
If False, returns 1 (anchor, positive, negative_1, ..., negative_n) tuple for each input sample. Defaults to True.
batch_size (int): Batch size for processing. Defaults to 32.
batch_size (int): Batch size for encoding the dataset. Defaults to 32.
faiss_batch_size (int): Batch size for FAISS top-k search. Defaults to 16384.
use_faiss (bool): Whether to use FAISS for similarity search. May be recommended for large datasets. Defaults to False.
verbose (bool): Whether to print statistics and logging. Defaults to True.
Expand All @@ -640,39 +646,65 @@ def mine_hard_negatives(

from datasets import Dataset

# If a dataset has duplicate queries, assume that all duplicates are positive pairs.
columns = dataset.column_names
if len(columns) != 2:
raise ValueError("Dataset must contain exactly two columns.")

# To avoid re-embedding the same query multiple times, we keep a counter of the number of positives per query
positives_per_query = list(dataset.to_pandas().groupby(columns[0]).count().to_dict()[columns[1]].values())
max_positives = max(positives_per_query)

if range_max is None:
if margin is not None or max_score is not None:
# 1 positive, 10 * num_negatives negatives because some might be skipped, and range_min skipped
range_max = range_min + (num_negatives * 10) + 1
# max_positives + 10 * num_negatives negatives because some might be skipped, and range_min skipped
range_max = range_min + (num_negatives * 10) + max_positives
else:
# 1 positive, num_negatives negatives, and range_min skipped
range_max = range_min + num_negatives + 1
# max_positives, num_negatives negatives, and range_min skipped
range_max = range_min + num_negatives + max_positives
if range_max > 2048 and use_faiss:
# FAISS on GPU can only retrieve up to 2048 documents per query
range_max = 2048
if verbose:
print("Using FAISS, we can only retrieve up to 2048 documents per query. Setting range_max to 2048.")
if verbose:
print(f"Setting range_max to {range_max} based on other parameters.")

# Combine anchor and positive sentences to get unique corpus
columns = dataset.column_names
if len(columns) != 2:
raise ValueError("Dataset must contain exactly two columns.")
print(f"Setting range_max to {range_max} based on the provided parameters.")

log_counters = {}
queries = dataset[columns[0]]
corpus = dataset[columns[1]]
positives = dataset[columns[1]]
separate_corpus = corpus is not None
if not separate_corpus:
corpus = positives

# Deduplicate the corpus
# make sure all the positives are also in the corpus and de-duplicate it.
corpus = list(set(corpus) | set(positives))

# Embed the corpus and queries
corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
batch_idx = torch.arange(len(queries)).unsqueeze(-1)
# corpus_idx maps the corpus text into its position in the corpus
# This position does not necessarily matches the original corpus, as it was de-duplicated.
corpus_idx = {text: idx for idx, text in enumerate(corpus)}

# Deduplicate the queries, but keep the original one for later reference.
all_queries = queries.copy()
queries = list(set(queries))
queries_idx = {query: idx for idx, query in enumerate(queries)}
n_queries = len(queries)
batch_idx = torch.arange(n_queries).unsqueeze(-1)

device = model.device

if n_queries != len(all_queries) and verbose:
print(f"Found {n_queries} unique queries out of {len(all_queries)} total queries.")

if max_positives > 1:
avg_positives_per_query = np.mean(positives_per_query)
print(f"Found an average of {avg_positives_per_query:.3f} positives per query.")

if use_faiss:
import faiss

# Compute the positive scores separate from FAISS
positive_scores = model.similarity_pairwise(query_embeddings, corpus_embeddings).cpu()

query_embeddings = query_embeddings.cpu().numpy()
corpus_embeddings = corpus_embeddings.cpu().numpy()
index = faiss.IndexFlatIP(len(corpus_embeddings[0]))
index = faiss.IndexFlatIP(model.get_sentence_embedding_dimension())
# Move the index to the GPU if available
try:
co = faiss.GpuMultipleClonerOptions()
Expand All @@ -681,57 +713,98 @@ def mine_hard_negatives(
index: faiss.IndexFlatIP = faiss.index_cpu_to_all_gpus(index, co=co)
except Exception:
pass

corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True)
query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True)
index.add(corpus_embeddings)

scores_list = []
indices_list = []
# Iterate over query embeddings in batches so we can track the progress
for i in trange(0, len(query_embeddings), batch_size, desc="Querying FAISS index"):
query_chunk = query_embeddings[i : i + batch_size]
for i in trange(0, len(query_embeddings), faiss_batch_size, desc="Querying FAISS index"):
query_chunk = query_embeddings[i : i + faiss_batch_size]
scores, indices = index.search(query_chunk, k=range_max + 1)
scores_list.append(scores)
indices_list.append(indices)
scores = torch.from_numpy(np.concatenate(scores_list, axis=0))
indices = torch.from_numpy(np.concatenate(indices_list, axis=0))
scores = torch.from_numpy(np.concatenate(scores_list, axis=0)).to(device)
indices = torch.from_numpy(np.concatenate(indices_list, axis=0)).to(device)

else:
# Compute all similarity scores
scores = model.similarity(query_embeddings, corpus_embeddings).cpu()
positive_scores = scores.diagonal().clone()
# Embed the corpus and the queries
corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True)
query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True)
scores = model.similarity(query_embeddings, corpus_embeddings).to(device)

# Keep only the range_max + 1 highest scores. We offset by 1 to potentially include the positive pair
scores, indices = torch.topk(scores, k=range_max + 1, dim=1)
del query_embeddings
del corpus_embeddings
# Keep only the range_max + max_positives highest scores. We offset by 1 to potentially include the positive pair
scores, indices = torch.topk(scores, k=range_max + max_positives, dim=1)

# Scores is a [num_queries, range_max + 1] tensor, where we set the values to -inf to disqualify the corresponding
# text as a negative candidate. Here we disqualify the positive pair
positive_indices = indices == torch.arange(len(queries), device=indices.device).unsqueeze(-1)
scores[positive_indices] = -float("inf")
# As we may have duplicated queries (i.e., a single query with multiple positives),
# We keep track, for each unique query, of where their positives are in the list of positives (positive_indices).
# Note that as queries may have differing numbers of positives, we cannot guarantee that this is a fixed-length matrix.
positive_indices = [[] for _ in range(n_queries)]

num_candidates = scores.numel()
for query, positive in zip(all_queries, positives):
query_idx = queries_idx[query]
positive_indices[query_idx].append(corpus_idx[positive])

n_positives = [len(p) for p in positive_indices]

# re-sort the positives and all_queries according to the deduplicated queries
positives = []
all_queries = []
for idx in range(n_queries):
positives.extend([corpus[doc_idx] for doc_idx in positive_indices[idx]])
all_queries.extend([queries[idx]] * n_positives[idx])

positive_indices = [torch.tensor(p, device=device) for p in positive_indices]

# Compute the positive scores
query_embeddings = query_embeddings[[idx for idx in range(n_queries) for _ in range(n_positives[idx])]]
positive_embeddings = corpus_embeddings[torch.cat(positive_indices).tolist()]
positive_scores = model.similarity_pairwise(query_embeddings, positive_embeddings).to(device)

del query_embeddings
del positive_embeddings
del corpus_embeddings

# Rescore with cross_encoder
if cross_encoder is not None and (margin is not None or max_score is not None):
for idx, candidate_neg_idx in tqdm(enumerate(indices), desc="Rescoring with CrossEncoder", total=len(indices)):
for idx, candidate_idx in tqdm(enumerate(indices), desc="Rescoring with CrossEncoder", total=len(indices)):
query = queries[idx]
candidate_passages = [corpus[neg_idx] for neg_idx in candidate_neg_idx]
candidate_passages = [corpus[_idx] for _idx in candidate_idx]
pred_scores = cross_encoder.predict(
list(zip([query] * (range_max + 1), candidate_passages)),
batch_size=batch_size,
convert_to_tensor=True,
)
# If we rescored a positive pair, make sure that it is disqualified again
if idx in candidate_neg_idx:
pred_scores[candidate_neg_idx == idx] = -float("inf")
scores[idx] = pred_scores
positive_scores = cross_encoder.predict(
list(zip(queries, corpus)),
list(zip(all_queries, positives)),
batch_size=batch_size,
convert_to_tensor=True,
)

# for each query, create a mask that is True for the positives and False for the negatives in the indices
positive_mask = torch.stack([torch.isin(indices[q_idx], positive_indices[q_idx]) for q_idx in range(n_queries)])

# Scores is a [num_queries, range_max] tensor, where we set the values to -inf to disqualify the corresponding
# positive candidates
scores[positive_mask] = -float("inf")

num_candidates = scores.numel()

# Remove based on margin
if margin is not None:
removed_indices = scores + margin > positive_scores.repeat(scores.size(1), 1).T
# If we have a margin, we will remove candidates that are too close to the positive pair
# If there are multiple positives, we need to define which one to use for the margin
# To be on the safe side, we will use the _minimum_ positive score (i.e., harder positive) for the margin
max_positive_scores = torch.empty(n_queries, device=positive_scores.device, dtype=positive_scores.dtype)
start_idx = 0
for q_idx in range(n_queries):
max_positive_scores[q_idx] = torch.min(positive_scores[start_idx : start_idx + n_positives[q_idx]])
start_idx += n_positives[q_idx - 1]

removed_indices = scores + margin > max_positive_scores.repeat(scores.size(1), 1).T
scores[removed_indices] = -float("inf")

num_skipped = removed_indices.sum().item()
Expand All @@ -757,6 +830,7 @@ def mine_hard_negatives(
# Grab the top negative candidates and remove the first range_min candidates
negative_scores, local_indices = torch.topk(scores, k=range_max, dim=1)
indices = indices[batch_idx, local_indices]

if range_min:
indices = indices[:, range_min:]
negative_scores = negative_scores[:, range_min:]
Expand All @@ -765,6 +839,7 @@ def mine_hard_negatives(
if sampling_strategy == "top":
indices = indices[:, :num_negatives]
negative_scores = negative_scores[:, :num_negatives]

elif sampling_strategy == "random":
# Prevent sampling -inf values if possible
num_options = indices.size(1) - negative_scores.isinf().sum(1)
Expand All @@ -777,23 +852,42 @@ def mine_hard_negatives(
negative_scores, local_indices = negative_scores.sort(dim=1, descending=True)
indices = indices[batch_idx, local_indices]

# repeat indices and negative_scores by the number of positives of each query
indices = torch.cat([indices[idx].repeat(n_positives[idx], 1) for idx in range(n_queries)])
negative_scores = torch.cat([negative_scores[idx].repeat(n_positives[idx], 1) for idx in range(n_queries)])

if as_triplets:
# negative_scores is [num_queries, num_negatives], but may contain some -inf values if not enough negatives were found
# If calling as triples and there are multiple positives per query, we will explode the dataset into triplets.
indices_to_keep = negative_scores != -float("inf")
# This turns indices and negative_scores into 1d tensors
anchor_indices = torch.empty_like(indices)
pos_indices = torch.empty_like(indices)

indices = indices[indices_to_keep]
negative_scores = negative_scores[indices_to_keep]
anchor_indices = torch.arange(len(queries), device=indices_to_keep.device).repeat(num_negatives, 1).T

# the anchor_indices matrix is shaped [n_total_queries, n_negatives]
start_idx = 0
for q_idx in range(n_queries):
anchor_indices[start_idx : start_idx + n_positives[q_idx]] = torch.tensor(q_idx).repeat(
n_positives[q_idx], num_negatives
)
pos_indices[start_idx : start_idx + n_positives[q_idx]] = (
positive_indices[q_idx].repeat(num_negatives, 1).T
)
start_idx += n_positives[q_idx]

anchor_indices = anchor_indices[indices_to_keep]
positive_indices = pos_indices[indices_to_keep]

triplets_data = {
columns[0]: [],
columns[1]: [],
"negative": [],
}
for anchor_idx, negative_idx in zip(anchor_indices, indices):

for anchor_idx, negative_idx, positive_idx in zip(anchor_indices, indices, positive_indices):
triplets_data[columns[0]].append(queries[anchor_idx])
triplets_data[columns[1]].append(corpus[anchor_idx])
triplets_data[columns[1]].append(corpus[positive_idx])
triplets_data["negative"].append(corpus[negative_idx])
difference_scores = positive_scores.repeat(num_negatives, 1).T[indices_to_keep] - negative_scores

Expand All @@ -803,18 +897,16 @@ def mine_hard_negatives(
negative_scores = negative_scores[indices_to_keep]
indices = indices[indices_to_keep]

# Create a list of (anchor, positive, negative_1, ..., negative_`num_negatives`) tuples
triplets_data = {
columns[0]: [queries[idx] for idx in range(len(queries)) if indices_to_keep[idx]],
columns[1]: [corpus[idx] for idx in range(len(corpus)) if indices_to_keep[idx]],
columns[0]: [all_queries[idx] for idx, keep in enumerate(indices_to_keep) if keep],
columns[1]: [positives[idx] for idx, keep in enumerate(indices_to_keep) if keep],
**{
f"negative_{i}": [corpus[neg_idx] for neg_idx in neg_indices]
for i, neg_indices in enumerate(indices.T, start=1)
},
}
# Flatten it so we can use for logging
negative_scores = negative_scores.flatten()
difference_scores = positive_scores[indices_to_keep].repeat(num_negatives, 1).T.flatten() - negative_scores
difference_scores = positive_scores.repeat(num_negatives, 1).T[indices_to_keep].flatten() - negative_scores

if len(triplets_data) == 0:
raise ValueError("No triplets could be generated. Please check the parameters and dataset.")
Expand All @@ -823,10 +915,17 @@ def mine_hard_negatives(
# Report some statistics
if verbose:
row_format = "{:<6} {:>14} {:>14} {:>14}"
formatter = lambda value: f"{value.item():.4f}" if isinstance(value, torch.Tensor) else f"{value:,}"
formatter = lambda value: (f"{value.item():.4f}" if isinstance(value, torch.Tensor) else f"{value:,}")
print(row_format.format("Metric", "Positive", "Negative", "Difference"))
print(
row_format.format(
"Count",
formatter(len(positive_scores)),
formatter(len(negative_scores)),
"",
)
)
for metric, function in [
("count", len),
("mean", torch.mean),
("median", torch.median),
("std", torch.std),
Expand Down Expand Up @@ -854,7 +953,7 @@ def mine_hard_negatives(
f"Skipped {log_counters['max_score']['skipped']} potential negatives ({log_counters['max_score']['ratio']:.2%}) due to the maximum score of {max_score}."
)

missing_negatives = (num_negatives * len(queries)) - len(negative_scores)
missing_negatives = (num_negatives * len(dataset)) - len(negative_scores)
if missing_negatives > 0:
solutions = ["range_max"]
if range_min > 0:
Expand All @@ -866,7 +965,7 @@ def mine_hard_negatives(
considerations = ", ".join(solutions[:-1])
if len(solutions) > 1:
considerations += " and " + solutions[-1]
missing_negatives_ratio = missing_negatives / (num_negatives * len(queries))
missing_negatives_ratio = missing_negatives / (num_negatives * len(dataset))
print(
f"Could not find enough negatives for {missing_negatives} samples ({missing_negatives_ratio:.2%})."
f" Consider adjusting the {considerations} parameter{'s' if len(solutions) > 1 else ''} if you'd like to find more valid negatives."
Expand Down

0 comments on commit 6e222cb

Please sign in to comment.