Add Iteration Event
commit-id:26924316
b-chu committed Feb 27, 2024
1 parent a8f921b commit f46c7c4
Showing 4 changed files with 118 additions and 59 deletions.
45 changes: 43 additions & 2 deletions composer/core/callback.py
@@ -135,6 +135,16 @@ def fit_start(self, state: State, logger: Logger) -> None:
del state, logger # unused
pass

def iteration_start(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.ITERATION_START` event.
Args:
state (State): The training state.
logger (Logger): The logger.
"""
del state, logger # unused
pass

def epoch_start(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.EPOCH_START` event.
@@ -299,8 +309,14 @@ def epoch_end(self, state: State, logger: Logger) -> None:
.. note::
:attr:`.State.timestamp` member variable :attr:`.Timestamp.epoch`
is incremented immediately before :attr:`.Event.EPOCH_END`.
The following :attr:`.State.timestamp` member variables are
incremented immediately before the :attr:`.Event.EPOCH_END` event.
+--------------------------------------+
| :attr:`.Timestamp.epoch` |
+--------------------------------------+
| :attr:`.Timestamp.epoch_in_iteration`|
+--------------------------------------+
Args:
state (State): The training state.
@@ -319,6 +335,31 @@ def epoch_checkpoint(self, state: State, logger: Logger) -> None:
del state, logger # unused
pass

def iteration_end(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.ITERATION_END` event.
.. note::
:attr:`.State.timestamp` member variable :attr:`.Timestamp.iteration`
is incremented immediately before :attr:`.Event.ITERATION_END`.
Args:
state (State): The training state.
logger (Logger): The logger.
"""
del state, logger # unused
pass

def iteration_checkpoint(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.ITERATION_CHECKPOINT` event.
Args:
state (State): The training state.
logger (Logger): The logger.
"""
del state, logger # unused
pass

def predict_start(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.PREDICT_START` event.
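To adopt the new hooks, a callback only needs to override the methods added above. A minimal sketch, assuming the `Timestamp.iteration` attribute referenced in the `iteration_end` note; the `IterationMonitor` name and metric keys are illustrative, not part of this commit:

```python
from composer.core import Callback, State
from composer.loggers import Logger


class IterationMonitor(Callback):
    """Sketch of a callback built on the new iteration events."""

    def iteration_start(self, state: State, logger: Logger) -> None:
        # Timestamp.iteration counts completed iterations, so this logs 0
        # at the start of the first iteration.
        logger.log_metrics({'iteration/start': state.timestamp.iteration.value})

    def iteration_end(self, state: State, logger: Logger) -> None:
        # Per the note above, Timestamp.iteration is incremented immediately
        # before ITERATION_END fires.
        logger.log_metrics({'iteration/end': state.timestamp.iteration.value})
```

Registering it is the usual `Trainer(..., callbacks=[IterationMonitor()])`.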
11 changes: 8 additions & 3 deletions composer/core/engine.py
@@ -353,7 +353,11 @@ def _assert_dataloader_and_duration_set(state: State, event: Event):

# dataloader should be set on all events except INIT/BEFORE_LOAD/AFTER_LOAD/EVAL_STANDALONE_START/EVAL_STANDALONE_END
if event not in {
Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD, Event.EVAL_STANDALONE_START, Event.EVAL_STANDALONE_END
Event.INIT,
Event.BEFORE_LOAD,
Event.AFTER_LOAD,
Event.EVAL_STANDALONE_START,
Event.EVAL_STANDALONE_END,
}:
assert state.dataloader is not None, f'The trainer should have set state.dataloader for event {event}.'

@@ -391,8 +395,9 @@ def _run_algorithms(
run=True)

if len(trace) > 0:
self.logger.log_traces(
{f'algorithm_traces/{tr.name}/{tr.event}': 1 if tr.run else 0 for _, tr in trace.items()})
self.logger.log_traces({
f'algorithm_traces/{tr.name}/{tr.event}': 1 if tr.run else 0 for _, tr in trace.items()
})

return trace

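The `log_traces` change is formatting only; the payload is unchanged. As a rough illustration, the dict it logs maps one key per algorithm trace to a 0/1 flag (the algorithm names and exact event rendering below are illustrative):

```python
# Sketch of the dict passed to Logger.log_traces: 1 if the algorithm ran
# on that event, 0 if it did not.
{
    'algorithm_traces/BlurPool/Event.INIT': 1,
    'algorithm_traces/LayerFreezing/Event.EPOCH_END': 0,
}
```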
119 changes: 66 additions & 53 deletions composer/core/event.py
@@ -21,39 +21,59 @@ class Event(StringEnum):
# <BEFORE_LOAD>
# <AFTER_LOAD>
# <FIT_START>
for iteration in range(NUM_ITERATIONS):
    # <ITERATION_START>
    for epoch in range(NUM_EPOCHS):
        # <EPOCH_START>
        while True:
            # <BEFORE_DATALOADER>
            batch = next(dataloader)
            if batch is None:
                break
            # <AFTER_DATALOADER>

            # <BATCH_START>

            # <BEFORE_TRAIN_BATCH>

            for microbatch in batch.split(device_train_microbatch_size):

                # <BEFORE_FORWARD>
                outputs = model(batch)
                # <AFTER_FORWARD>

                # <BEFORE_LOSS>
                loss = model.loss(outputs, batch)
                # <AFTER_LOSS>

                # <BEFORE_BACKWARD>
                loss.backward()
                # <AFTER_BACKWARD>

            # Un-scale gradients

            # <AFTER_TRAIN_BATCH>
            optimizer.step()

            # <BATCH_END>

            # <BEFORE_EVAL_ALL>
            for eval_dataloader in eval_dataloaders:
                if should_eval(batch=True):
                    # <EVAL_START>
                    for batch in eval_dataloader:
                        # <EVAL_BATCH_START>
                        # <EVAL_BEFORE_FORWARD>
                        outputs, targets = model(batch)
                        # <EVAL_AFTER_FORWARD>
                        metrics.update(outputs, targets)
                        # <EVAL_BATCH_END>
                    # <EVAL_END>
            # <AFTER_EVAL_ALL>

            # <BATCH_CHECKPOINT>
        # <EPOCH_END>

        # <BEFORE_EVAL_ALL>
        for eval_dataloader in eval_dataloaders:
            if should_eval(batch=True):
                # <EVAL_START>
                for batch in eval_dataloader:
                    # <EVAL_BATCH_START>
                    # <EVAL_BEFORE_FORWARD>
                    outputs, targets = model(batch)
                    # <EVAL_AFTER_FORWARD>
                    metrics.update(outputs, targets)
                    # <EVAL_BATCH_END>
                # <EVAL_END>
        # <AFTER_EVAL_ALL>

        # <EPOCH_CHECKPOINT>
    # <ITERATION_END>
    # <ITERATION_CHECKPOINT>
# <FIT_END>
Attributes:
@@ -98,6 +102,7 @@
AFTER_LOAD: Immediately after checkpoint is loaded in constructor of :class:`~.trainer.Trainer`.
FIT_START: Invoked at the beginning of each call to :meth:`.Trainer.fit`. Dataset transformations typically
occur here.
ITERATION_START: Start of an iteration.
EPOCH_START: Start of an epoch.
BEFORE_DATALOADER: Immediately before the dataloader is called.
AFTER_DATALOADER: Immediately after the dataloader is called. Typically used for on-GPU dataloader transforms.
@@ -125,7 +130,10 @@ class Event(StringEnum):
EPOCH_END: End of an epoch.
EPOCH_CHECKPOINT: After :attr:`.Event.EPOCH_END` and any epoch-wise evaluation. Saving checkpoints at this
event allows the checkpoint saver to use the results from any epoch-wise evaluation to determine whether
a checkpointshould be saved.
a checkpoint should be saved.
ITERATION_END: End of an iteration.
ITERATION_CHECKPOINT: After :attr:`.Event.ITERATION_END`. Saving checkpoints at this event allows the checkpoint
saver to determine whether a checkpoint should be saved.
FIT_END: Invoked at the end of each call to :meth:`.Trainer.fit`. This event exists primarily for logging information
and flushing callbacks. Algorithms should not transform the training state on this event, as any changes will not
be preserved in checkpoints.
@@ -148,6 +156,8 @@ class Event(StringEnum):
AFTER_LOAD = 'after_load'
FIT_START = 'fit_start'

ITERATION_START = 'iteration_start'

EPOCH_START = 'epoch_start'

BEFORE_DATALOADER = 'before_dataloader'
@@ -174,6 +184,9 @@
EPOCH_END = 'epoch_end'
EPOCH_CHECKPOINT = 'epoch_checkpoint'

ITERATION_END = 'iteration_end'
ITERATION_CHECKPOINT = 'iteration_checkpoint'

FIT_END = 'fit_end'

EVAL_BEFORE_ALL = 'eval_before_all'
@@ -246,12 +259,12 @@ def is_eval(self) -> bool:
return self.value.startswith('eval')


_BEFORE_EVENTS = (Event.BEFORE_LOAD, Event.FIT_START, Event.EPOCH_START, Event.BEFORE_DATALOADER, Event.BATCH_START,
Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS, Event.BEFORE_BACKWARD,
Event.EVAL_BEFORE_ALL, Event.EVAL_START, Event.EVAL_BATCH_START, Event.EVAL_BEFORE_FORWARD,
Event.PREDICT_START, Event.PREDICT_BATCH_START, Event.PREDICT_BEFORE_FORWARD,
Event.EVAL_STANDALONE_START)
_AFTER_EVENTS = (Event.AFTER_LOAD, Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER, Event.AFTER_TRAIN_BATCH,
Event.AFTER_FORWARD, Event.AFTER_LOSS, Event.AFTER_BACKWARD, Event.EVAL_AFTER_ALL, Event.EVAL_END,
Event.EVAL_BATCH_END, Event.EVAL_AFTER_FORWARD, Event.FIT_END, Event.PREDICT_END,
Event.PREDICT_BATCH_END, Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END)
_BEFORE_EVENTS = (Event.BEFORE_LOAD, Event.FIT_START, Event.ITERATION_START, Event.EPOCH_START, Event.BEFORE_DATALOADER,
Event.BATCH_START, Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS,
Event.BEFORE_BACKWARD, Event.EVAL_BEFORE_ALL, Event.EVAL_START, Event.EVAL_BATCH_START,
Event.EVAL_BEFORE_FORWARD, Event.PREDICT_START, Event.PREDICT_BATCH_START,
Event.PREDICT_BEFORE_FORWARD, Event.EVAL_STANDALONE_START)
_AFTER_EVENTS = (Event.AFTER_LOAD, Event.ITERATION_END, Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER,
Event.AFTER_TRAIN_BATCH, Event.AFTER_FORWARD, Event.AFTER_LOSS, Event.AFTER_BACKWARD,
Event.EVAL_AFTER_ALL, Event.EVAL_END, Event.EVAL_BATCH_END, Event.EVAL_AFTER_FORWARD, Event.FIT_END,
Event.PREDICT_END, Event.PREDICT_BATCH_END, Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END)
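With the iteration events registered in `_BEFORE_EVENTS`/`_AFTER_EVENTS`, algorithms can target the iteration boundary like any other event. A minimal sketch against Composer's two-method `Algorithm` interface; the class name and metric key are illustrative, not part of this commit:

```python
from typing import Optional

from composer.core import Algorithm, Event, State
from composer.loggers import Logger


class IterationBoundaryLogger(Algorithm):
    """Sketch of an algorithm that fires once per iteration."""

    def match(self, event: Event, state: State) -> bool:
        return event == Event.ITERATION_END

    def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
        # By ITERATION_END, Timestamp.iteration has already been incremented,
        # mirroring how Timestamp.epoch behaves at EPOCH_END.
        logger.log_metrics({'iterations_completed': state.timestamp.iteration.value})
        return None
```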
2 changes: 1 addition & 1 deletion composer/trainer/trainer.py
@@ -2078,7 +2078,7 @@ def _accumulate_time_across_ranks(

def _train_loop(self) -> None:
"""Run training for the specified number of epochs and log results."""
# print training start
# Log training start
log.info('Using precision %s', self.state.precision)
self.logger.log_hyperparameters(
{'enabled_algorithms/' + algo.__class__.__name__: True for algo in self.state.algorithms})
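For context, the `log_hyperparameters` call this hunk touches emits one boolean flag per enabled algorithm at the start of training; the payload looks roughly like this (class names illustrative):

```python
# Sketch of the hyperparameters logged at the start of _train_loop.
{
    'enabled_algorithms/BlurPool': True,
    'enabled_algorithms/GradientClipping': True,
}
```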
