Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add environment and transformers version logging in results dump #1464

Merged
merged 5 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lm_eval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np

from lm_eval import evaluator, utils
from lm_eval.logging_utils import WandbLogger
from lm_eval.logging_utils import WandbLogger, add_env_info
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table

Expand Down Expand Up @@ -310,6 +310,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if results is not None:
if args.log_samples:
samples = results.pop("samples")
add_env_info(results) # place this after popping out samples and before dumping
dumped = json.dumps(
results, indent=2, default=_handle_non_serializable, ensure_ascii=False
)
Expand Down
38 changes: 37 additions & 1 deletion lm_eval/logging_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import copy
import json
import logging
import os
import re
from typing import Any, Dict, List, Literal, Tuple, Union
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd
from packaging.version import Version
from torch.utils.collect_env import get_pretty_env_info
from transformers import __version__ as trans_version

from lm_eval import utils

Expand Down Expand Up @@ -384,3 +388,35 @@ def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None:
self._log_samples_as_artifact(eval_preds, task_name)

self.run.log({f"{group}_eval_results": grouped_df})


def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
    """Read the current commit hash of a git repository without invoking git.

    Args:
        repo_path: Directory expected to contain a ``.git`` entry (either a
            directory, or a ``gitdir: <path>`` pointer file).

    Returns:
        The commit hash HEAD resolves to, or ``None`` when ``repo_path`` has
        no readable HEAD or the ref HEAD names cannot be found in this git
        directory (neither as a loose ref file nor in ``packed-refs``).
    """
    git_folder = Path(repo_path, ".git")
    if git_folder.is_file():
        # ".git" is a pointer file of the form "gitdir: <path>"; the path is
        # taken relative to the directory holding the pointer file.
        pointer = git_folder.read_text().split("\n")[0].strip()
        if pointer.startswith("gitdir:"):
            # Slice off the prefix instead of splitting on spaces so that
            # paths containing spaces survive intact.
            target = pointer[len("gitdir:"):].strip()
        else:
            target = pointer.split(" ")[-1]
        git_folder = Path(git_folder.parent, target)

    head_file = Path(git_folder, "HEAD")
    if not head_file.exists():
        return None

    head_line = head_file.read_text().split("\n")[0].strip()
    if not head_line.startswith("ref:"):
        # Detached HEAD: the file holds the commit hash itself, not a ref name.
        return head_line or None

    ref_name = head_line[len("ref:"):].strip()
    ref_file = Path(git_folder, ref_name)
    if ref_file.is_file():
        return ref_file.read_text().replace("\n", "")

    # The ref may have been packed, in which case no loose ref file exists.
    packed_refs = Path(git_folder, "packed-refs")
    if packed_refs.is_file():
        for line in packed_refs.read_text().splitlines():
            fields = line.split()
            # Entries look like "<hash> <refname>"; header/peeled lines don't.
            if len(fields) == 2 and fields[1] == ref_name:
                return fields[0]
    return None


def add_env_info(storage: Dict[str, Any]) -> None:
    """Add environment details to ``storage`` in place.

    Three keys are written: ``pretty_env_info`` (the report from
    ``torch.utils.collect_env.get_pretty_env_info``, or the error text if
    collecting it failed), ``transformers_version`` and ``upper_git_hash``
    (commit of the repository in the parent of the current working
    directory, or ``None``).

    Args:
        storage: Results dictionary to update.
    """
    try:
        pretty_env_info = get_pretty_env_info()
    except Exception as err:
        # Best effort: keep the failure reason instead of the report.
        pretty_env_info = str(err)

    try:
        # git hash of the parent directory's repo, if it is one
        upper_dir_commit = get_commit_from_path(Path(os.getcwd(), ".."))
    except OSError:
        # An unreadable or unusual git layout must not abort the results dump.
        upper_dir_commit = None

    storage.update(
        {
            "pretty_env_info": pretty_env_info,
            "transformers_version": f"Transformers: {trans_version}",
            "upper_git_hash": upper_dir_commit,  # in case this repo is a submodule
        }
    )
18 changes: 8 additions & 10 deletions lm_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,18 @@
import inspect
import logging
import os
import pathlib
import re
import subprocess
import sys
from itertools import islice
from typing import (
Any,
Callable,
List,
)
from pathlib import Path
from typing import Any, Callable, List

import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined

from lm_eval.logging_utils import get_commit_from_path
haileyschoelkopf marked this conversation as resolved.
Show resolved Hide resolved


logging.basicConfig(
format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
Expand Down Expand Up @@ -291,7 +289,7 @@ def _wrapper(*args, **kwargs):


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
def find_test_root(start_path: Path) -> Path:
"""
Search upward in the directory tree to a maximum of three layers
to find and return the package root (containing the 'tests' folder)
Expand All @@ -315,7 +313,7 @@ def run_task_tests(task_list: List[str]):
"""
import pytest

package_root = find_test_root(start_path=pathlib.Path(__file__))
package_root = find_test_root(start_path=Path(__file__))
task_string = " or ".join(task_list)
args = [
f"{package_root}/tests/test_version_stable.py",
Expand All @@ -339,9 +337,9 @@ def get_git_commit_hash():
try:
git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
git_hash = git_hash.decode()
except subprocess.CalledProcessError or FileNotFoundError:
except (subprocess.CalledProcessError, FileNotFoundError):
# FileNotFoundError occurs when git not installed on system
git_hash = None
git_hash = get_commit_from_path(os.getcwd()) # git hash of repo if exists
return git_hash


Expand Down
Loading