Skip to content

Commit

Permalink
Refactoring the function to accept list of metric names instead of a …
Browse files Browse the repository at this point in the history
…dictionary of metrics. (#938)

* ..

* undoing prev commit

* Refactoring the  function to accept list of metric names instead of dictionary

* ..

* ..

* ..

* ..
  • Loading branch information
ShashankMosaicML committed Feb 3, 2024
1 parent 15ee0ac commit 706ea7d
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 12 deletions.
4 changes: 1 addition & 3 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from llmfoundry.callbacks import (AsyncEval, EvalGauntlet, FDiffMetrics,
Expand Down Expand Up @@ -108,9 +107,8 @@ def build_eval_loaders(

def add_metrics_to_eval_loaders(
evaluators: List[Evaluator],
metrics: Dict[str, Metric],
metric_names: List[str],
) -> List[Evaluator]:
metric_names = list(metrics.keys())
eval_loaders, other_evaluators = [], []
for evaluator in evaluators:
if evaluator.metric_names == []:
Expand Down
3 changes: 2 additions & 1 deletion scripts/eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def evaluate_model(
# Now add the eval metrics
if eval_loader_config is not None:
train_metrics = composer_model.get_metrics(is_train=True)
evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)
evaluators = add_metrics_to_eval_loaders(evaluators,
list(train_metrics.keys()))

if eval_gauntlet_df is None and eval_gauntlet_callback is not None:
eval_gauntlet_df = pd.DataFrame(
Expand Down
3 changes: 2 additions & 1 deletion scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,8 @@ def main(cfg: DictConfig) -> Trainer:
# Now add the eval metrics
if eval_loader_config is not None and not use_async_eval:
train_metrics = model.get_metrics(is_train=True)
evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)
evaluators = add_metrics_to_eval_loaders(evaluators,
list(train_metrics.keys()))

# Build the Trainer
log.info('Building trainer...')
Expand Down
8 changes: 1 addition & 7 deletions tests/utils/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,7 @@ def test_add_metrics_to_eval_loaders():
)
]

new_evaluators = add_metrics_to_eval_loaders(
evaluators,
{
'new1': 'foo',
'new2': 'bar'
}, # type: ignore
)
new_evaluators = add_metrics_to_eval_loaders(evaluators, ['new1', 'new2'])
assert len(new_evaluators) == 3

assert new_evaluators[0].label == 'second'
Expand Down

0 comments on commit 706ea7d

Please sign in to comment.