Refactor evaluator.evaluate (#1441)
* change `all_gather` to `gather` (see the sketch after this commit list)

* add TaskOutput utility class

* Add FilterResults class and refactor task handling.

* Rename `key` to `filter_key` for clarity

* Add `print_writeout` function in utils.py

* Add function to calculate limit size.

* Add doc_iterator method to Task class

* Refactor `doc_iterator` and cleanup in Task class

* remove superfluous bits

* change `all_gather` to `gather`

* bugfix

* bugfix

* fix `gather`

* Refactor `gather` loop

* Refactor aggregate metrics calculation

* Refactor and simplify aggregate metrics calculation
Removed unused code

* Simplify metrics calculation and remove unused code.

* simplify the metrics calculation in `utils.py` and `evaluator.py`.

* Fix group metric

* change evaluate to hf_evaluate

* change evaluate to hf_evaluate

* add docs

* add docs

* nits

* make `islice` keyword only

* nit

* add todo

* nit

* nit

* nit: swap order samples_metrics tuple

* move instance sorting outside loop

* nit

* nit

* Add __repr__ for ConfigurableTask

* nit

* nit

* Revert "nit"

This reverts commit dab8d99.

* fix some logging

* nit

* fix `predict_only` bug. thanks to `@LSinev`!

* change `print_tasks` to `prepare_print_tasks`

* nits

* move eval utils

* move eval utils

* nit

* add comment

* added tqdm descriptions

* Update lm_eval/evaluator_utils.py

Co-authored-by: Hailey Schoelkopf <[email protected]>

* fix mgsm bug

* nit

* fix `build_all_requests`

* pre-commit

* add ceil to limit (see the limit sketch after this commit list)

---------

Co-authored-by: Hailey Schoelkopf <[email protected]>
baberabb and haileyschoelkopf committed Feb 27, 2024
1 parent 96d185f commit 5ccd65d
Showing 7 changed files with 487 additions and 337 deletions.
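
The `all_gather` → `gather` bullets concern multi-process evaluation: per-rank results only need to be assembled on the rank that computes aggregate metrics, so a one-destination `gather` avoids materializing every rank's results on every rank the way `all_gather` does. A minimal sketch of the idea with `torch.distributed` (illustrative names, not the harness's exact code; it assumes equal-sized shards per rank):

import torch
import torch.distributed as dist

def collect_on_dst(local_results: torch.Tensor, dst: int = 0):
    """Gather per-rank result tensors onto `dst` only (unlike all_gather,
    which would give the full list to every rank)."""
    world_size = dist.get_world_size()
    if dist.get_rank() == dst:
        buckets = [torch.empty_like(local_results) for _ in range(world_size)]
        dist.gather(local_results, gather_list=buckets, dst=dst)
        return torch.cat(buckets)  # only the aggregating rank sees everything
    dist.gather(local_results, dst=dst)  # other ranks just send and move on
    return None

The later `gather`-loop refactor bullets then iterate over the gathered buckets on the aggregating rank alone.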
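The "calculate limit size" and "add ceil to limit" bullets concern the `--limit` option, which may be an absolute document count or a fraction of the dataset. The helper below is a hypothetical stand-in for that behavior (the actual change lands in evaluator files not shown on this page): rounding up with `ceil` ensures a small fraction still yields at least one document.

import math
from typing import Optional, Union

def get_sample_size(num_docs: int, limit: Optional[Union[int, float]]) -> Optional[int]:
    """Interpret `limit`: fractions scale the dataset size, rounded up."""
    if limit is None:
        return None
    return int(math.ceil(num_docs * limit)) if limit < 1.0 else int(limit)

For example, get_sample_size(100, 0.005) returns 1 rather than truncating to 0.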
6 changes: 4 additions & 2 deletions lm_eval/api/model.py
@@ -225,7 +225,7 @@ def fn(requests):
             eval_logger.info(
                 f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
             )
-            for req in tqdm(requests):
+            for req in tqdm(requests, desc="Checking cached requests"):
                 hsh = hash_args(attr, req.args)
                 if attr == "generate_until" and req.args[1].get("do_sample", False):
                     # when we are doing non-greedy generation, don't use the cache
@@ -246,7 +246,9 @@ def fn(requests):
                 else:
                     res.append(None)
                     remaining_reqs.append(req)
-
+            eval_logger.info(
+                f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
+            )
             # actually run the LM on the requests that do not have cached results
             rem_res = getattr(self.lm, attr)(remaining_reqs)
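
The two hunks above show the caching path in lm_eval/api/model.py: each request is looked up by a hash of its arguments, hits are served from the cache, and only the misses are forwarded to the wrapped model; the new log line makes the hit/miss split visible before the expensive call. A self-contained sketch of that flow (simplified names; `hash_args` and the sqlite-backed cache are replaced by a plain dict here):

from typing import Callable, Dict, List, Optional, Tuple

def run_with_cache(
    cache: Dict[str, str],
    requests: List[Tuple[str, ...]],
    run_model: Callable[[List[Tuple[str, ...]]], List[str]],
) -> List[str]:
    """Serve cached responses where possible; batch the rest through the model."""
    results: List[Optional[str]] = []
    remaining: List[Tuple[str, ...]] = []
    for req in requests:
        key = repr(req)  # stand-in for hash_args(attr, req.args)
        if key in cache:
            results.append(cache[key])
        else:
            results.append(None)  # placeholder, filled from fresh results below
            remaining.append(req)
    fresh = iter(run_model(remaining))  # only cache misses hit the model
    return [res if res is not None else next(fresh) for res in results]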
56 changes: 35 additions & 21 deletions lm_eval/api/task.py
@@ -7,7 +7,7 @@
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from inspect import getsource
-from typing import Any, List, Literal, Tuple, Union
+from typing import Any, Iterator, List, Literal, Tuple, Union

 import datasets
 import numpy as np
@@ -327,7 +327,7 @@ def _process_doc(self, doc):
         return doc

     @property
-    def instances(self):
+    def instances(self) -> List[Instance]:
         """After calling `task.build_all_requests()`, tasks
         maintain a list of the dataset instances which will be evaluated.
         """
@@ -355,6 +355,7 @@ def doc_to_target(self, doc):

     def build_all_requests(
         self,
+        *,
         limit=None,
         rank=None,
         world_size=None,
@@ -382,13 +383,6 @@ def build_all_requests(
             self._instances = flattened_instances
             return

-        if self.has_test_docs():
-            docs = self.test_docs()
-        elif self.has_validation_docs():
-            docs = self.validation_docs()
-        else:
-            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
-
         eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")

         instances = []
@@ -402,12 +396,7 @@ def build_all_requests(
                 limit = None

         doc_id_docs = list(
-            utils.create_iterator(
-                enumerate(docs),
-                rank,
-                world_size,
-                limit,
-            )
+            self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
         )

         num_docs = len(doc_id_docs)
@@ -632,6 +621,27 @@ def override_metric(self, metric_name: str) -> None:
         setattr(self._config, "metric_list", [{"metric": metric_name}])
         setattr(self._config, "process_results", None)

+    @property
+    def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
+        if self.has_test_docs():
+            return self.test_docs()
+        elif self.has_validation_docs():
+            return self.validation_docs()
+        else:
+            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
+
+    def doc_iterator(
+        self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
+    ) -> Iterator[Tuple[int, Any]]:
+        limit = int(limit) if limit else None
+        doc_iterator = utils.create_iterator(
+            enumerate(self.eval_docs),
+            rank=int(rank),
+            limit=limit,
+            world_size=int(world_size),
+        )
+        return doc_iterator
+

 class ConfigurableTask(Task):
     VERSION = "Yaml"
@@ -781,12 +791,7 @@ def __init__(
                 else "default"
             )(list(self.fewshot_docs()), self, rnd=random.Random(1234))

-        if self.has_test_docs():
-            self.task_docs = self.test_docs()
-        elif self.has_validation_docs():
-            self.task_docs = self.validation_docs()
-        else:
-            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
+        self.task_docs = self.eval_docs

         # Test One Doc
         self.features = list(self.task_docs.features.keys())
@@ -1336,6 +1341,15 @@ def higher_is_better(self) -> dict:
     def get_config(self, key: str) -> Any:
         return getattr(self._config, key, None)

+    def __repr__(self):
+        return (
+            f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
+            f"group_name={getattr(self.config, 'group', None)},"
+            f"output_type={self.OUTPUT_TYPE},"
+            f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
+            f"num_samples={len(self.eval_docs)})"
+        )
+

 class MultipleChoiceTask(Task):
     OUTPUT_TYPE: str = "loglikelihood"
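
The new `doc_iterator` centralizes both the test/validation fallback (via the `eval_docs` property) and the rank/world-size sharding that `build_all_requests` previously inlined. `utils.create_iterator` is not shown in this diff; assuming it is an `itertools.islice` wrapper, the sharding works roughly like this sketch:

import itertools
from typing import Any, Iterable, Iterator, Optional, Tuple

def create_iterator(
    raw_iterator: Iterable[Tuple[int, Any]],
    *,
    rank: int = 0,
    world_size: int = 1,
    limit: Optional[int] = None,
) -> Iterator[Tuple[int, Any]]:
    """Each rank starts at its own offset and strides by world_size,
    so the (doc_id, doc) pairs are partitioned without overlap."""
    return itertools.islice(raw_iterator, rank, limit, world_size)

docs = ["a", "b", "c", "d", "e"]
print(list(create_iterator(enumerate(docs), rank=0, world_size=2)))
# -> [(0, 'a'), (2, 'c'), (4, 'e')]; rank 1 would see [(1, 'b'), (3, 'd')]

The added `__repr__` builds a compact one-line task summary from the same `eval_docs`; note that its f-string fragments concatenate with no space after each comma, so the fields run together in the output.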