Skip to content

Commit

Permalink
Store more info about environment and transformers version in results…
Browse files Browse the repository at this point in the history
… to help researchers track inconsistencies
  • Loading branch information
LSinev committed Feb 23, 2024
1 parent 2ab0d73 commit f5c4eaf
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
3 changes: 2 additions & 1 deletion lm_eval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from lm_eval import evaluator, utils
from lm_eval.logging_utils import WandbLogger
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table
from lm_eval.utils import add_env_info, make_table


def _handle_non_serializable(o):
Expand Down Expand Up @@ -310,6 +310,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if results is not None:
if args.log_samples:
samples = results.pop("samples")
add_env_info(results) # place this after popping out samples and before dumping
dumped = json.dumps(
results, indent=2, default=_handle_non_serializable, ensure_ascii=False
)
Expand Down
21 changes: 20 additions & 1 deletion lm_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
import sys
from itertools import islice
from pathlib import Path
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional

import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined
from torch.utils.collect_env import get_pretty_env_info
from transformers import __version__ as trans_version


logging.basicConfig(
Expand Down Expand Up @@ -356,6 +358,23 @@ def get_git_commit_hash():
return git_hash


def add_env_info(storage: Dict[str, Any]):
try:
pretty_env_info = get_pretty_env_info()
except Exception as err:
pretty_env_info = str(err)
transformers_version = "Transformers: %s" % trans_version
upper_dir_commit = get_commit_from_path(
Path(os.getcwd(), "..")
) # git hash of upper repo if exists
added_info = {
"pretty_env_info": pretty_env_info,
"transformers_version": transformers_version,
"upper_git_hash": upper_dir_commit, # in case this repo is submodule
}
storage.update(added_info)


def ignore_constructor(loader, node):
return node

Expand Down

0 comments on commit f5c4eaf

Please sign in to comment.