diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py index a05c89b3f..59e7bf0c4 100644 --- a/sentence_transformers/util.py +++ b/sentence_transformers/util.py @@ -520,6 +520,7 @@ def semantic_search( def mine_hard_negatives( dataset: Dataset, model: SentenceTransformer, + corpus: list[str] | None = None, cross_encoder: CrossEncoder | None = None, range_min: int = 0, range_max: int | None = None, @@ -528,7 +529,8 @@ def mine_hard_negatives( num_negatives: int = 3, sampling_strategy: Literal["random", "top"] = "top", as_triplets: bool = True, - batch_size=32, + batch_size: int = 32, + faiss_batch_size: int = 16384, use_faiss: bool = False, verbose: bool = True, ) -> Dataset: @@ -619,6 +621,9 @@ def mine_hard_negatives( Args: dataset (Dataset): A dataset containing (anchor, positive) pairs. model (SentenceTransformer): A SentenceTransformer model to use for embedding the sentences. + corpus (List[str], optional): A list containing documents as strings that will be used as candidate negatives + in addition to the second column in `dataset`. Defaults to None, in which case the second column in + `dataset` will exclusively be used as the negative candidate corpus. cross_encoder (CrossEncoder, optional): A CrossEncoder model to use for rescoring the candidates. Defaults to None. range_min (int): Minimum rank of the closest matches to consider as negatives. Defaults to 0. range_max (int, optional): Maximum rank of the closest matches to consider as negatives. Defaults to None. @@ -628,7 +633,8 @@ def mine_hard_negatives( sampling_strategy (Literal["random", "top"]): Sampling strategy for negatives: "top" or "random". Defaults to "top". as_triplets (bool): If True, returns up to `num_negatives` (anchor, positive, negative) triplets for each input sample. If False, returns 1 (anchor, positive, negative_1, ..., negative_n) tuple for each input sample. Defaults to True. - batch_size (int): Batch size for processing. Defaults to 32. + batch_size (int): Batch size for encoding the dataset. Defaults to 32. + faiss_batch_size (int): Batch size for FAISS top-k search. Defaults to 16384. use_faiss (bool): Whether to use FAISS for similarity search. May be recommended for large datasets. Defaults to False. verbose (bool): Whether to print statistics and logging. Defaults to True. @@ -640,39 +646,65 @@ def mine_hard_negatives( from datasets import Dataset + # If a dataset has duplicate queries, assume that all duplicates are positive pairs. + columns = dataset.column_names + if len(columns) != 2: + raise ValueError("Dataset must contain exactly two columns.") + + # To avoid re-embedding the same query multiple times, we keep a counter of the number of positives per query + positives_per_query = list(dataset.to_pandas().groupby(columns[0]).count().to_dict()[columns[1]].values()) + max_positives = max(positives_per_query) + if range_max is None: if margin is not None or max_score is not None: - # 1 positive, 10 * num_negatives negatives because some might be skipped, and range_min skipped - range_max = range_min + (num_negatives * 10) + 1 + # max_positives + 10 * num_negatives negatives because some might be skipped, and range_min skipped + range_max = range_min + (num_negatives * 10) + max_positives else: - # 1 positive, num_negatives negatives, and range_min skipped - range_max = range_min + num_negatives + 1 + # max_positives, num_negatives negatives, and range_min skipped + range_max = range_min + num_negatives + max_positives + if range_max > 2048 and use_faiss: + # FAISS on GPU can only retrieve up to 2048 documents per query + range_max = 2048 + if verbose: + print("Using FAISS, we can only retrieve up to 2048 documents per query. Setting range_max to 2048.") if verbose: - print(f"Setting range_max to {range_max} based on other parameters.") - - # Combine anchor and positive sentences to get unique corpus - columns = dataset.column_names - if len(columns) != 2: - raise ValueError("Dataset must contain exactly two columns.") + print(f"Setting range_max to {range_max} based on the provided parameters.") log_counters = {} queries = dataset[columns[0]] - corpus = dataset[columns[1]] + positives = dataset[columns[1]] + separate_corpus = corpus is not None + if not separate_corpus: + corpus = positives + + # Deduplicate the corpus + # make sure all the positives are also in the corpus and de-duplicate it. + corpus = list(set(corpus) | set(positives)) - # Embed the corpus and queries - corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True) - query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True) - batch_idx = torch.arange(len(queries)).unsqueeze(-1) + # corpus_idx maps the corpus text into its position in the corpus + # This position does not necessarily matches the original corpus, as it was de-duplicated. + corpus_idx = {text: idx for idx, text in enumerate(corpus)} + + # Deduplicate the queries, but keep the original one for later reference. + all_queries = queries.copy() + queries = list(set(queries)) + queries_idx = {query: idx for idx, query in enumerate(queries)} + n_queries = len(queries) + batch_idx = torch.arange(n_queries).unsqueeze(-1) + + device = model.device + + if n_queries != len(all_queries) and verbose: + print(f"Found {n_queries} unique queries out of {len(all_queries)} total queries.") + + if max_positives > 1: + avg_positives_per_query = np.mean(positives_per_query) + print(f"Found an average of {avg_positives_per_query:.3f} positives per query.") if use_faiss: import faiss - # Compute the positive scores separate from FAISS - positive_scores = model.similarity_pairwise(query_embeddings, corpus_embeddings).cpu() - - query_embeddings = query_embeddings.cpu().numpy() - corpus_embeddings = corpus_embeddings.cpu().numpy() - index = faiss.IndexFlatIP(len(corpus_embeddings[0])) + index = faiss.IndexFlatIP(model.get_sentence_embedding_dimension()) # Move the index to the GPU if available try: co = faiss.GpuMultipleClonerOptions() @@ -681,57 +713,98 @@ def mine_hard_negatives( index: faiss.IndexFlatIP = faiss.index_cpu_to_all_gpus(index, co=co) except Exception: pass + + corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) + query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) index.add(corpus_embeddings) + scores_list = [] indices_list = [] # Iterate over query embeddings in batches so we can track the progress - for i in trange(0, len(query_embeddings), batch_size, desc="Querying FAISS index"): - query_chunk = query_embeddings[i : i + batch_size] + for i in trange(0, len(query_embeddings), faiss_batch_size, desc="Querying FAISS index"): + query_chunk = query_embeddings[i : i + faiss_batch_size] scores, indices = index.search(query_chunk, k=range_max + 1) scores_list.append(scores) indices_list.append(indices) - scores = torch.from_numpy(np.concatenate(scores_list, axis=0)) - indices = torch.from_numpy(np.concatenate(indices_list, axis=0)) + scores = torch.from_numpy(np.concatenate(scores_list, axis=0)).to(device) + indices = torch.from_numpy(np.concatenate(indices_list, axis=0)).to(device) + else: - # Compute all similarity scores - scores = model.similarity(query_embeddings, corpus_embeddings).cpu() - positive_scores = scores.diagonal().clone() + # Embed the corpus and the queries + corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) + query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) + scores = model.similarity(query_embeddings, corpus_embeddings).to(device) - # Keep only the range_max + 1 highest scores. We offset by 1 to potentially include the positive pair - scores, indices = torch.topk(scores, k=range_max + 1, dim=1) - del query_embeddings - del corpus_embeddings + # Keep only the range_max + max_positives highest scores. We offset by 1 to potentially include the positive pair + scores, indices = torch.topk(scores, k=range_max + max_positives, dim=1) - # Scores is a [num_queries, range_max + 1] tensor, where we set the values to -inf to disqualify the corresponding - # text as a negative candidate. Here we disqualify the positive pair - positive_indices = indices == torch.arange(len(queries), device=indices.device).unsqueeze(-1) - scores[positive_indices] = -float("inf") + # As we may have duplicated queries (i.e., a single query with multiple positives), + # We keep track, for each unique query, of where their positives are in the list of positives (positive_indices). + # Note that as queries may have differing numbers of positives, we cannot guarantee that this is a fixed-length matrix. + positive_indices = [[] for _ in range(n_queries)] - num_candidates = scores.numel() + for query, positive in zip(all_queries, positives): + query_idx = queries_idx[query] + positive_indices[query_idx].append(corpus_idx[positive]) + + n_positives = [len(p) for p in positive_indices] + + # re-sort the positives and all_queries according to the deduplicated queries + positives = [] + all_queries = [] + for idx in range(n_queries): + positives.extend([corpus[doc_idx] for doc_idx in positive_indices[idx]]) + all_queries.extend([queries[idx]] * n_positives[idx]) + + positive_indices = [torch.tensor(p, device=device) for p in positive_indices] + + # Compute the positive scores + query_embeddings = query_embeddings[[idx for idx in range(n_queries) for _ in range(n_positives[idx])]] + positive_embeddings = corpus_embeddings[torch.cat(positive_indices).tolist()] + positive_scores = model.similarity_pairwise(query_embeddings, positive_embeddings).to(device) + + del query_embeddings + del positive_embeddings + del corpus_embeddings # Rescore with cross_encoder if cross_encoder is not None and (margin is not None or max_score is not None): - for idx, candidate_neg_idx in tqdm(enumerate(indices), desc="Rescoring with CrossEncoder", total=len(indices)): + for idx, candidate_idx in tqdm(enumerate(indices), desc="Rescoring with CrossEncoder", total=len(indices)): query = queries[idx] - candidate_passages = [corpus[neg_idx] for neg_idx in candidate_neg_idx] + candidate_passages = [corpus[_idx] for _idx in candidate_idx] pred_scores = cross_encoder.predict( list(zip([query] * (range_max + 1), candidate_passages)), batch_size=batch_size, convert_to_tensor=True, ) - # If we rescored a positive pair, make sure that it is disqualified again - if idx in candidate_neg_idx: - pred_scores[candidate_neg_idx == idx] = -float("inf") scores[idx] = pred_scores positive_scores = cross_encoder.predict( - list(zip(queries, corpus)), + list(zip(all_queries, positives)), batch_size=batch_size, convert_to_tensor=True, ) + # for each query, create a mask that is True for the positives and False for the negatives in the indices + positive_mask = torch.stack([torch.isin(indices[q_idx], positive_indices[q_idx]) for q_idx in range(n_queries)]) + + # Scores is a [num_queries, range_max] tensor, where we set the values to -inf to disqualify the corresponding + # positive candidates + scores[positive_mask] = -float("inf") + + num_candidates = scores.numel() + # Remove based on margin if margin is not None: - removed_indices = scores + margin > positive_scores.repeat(scores.size(1), 1).T + # If we have a margin, we will remove candidates that are too close to the positive pair + # If there are multiple positives, we need to define which one to use for the margin + # To be on the safe side, we will use the _minimum_ positive score (i.e., harder positive) for the margin + max_positive_scores = torch.empty(n_queries, device=positive_scores.device, dtype=positive_scores.dtype) + start_idx = 0 + for q_idx in range(n_queries): + max_positive_scores[q_idx] = torch.min(positive_scores[start_idx : start_idx + n_positives[q_idx]]) + start_idx += n_positives[q_idx - 1] + + removed_indices = scores + margin > max_positive_scores.repeat(scores.size(1), 1).T scores[removed_indices] = -float("inf") num_skipped = removed_indices.sum().item() @@ -757,6 +830,7 @@ def mine_hard_negatives( # Grab the top negative candidates and remove the first range_min candidates negative_scores, local_indices = torch.topk(scores, k=range_max, dim=1) indices = indices[batch_idx, local_indices] + if range_min: indices = indices[:, range_min:] negative_scores = negative_scores[:, range_min:] @@ -765,6 +839,7 @@ def mine_hard_negatives( if sampling_strategy == "top": indices = indices[:, :num_negatives] negative_scores = negative_scores[:, :num_negatives] + elif sampling_strategy == "random": # Prevent sampling -inf values if possible num_options = indices.size(1) - negative_scores.isinf().sum(1) @@ -777,23 +852,42 @@ def mine_hard_negatives( negative_scores, local_indices = negative_scores.sort(dim=1, descending=True) indices = indices[batch_idx, local_indices] + # repeat indices and negative_scores by the number of positives of each query + indices = torch.cat([indices[idx].repeat(n_positives[idx], 1) for idx in range(n_queries)]) + negative_scores = torch.cat([negative_scores[idx].repeat(n_positives[idx], 1) for idx in range(n_queries)]) + if as_triplets: - # negative_scores is [num_queries, num_negatives], but may contain some -inf values if not enough negatives were found + # If calling as triples and there are multiple positives per query, we will explode the dataset into triplets. indices_to_keep = negative_scores != -float("inf") - # This turns indices and negative_scores into 1d tensors + anchor_indices = torch.empty_like(indices) + pos_indices = torch.empty_like(indices) + indices = indices[indices_to_keep] negative_scores = negative_scores[indices_to_keep] - anchor_indices = torch.arange(len(queries), device=indices_to_keep.device).repeat(num_negatives, 1).T + + # the anchor_indices matrix is shaped [n_total_queries, n_negatives] + start_idx = 0 + for q_idx in range(n_queries): + anchor_indices[start_idx : start_idx + n_positives[q_idx]] = torch.tensor(q_idx).repeat( + n_positives[q_idx], num_negatives + ) + pos_indices[start_idx : start_idx + n_positives[q_idx]] = ( + positive_indices[q_idx].repeat(num_negatives, 1).T + ) + start_idx += n_positives[q_idx] + anchor_indices = anchor_indices[indices_to_keep] + positive_indices = pos_indices[indices_to_keep] triplets_data = { columns[0]: [], columns[1]: [], "negative": [], } - for anchor_idx, negative_idx in zip(anchor_indices, indices): + + for anchor_idx, negative_idx, positive_idx in zip(anchor_indices, indices, positive_indices): triplets_data[columns[0]].append(queries[anchor_idx]) - triplets_data[columns[1]].append(corpus[anchor_idx]) + triplets_data[columns[1]].append(corpus[positive_idx]) triplets_data["negative"].append(corpus[negative_idx]) difference_scores = positive_scores.repeat(num_negatives, 1).T[indices_to_keep] - negative_scores @@ -803,18 +897,16 @@ def mine_hard_negatives( negative_scores = negative_scores[indices_to_keep] indices = indices[indices_to_keep] - # Create a list of (anchor, positive, negative_1, ..., negative_`num_negatives`) tuples triplets_data = { - columns[0]: [queries[idx] for idx in range(len(queries)) if indices_to_keep[idx]], - columns[1]: [corpus[idx] for idx in range(len(corpus)) if indices_to_keep[idx]], + columns[0]: [all_queries[idx] for idx, keep in enumerate(indices_to_keep) if keep], + columns[1]: [positives[idx] for idx, keep in enumerate(indices_to_keep) if keep], **{ f"negative_{i}": [corpus[neg_idx] for neg_idx in neg_indices] for i, neg_indices in enumerate(indices.T, start=1) }, } - # Flatten it so we can use for logging negative_scores = negative_scores.flatten() - difference_scores = positive_scores[indices_to_keep].repeat(num_negatives, 1).T.flatten() - negative_scores + difference_scores = positive_scores.repeat(num_negatives, 1).T[indices_to_keep].flatten() - negative_scores if len(triplets_data) == 0: raise ValueError("No triplets could be generated. Please check the parameters and dataset.") @@ -823,10 +915,17 @@ def mine_hard_negatives( # Report some statistics if verbose: row_format = "{:<6} {:>14} {:>14} {:>14}" - formatter = lambda value: f"{value.item():.4f}" if isinstance(value, torch.Tensor) else f"{value:,}" + formatter = lambda value: (f"{value.item():.4f}" if isinstance(value, torch.Tensor) else f"{value:,}") print(row_format.format("Metric", "Positive", "Negative", "Difference")) + print( + row_format.format( + "Count", + formatter(len(positive_scores)), + formatter(len(negative_scores)), + "", + ) + ) for metric, function in [ - ("count", len), ("mean", torch.mean), ("median", torch.median), ("std", torch.std), @@ -854,7 +953,7 @@ def mine_hard_negatives( f"Skipped {log_counters['max_score']['skipped']} potential negatives ({log_counters['max_score']['ratio']:.2%}) due to the maximum score of {max_score}." ) - missing_negatives = (num_negatives * len(queries)) - len(negative_scores) + missing_negatives = (num_negatives * len(dataset)) - len(negative_scores) if missing_negatives > 0: solutions = ["range_max"] if range_min > 0: @@ -866,7 +965,7 @@ def mine_hard_negatives( considerations = ", ".join(solutions[:-1]) if len(solutions) > 1: considerations += " and " + solutions[-1] - missing_negatives_ratio = missing_negatives / (num_negatives * len(queries)) + missing_negatives_ratio = missing_negatives / (num_negatives * len(dataset)) print( f"Could not find enough negatives for {missing_negatives} samples ({missing_negatives_ratio:.2%})." f" Consider adjusting the {considerations} parameter{'s' if len(solutions) > 1 else ''} if you'd like to find more valid negatives."