from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any, Callable

import torch

logger = logging.getLogger(__name__)

# Groups of conventional column names, each listed in the order (index 0, 1, 2, ...)
# in which a column with that name is usually expected to appear in the dataset.
_EXPECTED_COLUMN_ORDERS: tuple[tuple[str, ...], ...] = (
    ("anchor", "positive", "negative"),
    ("question", "answer"),
    ("query", "response"),
    ("hypothesis", "entailment", "contradiction"),
)


@dataclass
class SentenceTransformerDataCollator:
    """Collate a list of dataset rows into one batch dictionary.

    Every non-label, non-``dataset_name`` column is passed through ``tokenize_fn`` and
    each key of the tokenizer output is stored as ``{column}_{key}`` (for example
    ``{column}_input_ids`` and ``{column}_attention_mask``).
    This works with the two text dataset that is used as the example in the training overview:
    https://www.sbert.net/docs/sentence_transformer/training_overview.html

    It is important that the columns are in the expected order. For example, if your dataset has columns
    "answer", "question" in that order, then the MultipleNegativesRankingLoss will consider
    "answer" as the anchor and "question" as the positive, and it will (unexpectedly) optimize for
    "given the answer, what is the question?".

    Attributes:
        tokenize_fn: Callable that receives the list of values of one column and returns a
            mapping from output key to batched value.
        valid_label_columns: Column names that are treated as the label; the first one found
            in the rows is converted to a tensor and stored under ``"label"``.
    """

    tokenize_fn: Callable
    valid_label_columns: list[str] = field(default_factory=lambda: ["label", "score"])
    # Column-name tuples that were already checked, so each layout is only warned about once.
    _warned_columns: set[tuple[str, ...]] = field(default_factory=set, init=False, repr=False)

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        """Build a batch from ``features``.

        Args:
            features: Rows of the dataset; all rows are expected to share the keys of the first row.

        Returns:
            A dictionary with the tokenized columns, plus ``"label"`` and ``"dataset_name"``
            when the rows provide them.
        """
        column_names = list(features[0].keys())

        # We should always be able to return a loss, label or not:
        batch = {}

        if "dataset_name" in column_names:
            column_names.remove("dataset_name")
            batch["dataset_name"] = features[0]["dataset_name"]

        # Extract the label column if it exists
        for label_column in self.valid_label_columns:
            if label_column in column_names:
                batch["label"] = torch.tensor([row[label_column] for row in features])
                column_names.remove(label_column)
                break

        # Check the order only after the label is removed: only the remaining columns are
        # tokenized, so a label column placed before or between the text columns must not
        # shift their indices and trigger a false warning.
        if tuple(column_names) not in self._warned_columns:
            self.maybe_warn_about_column_order(column_names)

        # Extract the feature columns
        for column_name in column_names:
            tokenized = self.tokenize_fn([row[column_name] for row in features])
            for key, value in tokenized.items():
                batch[f"{column_name}_{key}"] = value

        return batch

    def maybe_warn_about_column_order(self, column_names: list[str]) -> None:
        """Warn the user if the columns are likely not in the expected order.

        At most one warning is logged per call, and the given layout is remembered in
        ``_warned_columns`` whether or not a warning was emitted.

        Args:
            column_names: Names of the columns, in dataset order.
        """
        # We only need to check once per list of column names to prevent spamming the user
        self._warned_columns.add(tuple(column_names))

        for expected_order in _EXPECTED_COLUMN_ORDERS:
            for expected_idx, column_name in enumerate(expected_order):
                if column_name not in column_names:
                    continue
                actual_idx = column_names.index(column_name)
                if actual_idx == expected_idx:
                    continue

                # Only propose columns that actually exist, otherwise the suggested
                # `select_columns` call would fail on the missing ones.
                proposed_fix_columns = [name for name in expected_order if name in column_names]
                logger.warning(
                    f"Column {column_name!r} is at index {actual_idx}, whereas "
                    f"a column with this name is usually expected at index {expected_idx}. Note that the column "
                    "order can be important for some losses, e.g. MultipleNegativesRankingLoss will always "
                    "consider the first column as the anchor and the second as the positive, regardless of "
                    "the dataset column names. Consider renaming the columns to match the expected order, e.g.:\n"
                    f"dataset = dataset.select_columns({proposed_fix_columns})"
                )
                return