From 706ea7dd40ba60a98dea5f37695d143d91c98b6c Mon Sep 17 00:00:00 2001
From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com>
Date: Fri, 2 Feb 2024 18:22:44 -0800
Subject: [PATCH] Refactoring the function to accept list of metric names
 instead of a dictionary of metrics. (#938)

* ..

* undoing prev commit

* Refactoring the function to accept list of metric names instead of dictionary

* ..

* ..

* ..

* ..

---
 llmfoundry/utils/builders.py | 4 +---
 scripts/eval/eval.py         | 3 ++-
 scripts/train/train.py       | 3 ++-
 tests/utils/test_builders.py | 8 +-------
 4 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index 457f146986..49d4eff1cb 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -28,7 +28,6 @@
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from torch.optim.optimizer import Optimizer
-from torchmetrics import Metric
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from llmfoundry.callbacks import (AsyncEval, EvalGauntlet, FDiffMetrics,
@@ -108,9 +107,8 @@ def build_eval_loaders(
 
 def add_metrics_to_eval_loaders(
     evaluators: List[Evaluator],
-    metrics: Dict[str, Metric],
+    metric_names: List[str],
 ) -> List[Evaluator]:
-    metric_names = list(metrics.keys())
     eval_loaders, other_evaluators = [], []
     for evaluator in evaluators:
         if evaluator.metric_names == []:
diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py
index fb4b75ec31..303f1bd6bc 100644
--- a/scripts/eval/eval.py
+++ b/scripts/eval/eval.py
@@ -184,7 +184,8 @@ def evaluate_model(
     # Now add the eval metrics
     if eval_loader_config is not None:
         train_metrics = composer_model.get_metrics(is_train=True)
-        evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)
+        evaluators = add_metrics_to_eval_loaders(evaluators,
+                                                 list(train_metrics.keys()))
 
     if eval_gauntlet_df is None and eval_gauntlet_callback is not None:
         eval_gauntlet_df = pd.DataFrame(
diff --git a/scripts/train/train.py b/scripts/train/train.py
index fe69a87422..dbaaf13ebc 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -544,7 +544,8 @@ def main(cfg: DictConfig) -> Trainer:
     # Now add the eval metrics
     if eval_loader_config is not None and not use_async_eval:
         train_metrics = model.get_metrics(is_train=True)
-        evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)
+        evaluators = add_metrics_to_eval_loaders(evaluators,
+                                                 list(train_metrics.keys()))
 
     # Build the Trainer
     log.info('Building trainer...')
diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py
index b35e053c5d..08c3504491 100644
--- a/tests/utils/test_builders.py
+++ b/tests/utils/test_builders.py
@@ -335,13 +335,7 @@ def test_add_metrics_to_eval_loaders():
         )
     ]
 
-    new_evaluators = add_metrics_to_eval_loaders(
-        evaluators,
-        {
-            'new1': 'foo',
-            'new2': 'bar'
-        },  # type: ignore
-    )
+    new_evaluators = add_metrics_to_eval_loaders(evaluators, ['new1', 'new2'])
     assert len(new_evaluators) == 3
     assert new_evaluators[0].label == 'second'
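
Reviewer note (not part of the patch): a minimal call-site sketch of the new
signature, assuming `model` and `evaluators` have already been built as in
scripts/train/train.py. Only the metric names cross the builder boundary now,
which is what lets llmfoundry/utils/builders.py drop its
`from torchmetrics import Metric` import.

    from llmfoundry.utils.builders import add_metrics_to_eval_loaders

    # `model` is a composer model and `evaluators` is a List[Evaluator]; both
    # are assumed to exist already (illustrative sketch, not runnable alone).
    train_metrics = model.get_metrics(is_train=True)  # Dict[str, Metric]

    # Pre-patch: the Metric dict itself was passed and the keys were extracted
    # inside the builder.
    # evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)

    # Post-patch: callers extract the metric names up front.
    evaluators = add_metrics_to_eval_loaders(evaluators,
                                             list(train_metrics.keys()))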