-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[
feat
] Add column order warnings to the data collator (#2928)
* Use clearer variable names * Add column order warnings to the data collator * Make comment clearer
- Loading branch information
Showing
1 changed file
with
59 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,93 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from dataclasses import dataclass, field | ||
from typing import Any, Callable | ||
|
||
import torch | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class SentenceTransformerDataCollator: | ||
"""Collator for a SentenceTransformers model. | ||
This encodes the text columns to {column}_input_ids and {column}_attention_mask columns. | ||
This works with the two text dataset that is used as the example in the training overview: | ||
https://www.sbert.net/docs/sentence_transformer/training_overview.html | ||
It is important that the columns are in the expected order. For example, if your dataset has columns | ||
"answer", "question" in that order, then the MultipleNegativesRankingLoss will consider | ||
"answer" as the anchor and "question" as the positive, and it will (unexpectedly) optimize for | ||
"given the answer, what is the question?". | ||
""" | ||
|
||
tokenize_fn: Callable | ||
valid_label_columns: list[str] = field(default_factory=lambda: ["label", "score"]) | ||
_warned_columns: set[tuple[str]] = field(default_factory=set, init=False, repr=False) | ||
|
||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]: | ||
columns = list(features[0].keys()) | ||
column_names = list(features[0].keys()) | ||
|
||
# We should always be able to return a loss, label or not: | ||
batch = {} | ||
|
||
if "dataset_name" in columns: | ||
columns.remove("dataset_name") | ||
if "dataset_name" in column_names: | ||
column_names.remove("dataset_name") | ||
batch["dataset_name"] = features[0]["dataset_name"] | ||
|
||
if tuple(column_names) not in self._warned_columns: | ||
self.maybe_warn_about_column_order(column_names) | ||
|
||
# Extract the label column if it exists | ||
for label_column in self.valid_label_columns: | ||
if label_column in columns: | ||
if label_column in column_names: | ||
batch["label"] = torch.tensor([row[label_column] for row in features]) | ||
columns.remove(label_column) | ||
column_names.remove(label_column) | ||
break | ||
|
||
# Extract the feature columns | ||
for column in columns: | ||
tokenized = self.tokenize_fn([row[column] for row in features]) | ||
for column_name in column_names: | ||
tokenized = self.tokenize_fn([row[column_name] for row in features]) | ||
for key, value in tokenized.items(): | ||
batch[f"{column}_{key}"] = value | ||
batch[f"{column_name}_{key}"] = value | ||
return batch | ||
|
||
def maybe_warn_about_column_order(self, column_names: list[str]) -> None: | ||
"""Warn the user if the columns are likely not in the expected order.""" | ||
# A mapping from common column names to the expected index in the dataset | ||
column_name_to_expected_idx = { | ||
"anchor": 0, | ||
"positive": 1, | ||
"negative": 2, | ||
"question": 0, | ||
"answer": 1, | ||
"query": 0, | ||
"response": 1, | ||
"hypothesis": 0, | ||
"entailment": 1, | ||
"contradiction": 2, | ||
} | ||
for column_name, expected_idx in column_name_to_expected_idx.items(): | ||
if column_name in column_names and column_names.index(column_name) != expected_idx: | ||
if column_name in ("anchor", "positive", "negative"): | ||
proposed_fix_columns = ["anchor", "positive", "negative"] | ||
elif column_name in ("question", "answer"): | ||
proposed_fix_columns = ["question", "answer"] | ||
elif column_name in ("query", "response"): | ||
proposed_fix_columns = ["query", "response"] | ||
elif column_name in ("hypothesis", "entailment", "contradiction"): | ||
proposed_fix_columns = ["hypothesis", "entailment", "contradiction"] | ||
|
||
logger.warning( | ||
f"Column {column_name!r} is at index {column_names.index(column_name)}, whereas " | ||
f"a column with this name is usually expected at index {expected_idx}. Note that the column " | ||
"order can be important for some losses, e.g. MultipleNegativesRankingLoss will always " | ||
"consider the first column as the anchor and the second as the positive, regardless of " | ||
"the dataset column names. Consider renaming the columns to match the expected order, e.g.:\n" | ||
f"dataset = dataset.select_columns({proposed_fix_columns})" | ||
) | ||
# We only need to warn once per list of column names to prevent spamming the user | ||
break | ||
|
||
self._warned_columns.add(tuple(column_names)) |