From f46c7c4e052988a5ed45aa8bb141d3abf9a828e9 Mon Sep 17 00:00:00 2001 From: root <23239305+b-chu@users.noreply.github.com> Date: Thu, 15 Feb 2024 22:59:27 +0000 Subject: [PATCH] Add Iteration Event commit-id:26924316 --- composer/core/callback.py | 45 +++++++++++++- composer/core/engine.py | 11 +++- composer/core/event.py | 119 ++++++++++++++++++++---------------- composer/trainer/trainer.py | 2 +- 4 files changed, 118 insertions(+), 59 deletions(-) diff --git a/composer/core/callback.py b/composer/core/callback.py index 68c170bcab7..fef48ca1b19 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -135,6 +135,16 @@ def fit_start(self, state: State, logger: Logger) -> None: del state, logger # unused pass + def iteration_start(self, state: State, logger: Logger) -> None: + """Called on the :attr:`.Event.ITERATION_START` event. + + Args: + state (State): The training state. + logger (Logger): The logger. + """ + del state, logger # unused + pass + def epoch_start(self, state: State, logger: Logger) -> None: """Called on the :attr:`.Event.EPOCH_START` event. @@ -299,8 +309,14 @@ def epoch_end(self, state: State, logger: Logger) -> None: .. note:: - :attr:`.State.timestamp` member variable :attr:`.Timestamp.epoch` - is incremented immediately before :attr:`.Event.EPOCH_END`. + The following :attr:`.State.timestamp` member variables are + incremented immediately before the :attr:`.Event.EPOCH_END` event. + + +--------------------------------------+ + | :attr:`.Timestamp.epoch` | + +--------------------------------------+ + | :attr:`.Timestamp.epoch_in_iteration`| + +--------------------------------------+ Args: state (State): The training state. @@ -319,6 +335,31 @@ def epoch_checkpoint(self, state: State, logger: Logger) -> None: del state, logger # unused pass + def iteration_end(self, state: State, logger: Logger) -> None: + """Called on the :attr:`.Event.ITERATION_END` event. + + .. note:: + + :attr:`.State.timestamp` member variable :attr:`.Timestamp.iteration` + is incremented immediately before :attr:`.Event.ITERATION_END`. + + Args: + state (State): The training state. + logger (Logger): The logger. + """ + del state, logger # unused + pass + + def iteration_checkpoint(self, state: State, logger: Logger) -> None: + """Called on the :attr:`.Event.ITERATION_CHECKPOINT` event. + + Args: + state (State): The training state. + logger (Logger): The logger. + """ + del state, logger # unused + pass + def predict_start(self, state: State, logger: Logger) -> None: """Called on the :attr:`.Event.PREDICT_START` event. diff --git a/composer/core/engine.py b/composer/core/engine.py index 75d89d0f9a6..b2da4ca2eb9 100644 --- a/composer/core/engine.py +++ b/composer/core/engine.py @@ -353,7 +353,11 @@ def _assert_dataloader_and_duration_set(state: State, event: Event): # dataloader should be set on all events except INIT/BEFORE_LOAD/AFTER_LOAD/EVAL_STANDALONE_START/EVAL_STANDALONE_END if event not in { - Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD, Event.EVAL_STANDALONE_START, Event.EVAL_STANDALONE_END + Event.INIT, + Event.BEFORE_LOAD, + Event.AFTER_LOAD, + Event.EVAL_STANDALONE_START, + Event.EVAL_STANDALONE_END, }: assert state.dataloader is not None, f'The trainer should have set state.dataloader for event {event}.' 
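Note (not part of the patch): to make the new hooks concrete, here is a minimal sketch of a user-defined callback consuming them, assuming only the `Callback` and `Logger` APIs shown above. The `IterationTimer` name and the `'time/iteration_seconds'` metric key are hypothetical.

    import time

    from composer.core import Callback, State
    from composer.loggers import Logger


    class IterationTimer(Callback):
        """Hypothetical callback that logs the wall-clock duration of each iteration."""

        def __init__(self) -> None:
            self._start = 0.0

        def iteration_start(self, state: State, logger: Logger) -> None:
            # Runs on Event.ITERATION_START, before the first epoch of the iteration.
            self._start = time.monotonic()

        def iteration_end(self, state: State, logger: Logger) -> None:
            # Runs on Event.ITERATION_END; per the docstring above,
            # state.timestamp.iteration has already been incremented by now.
            logger.log_metrics({'time/iteration_seconds': time.monotonic() - self._start})

It would be registered like any other callback, e.g. `Trainer(..., callbacks=[IterationTimer()])`.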
@@ -391,8 +395,9 @@ def _run_algorithms(
                                          run=True)
 
         if len(trace) > 0:
-            self.logger.log_traces(
-                {f'algorithm_traces/{tr.name}/{tr.event}': 1 if tr.run else 0 for _, tr in trace.items()})
+            self.logger.log_traces({
+                f'algorithm_traces/{tr.name}/{tr.event}': 1 if tr.run else 0 for _, tr in trace.items()
+            })
 
         return trace
 
diff --git a/composer/core/event.py b/composer/core/event.py
index cb05d393fff..035a74196d5 100644
--- a/composer/core/event.py
+++ b/composer/core/event.py
@@ -21,39 +21,59 @@ class Event(StringEnum):
         # <BEFORE_LOAD>
         # <AFTER_LOAD>
         # <FIT_START>
-        for epoch in range(NUM_EPOCHS):
-            # <EPOCH_START>
-            while True:
-                # <BEFORE_DATALOADER>
-                batch = next(dataloader)
-                if batch is None:
-                    break
-                # <AFTER_DATALOADER>
+        for iteration in range(NUM_ITERATIONS):
+            # <ITERATION_START>
+            for epoch in range(NUM_EPOCHS):
+                # <EPOCH_START>
+                while True:
+                    # <BEFORE_DATALOADER>
+                    batch = next(dataloader)
+                    if batch is None:
+                        break
+                    # <AFTER_DATALOADER>
 
-                # <BATCH_START>
+                    # <BATCH_START>
 
-                # <BEFORE_TRAIN_BATCH>
+                    # <BEFORE_TRAIN_BATCH>
 
-                for microbatch in batch.split(device_train_microbatch_size):
+                    for microbatch in batch.split(device_train_microbatch_size):
 
-                    # <BEFORE_FORWARD>
-                    outputs = model(batch)
-                    # <AFTER_FORWARD>
+                        # <BEFORE_FORWARD>
+                        outputs = model(batch)
+                        # <AFTER_FORWARD>
 
-                    # <BEFORE_LOSS>
-                    loss = model.loss(outputs, batch)
-                    # <AFTER_LOSS>
+                        # <BEFORE_LOSS>
+                        loss = model.loss(outputs, batch)
+                        # <AFTER_LOSS>
 
-                    # <BEFORE_BACKWARD>
-                    loss.backward()
-                    # <AFTER_BACKWARD>
+                        # <BEFORE_BACKWARD>
+                        loss.backward()
+                        # <AFTER_BACKWARD>
 
-                # Un-scale gradients
+                    # Un-scale gradients
 
-                # <AFTER_TRAIN_BATCH>
-                optimizer.step()
+                    # <AFTER_TRAIN_BATCH>
+                    optimizer.step()
 
-                # <BATCH_END>
+                    # <BATCH_END>
+
+                    # <EVAL_BEFORE_ALL>
+                    for eval_dataloader in eval_dataloaders:
+                        if should_eval(batch=True):
+                            # <EVAL_START>
+                            for batch in eval_dataloader:
+                                # <EVAL_BATCH_START>
+                                # <EVAL_BEFORE_FORWARD>
+                                outputs, targets = model(batch)
+                                # <EVAL_AFTER_FORWARD>
+                                metrics.update(outputs, targets)
+                                # <EVAL_BATCH_END>
+                            # <EVAL_END>
+
+                    # <EVAL_AFTER_ALL>
+
+                    # <BATCH_CHECKPOINT>
+                # <EPOCH_END>
 
                 # <EVAL_BEFORE_ALL>
                 for eval_dataloader in eval_dataloaders:
@@ -70,25 +90,9 @@ class Event(StringEnum):
 
                 # <EVAL_AFTER_ALL>
 
-                # <BATCH_CHECKPOINT>
-            # <EPOCH_END>
-
-            # <EVAL_BEFORE_ALL>
-            for eval_dataloader in eval_dataloaders:
-                if should_eval(batch=True):
-                    # <EVAL_START>
-                    for batch in eval_dataloader:
-                        # <EVAL_BATCH_START>
-                        # <EVAL_BEFORE_FORWARD>
-                        outputs, targets = model(batch)
-                        # <EVAL_AFTER_FORWARD>
-                        metrics.update(outputs, targets)
-                        # <EVAL_BATCH_END>
-                    # <EVAL_END>
-
-            # <EVAL_AFTER_ALL>
-
-            # <EPOCH_CHECKPOINT>
+                # <EPOCH_CHECKPOINT>
+            # <ITERATION_END>
+            # <ITERATION_CHECKPOINT>
         # <FIT_END>
 
     Attributes:
@@ -98,6 +102,7 @@ class Event(StringEnum):
         AFTER_LOAD: Immediately after checkpoint is loaded in constructor of :class:`~.trainer.Trainer`.
         FIT_START: Invoked at the beginning of each call to :meth:`.Trainer.fit`. Dataset transformations typically
             occur here.
+        ITERATION_START: Start of an iteration.
         EPOCH_START: Start of an epoch.
         BEFORE_DATALOADER: Immediately before the dataloader is called.
         AFTER_DATALOADER: Immediately after the dataloader is called. Typically used for on-GPU dataloader transforms.
@@ -125,7 +130,10 @@ class Event(StringEnum):
         EPOCH_END: End of an epoch.
         EPOCH_CHECKPOINT: After :attr:`.Event.EPOCH_END` and any epoch-wise evaluation. Saving checkpoints at this
             event allows the checkpoint saver to use the results from any epoch-wise evaluation to determine whether
-            a checkpointshould be saved.
+            a checkpoint should be saved.
+        ITERATION_END: End of an iteration.
+        ITERATION_CHECKPOINT: After :attr:`.Event.ITERATION_END`. Saving checkpoints at this event allows the checkpoint
+            saver to determine whether a checkpoint should be saved.
         FIT_END: Invoked at the end of each call to :meth:`.Trainer.fit`. This event exists primarily for logging information
             and flushing callbacks. Algorithms should not transform the training state on this event, as any changes will not
             be preserved in checkpoints.
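Note (not part of the patch): the Attributes above describe when ITERATION_START and ITERATION_END fire; a short sketch of an algorithm targeting them, assuming the existing `Algorithm` match/apply interface. The `LogIterationBoundaries` name and the `'trainer/iteration'` key are hypothetical.

    from typing import Optional

    from composer.core import Algorithm, Event, State
    from composer.loggers import Logger


    class LogIterationBoundaries(Algorithm):
        """Hypothetical algorithm that logs the iteration counter at its boundaries."""

        def match(self, event: Event, state: State) -> bool:
            # Run only on the two events introduced in this patch.
            return event in (Event.ITERATION_START, Event.ITERATION_END)

        def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
            # Timestamp.iteration is the counter referenced by the new docstrings.
            logger.log_metrics({'trainer/iteration': state.timestamp.iteration.value})
            return None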
@@ -148,6 +156,8 @@ class Event(StringEnum): AFTER_LOAD = 'after_load' FIT_START = 'fit_start' + ITERATION_START = 'iteration_start' + EPOCH_START = 'epoch_start' BEFORE_DATALOADER = 'before_dataloader' @@ -174,6 +184,9 @@ class Event(StringEnum): EPOCH_END = 'epoch_end' EPOCH_CHECKPOINT = 'epoch_checkpoint' + ITERATION_END = 'iteration_end' + ITERATION_CHECKPOINT = 'iteration_checkpoint' + FIT_END = 'fit_end' EVAL_BEFORE_ALL = 'eval_before_all' @@ -246,12 +259,12 @@ def is_eval(self) -> bool: return self.value.startswith('eval') -_BEFORE_EVENTS = (Event.BEFORE_LOAD, Event.FIT_START, Event.EPOCH_START, Event.BEFORE_DATALOADER, Event.BATCH_START, - Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS, Event.BEFORE_BACKWARD, - Event.EVAL_BEFORE_ALL, Event.EVAL_START, Event.EVAL_BATCH_START, Event.EVAL_BEFORE_FORWARD, - Event.PREDICT_START, Event.PREDICT_BATCH_START, Event.PREDICT_BEFORE_FORWARD, - Event.EVAL_STANDALONE_START) -_AFTER_EVENTS = (Event.AFTER_LOAD, Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER, Event.AFTER_TRAIN_BATCH, - Event.AFTER_FORWARD, Event.AFTER_LOSS, Event.AFTER_BACKWARD, Event.EVAL_AFTER_ALL, Event.EVAL_END, - Event.EVAL_BATCH_END, Event.EVAL_AFTER_FORWARD, Event.FIT_END, Event.PREDICT_END, - Event.PREDICT_BATCH_END, Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END) +_BEFORE_EVENTS = (Event.BEFORE_LOAD, Event.FIT_START, Event.ITERATION_START, Event.EPOCH_START, Event.BEFORE_DATALOADER, + Event.BATCH_START, Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS, + Event.BEFORE_BACKWARD, Event.EVAL_BEFORE_ALL, Event.EVAL_START, Event.EVAL_BATCH_START, + Event.EVAL_BEFORE_FORWARD, Event.PREDICT_START, Event.PREDICT_BATCH_START, + Event.PREDICT_BEFORE_FORWARD, Event.EVAL_STANDALONE_START) +_AFTER_EVENTS = (Event.AFTER_LOAD, Event.ITERATION_END, Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER, + Event.AFTER_TRAIN_BATCH, Event.AFTER_FORWARD, Event.AFTER_LOSS, Event.AFTER_BACKWARD, + Event.EVAL_AFTER_ALL, Event.EVAL_END, Event.EVAL_BATCH_END, Event.EVAL_AFTER_FORWARD, Event.FIT_END, + Event.PREDICT_END, Event.PREDICT_BATCH_END, Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 0d2349bf932..7255c9d7f00 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2078,7 +2078,7 @@ def _accumulate_time_across_ranks( def _train_loop(self) -> None: """Run training for the specified number of epochs and log results.""" - # print training start + # Log training start log.info('Using precision %s', self.state.precision) self.logger.log_hyperparameters( {'enabled_algorithms/' + algo.__class__.__name__: True for algo in self.state.algorithms})
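Note (not part of the patch): a quick sanity check of what the updated tuples imply, assuming the module's existing `Event.is_before_event`, `Event.is_after_event`, and `Event.canonical_name` properties, which are backed by `_BEFORE_EVENTS` and `_AFTER_EVENTS`.

    from composer.core import Event

    assert Event.ITERATION_START.is_before_event
    assert Event.ITERATION_END.is_after_event
    # Paired events share a canonical name once the _start/_end suffix is stripped.
    assert Event.ITERATION_START.canonical_name == 'iteration'
    assert Event.ITERATION_END.canonical_name == 'iteration'

ITERATION_CHECKPOINT, like EPOCH_CHECKPOINT, is deliberately in neither tuple: checkpoint events have no before/after pair.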