From 2caa6361b2a657e1454081518e8659396ded8d40 Mon Sep 17 00:00:00 2001 From: Milutin Studen <81291567+milistu@users.noreply.github.com> Date: Mon, 22 Jan 2024 10:38:51 +0100 Subject: [PATCH] Add '@k' at the end of csv file name (#2427) * Add '@k' at the end of csv file name * Add NDCG metric * Update sentence_transformers/evaluation/RerankingEvaluator.py Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> * Update sentence_transformers/evaluation/RerankingEvaluator.py Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> * Update sentence_transformers/evaluation/RerankingEvaluator.py Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> * Fix code style with command 'make style' * Fix missing self at at_k assigning --------- Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> --- .../evaluation/RerankingEvaluator.py | 58 ++++++++++++++----- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/sentence_transformers/evaluation/RerankingEvaluator.py b/sentence_transformers/evaluation/RerankingEvaluator.py index 2976dc422..06640b5da 100644 --- a/sentence_transformers/evaluation/RerankingEvaluator.py +++ b/sentence_transformers/evaluation/RerankingEvaluator.py @@ -5,8 +5,9 @@ import csv from ..util import cos_sim import torch -from sklearn.metrics import average_precision_score +from sklearn.metrics import average_precision_score, ndcg_score import tqdm +from typing import Optional logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ class RerankingEvaluator(SentenceEvaluator): This class evaluates a SentenceTransformer model for the task of re-ranking. Given a query and a list of documents, it computes the score [query, doc_i] for all possible - documents and sorts them in decreasing order. Then, MRR@10 and MAP is compute to measure the quality of the ranking. + documents and sorts them in decreasing order. Then, MRR@10, NDCG@10 and MAP is compute to measure the quality of the ranking. :param samples: Must be a list and each element is of the form: {'query': '', 'positive': [], 'negative': []}. Query is the search query, positive is a list of positive (relevant) documents, negative is a list of negative (irrelevant) documents. @@ -25,17 +26,22 @@ class RerankingEvaluator(SentenceEvaluator): def __init__( self, samples, - mrr_at_k: int = 10, + at_k: int = 10, name: str = "", write_csv: bool = True, similarity_fct=cos_sim, batch_size: int = 64, show_progress_bar: bool = False, use_batched_encoding: bool = True, + mrr_at_k: Optional[int] = None, ): self.samples = samples self.name = name - self.mrr_at_k = mrr_at_k + if mrr_at_k is not None: + logger.warning(f"The `mrr_at_k` parameter has been deprecated; please use `at_k={mrr_at_k}` instead.") + self.at_k = mrr_at_k + else: + self.at_k = at_k self.similarity_fct = similarity_fct self.batch_size = batch_size self.show_progress_bar = show_progress_bar @@ -49,8 +55,14 @@ def __init__( sample for sample in self.samples if len(sample["positive"]) > 0 and len(sample["negative"]) > 0 ] - self.csv_file = "RerankingEvaluator" + ("_" + name if name else "") + "_results.csv" - self.csv_headers = ["epoch", "steps", "MAP", "MRR@{}".format(mrr_at_k)] + self.csv_file = "RerankingEvaluator" + ("_" + name if name else "") + f"_results_@{self.at_k}.csv" + self.csv_headers = [ + "epoch", + "steps", + "MAP", + "MRR@{}".format(self.at_k), + "NDCG@{}".format(self.at_k), + ] self.write_csv = write_csv def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: @@ -67,6 +79,7 @@ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = scores = self.compute_metrices(model) mean_ap = scores["map"] mean_mrr = scores["mrr"] + mean_ndcg = scores["ndcg"] #### Some stats about the dataset num_positives = [len(sample["positive"]) for sample in self.samples] @@ -84,7 +97,8 @@ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = ) ) logger.info("MAP: {:.2f}".format(mean_ap * 100)) - logger.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr * 100)) + logger.info("MRR@{}: {:.2f}".format(self.at_k, mean_mrr * 100)) + logger.info("NDCG@{}: {:.2f}".format(self.at_k, mean_ndcg * 100)) #### Write results to disc if output_path is not None and self.write_csv: @@ -95,7 +109,7 @@ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = if not output_file_exists: writer.writerow(self.csv_headers) - writer.writerow([epoch, steps, mean_ap, mean_mrr]) + writer.writerow([epoch, steps, mean_ap, mean_mrr, mean_ndcg]) return mean_ap @@ -112,6 +126,7 @@ def compute_metrices_batched(self, model): all documents together """ all_mrr_scores = [] + all_ndcg_scores = [] all_ap_scores = [] all_query_embs = model.encode( @@ -150,23 +165,28 @@ def compute_metrices_batched(self, model): pred_scores = pred_scores[0] pred_scores_argsort = torch.argsort(-pred_scores) # Sort in decreasing order + pred_scores = pred_scores.cpu().tolist() # Compute MRR score - is_relevant = [True] * num_pos + [False] * num_neg + is_relevant = [1] * num_pos + [0] * num_neg mrr_score = 0 - for rank, index in enumerate(pred_scores_argsort[0 : self.mrr_at_k]): + for rank, index in enumerate(pred_scores_argsort[0 : self.at_k]): if is_relevant[index]: mrr_score = 1 / (rank + 1) break all_mrr_scores.append(mrr_score) + # Compute NDCG score + all_ndcg_scores.append(ndcg_score([is_relevant], [pred_scores], k=self.at_k)) + # Compute AP - all_ap_scores.append(average_precision_score(is_relevant, pred_scores.cpu().tolist())) + all_ap_scores.append(average_precision_score(is_relevant, pred_scores)) mean_ap = np.mean(all_ap_scores) mean_mrr = np.mean(all_mrr_scores) + mean_ndcg = np.mean(all_ndcg_scores) - return {"map": mean_ap, "mrr": mean_mrr} + return {"map": mean_ap, "mrr": mean_mrr, "ndcg": mean_ndcg} def compute_metrices_individual(self, model): """ @@ -176,6 +196,7 @@ def compute_metrices_individual(self, model): a really large test set """ all_mrr_scores = [] + all_ndcg_scores = [] all_ap_scores = [] for instance in tqdm.tqdm(self.samples, disable=not self.show_progress_bar, desc="Samples"): @@ -187,7 +208,7 @@ def compute_metrices_individual(self, model): continue docs = positive + negative - is_relevant = [True] * len(positive) + [False] * len(negative) + is_relevant = [1] * len(positive) + [0] * len(negative) query_emb = model.encode( [query], convert_to_tensor=True, batch_size=self.batch_size, show_progress_bar=False @@ -199,19 +220,24 @@ def compute_metrices_individual(self, model): pred_scores = pred_scores[0] pred_scores_argsort = torch.argsort(-pred_scores) # Sort in decreasing order + pred_scores = pred_scores.cpu().tolist() # Compute MRR score mrr_score = 0 - for rank, index in enumerate(pred_scores_argsort[0 : self.mrr_at_k]): + for rank, index in enumerate(pred_scores_argsort[0 : self.at_k]): if is_relevant[index]: mrr_score = 1 / (rank + 1) break all_mrr_scores.append(mrr_score) + # Compute NDCG score + all_ndcg_scores.append(ndcg_score([is_relevant], [pred_scores], k=self.at_k)) + # Compute AP - all_ap_scores.append(average_precision_score(is_relevant, pred_scores.cpu().tolist())) + all_ap_scores.append(average_precision_score(is_relevant, pred_scores)) mean_ap = np.mean(all_ap_scores) mean_mrr = np.mean(all_mrr_scores) + mean_ndcg = np.mean(all_ndcg_scores) - return {"map": mean_ap, "mrr": mean_mrr} + return {"map": mean_ap, "mrr": mean_mrr, "ndcg": mean_ndcg}