From 88c3023c66745a095a6678182de10158779c6417 Mon Sep 17 00:00:00 2001
From: root <23239305+b-chu@users.noreply.github.com>
Date: Sat, 10 Aug 2024 23:51:17 +0000
Subject: [PATCH] Simplify CL API

---
 .../callbacks/curriculum_learning_callback.py | 61 ++++++++++---------
 .../test_curriculum_learning_callback.py      | 48 ++++++++++-----
 2 files changed, 65 insertions(+), 44 deletions(-)

diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py
index 449ab338bc..3899e59010 100644
--- a/llmfoundry/callbacks/curriculum_learning_callback.py
+++ b/llmfoundry/callbacks/curriculum_learning_callback.py
@@ -9,7 +9,7 @@
 import copy
 import logging
-from typing import Any
+from typing import Any, Union
 
 from composer import DataSpec
 from composer.core import State, Time, TimeUnit, ensure_time
@@ -32,19 +32,21 @@ class CurriculumLearning(CallbackWithConfig):
     """Starts an epoch with a different dataset when resuming from a checkpoint.
 
+    Example duration:
+    <duration>tok
+
     Example schedule:
     [
         {
             'duration': <duration>tok,
-            'train_loader': <train loader parameters>, # matches top level train_loader
+            'dataset': <dataset parameters>,
         },
         {
             'duration': <duration>tok,
-            'train_loader': <train loader parameters>,
+            'dataset': <dataset parameters>,
         },
         {
             'duration': <duration>tok,
-            'train_loader': <train loader parameters>,
+            'dataset': <dataset parameters>,
         },
     ]
 
@@ -53,48 +55,47 @@
             being used. Note that this is the full train config and must contain
             the 'train_loader', 'device_train_batch_size', and 'tokenizer' keys.
+        duration (Union[Time, str, int]): The duration of the first datamix,
+            i.e. of the dataset configured by the train_loader in the
+            train_config.
         schedule (list[dict[str, Any]]): The list of datamixes to use and their
             durations. Duration units must match max_duration and be in terms of
             a TimeUnit that is supported by Iteration. The duration values must
             be positive. There must be at least one datamix in the schedule. The
-            first datamix in the schedule must match the train_loader in the
-            train_config. On resumption, previously trained on datamixes and
-            durations cannot be changed. The duration of the current datamix
-            must be greater than the saved timestamp. The dataset must be a
-            StreamingDataset.
+            first datamix trained on is not listed in the schedule; it is
+            defined by the train_loader in the train_config together with the
+            duration argument. On resumption, datamixes and durations that have
+            already been trained on cannot be changed. The duration of the
+            current datamix must be greater than the saved timestamp. The
+            dataset must be a StreamingDataset.
     """
 
     def __init__(
         self,
         train_config: dict[str, Any],
+        duration: Union[Time, str, int],
         schedule: list[dict[str, Any]],
     ):
         # Ensure all duration units are in epochs or tokens and values are positive
         self._schedule = schedule
         if len(self._schedule) == 0:
             raise ValueError('The schedule must have at least one datamix.')
-        for index, datamix in enumerate(self._schedule):
+        # Prepend the first datamix, defined by the top-level train_loader and
+        # the duration argument, so it is validated along with the rest.
+        first_datamix = {
+            'duration': duration,
+            'dataset': train_config['train_loader']['dataset'],
+        }
+        self._schedule.insert(0, first_datamix)
+        for datamix in self._schedule:
             self._validate_datamix(datamix)
-            if (
-                index == 0 and
-                train_config['train_loader'] != datamix['train_loader']
-            ):
-                raise ValueError((
-                    'The first datamix in the schedule must match the '
-                    'train_loader in the train_config.'
-                ))
-        self._schedule_index = 0
-        self.device_train_batch_size = train_config['device_train_batch_size']
-        self.tokenizer = None
+        self._schedule_index = 0
+        self._train_loader_config: dict[str, Any] = train_config['train_loader']
+        self._device_train_batch_size = train_config['device_train_batch_size']
+        self._tokenizer = None
 
     def init(self, state: State, logger: Logger):
         del logger  # unused
 
         if not hasattr(state.model, 'tokenizer'):
             raise ValueError('state.model must have a tokenizer attribute.')
-        self.tokenizer = state.model.tokenizer
+        self._tokenizer = state.model.tokenizer
 
     def before_load(self, state: State, logger: Logger):
         del logger  # unused
@@ -151,8 +152,10 @@ def iteration_start(self, state: State, logger: Logger):
         # which is stale
         clean_stale_shared_memory()
         datamix = copy.deepcopy(self._schedule[self._schedule_index])
+        # Overlay the datamix's dataset keys on a copy of the base train_loader
+        # config; loader settings outside 'dataset' are inherited unchanged.
+        train_loader_config = copy.deepcopy(self._train_loader_config)
+        train_loader_config['dataset'].update(datamix['dataset'])
         data_spec = self._build_train_loader(
-            train_loader_config=datamix['train_loader'],
+            train_loader_config=train_loader_config,
             logger=logger,
         )
         state.set_dataloader(
@@ -211,18 +214,20 @@ def _build_train_loader(
         train_loader_config: dict[str, Any],
         logger: Logger,
     ) -> DataSpec:
+        del logger  # unused
+
         from llmfoundry.data.dataloader import build_dataloader
 
         # Copied from scripts/train/train.py
         log.info(
             f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}',
         )
-        assert self.tokenizer is not None
+        assert self._tokenizer is not None
         try:
             return build_dataloader(
                 train_loader_config,
-                self.tokenizer,
-                self.device_train_batch_size,
+                self._tokenizer,
+                self._device_train_batch_size,
             )
         except BaseContextualError as e:
             e.location = TrainDataLoaderLocation
@@ -260,5 +265,5 @@ def _validate_datamix(self, datamix: dict[str, Any]):
                 'Schedules can only be defined in terms of epochs or tokens.',
             )
 
-        if 'train_loader' not in datamix:
-            raise ValueError('Each datamix must have a train_loader.')
+        if 'dataset' not in datamix:
+            raise ValueError('Each datamix must have a dataset.')
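Net effect of the callback changes above: a schedule entry no longer carries a
full train_loader config. Each entry overrides only keys inside 'dataset', and
every other loader setting is inherited from the top-level train_loader. A
minimal sketch of the merge performed in iteration_start, assuming a
finetuning-style loader config (the field names 'hf_name' and 'max_seq_len'
are illustrative, not required by the callback):

    import copy

    base_train_loader = {
        'name': 'finetuning',
        'dataset': {'hf_name': 'base_dataset', 'max_seq_len': 2048},
    }
    datamix = {'duration': '10tok', 'dataset': {'hf_name': 'second_dataset'}}

    # Deep-copy the base config, then overlay the datamix's dataset keys;
    # keys the datamix does not set (max_seq_len here) are inherited.
    train_loader_config = copy.deepcopy(base_train_loader)
    train_loader_config['dataset'].update(datamix['dataset'])

    assert train_loader_config['dataset'] == {
        'hf_name': 'second_dataset',
        'max_seq_len': 2048,
    }

Note that dict.update is a shallow merge: any nested dict inside 'dataset'
(e.g. a streams mapping) is replaced wholesale rather than merged key-by-key.
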
diff --git a/tests/callbacks/test_curriculum_learning_callback.py b/tests/callbacks/test_curriculum_learning_callback.py
index 075698a4c0..0e6a6c1efe 100644
--- a/tests/callbacks/test_curriculum_learning_callback.py
+++ b/tests/callbacks/test_curriculum_learning_callback.py
@@ -22,7 +22,7 @@
     [
         (None, '1ep'),
         ({
-            'dataset': 'some_dataset',
+            'hf_name': 'some_dataset',
         }, '1ep'),
         (None, '10tok'),
         (None, ''),
@@ -36,23 +36,29 @@ def test_curriculum_learning_callback_init(
 ):
     test_cfg = _get_test_cfg()
     test_cfg['train_loader'] = tiny_ft_dataloader_cfg
-    train_loader = test_cfg['train_loader'] if datamix is None else datamix
+    if datamix is None:
+        dataset = test_cfg['train_loader']['dataset']
+    else:
+        dataset = datamix
     kwargs = {
         'schedule': [{
            'duration': duration,
-            'train_loader': train_loader,
+            'dataset': dataset,
         }, {
             'duration': '2ep',
-            'train_loader': {},
+            'dataset': {},
         }],
     }
+
+    # The first datamix is now passed via the top-level duration kwarg.
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     if duration == '':
         del kwargs['schedule'][0]['duration']
     if datamix is not None and len(datamix) == 0:
-        del kwargs['schedule'][0]['train_loader']
+        del kwargs['schedule'][0]['dataset']
     context = nullcontext()
-    if datamix is not None or duration == '':
+    if (datamix is not None and len(datamix) == 0) or duration == '':
         context = pytest.raises(ValueError)
     with context:
         callback = build_callback(
@@ -85,13 +91,15 @@ def test_curriculum_learning_callback_before_load(
     kwargs = {
         'schedule': [{
             'duration': duration,
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -123,13 +131,15 @@ def test_curriculum_learning_callback_after_load(build_tiny_mpt: Callable,):
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -168,13 +178,15 @@ def test_curriculum_learning_callback_iteration(
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -208,13 +220,15 @@ def test_curriculum_learning_callback_state_dict(build_tiny_mpt: Callable,):
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
@@ -249,13 +263,15 @@ def test_curriculum_learning_callback_load_state_dict(
     kwargs = {
         'schedule': [{
             'duration': '1ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }, {
             'duration': '2ep',
-            'train_loader': test_cfg['train_loader'],
+            'dataset': test_cfg['train_loader']['dataset'],
         }],
     }
 
+    kwargs['duration'] = kwargs['schedule'].pop(0)['duration']
+
     callback = build_callback(
         'curriculum_learning',
         kwargs=kwargs,
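
End to end, the simplified API is configured as in the sketch below, mirroring
the updated tests: the 'curriculum_learning' registry name and the
build_callback(name, kwargs, train_config) call shape are the ones used in the
tests above, while the dataset names, durations, and tokenizer config are
placeholder values, not defaults of the callback:

    from llmfoundry.utils.builders import build_callback

    train_config = {
        'train_loader': {
            'name': 'finetuning',
            'dataset': {'hf_name': 'base_dataset'},  # the first datamix
        },
        'device_train_batch_size': 1,
        'tokenizer': {'name': 'EleutherAI/gpt-neox-20b'},  # placeholder
    }

    callback = build_callback(
        'curriculum_learning',
        kwargs={
            # How long to train on the train_loader's own dataset before the
            # first switch.
            'duration': '10tok',
            # Subsequent datamixes override only 'dataset' keys; the rest of
            # the train_loader config is inherited.
            'schedule': [
                {'duration': '20tok', 'dataset': {'hf_name': 'second_dataset'}},
                {'duration': '30tok', 'dataset': {'hf_name': 'third_dataset'}},
            ],
        },
        train_config=train_config,
    )

Durations may be given in epochs ('ep') or tokens ('tok'), the two units
accepted by _validate_datamix, and per the docstring they must match the units
of max_duration.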