Skip to content

Commit

Permalink
Simplify CL API
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu committed Sep 19, 2024
1 parent 02802c5 commit 88c3023
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 44 deletions.
61 changes: 33 additions & 28 deletions llmfoundry/callbacks/curriculum_learning_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import copy
import logging
from typing import Any
from typing import Any, Union

from composer import DataSpec
from composer.core import State, Time, TimeUnit, ensure_time
Expand All @@ -32,19 +32,21 @@
class CurriculumLearning(CallbackWithConfig):
"""Starts an epoch with a different dataset when resuming from a checkpoint.
Example duration:
<number>tok
Example schedule:
[
{
'duration': <number>tok,
'train_loader': <dataloader parameters>, # matches top level train_loader
'dataset': <dataset parameters>,
},
{
'duration': <number>tok,
'train_loader': <dataloader parameters>,
'dataset': <dataset parameters>,
},
{
'duration': <number>tok,
'train_loader': <dataloader parameters>,
'dataset': <dataset parameters>,
],
]
Expand All @@ -53,48 +55,47 @@ class CurriculumLearning(CallbackWithConfig):
being used. Note that this is the full train config and must
contain the 'train_loader', 'device_train_batch_size', and
'tokenizer' keys.
duration (Union[Time, str, int]): The duration of the first datamix
(which corresponds to the train_loader).
schedule (list[dict[str, Any]]): The list of datamixes to use and their
durations. Duration units must match max_duration and be in terms of
a TimeUnit that is supported by Iteration. The duration values must
be positive. There must be at least one datamix in the schedule. The
first datamix in the schedule must match the train_loader in the
train_config. On resumption, previously trained on datamixes and
durations cannot be changed. The duration of the current datamix
must be greater than the saved timestamp. The dataset must be a
StreamingDataset.
first datamix during training is not included in the schedule. On
resumption, previously trained on datamixes and durations cannot be
changed. The duration of the current datamix must be greater than
the saved timestamp. The dataset must be a StreamingDataset.
"""

def __init__(
self,
train_config: dict[str, Any],
duration: Union[Time, str, int],
schedule: list[dict[str, Any]],
):
# Ensure all duration units are in epochs or tokens and values are positive
self._schedule = schedule
if len(self._schedule) == 0:
raise ValueError('The schedule must have at least one datamix.')
for index, datamix in enumerate(self._schedule):
first_datamix = {
'duration': duration,
'dataset': train_config['train_loader']['dataset'],
}
self._schedule.insert(0, first_datamix)
for datamix in self._schedule:
self._validate_datamix(datamix)

if (
index == 0 and
train_config['train_loader'] != datamix['train_loader']
):
raise ValueError((
'The first datamix in the schedule must match the '
'train_loader in the train_config.'
))

self._schedule_index = 0
self.device_train_batch_size = train_config['device_train_batch_size']
self.tokenizer = None
self._train_loader_config: dict[str, Any] = train_config['train_loader']
self._device_train_batch_size = train_config['device_train_batch_size']
self._tokenizer = None

def init(self, state: State, logger: Logger):
del logger # unused

if not hasattr(state.model, 'tokenizer'):
raise ValueError('state.model must have a tokenizer attribute.')
self.tokenizer = state.model.tokenizer
self._tokenizer = state.model.tokenizer

