From 5ccd65d4173a48bfdbc52603f82f7bf636973eb7 Mon Sep 17 00:00:00 2001 From: Baber Abbasi <92168766+baberabb@users.noreply.github.com> Date: Tue, 27 Feb 2024 19:03:56 +0500 Subject: [PATCH] Refactor `evaluater.evaluate` (#1441) * change `all_gather` to `gather` * add TaskOutput utility class * Add FilterResults class and refactor task handling. * Rename `key` to `filter_key` for clarity * Add `print_writeout` function in utils.py * Add function to calculate limit size. * Add doc_iterator method to Task class * Refactor `doc_iterator` and cleanup in Task class * remove superfluous bits * change `all_gather` to `gather` * bugfix * bugfix * fix `gather` * Refactor `gather` loop * Refactor aggregate metrics calculation * Refactor and simplify aggregate metrics calculation Removed unused code * Simplify metrics calculation and remove unused code. * simplify the metrics calculation in `utils.py` and `evaluator.py`. * Fix group metric * change evaluate to hf_evaluate * change evaluate to hf_evaluate * add docs * add docs * nits * make isslice keyword only * nit * add todo * nit * nit * nit: swap order samples_metrics tuple * move instance sorting outside loop * nit * nit * Add __repr__ for ConfigurableTask * nit * nit * Revert "nit" This reverts commit dab8d9977a643752a17f840fd8cf7e4b107df28f. * fix some logging * nit * fix `predict_only` bug. thanks to `@LSinev`! * change `print_tasks` to `prepare_print_tasks` * nits * move eval utils * move eval utils * nit * add comment * added tqdm descriptions * Update lm_eval/evaluator_utils.py Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> * fix mgsm bug * nit * fix `build_all_requests` * pre-commit * add ceil to limit --------- Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> --- lm_eval/api/model.py | 6 +- lm_eval/api/task.py | 56 +++-- lm_eval/evaluator.py | 373 ++++++++++--------------------- lm_eval/evaluator_utils.py | 312 ++++++++++++++++++++++++++ lm_eval/models/huggingface.py | 12 +- lm_eval/models/vllm_causallms.py | 12 +- lm_eval/utils.py | 53 +---- 7 files changed, 487 insertions(+), 337 deletions(-) create mode 100644 lm_eval/evaluator_utils.py diff --git a/lm_eval/api/model.py b/lm_eval/api/model.py index 6bb93a344e..0bb16d419b 100644 --- a/lm_eval/api/model.py +++ b/lm_eval/api/model.py @@ -225,7 +225,7 @@ def fn(requests): eval_logger.info( f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..." 
) - for req in tqdm(requests): + for req in tqdm(requests, desc="Checking cached requests"): hsh = hash_args(attr, req.args) if attr == "generate_until" and req.args[1].get("do_sample", False): # when we are doing non-greedy generation, don't use the cache @@ -246,7 +246,9 @@ def fn(requests): else: res.append(None) remaining_reqs.append(req) - + eval_logger.info( + f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}" + ) # actually run the LM on the requests that do not have cached results rem_res = getattr(self.lm, attr)(remaining_reqs) diff --git a/lm_eval/api/task.py b/lm_eval/api/task.py index 26f5333f42..9fa4b78249 100644 --- a/lm_eval/api/task.py +++ b/lm_eval/api/task.py @@ -7,7 +7,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from inspect import getsource -from typing import Any, List, Literal, Tuple, Union +from typing import Any, Iterator, List, Literal, Tuple, Union import datasets import numpy as np @@ -327,7 +327,7 @@ def _process_doc(self, doc): return doc @property - def instances(self): + def instances(self) -> List[Instance]: """After calling `task.build_all_requests()`, tasks maintain a list of the dataset instances which will be evaluated. """ @@ -355,6 +355,7 @@ def doc_to_target(self, doc): def build_all_requests( self, + *, limit=None, rank=None, world_size=None, @@ -382,13 +383,6 @@ def build_all_requests( self._instances = flattened_instances return - if self.has_test_docs(): - docs = self.test_docs() - elif self.has_validation_docs(): - docs = self.validation_docs() - else: - assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" - eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...") instances = [] @@ -402,12 +396,7 @@ def build_all_requests( limit = None doc_id_docs = list( - utils.create_iterator( - enumerate(docs), - rank, - world_size, - limit, - ) + self.doc_iterator(rank=rank, limit=limit, world_size=world_size) ) num_docs = len(doc_id_docs) @@ -632,6 +621,27 @@ def override_metric(self, metric_name: str) -> None: setattr(self._config, "metric_list", [{"metric": metric_name}]) setattr(self._config, "process_results", None) + @property + def eval_docs(self) -> Union[datasets.Dataset, List[dict]]: + if self.has_test_docs(): + return self.test_docs() + elif self.has_validation_docs(): + return self.validation_docs() + else: + assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" + + def doc_iterator( + self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1 + ) -> Iterator[Tuple[int, Any]]: + limit = int(limit) if limit else None + doc_iterator = utils.create_iterator( + enumerate(self.eval_docs), + rank=int(rank), + limit=limit, + world_size=int(world_size), + ) + return doc_iterator + class ConfigurableTask(Task): VERSION = "Yaml" @@ -781,12 +791,7 @@ def __init__( else "default" )(list(self.fewshot_docs()), self, rnd=random.Random(1234)) - if self.has_test_docs(): - self.task_docs = self.test_docs() - elif self.has_validation_docs(): - self.task_docs = self.validation_docs() - else: - assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" 
+ self.task_docs = self.eval_docs # Test One Doc self.features = list(self.task_docs.features.keys()) @@ -1336,6 +1341,15 @@ def higher_is_better(self) -> dict: def get_config(self, key: str) -> Any: return getattr(self._config, key, None) + def __repr__(self): + return ( + f"ConfigurableTask(task_name={getattr(self.config, 'task', None)}," + f"group_name={getattr(self.config, 'group', None)}," + f"output_type={self.OUTPUT_TYPE}," + f"num_fewshot={getattr(self.config, 'num_fewshot', None)}," + f"num_samples={len(self.eval_docs)})" + ) + class MultipleChoiceTask(Task): OUTPUT_TYPE: str = "loglikelihood" diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index d3afc70c61..c0c76306f4 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -1,7 +1,6 @@ import collections import itertools import logging -import math import random from typing import TYPE_CHECKING, Optional, Union @@ -11,12 +10,19 @@ import lm_eval.api.metrics import lm_eval.api.registry import lm_eval.models +from lm_eval.evaluator_utils import ( + consolidate_results, + get_sample_size, + get_task_list, + prepare_print_tasks, + print_writeout, + run_task_tests, +) from lm_eval.logging_utils import add_env_info, get_git_commit_hash from lm_eval.tasks import TaskManager, get_task_dict from lm_eval.utils import ( eval_logger, positional_deprecated, - run_task_tests, simple_parse_args_string, ) @@ -111,19 +117,23 @@ def simple_evaluate( eval_logger.info("Deleting requests cache...") delete_cache() + seed_message = [] if random_seed is not None: # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412 - eval_logger.info(f"Setting random seed to {random_seed}") + seed_message.append(f"Setting random seed to {random_seed}") random.seed(random_seed) if numpy_random_seed is not None: - eval_logger.info(f"Setting numpy seed to {numpy_random_seed}") + seed_message.append(f"Setting numpy seed to {numpy_random_seed}") np.random.seed(numpy_random_seed) if torch_random_seed is not None: - eval_logger.info(f"Setting torch manual seed to {torch_random_seed}") + seed_message.append(f"Setting torch manual seed to {torch_random_seed}") torch.manual_seed(torch_random_seed) + if seed_message: + eval_logger.info(" | ".join(seed_message)) + if tasks is None: tasks = [] assert ( @@ -166,7 +176,7 @@ def simple_evaluate( lm = model if use_cache is not None: - print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}") + eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}") lm = lm_eval.api.model.CachingLM( lm, use_cache @@ -198,13 +208,13 @@ def simple_evaluate( key="generation_kwargs", value=gen_kwargs, update=True ) - if predict_only: - log_samples = True - eval_logger.info( - f"Processing {task_name} in output-only mode. Metrics will not be calculated!" - ) - # we have to change the class properties post-hoc. This is pretty hacky. - task_obj.override_metric(metric_name="bypass") + if predict_only: + log_samples = True + eval_logger.info( + f"Processing {task_name} in output-only mode. Metrics will not be calculated!" + ) + # we have to change the class properties post-hoc. This is pretty hacky. 
+ task_obj.override_metric(metric_name="bypass") if num_fewshot is not None: if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0: @@ -299,82 +309,22 @@ def evaluate( eval_logger.setLevel(getattr(logging, f"{verbosity}")) # decontaminate = decontamination_ngrams_path is not None - for task_name, task in task_dict.items(): - if isinstance(task, tuple): - _, task = task - if not log_samples: - assert ( - "bypass" not in getattr(task, "_metric_fn_list", {}).keys() - ), f"log_samples must be True for 'bypass' only tasks: {task_name}" - - # stores the final result for each task, for each metric/filter pair. - results = collections.defaultdict(dict) - # Tracks each task's version. - versions = collections.defaultdict(dict) - # Tracks the YAML configs of all chosen tasks. - configs = collections.defaultdict(dict) - # logs info about each document evaluated. - samples = collections.defaultdict(list) # tracks all Instances/requests a model must generate output on. requests = collections.defaultdict(list) - # Aggregated task scores presented with groups - results_agg = collections.defaultdict(dict) - # Aggregated groups scores only - groups_agg = collections.defaultdict(dict) # stores the amount to pad out reqs per req. type so that # number of fwd passes per distributed rank is equal padding_requests = collections.defaultdict(int) - # store the hierarchy to do proper ordering - task_hierarchy = collections.defaultdict(list) - # store num-fewshot value per task - num_fewshot = collections.defaultdict(int) - - # get lists of each type of request - for task_name, task in task_dict.items(): - task: Task - - if isinstance(task, tuple): - group_name, task = task - task_hierarchy[group_name].append(task_name) - versions[group_name] = "N/A" - - else: - group_name = None - task_hierarchy[task_name] = [] - - if task is None: - continue - - versions[task_name] = task.VERSION - configs[task_name] = dict(task.dump_config()) - - # Number of few-shots for printing. 
- if (n_shot := configs[task_name].get("num_fewshot")) == 0: - n_shot = configs[task_name].get("metadata", {}).get("num_fewshot", 0) - num_fewshot[task_name] = n_shot - - if "task_alias" in configs[task_name]: - results[task_name]["alias"] = configs[task_name]["task_alias"] - - if ( - ("group_alias" in configs[task_name]) - and (group_name not in results) - and (group_name is not None) - ): - results[group_name]["alias"] = configs[task_name]["group_alias"] - - if limit is not None: - if task.has_test_docs(): - task_docs = task.test_docs() - elif task.has_validation_docs(): - task_docs = task.validation_docs() - else: - raise RuntimeError("Task has neither test_docs nor validation_docs") - - num_docs = len(task_docs) * limit - # ceil to prevent limit being equal to 0 - limit = int(math.ceil(num_docs)) if limit < 1.0 else int(limit) + # get lists of group hierarchy and each type of request + task_hierarchy, eval_tasks = get_task_list(task_dict) + if not log_samples: + assert all( + "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys() + for task_output in eval_tasks + ), "log_samples must be True for 'bypass' only tasks" + for task_output in eval_tasks: + task: Task = task_output.task + limit = get_sample_size(task, limit) task.build_all_requests( limit=limit, rank=lm.rank, @@ -382,21 +332,12 @@ def evaluate( cache_requests=cache_requests, rewrite_requests_cache=rewrite_requests_cache, ) - eval_logger.debug( - f"Task: {task_name}; number of requests on this rank: {len(task.instances)}" + f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}" ) if write_out: - for inst in task.instances: - # print the prompt for the first few documents - if inst.doc_id < 1: - eval_logger.info( - f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\ -\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)" - ) - eval_logger.info(f"Request: {str(inst)}") - + print_writeout(task) # aggregate Instances by LM method requested to get output. for instance in task.instances: reqtype = instance.request_type @@ -408,7 +349,7 @@ def evaluate( lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist() ) - # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks) + # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks) numpad = max(gathered_item) - gathered_item[lm.rank] padding_requests[task.OUTPUT_TYPE] += numpad @@ -435,42 +376,33 @@ def evaluate( if lm.world_size > 1: lm.accelerator.wait_for_everyone() + RANK = lm.rank + WORLD_SIZE = lm.world_size ### Postprocess outputs ### # TODO: del model here, maybe (idea: allow user to specify device of e.g. 
reward model separately) - for task_name, task in task_dict.items(): - if isinstance(task, tuple): - group, task = task - if task is None: - continue + for task_output in eval_tasks: + task = task_output.task task.apply_filters() - ### Collect values of metrics on all datapoints ### - vals = collections.defaultdict(list) - - # unpack results and sort back in order and return control to Task - for task_name, task in task_dict.items(): - if isinstance(task, tuple): - group, task = task - if task is None: - continue + ### Collect values of metrics on all datapoints ### + # # unpack results and sort back in order and return control to Task # TODO: make it possible to use a different metric per filter + # Pre-process task.instances to group by doc_id + instances_by_doc_id = collections.defaultdict(list) + for instance in task.instances: + instances_by_doc_id[instance.doc_id].append(instance) + # Sort instances within each group + for instances in instances_by_doc_id.values(): + instances.sort(key=lambda x: x.idx) # iterate over different filters used - for key in task.instances[0].filtered_resps.keys(): - doc_iterator = ( - itertools.islice( - enumerate(task.test_docs()), lm.rank, limit, lm.world_size - ) - if task.has_test_docs() - else itertools.islice( - enumerate(task.validation_docs()), lm.rank, limit, lm.world_size - ) + for filter_key in task.instances[0].filtered_resps.keys(): + doc_iterator = task.doc_iterator( + rank=RANK, limit=limit, world_size=WORLD_SIZE ) for doc_id, doc in doc_iterator: - # subset instances to only this document id ; sort by idx - requests = list(filter(lambda x: x.doc_id == doc_id, task.instances)) - requests.sort(key=lambda x: x.idx) + requests = instances_by_doc_id[doc_id] metrics = task.process_results( - doc, [req.filtered_resps[key] for req in requests] + doc, [req.filtered_resps[filter_key] for req in requests] ) if log_samples: target = task.doc_to_target(doc) @@ -480,93 +412,56 @@ def evaluate( "target": target, "arguments": [req.args for req in requests], "resps": [req.resps for req in requests], - "filtered_resps": [req.filtered_resps[key] for req in requests], + "filtered_resps": [ + req.filtered_resps[filter_key] for req in requests + ], } example.update(metrics) - samples[task_name].append(example) + task_output.logged_samples.append(example) for metric, value in metrics.items(): - vals[(task_name, key, metric)].append(value) + task_output.sample_metrics[(metric, filter_key)].append(value) - if lm.world_size > 1: - # if multigpu, then gather data across all ranks + if WORLD_SIZE > 1: + # if multigpu, then gather data across all ranks to rank 0 # first gather logged samples across all ranks - for task_name, task_samples in list(samples.items()): - full_samples = [None] * lm.world_size - torch.distributed.all_gather_object(full_samples, task_samples) - - samples[task_name] = list(itertools.chain.from_iterable(full_samples)) - - # then collect metrics across all ranks - vals_torch = collections.defaultdict(list) - for (task_name, key, metric), items in vals.items(): - numitem = 0 - if isinstance(items[0], tuple): - numitem = len(items[0]) - - if isinstance(items[0], (str, list, tuple)): - # handle the string case - gathered_items = [None] * lm.accelerator.num_processes - torch.distributed.all_gather_object(gathered_items, items) - - gathered_item = list(itertools.chain.from_iterable(gathered_items)) - else: - # distributed gather requires all ranks to have same dimensions - # so we pad out with float32 min value - pad_value = 
torch.finfo(torch.float32).min - metrics_tensor = torch.tensor(items, device=lm.device) - - original_dtype = metrics_tensor.dtype # store original dtype - torch_device_tensor = lm.accelerator.pad_across_processes( - metrics_tensor.to(torch.float32), pad_index=pad_value + for task_output in eval_tasks: + if log_samples: + # for task_name, task_samples in list(samples.items()): + full_samples = [None] * WORLD_SIZE if RANK == 0 else None + torch.distributed.gather_object( + obj=task_output.logged_samples, + object_gather_list=full_samples, + dst=0, ) - gathered_item = lm.accelerator.gather(torch_device_tensor) - if numitem > 0: - gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value] - else: - gathered_filtered = gathered_item[gathered_item != pad_value] + if RANK == 0: + task_output.logged_samples = list( + itertools.chain.from_iterable(full_samples) + ) - gathered_item = ( - gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist() + # then collect metrics across all ranks + for metrics in task_output.sample_metrics: + metric_list = [None] * WORLD_SIZE if RANK == 0 else None + torch.distributed.gather_object( + obj=task_output.sample_metrics[metrics], + object_gather_list=metric_list, + dst=0, ) - # reconvert if we were passed a tuple of values - if numitem > 0: - gathered_item = [tuple(g) for g in gathered_item] - - if lm.rank == 0: - vals_torch[(task_name, key, metric)] = gathered_item - - vals = vals_torch + if RANK == 0: + task_output.sample_metrics[metrics] = list( + itertools.chain.from_iterable(metric_list) + ) - if lm.rank == 0: + if RANK == 0: ### Aggregate results over all datapoints ### # aggregate results ; run bootstrap CIs - for (task_name, key, metric), items in vals.items(): - task = task_dict[task_name] - group_name, task = task if isinstance(task, tuple) else (None, task) - - metric_key = f"{metric},{key}" - agg_fn = task.aggregation()[metric] - - results[task_name][metric_key] = agg_fn(items) - results[task_name]["samples"] = len(items) - - # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap - # so we run them less iterations. still looking for a cleaner way to do this - if bootstrap_iters > 0: - stderr_fn = lm_eval.api.metrics.stderr_for_metric( - metric=agg_fn, - bootstrap_iters=( - min(bootstrap_iters, 100) - if metric in ["bleu", "chrf", "ter"] - else bootstrap_iters - ), - ) - - results[task_name][f"{metric}_stderr,{key}"] = ( - stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A" - ) + for task_output in eval_tasks: + task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters) + results, samples, configs, versions, num_fewshot = consolidate_results( + eval_tasks + ) + ### Calculate group metrics ### if bool(results): for group, task_list in reversed(task_hierarchy.items()): if len(task_list) == 0: @@ -575,19 +470,33 @@ def evaluate( # or `task_name: []`. # we only want to operate on groups here. 
continue - for metric in [ - key - for key in results[task_list[0]].keys() - if "_stderr" not in key and key not in ["alias", "samples"] - ]: # TODO: what if tasks don't all share the same metrics + metric_list = list( + { + key + for task in task_list + for key in results[task].keys() + if "_stderr" not in key and key not in ["alias", "samples"] + } + ) + for metric in metric_list: stderr = "_stderr,".join(metric.split(",")) # gather metrics, sizes, and stderrs from subtasks metrics = [ - results[task][metric] for task in task_list + results[task][metric] + for task in task_list + if metric in results[task] ] # TODO: copy? - stderrs = [results[task][stderr] for task in task_list] - sizes = [results[task]["samples"] for task in task_list] + stderrs = [ + results[task][stderr] + for task in task_list + if stderr in results[task] + ] + sizes = [ + results[task]["samples"] + for task in task_list + if metric in results[task] + ] # compute group's pooled metric and stderr results[group][ @@ -606,60 +515,6 @@ def evaluate( results[group]["samples"] = sum(sizes) - def print_tasks(task_hierarchy, results, tab=0): - results_agg = collections.defaultdict(dict) - groups_agg = collections.defaultdict(dict) - - (group_name, task_list), *_ = task_hierarchy.items() - task_list = sorted(task_list) - - results_agg[group_name] = results[group_name].copy() - # results_agg[group_name]["tab"] = tab - if "samples" in results_agg[group_name]: - results_agg[group_name].pop("samples") - - tab_string = " " * tab + "- " if tab > 0 else "" - - if "alias" in results_agg[group_name]: - results_agg[group_name]["alias"] = ( - tab_string + results_agg[group_name]["alias"] - ) - else: - results_agg[group_name]["alias"] = tab_string + group_name - - if len(task_list) > 0: - groups_agg[group_name] = results[group_name].copy() - # groups_agg[group_name]["tab"] = tab - if "samples" in groups_agg[group_name]: - groups_agg[group_name].pop("samples") - - if "alias" in groups_agg[group_name]: - groups_agg[group_name]["alias"] = ( - tab_string + groups_agg[group_name]["alias"] - ) - else: - groups_agg[group_name]["alias"] = tab_string + group_name - - for task_name in task_list: - if task_name in task_hierarchy: - _task_hierarchy = { - **{task_name: task_hierarchy[task_name]}, - **task_hierarchy, - } - else: - _task_hierarchy = { - **{task_name: []}, - **task_hierarchy, - } - - _results_agg, _groups_agg = print_tasks( - _task_hierarchy, results, tab + 1 - ) - results_agg = {**results_agg, **_results_agg} - groups_agg = {**groups_agg, **_groups_agg} - - return results_agg, groups_agg - results_agg = collections.defaultdict(dict) groups_agg = collections.defaultdict(dict) all_tasks_list = list(task_hierarchy.keys()) @@ -673,7 +528,7 @@ def print_tasks(task_hierarchy, results, tab=0): _task_hierarchy = { k: v for k, v in task_hierarchy.items() if k in left_tasks_list } - _results_agg, _groups_agg = print_tasks(_task_hierarchy, results) + _results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results) results_agg = {**results_agg, **_results_agg} groups_agg = {**groups_agg, **_groups_agg} diff --git a/lm_eval/evaluator_utils.py b/lm_eval/evaluator_utils.py new file mode 100644 index 0000000000..fcb18206f6 --- /dev/null +++ b/lm_eval/evaluator_utils.py @@ -0,0 +1,312 @@ +import collections +import math +import pathlib +import sys +from typing import Dict, List, Optional, Tuple, Union + +from lm_eval.api import metrics +from lm_eval.utils import eval_logger, positional_deprecated + + +class TaskOutput: + """ + Wrapper class 
for Task outputs.It contains various attributes and methods to manage and calculate metrics for the task. + + Attributes: + task (object): The task object. + task_name (str): The name of the task. + task_config (dict): The configuration of the task. + version (str): The version of the task. + group_name (str): The name of the task group. + n_shot (int): The number of shots for the task. + task_alias (str): The alias of the task. + group_alias (str): The alias of the task group. + is_group (bool): Indicates if the task is a group. + logged_samples (list): The list of logged samples. + sample_len (int): The length of the samples. + sample_metrics (defaultdict): The dictionary of samples' metrics. + agg_metrics (defaultdict): The dictionary of aggregate metrics. + + Methods: + from_taskdict(cls, task_name: str, task): + Creates a TaskOutput instance from a task dictionary. + + calculate_aggregate_metric(bootstrap_iters=100000) -> None: + Calculates the aggregate metrics for the task. + """ + + def __init__( + self, + task=None, + task_name=None, + task_config=None, + version=None, + group_name=None, + n_shot=None, + task_alias=None, + group_alias=None, + is_group=None, + ): + self.task = task + self.task_config = task_config + self.task_name = task_name + self.group_name = group_name + self.version = version + self.n_shot = n_shot + self.task_alias = task_alias + self.group_alias = group_alias + self.is_group = is_group + self.logged_samples = [] + self.sample_len = None + self.sample_metrics = collections.defaultdict(list) + self.agg_metrics = collections.defaultdict(list) + + @classmethod + def from_taskdict(cls, task_name: str, task): + if isinstance(task, tuple): + group_name, task = task + else: + group_name = None + if not task: + # these gets filtered out in get_task_list + # once they are added to group hierarchy + is_group = True + return cls( + task=task, task_name=task_name, is_group=is_group, group_name=group_name + ) + version = task.VERSION + task_config = dict(task.dump_config()) + if (n_shot := task_config.get("num_fewshot")) == 0: + n_shot = task_config.get("metadata", {}).get("num_fewshot", 0) + task_alias = task_config.get("alias") + group_alias = task_config.get("group_alias") + return cls( + task=task, + task_name=task_name, + task_config=task_config, + group_name=group_name, + version=version, + n_shot=n_shot, + task_alias=task_alias, + group_alias=group_alias, + ) + + def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None: + for (metric, filter_key), items in self.sample_metrics.items(): + agg_fn = self.task.aggregation()[metric] + metric_key = f"{metric},{filter_key}" + self.agg_metrics[metric_key] = agg_fn(items) + self.sample_len = len(items) # TODO: same sample size for each metric? 
+ if bootstrap_iters: + stderr_fn = metrics.stderr_for_metric( + metric=agg_fn, + bootstrap_iters=min(bootstrap_iters, 100) + if metric in ["bleu", "chrf", "ter"] + else bootstrap_iters, + ) + self.agg_metrics[f"{metric}_stderr,{filter_key}"] = ( + stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A" + ) + + def __repr__(self): + return ( + f"TaskOutput(task_name={self.task_name}, " + f"group_name={self.group_name}, " + f"version={self.version}," + f"n_shot={self.n_shot}" + f"task_alias={self.task_alias}, group_alias={self.group_alias})" + ) + + +def get_task_list(task_dict: dict) -> Tuple[Dict[str, list], List[TaskOutput]]: + task_hierarchy = collections.defaultdict(list) + outputs = list(TaskOutput.from_taskdict(x, y) for x, y in task_dict.items()) + for task_output in outputs: + if group_name := task_output.group_name: + task_hierarchy[group_name].append(task_output.task_name) + else: + task_hierarchy[task_output.task_name] = [] + # returns task_hierarchy tracking which groups contain which subtasks, + # and a list of TaskOutput classes for each non-group subtask + return task_hierarchy, [x for x in outputs if x.task] + + +def print_writeout(task) -> None: + for inst in task.instances: + # print the prompt for the first few documents + if inst.doc_id < 1: + eval_logger.info( + f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\ + \n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)" + ) + eval_logger.info(f"Request: {str(inst)}") + + +def get_sample_size(task, limit: Optional[int]) -> Union[int, None]: + if limit is not None: + limit = ( + int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit) + ) + return limit + + +def prepare_print_tasks( + task_hierarchy: dict, results: dict, tab=0 +) -> Tuple[dict, dict]: + """ + @param task_hierarchy: Dictionary representing the group hierarchy of tasks. Each key is a group name and its + value is a list of task names. + @param results: Dictionary containing the results of each task. Each key is a + group name and its value is a dictionary of task results. + @param tab: The indentation level for printing the task + hierarchy. Default is 0. + @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains + aggregated results for each task, and groups_agg contains aggregated results for each group. + + Prepares the task hierarchy and aggregates the results for each task and group recursively for printing. 
+ """ + results_agg = collections.defaultdict(dict) + groups_agg = collections.defaultdict(dict) + + (group_name, task_list), *_ = task_hierarchy.items() + task_list = sorted(task_list) + + results_agg[group_name] = results[group_name].copy() + # results_agg[group_name]["tab"] = tab + if "samples" in results_agg[group_name]: + results_agg[group_name].pop("samples") + + tab_string = " " * tab + "- " if tab > 0 else "" + + if "alias" in results_agg[group_name]: + results_agg[group_name]["alias"] = tab_string + results_agg[group_name]["alias"] + else: + results_agg[group_name]["alias"] = tab_string + group_name + + if len(task_list) > 0: + groups_agg[group_name] = results[group_name].copy() + # groups_agg[group_name]["tab"] = tab + if "samples" in groups_agg[group_name]: + groups_agg[group_name].pop("samples") + + if "alias" in groups_agg[group_name]: + groups_agg[group_name]["alias"] = ( + tab_string + groups_agg[group_name]["alias"] + ) + else: + groups_agg[group_name]["alias"] = tab_string + group_name + + for task_name in task_list: + if task_name in task_hierarchy: + _task_hierarchy = { + **{task_name: task_hierarchy[task_name]}, + **task_hierarchy, + } + else: + _task_hierarchy = { + **{task_name: []}, + **task_hierarchy, + } + + _results_agg, _groups_agg = prepare_print_tasks( + _task_hierarchy, results, tab + 1 + ) + results_agg = {**results_agg, **_results_agg} + groups_agg = {**groups_agg, **_groups_agg} + + return results_agg, groups_agg + + +def consolidate_results( + eval_tasks: List[TaskOutput], +) -> Tuple[dict, dict, dict, dict, dict]: + """ + @param eval_tasks: list(TaskOutput). + @return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot. + + Consolidates the results of multiple evaluation tasks into a single structure. + + The method iterates over each evaluation instance and extracts relevant information to create the consolidated + results structure. The consolidated results structure has the following properties: + + - results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains + metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task + aliases specified in the task configuration. + - samples: A defaultdict with task names as keys and lists of log samples as values. + - configs: A defaultdict with task names as keys and task configurations as values. + - versions: A defaultdict with task names as keys and task versions as values. + - num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values. + + The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple. + """ + # stores the final result for each task, for each metric/filter pair. + results = collections.defaultdict(dict) + # logs info about each document evaluated. + samples = collections.defaultdict(list) + # store num-fewshot value per task + num_fewshot = collections.defaultdict(int) + # Tracks the YAML configs of all chosen task + configs = collections.defaultdict(dict) + # Tracks each task's version. 
+ versions = collections.defaultdict(dict) + for task_output in eval_tasks: + if "task_alias" in (task_config := task_output.task_config): + results[task_output.task_name]["alias"] = task_config["task_alias"] + if group_alias := task_output.group_alias: + if group_alias not in results and (group_name := task_output.group_name): + results[group_name]["alias"] = group_alias + num_fewshot[task_output.task_name] = task_output.n_shot + configs[task_output.task_name] = task_output.task_config + versions[task_output.task_name] = task_output.version + samples[task_output.task_name] = task_output.logged_samples + for (metric, filter_key), items in task_output.sample_metrics.items(): + metric_key = f"{metric},{filter_key}" + results[task_output.task_name][metric_key] = task_output.agg_metrics[ + metric_key + ] + results[task_output.task_name]["samples"] = task_output.sample_len + results[task_output.task_name][ + f"{metric}_stderr,{filter_key}" + ] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"] + return results, samples, configs, versions, num_fewshot + + +@positional_deprecated +def find_test_root(start_path: pathlib.Path) -> pathlib.Path: + """ + Search upward in the directory tree to a maximum of three layers + to find and return the package root (containing the 'tests' folder) + """ + cur_path = start_path.resolve() + max_layers = 3 + for _ in range(max_layers): + if (cur_path / "tests" / "test_version_stable.py").exists(): + return cur_path + else: + cur_path = cur_path.parent.resolve() + raise FileNotFoundError( + f"Unable to find package root within {max_layers} upwards" + f"of {start_path}" + ) + + +@positional_deprecated +def run_task_tests(task_list: List[str]): + """ + Find the package root and run the tests for the given tasks + """ + import pytest + + package_root = find_test_root(start_path=pathlib.Path(__file__)) + task_string = " or ".join(task_list) + args = [ + f"{package_root}/tests/test_version_stable.py", + f"--rootdir={package_root}", + "-k", + f"{task_string}", + ] + sys.path.append(str(package_root)) + pytest_return_val = pytest.main(args) + if pytest_return_val: + raise ValueError( + f"Not all tests for the specified tasks ({task_list}) ran successfully! 
Error code: {pytest_return_val}" + ) diff --git a/lm_eval/models/huggingface.py b/lm_eval/models/huggingface.py index 49085e202a..a1b4346d41 100644 --- a/lm_eval/models/huggingface.py +++ b/lm_eval/models/huggingface.py @@ -921,7 +921,11 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): ) chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) - pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0))) + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running loglikelihood requests", + ) for chunk in chunks: inps = [] cont_toks_list = [] @@ -1089,7 +1093,11 @@ def _collate(req: Tuple[str, dict]): toks = self.tok_encode(req[0]) return -len(toks), req[0] - pbar = tqdm(total=len(requests), disable=(self.rank != 0)) + pbar = tqdm( + total=len(requests), + disable=(self.rank != 0), + desc="Running generate_until requests", + ) adaptive_batch_size = None if self.batch_size == "auto": # using rolling window with maximum context diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py index b53a4299ad..63f5e64a78 100644 --- a/lm_eval/models/vllm_causallms.py +++ b/lm_eval/models/vllm_causallms.py @@ -254,7 +254,11 @@ def _collate_gen(_requests): n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None ) - pbar = tqdm(total=len(requests), disable=(self.rank != 0)) + pbar = tqdm( + total=len(requests), + disable=(self.rank != 0), + desc="Running generate_until requests", + ) # for each different set of kwargs, we execute all requests, by batch. for chunk in chunks: context_and_encoding, all_gen_kwargs = zip(*chunk) @@ -329,7 +333,11 @@ def _collate(x): n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None ) - pbar = tqdm(total=len(requests), disable=disable_tqdm) + pbar = tqdm( + total=len(requests), + disable=disable_tqdm, + desc="Running loglikelihood requests", + ) for chunk in chunks: inputs = [] ctxlens = [] diff --git a/lm_eval/utils.py b/lm_eval/utils.py index 215b44b850..30de3e2506 100644 --- a/lm_eval/utils.py +++ b/lm_eval/utils.py @@ -6,9 +6,7 @@ import logging import os import re -import sys from itertools import islice -from pathlib import Path from typing import Any, Callable, List import numpy as np @@ -244,7 +242,7 @@ def make_table(result_dict, column: str = "results"): values = [] for k, dic in result_dict[column].items(): - version = result_dict["versions"][k] + version = result_dict["versions"].get(k, "N/A") n = str(result_dict["n-shot"][k]) if "alias" in dic: @@ -292,47 +290,6 @@ def _wrapper(*args, **kwargs): return _wrapper -@positional_deprecated -def find_test_root(start_path: Path) -> Path: - """ - Search upward in the directory tree to a maximum of three layers - to find and return the package root (containing the 'tests' folder) - """ - cur_path = start_path.resolve() - max_layers = 3 - for _ in range(max_layers): - if (cur_path / "tests" / "test_version_stable.py").exists(): - return cur_path - else: - cur_path = cur_path.parent.resolve() - raise FileNotFoundError( - f"Unable to find package root within {max_layers} upwards" + f"of {start_path}" - ) - - -@positional_deprecated -def run_task_tests(task_list: List[str]): - """ - Find the package root and run the tests for the given tasks - """ - import pytest - - package_root = find_test_root(start_path=Path(__file__)) - task_string = " or ".join(task_list) - args = [ - f"{package_root}/tests/test_version_stable.py", - f"--rootdir={package_root}", - "-k", - 
f"{task_string}", - ] - sys.path.append(str(package_root)) - pytest_return_val = pytest.main(args) - if pytest_return_val: - raise ValueError( - f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}" - ) - - def ignore_constructor(loader, node): return node @@ -414,16 +371,10 @@ def apply_template(template: str, doc: dict) -> str: return rtemplate.render(**doc) -def create_iterator(raw_iterator, rank, world_size, limit=None): +def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None): """ Method for creating a (potentially) sliced and limited iterator from a raw document iterator. Used for splitting data among ranks in multigpu setting or only pulling a sample of documents """ return islice(raw_iterator, rank, limit, world_size) - - -# Multi-token stopping criteria - - -# from more_itertools