From 96d185fa6232a5ab685ba7c43e45d1dbb3bb906d Mon Sep 17 00:00:00 2001
From: Lintang Sutawika
Date: Mon, 26 Feb 2024 23:12:39 +0700
Subject: [PATCH] Cont metrics (#1475)

* add brier_score

* process brier_score

* brier score is working for N-sized class

* fixed brier score

* add TED to BigBench and Brier score to MMLU

* format

* Update metrics.py

* Update task.py

* Update generate_until_template_yaml

* Delete lm_eval/tasks/bigbench/aux_metric.py

* Update generate_until_template_yaml

* Update _default_template_yaml

* Update _generate_configs.py

* Update _generate_configs.py

* Update _generate_configs.py

* fix (format?)

* format?

* format, once more

---------

Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
---
 lm_eval/api/metrics.py                   | 19 +++++++++++++++++++
 lm_eval/api/task.py                      |  9 +++++++++
 lm_eval/tasks/ammlu/_generate_configs.py |  9 ++++-----
 lm_eval/utils.py                         |  7 +++++++
 4 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/lm_eval/api/metrics.py b/lm_eval/api/metrics.py
index 9d66e7c8cf..acc70234b1 100644
--- a/lm_eval/api/metrics.py
+++ b/lm_eval/api/metrics.py
@@ -116,6 +116,25 @@ def ter(items):
     return sacrebleu.corpus_ter(preds, refs).score
 
 
+@register_aggregation("brier_score")
+def brier_score(items):
+    gold, predictions = list(zip(*items))
+    gold = list(gold)
+    gold_one_hot = np.eye(np.max(gold) + 1)[gold]  # one-hot encode gold labels
+    predictions = np.array(predictions)
+    return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
+
+
+@register_metric(
+    metric="brier_score",
+    higher_is_better=False,
+    output_type=["multiple_choice"],
+    aggregation="brier_score",
+)
+def brier_score_fn(items):  # This is a passthrough function
+    return items
+
+
 @register_metric(
     metric="acc",
     higher_is_better=True,
diff --git a/lm_eval/api/task.py b/lm_eval/api/task.py
index af640d98a5..26f5333f42 100644
--- a/lm_eval/api/task.py
+++ b/lm_eval/api/task.py
@@ -1227,12 +1227,21 @@ def process_results(self, doc, results):
             # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
             exact_match = int(is_greedy[gold]) if gold != -100 else 0
 
+        prob_norm = utils.softmax(lls)
+
+        # TODO: pass the metric inputs as keyword arguments instead?
+        # e.g. gold, pred, the normalized probabilities, and the original lls.
         result_dict = {
             **({"acc": acc} if "acc" in use_metric else {}),
             **({"f1": (gold, pred)} if "f1" in use_metric else {}),
             **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
             **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
             **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
+            **(
+                {"brier_score": (gold, prob_norm)}
+                if "brier_score" in use_metric
+                else {}
+            ),
         }
 
         if "acc_mutual_info" in use_metric:
diff --git a/lm_eval/tasks/ammlu/_generate_configs.py b/lm_eval/tasks/ammlu/_generate_configs.py
index b3776df802..5105e94c26 100644
--- a/lm_eval/tasks/ammlu/_generate_configs.py
+++ b/lm_eval/tasks/ammlu/_generate_configs.py
@@ -1,10 +1,10 @@
 """
 Take in a YAML, and output all other splits with this YAML
 """
-import os
-import yaml
 import argparse
+import os
 
+import yaml
 from tqdm import tqdm
 
 
@@ -68,6 +68,7 @@
     "world_religions": "العلوم الانسانية",
 }
 
+
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--base_yaml_path", required=True)
@@ -95,9 +96,7 @@ def parse_args():
         if args.cot_prompt_path is not None:
             description = cot_file[subject_eng]
         else:
-            description = (
-                f"فم بعملية التقييم في مجال {category} \n\n"
-            )
+            description = f"فم بعملية التقييم في مجال {category} \n\n"
 
         yaml_dict = {
             "include": base_yaml_name,
diff --git a/lm_eval/utils.py b/lm_eval/utils.py
index 803d2c132b..215b44b850 100644
--- a/lm_eval/utils.py
+++ b/lm_eval/utils.py
@@ -11,6 +11,7 @@
 from pathlib import Path
 from typing import Any, Callable, List
 
+import numpy as np
 import yaml
 from jinja2 import BaseLoader, Environment, StrictUndefined
 
@@ -104,6 +105,12 @@ def pattern_match(patterns, source_list):
     return sorted(list(task_names))
 
 
+def softmax(x):
+    """Compute softmax values for each set of scores in x."""
+    e_x = np.exp(x - np.max(x))
+    return e_x / e_x.sum()
+
+
 def general_detokenize(string):
     string = string.replace(" n't", "n't")
     string = string.replace(" )", ")")
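
A minimal end-to-end sketch of the new metric (illustrative only, not part of the patch). It assumes `items` is a list of (gold_index, probability_vector) pairs, matching the (gold, prob_norm) tuples that process_results emits above; the log-likelihood values below are made up. Brier score is an error measure, hence higher_is_better=False: 0.0 is a perfect forecast, 2.0 the worst possible.

import numpy as np


def softmax(x):
    """Compute softmax values for each set of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()


def brier_score(items):
    # Mean squared error between predicted probabilities and one-hot gold labels.
    gold, predictions = list(zip(*items))
    gold = list(gold)
    # Note: the one-hot width is max(gold) + 1, so this assumes the highest
    # class index appears among the gold labels being aggregated.
    gold_one_hot = np.eye(np.max(gold) + 1)[gold]
    predictions = np.array(predictions)
    return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))


# Two 3-way multiple-choice docs with made-up per-choice log-likelihoods.
lls_per_doc = [[-1.2, -0.3, -2.0], [-0.1, -3.0, -2.5]]
golds = [1, 2]
items = [(g, softmax(np.array(lls))) for g, lls in zip(golds, lls_per_doc)]
print(brier_score(items))  # lower is better; 0.0 is a perfect forecast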