Multi eval dataset logging (#603)
* added support for multiple eval datasets and logging their metrics separately

* fixed comments, deleted accidentally added files

* added tests

* added multi-dataset tests, linting

* modified to use tmp_path

snarayan21 committed Sep 27, 2023
1 parent fd36398 commit 3d4fa0f
Showing 4 changed files with 130 additions and 23 deletions.
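For context, a minimal sketch of the kind of eval_loader configuration this commit enables. Everything here is illustrative: the loader names, dataset paths, and splits are placeholders rather than values taken from the repo; only the requirement that each list entry carry a label comes from the diff below.

from omegaconf import OmegaConf as om

# Hypothetical config: eval_loader given as a list, one entry per eval dataset.
# Each entry needs a `label`; names, paths, and splits below are placeholders.
cfg = om.create({
    'eval_loader': [
        {'name': 'text', 'label': 'c4',
         'dataset': {'local': './my-copy-c4', 'split': 'val'}},
        {'name': 'text', 'label': 'arxiv',
         'dataset': {'local': './my-copy-arxiv', 'split': 'train'}},
    ]
})
print(type(cfg.eval_loader).__name__)  # ListConfig -- the new branch handled in train.py

With labels present, each dataset's metrics are logged under its own namespace (e.g. metrics/eval/c4/... and metrics/eval/arxiv/...), as the new test_train_multi_eval below asserts.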
45 changes: 29 additions & 16 deletions scripts/train/train.py
@@ -33,7 +33,16 @@ def validate_config(cfg: DictConfig):
"""Validates compatible model and dataloader selection."""
loaders = [cfg.train_loader]
if 'eval_loader' in cfg:
loaders.append(cfg.eval_loader)
eval_loader = cfg.eval_loader
if isinstance(eval_loader, ListConfig):
for loader in eval_loader:
if loader.label is None:
raise ValueError(
'When specifying multiple evaluation datasets, each one must include the \
`label` attribute.')
loaders.append(loader)
else:
loaders.append(eval_loader)
for loader in loaders:
if loader.name == 'text':
if cfg.model.name in ['hf_prefix_lm', 'hf_t5']:
@@ -245,10 +254,8 @@ def main(cfg: DictConfig) -> Trainer:
must_exist=False,
default_value=None,
convert=True)
eval_loader_config: Optional[DictConfig] = pop_config(cfg,
'eval_loader',
must_exist=False,
default_value=None)
eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config(
cfg, 'eval_loader', must_exist=False, default_value=None)
icl_tasks_config: Optional[Union[ListConfig,
str]] = pop_config(cfg,
'icl_tasks',
@@ -466,15 +473,21 @@ def main(cfg: DictConfig) -> Trainer:
## Evaluation
print('Building eval loader...')
evaluators = []
eval_loader = None
eval_loaders = []
if eval_loader_config is not None:
eval_dataloader = build_dataloader(eval_loader_config, tokenizer,
device_eval_batch_size)
eval_loader = Evaluator(
label='eval',
dataloader=eval_dataloader,
metric_names=[], # we will add these after model is created
)
is_multi_eval = isinstance(eval_loader_config, ListConfig)
eval_configs = eval_loader_config if is_multi_eval else [
eval_loader_config
]
for eval_config in eval_configs:
eval_dataloader = build_dataloader(eval_config, tokenizer,
device_eval_batch_size)
eval_loader = Evaluator(
label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
dataloader=eval_dataloader,
metric_names=[], # we will add these after model is created
)
eval_loaders.append(eval_loader)

eval_gauntlet_callback = None

@@ -514,11 +527,11 @@ def main(cfg: DictConfig) -> Trainer:

# Now add the eval metrics
if eval_loader_config is not None:
assert eval_loader is not None
assert model.train_metrics is not None
eval_metric_names = list(model.train_metrics.keys())
eval_loader.metric_names = eval_metric_names
evaluators.insert(0, eval_loader) # Put the base eval_loader first
for eval_loader in eval_loaders:
eval_loader.metric_names = eval_metric_names
evaluators.insert(0, eval_loader) # Put the base eval_loaders first

# Build the Trainer
print('Building trainer...')
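In short: when eval_loader is a list, the code above builds one Evaluator per dataset, labeled eval/<label>, and the evaluator label becomes part of each logged metric key. A rough sketch of the resulting keys, with labels and the metric name assumed from the tests in this commit rather than hard-coded in train.py:

# Illustrative only: the Evaluator label is folded into the logged metric key.
labels = ['c4', 'arxiv']  # hypothetical values of eval_loader[i].label
for label in labels:
    evaluator_name = f'eval/{label}'
    # e.g. metrics/eval/c4/LanguageCrossEntropy and metrics/eval/arxiv/LanguageCrossEntropy
    print(f'metrics/{evaluator_name}/LanguageCrossEntropy')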
4 changes: 2 additions & 2 deletions tests/test_data_prep_scripts.py
@@ -37,13 +37,13 @@ def test_download_script_from_api():

def test_json_script_from_api():
# test calling it directly
path = os.path.join(os.getcwd(), 'my-copy-c4-3')
path = os.path.join(os.getcwd(), 'my-copy-arxiv-1')
shutil.rmtree(path, ignore_errors=True)
main_json(
Namespace(
**{
'path': 'scripts/data_prep/example_data/arxiv.jsonl',
'out_root': './my-copy-c4-3',
'out_root': './my-copy-arxiv-1',
'compression': None,
'split': 'train',
'concat_tokens': None,
20 changes: 20 additions & 0 deletions tests/test_train_inputs.py
@@ -158,3 +158,23 @@ def test_invalid_name_in_scheduler_cfg_errors(self,
main(cfg)
assert str(exception_info.value
) == 'Not sure how to build scheduler: invalid-scheduler'

def test_no_label_multiple_eval_datasets(self, cfg: DictConfig) -> None:
data_local = './my-copy-c4-multi-eval'
make_fake_index_file(f'{data_local}/train/index.json')
make_fake_index_file(f'{data_local}/val/index.json')
cfg.train_loader.dataset.local = data_local
# Set up multiple eval datasets
first_eval_loader = cfg.eval_loader
first_eval_loader.dataset.local = data_local
second_eval_loader = copy.deepcopy(first_eval_loader)
# Set the first eval dataloader to have no label
first_eval_loader.label = None
second_eval_loader.label = 'eval_1'
cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])
with pytest.raises(ValueError) as exception_info:
main(cfg)
assert str(
exception_info.value
) == 'When specifying multiple evaluation datasets, each one must include the \
`label` attribute.'
84 changes: 79 additions & 5 deletions tests/test_training.py
@@ -1,6 +1,8 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import copy
import os
import pathlib
import shutil
import sys
from argparse import Namespace
@@ -16,13 +18,14 @@
sys.path.append(repo_dir)

from scripts.data_prep.convert_dataset_hf import main as main_hf # noqa: E402
from scripts.data_prep.convert_dataset_json import \
main as main_json # noqa: E402
from scripts.train.train import main # noqa: E402


def create_c4_dataset_xsmall(prefix: str) -> str:
def create_c4_dataset_xsmall(path: pathlib.Path) -> str:
"""Creates a small mocked version of the C4 dataset."""
c4_dir = os.path.join(os.getcwd(), f'my-copy-c4-{prefix}')
shutil.rmtree(c4_dir, ignore_errors=True)
c4_dir = os.path.join(path, f'my-copy-c4')
downloaded_split = 'val_xsmall' # very fast to convert

# Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188
@@ -52,6 +55,28 @@ def create_c4_dataset_xsmall(prefix: str) -> str:
return c4_dir


def create_arxiv_dataset(path: pathlib.Path) -> str:
"""Creates an arxiv dataset."""
arxiv_dir = os.path.join(path, f'my-copy-arxiv')
downloaded_split = 'train'

main_json(
Namespace(
**{
'path': 'data_prep/example_data/arxiv.jsonl',
'out_root': arxiv_dir,
'compression': None,
'split': downloaded_split,
'concat_tokens': None,
'bos_text': None,
'eos_text': None,
'no_wrap': False,
'num_workers': None
}))

return arxiv_dir


def gpt_tiny_cfg(dataset_name: str, device: str):
"""Create gpt tiny cfg."""
conf_path: str = os.path.join(repo_dir,
@@ -89,9 +114,9 @@ def set_correct_cwd():
os.chdir('..')


def test_train_gauntlet(set_correct_cwd: Any):
def test_train_gauntlet(set_correct_cwd: Any, tmp_path: pathlib.Path):
"""Test training run with a small dataset."""
dataset_name = create_c4_dataset_xsmall('cpu-gauntlet')
dataset_name = create_c4_dataset_xsmall(tmp_path)
test_cfg = gpt_tiny_cfg(dataset_name, 'cpu')
test_cfg.icl_tasks = ListConfig([
DictConfig({
@@ -150,3 +175,52 @@ def test_train_gauntlet(set_correct_cwd: Any):
inmemorylogger.data['icl/metrics/eval_gauntlet/average'][-1], tuple)

assert inmemorylogger.data['icl/metrics/eval_gauntlet/average'][-1][-1] == 0


def test_train_multi_eval(set_correct_cwd: Any, tmp_path: pathlib.Path):
"""Test training run with multiple eval datasets."""
c4_dataset_name = create_c4_dataset_xsmall(tmp_path)
test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu')
# Set up multiple eval dataloaders
first_eval_loader = test_cfg.eval_loader
first_eval_loader.label = 'c4'
# Create second eval dataloader using the arxiv dataset.
second_eval_loader = copy.deepcopy(first_eval_loader)
arxiv_dataset_name = create_arxiv_dataset(tmp_path)
second_eval_loader.data_local = arxiv_dataset_name
second_eval_loader.label = 'arxiv'
test_cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])
test_cfg.eval_subset_num_batches = 1 # -1 to evaluate on all batches

test_cfg.max_duration = '1ba'
test_cfg.eval_interval = '1ba'
test_cfg.loggers = DictConfig({'inmemory': DictConfig({})})
trainer = main(test_cfg)

assert isinstance(trainer.logger.destinations, tuple)

assert len(trainer.logger.destinations) > 0
inmemorylogger = trainer.logger.destinations[
0] # pyright: ignore [reportGeneralTypeIssues]
assert isinstance(inmemorylogger, InMemoryLogger)
print(inmemorylogger.data.keys())

# Checks for first eval dataloader
assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys()
assert isinstance(
inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], list)
assert len(
inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1]) > 0
assert isinstance(
inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], tuple)

# Checks for second eval dataloader
assert 'metrics/eval/arxiv/LanguageCrossEntropy' in inmemorylogger.data.keys(
)
assert isinstance(
inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], list)
assert len(
inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1]) > 0
assert isinstance(
inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1],
tuple)
