From f43be7c0eeb12663ffef9d599ec1d6fd9b7c06c9 Mon Sep 17 00:00:00 2001
From: root <23239305+b-chu@users.noreply.github.com>
Date: Fri, 1 Mar 2024 16:13:27 +0000
Subject: [PATCH] Add Iteration related Events

Adds ITERATION_START, ITERATION_END, and ITERATION_CHECKPOINT to Event,
Engine, and Callback. Increments iteration during training based on
iteration length defined in State. Iteration length is a private variable
in State and should have no effect by default.

commit-id:bdbe33f2
---
 composer/core/callback.py     |  45 ++++++++++++-
 composer/core/engine.py       |  20 ++++--
 composer/core/event.py        | 119 +++++++++++++++++++---------
 composer/core/state.py        |  20 +++++-
 composer/trainer/trainer.py   |  13 +++-
 tests/test_events.py          |   3 +
 tests/trainer/test_trainer.py |  19 ++++++
 7 files changed, 175 insertions(+), 64 deletions(-)

diff --git a/composer/core/callback.py b/composer/core/callback.py
index 68c170bcab..fef48ca1b1 100644
--- a/composer/core/callback.py
+++ b/composer/core/callback.py
@@ -135,6 +135,16 @@ def fit_start(self, state: State, logger: Logger) -> None:
         del state, logger  # unused
         pass
 
+    def iteration_start(self, state: State, logger: Logger) -> None:
+        """Called on the :attr:`.Event.ITERATION_START` event.
+
+        Args:
+            state (State): The training state.
+            logger (Logger): The logger.
+        """
+        del state, logger  # unused
+        pass
+
     def epoch_start(self, state: State, logger: Logger) -> None:
         """Called on the :attr:`.Event.EPOCH_START` event.
 
@@ -299,8 +309,14 @@ def epoch_end(self, state: State, logger: Logger) -> None:
 
         .. note::
 
-            :attr:`.State.timestamp` member variable :attr:`.Timestamp.epoch`
-            is incremented immediately before :attr:`.Event.EPOCH_END`.
+            The following :attr:`.State.timestamp` member variables are
+            incremented immediately before the :attr:`.Event.EPOCH_END` event.
+
+            +--------------------------------------+
+            | :attr:`.Timestamp.epoch`             |
+            +--------------------------------------+
+            | :attr:`.Timestamp.epoch_in_iteration`|
+            +--------------------------------------+
 
         Args:
             state (State): The training state.
             logger (Logger): The logger.
         """
         del state, logger  # unused
         pass
 
@@ -319,6 +335,31 @@ def epoch_checkpoint(self, state: State, logger: Logger) -> None:
         del state, logger  # unused
         pass
 
+    def iteration_end(self, state: State, logger: Logger) -> None:
+        """Called on the :attr:`.Event.ITERATION_END` event.
+
+        .. note::
+
+            :attr:`.State.timestamp` member variable :attr:`.Timestamp.iteration`
+            is incremented immediately before :attr:`.Event.ITERATION_END`.
+
+        Args:
+            state (State): The training state.
+            logger (Logger): The logger.
+        """
+        del state, logger  # unused
+        pass
+
+    def iteration_checkpoint(self, state: State, logger: Logger) -> None:
+        """Called on the :attr:`.Event.ITERATION_CHECKPOINT` event.
+
+        Args:
+            state (State): The training state.
+            logger (Logger): The logger.
+        """
+        del state, logger  # unused
+        pass
+
     def predict_start(self, state: State, logger: Logger) -> None:
         """Called on the :attr:`.Event.PREDICT_START` event.
diff --git a/composer/core/engine.py b/composer/core/engine.py
index 75d89d0f9a..90646cd480 100644
--- a/composer/core/engine.py
+++ b/composer/core/engine.py
@@ -353,7 +353,11 @@ def _assert_dataloader_and_duration_set(state: State, event: Event):
     # dataloader should be set on all events except INIT/BEFORE_LOAD/AFTER_LOAD/EVAL_STANDALONE_START/EVAL_STANDALONE_END
     if event not in {
-            Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD, Event.EVAL_STANDALONE_START, Event.EVAL_STANDALONE_END
+            Event.INIT,
+            Event.BEFORE_LOAD,
+            Event.AFTER_LOAD,
+            Event.EVAL_STANDALONE_START,
+            Event.EVAL_STANDALONE_END,
     }:
         assert state.dataloader is not None, f'The trainer should have set state.dataloader for event {event}.'
 
@@ -384,15 +388,17 @@ def _run_algorithms(
                 exit_code = algorithm.apply(event, self.state, self.logger)
 
             trace_key = f'{algorithm}/{event}'
-            trace[trace_key] = Trace(name=algorithm.__class__.__name__,
-                                     event=event,
-                                     exit_code=exit_code,
-                                     order=order,
-                                     run=True)
+            trace[trace_key] = Trace(
+                name=algorithm.__class__.__name__,
+                event=event,
+                exit_code=exit_code,
+                order=order,
+                run=True,
+            )
 
         if len(trace) > 0:
             self.logger.log_traces(
-                {f'algorithm_traces/{tr.name}/{tr.event}': 1 if tr.run else 0 for _, tr in trace.items()})
+                ({f'algorithm_traces/{tr.name}/{tr.event}': 1 if tr.run else 0 for _, tr in trace.items()}))
 
         return trace
 
diff --git a/composer/core/event.py b/composer/core/event.py
index cb05d393ff..035a74196d 100644
--- a/composer/core/event.py
+++ b/composer/core/event.py
@@ -21,39 +21,59 @@ class Event(StringEnum):
         # <BEFORE_LOAD>
         # <AFTER_LOAD>
         # <FIT_START>
-        for epoch in range(NUM_EPOCHS):
-            # <EPOCH_START>
-            while True:
-                # <BEFORE_DATALOADER>
-                batch = next(dataloader)
-                if batch is None:
-                    break
-                # <AFTER_DATALOADER>
+        for iteration in range(NUM_ITERATIONS):
+            # <ITERATION_START>
+            for epoch in range(NUM_EPOCHS):
+                # <EPOCH_START>
+                while True:
+                    # <BEFORE_DATALOADER>
+                    batch = next(dataloader)
+                    if batch is None:
+                        break
+                    # <AFTER_DATALOADER>
 
-                # <BATCH_START>
+                    # <BATCH_START>
 
-                # <BEFORE_TRAIN_BATCH>
+                    # <BEFORE_TRAIN_BATCH>
 
-                for microbatch in batch.split(device_train_microbatch_size):
+                    for microbatch in batch.split(device_train_microbatch_size):
 
-                    # <BEFORE_FORWARD>
-                    outputs = model(batch)
-                    # <AFTER_FORWARD>
+                        # <BEFORE_FORWARD>
+                        outputs = model(batch)
+                        # <AFTER_FORWARD>
 
-                    # <BEFORE_LOSS>
-                    loss = model.loss(outputs, batch)
-                    # <AFTER_LOSS>
+                        # <BEFORE_LOSS>
+                        loss = model.loss(outputs, batch)
+                        # <AFTER_LOSS>
 
-                    # <BEFORE_BACKWARD>
-                    loss.backward()
-                    # <AFTER_BACKWARD>
+                        # <BEFORE_BACKWARD>
+                        loss.backward()
+                        # <AFTER_BACKWARD>
 
-                # Un-scale gradients
+                    # Un-scale gradients
 
-                # <AFTER_TRAIN_BATCH>
-                optimizer.step()
+                    # <AFTER_TRAIN_BATCH>
+                    optimizer.step()
 
-                # <BATCH_END>
+                    # <BATCH_END>
+
+                    # <EVAL_BEFORE_ALL>
+                    for eval_dataloader in eval_dataloaders:
+                        if should_eval(batch=True):
+                            # <EVAL_START>
+                            for batch in eval_dataloader:
+                                # <EVAL_BATCH_START>
+                                # <EVAL_BEFORE_FORWARD>
+                                outputs, targets = model(batch)
+                                # <EVAL_AFTER_FORWARD>
+                                metrics.update(outputs, targets)
+                                # <EVAL_BATCH_END>
+                            # <EVAL_END>
+
+                    # <EVAL_AFTER_ALL>
+
+                    # <BATCH_CHECKPOINT>
+                # <EPOCH_END>
 
                 # <EVAL_BEFORE_ALL>
                 for eval_dataloader in eval_dataloaders:
@@ -70,25 +90,9 @@ class Event(StringEnum):
                     # <EVAL_END>
                 # <EVAL_AFTER_ALL>
 
-                # <BATCH_CHECKPOINT>
-            # <EPOCH_END>
-
-            # <EVAL_BEFORE_ALL>
-            for eval_dataloader in eval_dataloaders:
-                if should_eval(batch=True):
-                    # <EVAL_START>
-                    for batch in eval_dataloader:
-                        # <EVAL_BATCH_START>
-                        # <EVAL_BEFORE_FORWARD>
-                        outputs, targets = model(batch)
-                        # <EVAL_AFTER_FORWARD>
-                        metrics.update(outputs, targets)
-                        # <EVAL_BATCH_END>
-                    # <EVAL_END>
-            # <EVAL_AFTER_ALL>
-
-            # <EPOCH_CHECKPOINT>
+                # <EPOCH_CHECKPOINT>
+            # <ITERATION_END>
+            # <ITERATION_CHECKPOINT>
 
         # <FIT_END>
 
     Attributes:
@@ -98,6 +102,7 @@ class Event(StringEnum):
         AFTER_LOAD: Immediately after checkpoint is loaded in constructor of :class:`~.trainer.Trainer`.
         FIT_START: Invoked at the beginning of each call to :meth:`.Trainer.fit`. Dataset transformations typically
            occur here.
+        ITERATION_START: Start of an iteration.
         EPOCH_START: Start of an epoch.
         BEFORE_DATALOADER: Immediately before the dataloader is called.
         AFTER_DATALOADER: Immediately after the dataloader is called. Typically used for on-GPU dataloader transforms.
@@ -125,7 +130,10 @@ class Event(StringEnum):
         EPOCH_END: End of an epoch.
         EPOCH_CHECKPOINT: After :attr:`.Event.EPOCH_END` and any epoch-wise evaluation.
            Saving checkpoints at this event allows the checkpoint saver to use the results from any epoch-wise
            evaluation to determine whether
-            a checkpointshould be saved.
+            a checkpoint should be saved.
+        ITERATION_END: End of an iteration.
+        ITERATION_CHECKPOINT: After :attr:`.Event.ITERATION_END`. Saving checkpoints at this event allows the checkpoint
+            saver to determine whether a checkpoint should be saved.
         FIT_END: Invoked at the end of each call to :meth:`.Trainer.fit`. This event exists primarily for logging
             information and flushing callbacks. Algorithms should not transform the training state on this event, as
             any changes will not be preserved in checkpoints.
@@ -148,6 +156,8 @@ class Event(StringEnum):
     AFTER_LOAD = 'after_load'
     FIT_START = 'fit_start'
 
+    ITERATION_START = 'iteration_start'
+
     EPOCH_START = 'epoch_start'
     BEFORE_DATALOADER = 'before_dataloader'
 
@@ -174,6 +184,9 @@ class Event(StringEnum):
     EPOCH_END = 'epoch_end'
     EPOCH_CHECKPOINT = 'epoch_checkpoint'
 
+    ITERATION_END = 'iteration_end'
+    ITERATION_CHECKPOINT = 'iteration_checkpoint'
+
     FIT_END = 'fit_end'
 
     EVAL_BEFORE_ALL = 'eval_before_all'
 
@@ -246,12 +259,12 @@ def is_eval(self) -> bool:
         return self.value.startswith('eval')
 
 
-_BEFORE_EVENTS = (Event.BEFORE_LOAD, Event.FIT_START, Event.EPOCH_START, Event.BEFORE_DATALOADER, Event.BATCH_START,
-                  Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS, Event.BEFORE_BACKWARD,
-                  Event.EVAL_BEFORE_ALL, Event.EVAL_START, Event.EVAL_BATCH_START, Event.EVAL_BEFORE_FORWARD,
-                  Event.PREDICT_START, Event.PREDICT_BATCH_START, Event.PREDICT_BEFORE_FORWARD,
-                  Event.EVAL_STANDALONE_START)
-_AFTER_EVENTS = (Event.AFTER_LOAD, Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER, Event.AFTER_TRAIN_BATCH,
-                 Event.AFTER_FORWARD, Event.AFTER_LOSS, Event.AFTER_BACKWARD, Event.EVAL_AFTER_ALL, Event.EVAL_END,
-                 Event.EVAL_BATCH_END, Event.EVAL_AFTER_FORWARD, Event.FIT_END, Event.PREDICT_END,
-                 Event.PREDICT_BATCH_END, Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END)
+_BEFORE_EVENTS = (Event.BEFORE_LOAD, Event.FIT_START, Event.ITERATION_START, Event.EPOCH_START, Event.BEFORE_DATALOADER,
+                  Event.BATCH_START, Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS,
+                  Event.BEFORE_BACKWARD, Event.EVAL_BEFORE_ALL, Event.EVAL_START, Event.EVAL_BATCH_START,
+                  Event.EVAL_BEFORE_FORWARD, Event.PREDICT_START, Event.PREDICT_BATCH_START,
+                  Event.PREDICT_BEFORE_FORWARD, Event.EVAL_STANDALONE_START)
+_AFTER_EVENTS = (Event.AFTER_LOAD, Event.ITERATION_END, Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER,
+                 Event.AFTER_TRAIN_BATCH, Event.AFTER_FORWARD, Event.AFTER_LOSS, Event.AFTER_BACKWARD,
+                 Event.EVAL_AFTER_ALL, Event.EVAL_END, Event.EVAL_BATCH_END, Event.EVAL_AFTER_FORWARD, Event.FIT_END,
+                 Event.PREDICT_END, Event.PREDICT_BATCH_END, Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END)
diff --git a/composer/core/state.py b/composer/core/state.py
index f99f6050ef..e4019ee2f1 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -29,7 +29,7 @@
 from composer.core.event import Event
 from composer.core.precision import Precision
 from composer.core.serializable import Serializable
-from composer.core.time import Time, Timestamp, TimeUnit
+from composer.core.time import Time, Timestamp, TimeUnit, ensure_time
 from composer.devices import Device
 from composer.utils import (batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed,
                             reproducibility)
@@ -412,6 +412,8 @@ def __init__(
         self.dataset_resumption = dataset_resumption or {}
         self._max_duration = None
         self.max_duration = max_duration
+        self.__iteration_length = None
+        self._iteration_length = self.__iteration_length
         self.save_metrics = save_metrics
 
         self._train_dataloader = train_dataloader
@@ -606,6 +608,22 @@ def get_elapsed_duration(self) -> Optional[Time[float]]:
             return None
         return self.timestamp.get(self.max_duration.unit) / self.max_duration
 
+    @property
+    def _iteration_length(self):
+        """The length of an iteration."""
+        return self.__iteration_length
+
+    @_iteration_length.setter
+    def _iteration_length(self, iteration_length: Optional[Union[str, Time[int]]]):
+        if iteration_length is None:
+            self.__iteration_length = None
+            return
+        if isinstance(iteration_length, str):
+            iteration_length = ensure_time(iteration_length, TimeUnit.EPOCH)
+        if iteration_length.unit != TimeUnit.EPOCH:
+            raise NotImplementedError(f'{iteration_length.unit} is not allowed as a unit for iteration_length.')
+        self.__iteration_length = iteration_length
+
     def stop_training(self):
         """Gracefully stop training.
 
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 4932c5dea6..f9c6e3469b 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -2078,7 +2078,7 @@ def _accumulate_time_across_ranks(
 
     def _train_loop(self) -> None:
         """Run training for the specified number of epochs and log results."""
-        # print training start
+        # Log training start
         log.info('Using precision %s', self.state.precision)
         self.logger.log_hyperparameters(
             {'enabled_algorithms/' + algo.__class__.__name__: True for algo in self.state.algorithms})
@@ -2109,6 +2109,9 @@ def _train_loop(self) -> None:
         log.debug('Starting training loop')
         while self.state.timestamp < self.state.max_duration:
+            if int(self.state.timestamp.epoch_in_iteration) == 0 and int(self.state.timestamp.batch_in_epoch) == 0:
+                self.engine.run_event(Event.ITERATION_START)
+
             if int(self.state.timestamp.batch_in_epoch) == 0:
                 self.engine.run_event(Event.EPOCH_START)
                 self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})
@@ -2244,6 +2247,14 @@ def _train_loop(self) -> None:
 
             self.engine.run_event(Event.EPOCH_CHECKPOINT)
 
+            # Increment iteration
+            if (self.state._iteration_length is not None and
+                    self.state.timestamp.epoch_in_iteration == self.state._iteration_length):
+                self.state.previous_timestamp = self.state.timestamp
+                self.state.timestamp = self.state.timestamp.to_next_iteration()
+                self.engine.run_event(Event.ITERATION_END)
+                self.engine.run_event(Event.ITERATION_CHECKPOINT)
+
         # Log final time values
         self.logger.log_metrics({
             'time/epoch': self.state.timestamp.epoch.value,
diff --git a/tests/test_events.py b/tests/test_events.py
index c81feea0b0..37124ff4ce 100644
--- a/tests/test_events.py
+++ b/tests/test_events.py
@@ -152,6 +152,7 @@ def _assert_expected_event_calls(self, trainer: Trainer, eval_interval: Time, nu
             Event.INIT: 1,
             Event.BEFORE_LOAD: 1,
             Event.AFTER_LOAD: 1,
+            Event.ITERATION_START: 1,
             Event.EPOCH_START: num_epochs,
             Event.BATCH_START: total_steps,
             Event.BEFORE_DATALOADER: total_steps + num_epochs,  # extra call per epoch when dataloader is exhausted
@@ -168,6 +169,8 @@ def _assert_expected_event_calls(self, trainer: Trainer, eval_interval: Time, nu
             Event.BATCH_CHECKPOINT: total_steps,
             Event.EPOCH_END: num_epochs,
             Event.EPOCH_CHECKPOINT: num_epochs,
+            Event.ITERATION_END: 0,
+            Event.ITERATION_CHECKPOINT: 0,
             Event.EVAL_BEFORE_ALL: total_evals,
             Event.EVAL_START: total_evals_start,
             Event.EVAL_BATCH_START: total_eval_steps,
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 97ca2005ee..4b6effb685 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1138,6 +1138,25 @@ def test_compile_uncompile_model_weights_trainer_fit(
         assert (torch.equal(next(compiled_model_trainer.state.model.parameters()),
                             next(uncompiled_model_trainer.state.model.parameters())))
 
+    def test_iteration(
+        self,
+        train_dataloader: DataLoader,
+        model: ComposerModel,
+    ):
+        """Tests that iteration is properly incremented during training when _iteration_length is set."""
+
+        # Train with max_duration set to 5 epochs with 2 epochs per iteration
+        trainer = Trainer(
+            model=model,
+            max_duration='5ep',
+            train_dataloader=train_dataloader,
+        )
+        trainer.state._iteration_length = '2ep'
+        trainer.fit()
+
+        assert trainer.state.timestamp.epoch == Time(5, TimeUnit.EPOCH)
+        assert trainer.state.timestamp.iteration == Time(2, TimeUnit.ITERATION)
+
     @world_size(1, 2)
     @device('cpu', 'gpu', 'gpu-amp', precision=True)
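
Usage sketch (not part of the patch): a minimal example of how the new hooks could be exercised once this change lands, assuming the public Composer imports (Trainer, Callback, State, Logger) and the _iteration_length behavior shown above; my_model, my_train_dataloader, and the 'time/iteration' metric key are placeholders chosen for illustration.

    from composer import Callback, Logger, State, Trainer

    class IterationMonitor(Callback):
        """Logs the iteration counter at iteration boundaries."""

        def iteration_start(self, state: State, logger: Logger) -> None:
            # Runs on Event.ITERATION_START, once per iteration.
            logger.log_metrics({'time/iteration': state.timestamp.iteration.value})

        def iteration_end(self, state: State, logger: Logger) -> None:
            # Timestamp.iteration is incremented immediately before Event.ITERATION_END.
            logger.log_metrics({'time/iteration': state.timestamp.iteration.value})

    trainer = Trainer(
        model=my_model,  # placeholder ComposerModel
        train_dataloader=my_train_dataloader,  # placeholder DataLoader
        max_duration='4ep',
        callbacks=[IterationMonitor()],
    )
    # _iteration_length is private and defaults to None, in which case ITERATION_END and
    # ITERATION_CHECKPOINT never fire; epochs are currently the only supported unit.
    trainer.state._iteration_length = '2ep'
    trainer.fit()

With max_duration='4ep' and an iteration length of '2ep', the run would finish with state.timestamp.iteration equal to 2, mirroring the assertion added in tests/trainer/test_trainer.py.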