def before_load(self, state: State, logger: Logger):
del logger # unused
Expand Down Expand Up @@ -151,8 +152,10 @@ def iteration_start(self, state: State, logger: Logger):
# which is stale
clean_stale_shared_memory()
datamix = copy.deepcopy(self._schedule[self._schedule_index])
train_loader_config = copy.deepcopy(self._train_loader_config)
train_loader_config['dataset'].update(datamix['dataset'])
data_spec = self._build_train_loader(
train_loader_config=datamix['train_loader'],
train_loader_config=train_loader_config,
logger=logger,
)
state.set_dataloader(
Expand Down Expand Up @@ -211,18 +214,20 @@ def _build_train_loader(
train_loader_config: dict[str, Any],
logger: Logger,
) -> DataSpec:
del logger # unused

from llmfoundry.data.dataloader import build_dataloader

# Copied from scripts/train/train.py
log.info(
f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}',
)
assert self.tokenizer is not None
assert self._tokenizer is not None
try:
return build_dataloader(
train_loader_config,
self.tokenizer,
self.device_train_batch_size,
self._tokenizer,
self._device_train_batch_size,
)
except BaseContextualError as e:
e.location = TrainDataLoaderLocation
Expand Down Expand Up @@ -260,5 +265,5 @@ def _validate_datamix(self, datamix: dict[str, Any]):
'Schedules can only be defined in terms of epochs or tokens.',
)

if 'train_loader' not in datamix:
raise ValueError('Each datamix must have a train_loader.')
if 'dataset' not in datamix:
raise ValueError('Each datamix must have a dataset.')
48 changes: 32 additions & 16 deletions tests/callbacks/test_curriculum_learning_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
[
(None, '1ep'),
({
'dataset': 'some_dataset',
'hf_name': 'some_dataset',
}, '1ep'),
(None, '10tok'),
(None, ''),
Expand All @@ -36,23 +36,29 @@ def test_curriculum_learning_callback_init(
):
test_cfg = _get_test_cfg()
test_cfg['train_loader'] = tiny_ft_dataloader_cfg
train_loader = test_cfg['train_loader'] if datamix is None else datamix
if datamix is None:
train_loader = test_cfg['train_loader']['dataset']
else:
train_loader = datamix
kwargs = {
'schedule': [{
'duration': duration,
'train_loader': train_loader,
'dataset': train_loader,
}, {
'duration': '2ep',
'train_loader': {},
'dataset': {},
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

if duration == '':
del kwargs['schedule'][0]['duration']
if datamix is not None and len(datamix) == 0:
del kwargs['schedule'][0]['train_loader']
del kwargs['schedule'][0]['dataset']

context = nullcontext()
if datamix is not None or duration == '':
if (datamix is not None and len(datamix) == 0) or duration == '':
context = pytest.raises(ValueError)
with context:
callback = build_callback(
Expand Down Expand Up @@ -85,13 +91,15 @@ def test_curriculum_learning_callback_before_load(
kwargs = {
'schedule': [{
'duration': duration,
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}, {
'duration': '2ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

callback = build_callback(
'curriculum_learning',
kwargs=kwargs,
Expand Down Expand Up @@ -123,13 +131,15 @@ def test_curriculum_learning_callback_after_load(build_tiny_mpt: Callable,):
kwargs = {
'schedule': [{
'duration': '1ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}, {
'duration': '2ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

callback = build_callback(
'curriculum_learning',
kwargs=kwargs,
Expand Down Expand Up @@ -168,13 +178,15 @@ def test_curriculum_learning_callback_iteration(
kwargs = {
'schedule': [{
'duration': '1ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}, {
'duration': '2ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

callback = build_callback(
'curriculum_learning',
kwargs=kwargs,
Expand Down Expand Up @@ -208,13 +220,15 @@ def test_curriculum_learning_callback_state_dict(build_tiny_mpt: Callable,):
kwargs = {
'schedule': [{
'duration': '1ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}, {
'duration': '2ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

callback = build_callback(
'curriculum_learning',
kwargs=kwargs,
Expand Down Expand Up @@ -249,13 +263,15 @@ def test_curriculum_learning_callback_load_state_dict(
kwargs = {
'schedule': [{
'duration': '1ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}, {
'duration': '2ep',
'train_loader': test_cfg['train_loader'],
'dataset': test_cfg['train_loader']['dataset'],
}],
}

kwargs['duration'] = kwargs['schedule'].pop(0)['duration']

callback = build_callback(
'curriculum_learning',
kwargs=kwargs,
Expand Down

0 comments on commit 88c3023

Please sign in to comment.