Add Iteration related Events #3076

Merged (1 commit) on Mar 5, 2024
45 changes: 43 additions & 2 deletions composer/core/callback.py
@@ -135,6 +135,16 @@ def fit_start(self, state: State, logger: Logger) -> None:
del state, logger # unused
pass

def iteration_start(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.ITERATION_START` event.

Args:
state (State): The training state.
logger (Logger): The logger.
"""
del state, logger # unused
pass

def epoch_start(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.EPOCH_START` event.

@@ -299,8 +309,14 @@ def epoch_end(self, state: State, logger: Logger) -> None:

.. note::

The following :attr:`.State.timestamp` member variables are
incremented immediately before the :attr:`.Event.EPOCH_END` event.

+---------------------------------------+
| :attr:`.Timestamp.epoch`              |
+---------------------------------------+
| :attr:`.Timestamp.epoch_in_iteration` |
+---------------------------------------+

Args:
state (State): The training state.
@@ -319,6 +335,31 @@ def epoch_checkpoint(self, state: State, logger: Logger) -> None:
del state, logger # unused
pass

def iteration_end(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.ITERATION_END` event.

.. note::

:attr:`.State.timestamp` member variable :attr:`.Timestamp.iteration`
is incremented immediately before :attr:`.Event.ITERATION_END`.

Args:
state (State): The training state.
logger (Logger): The logger.
"""
del state, logger # unused
pass

def iteration_checkpoint(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.ITERATION_CHECKPOINT` event.

Args:
state (State): The training state.
logger (Logger): The logger.
"""
del state, logger # unused
pass

def predict_start(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.PREDICT_START` event.

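The three new hooks follow the same no-op pattern as the surrounding callback methods, so a subclass only overrides what it needs. A minimal sketch of a callback built on them (the class name and metric keys here are illustrative, not part of this PR):

from composer.core import Callback, State
from composer.loggers import Logger

class IterationMonitor(Callback):
    """Hypothetical callback that logs iteration boundaries."""

    def iteration_start(self, state: State, logger: Logger) -> None:
        # Fires on Event.ITERATION_START, before any epoch of the iteration runs.
        logger.log_metrics({'iteration/start': state.timestamp.iteration.value})

    def iteration_end(self, state: State, logger: Logger) -> None:
        # Timestamp.iteration was incremented immediately before this event.
        logger.log_metrics({'iteration/end': state.timestamp.iteration.value})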
20 changes: 13 additions & 7 deletions composer/core/engine.py
@@ -353,7 +353,11 @@ def _assert_dataloader_and_duration_set(state: State, event: Event):

# dataloader should be set on all events except INIT/BEFORE_LOAD/AFTER_LOAD/EVAL_STANDALONE_START/EVAL_STANDALONE_END
if event not in {
Event.INIT,
Event.BEFORE_LOAD,
Event.AFTER_LOAD,
Event.EVAL_STANDALONE_START,
Event.EVAL_STANDALONE_END,
}:
assert state.dataloader is not None, f'The trainer should have set state.dataloader for event {event}.'

@@ -384,15 +388,17 @@ def _run_algorithms(
exit_code = algorithm.apply(event, self.state, self.logger)

trace_key = f'{algorithm}/{event}'
trace[trace_key] = Trace(
    name=algorithm.__class__.__name__,
    event=event,
    exit_code=exit_code,
    order=order,
    run=True,
)

if len(trace) > 0:
self.logger.log_traces(
    {f'algorithm_traces/{tr.name}/{tr.event}': 1 if tr.run else 0 for _, tr in trace.items()},
)

return trace

119 changes: 66 additions & 53 deletions composer/core/event.py
@@ -21,39 +21,59 @@ class Event(StringEnum):
# <BEFORE_LOAD>
# <AFTER_LOAD>
# <FIT_START>
for iteration in range(NUM_ITERATIONS):
    # <ITERATION_START>
    for epoch in range(NUM_EPOCHS):
        # <EPOCH_START>
        while True:
            # <BEFORE_DATALOADER>
            batch = next(dataloader)
            if batch is None:
                break
            # <AFTER_DATALOADER>

            # <BATCH_START>

            # <BEFORE_TRAIN_BATCH>

            for microbatch in batch.split(device_train_microbatch_size):

                # <BEFORE_FORWARD>
                outputs = model(batch)
                # <AFTER_FORWARD>

                # <BEFORE_LOSS>
                loss = model.loss(outputs, batch)
                # <AFTER_LOSS>

                # <BEFORE_BACKWARD>
                loss.backward()
                # <AFTER_BACKWARD>

            # Un-scale gradients

            # <AFTER_TRAIN_BATCH>
            optimizer.step()

            # <BATCH_END>

            # <BEFORE_EVAL_ALL>
            for eval_dataloader in eval_dataloaders:
                if should_eval(batch=True):
                    # <EVAL_START>
                    for batch in eval_dataloader:
                        # <EVAL_BATCH_START>
                        # <EVAL_BEFORE_FORWARD>
                        outputs, targets = model(batch)
                        # <EVAL_AFTER_FORWARD>
                        metrics.update(outputs, targets)
                        # <EVAL_BATCH_END>
                    # <EVAL_END>

            # <AFTER_EVAL_ALL>

            # <BATCH_CHECKPOINT>
        # <EPOCH_END>

        # <BEFORE_EVAL_ALL>
        for eval_dataloader in eval_dataloaders:
            if should_eval(batch=True):
                # <EVAL_START>
                for batch in eval_dataloader:
                    # <EVAL_BATCH_START>
                    # <EVAL_BEFORE_FORWARD>
                    outputs, targets = model(batch)
                    # <EVAL_AFTER_FORWARD>
                    metrics.update(outputs, targets)
                    # <EVAL_BATCH_END>
                # <EVAL_END>

        # <AFTER_EVAL_ALL>

        # <EPOCH_CHECKPOINT>
    # <ITERATION_END>
    # <ITERATION_CHECKPOINT>
# <FIT_END>

Attributes:
@@ -98,6 +102,7 @@
AFTER_LOAD: Immediately after checkpoint is loaded in constructor of :class:`~.trainer.Trainer`.
FIT_START: Invoked at the beginning of each call to :meth:`.Trainer.fit`. Dataset transformations typically
occur here.
ITERATION_START: Start of an iteration.
EPOCH_START: Start of an epoch.
BEFORE_DATALOADER: Immediately before the dataloader is called.
AFTER_DATALOADER: Immediately after the dataloader is called. Typically used for on-GPU dataloader transforms.
@@ -125,7 +130,10 @@ class Event(StringEnum):
EPOCH_END: End of an epoch.
EPOCH_CHECKPOINT: After :attr:`.Event.EPOCH_END` and any epoch-wise evaluation. Saving checkpoints at this
event allows the checkpoint saver to use the results from any epoch-wise evaluation to determine whether
a checkpoint should be saved.
ITERATION_END: End of an iteration.
ITERATION_CHECKPOINT: After :attr:`.Event.ITERATION_END`. Saving checkpoints at this event allows the checkpoint
saver to determine whether a checkpoint should be saved.
FIT_END: Invoked at the end of each call to :meth:`.Trainer.fit`. This event exists primarily for logging information
and flushing callbacks. Algorithms should not transform the training state on this event, as any changes will not
be preserved in checkpoints.
@@ -148,6 +156,8 @@
AFTER_LOAD = 'after_load'
FIT_START = 'fit_start'

ITERATION_START = 'iteration_start'

EPOCH_START = 'epoch_start'

BEFORE_DATALOADER = 'before_dataloader'
@@ -174,6 +184,9 @@
EPOCH_END = 'epoch_end'
EPOCH_CHECKPOINT = 'epoch_checkpoint'

ITERATION_END = 'iteration_end'
ITERATION_CHECKPOINT = 'iteration_checkpoint'

FIT_END = 'fit_end'

EVAL_BEFORE_ALL = 'eval_before_all'
@@ -246,12 +259,12 @@ def is_eval(self) -> bool:
return self.value.startswith('eval')


_BEFORE_EVENTS = (Event.BEFORE_LOAD, Event.FIT_START, Event.ITERATION_START, Event.EPOCH_START, Event.BEFORE_DATALOADER,
                  Event.BATCH_START, Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS,
                  Event.BEFORE_BACKWARD, Event.EVAL_BEFORE_ALL, Event.EVAL_START, Event.EVAL_BATCH_START,
                  Event.EVAL_BEFORE_FORWARD, Event.PREDICT_START, Event.PREDICT_BATCH_START,
                  Event.PREDICT_BEFORE_FORWARD, Event.EVAL_STANDALONE_START)
_AFTER_EVENTS = (Event.AFTER_LOAD, Event.ITERATION_END, Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER,
                 Event.AFTER_TRAIN_BATCH, Event.AFTER_FORWARD, Event.AFTER_LOSS, Event.AFTER_BACKWARD,
                 Event.EVAL_AFTER_ALL, Event.EVAL_END, Event.EVAL_BATCH_END, Event.EVAL_AFTER_FORWARD, Event.FIT_END,
                 Event.PREDICT_END, Event.PREDICT_BATCH_END, Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END)
20 changes: 19 additions & 1 deletion composer/core/state.py
@@ -29,7 +29,7 @@
from composer.core.event import Event
from composer.core.precision import Precision
from composer.core.serializable import Serializable
from composer.core.time import Time, Timestamp, TimeUnit, ensure_time
from composer.devices import Device
from composer.utils import (batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed,
reproducibility)
@@ -412,6 +412,8 @@ def __init__(
self.dataset_resumption = dataset_resumption or {}
self._max_duration = None
self.max_duration = max_duration
self.__iteration_length = None
self._iteration_length = self.__iteration_length
self.save_metrics = save_metrics

self._train_dataloader = train_dataloader
@@ -606,6 +608,22 @@ def get_elapsed_duration(self) -> Optional[Time[float]]:
return None
return self.timestamp.get(self.max_duration.unit) / self.max_duration

@property
def _iteration_length(self):
"""The length of an iteration."""
return self.__iteration_length

@_iteration_length.setter
def _iteration_length(self, iteration_length: Optional[Union[str, Time[int]]]):
if iteration_length is None:
self.__iteration_length = None
return
if isinstance(iteration_length, str):
iteration_length = ensure_time(iteration_length, TimeUnit.EPOCH)
if iteration_length.unit != TimeUnit.EPOCH:
raise NotImplementedError(f'{iteration_length.unit} is not allowed as a unit for iteration_length.')
self.__iteration_length = iteration_length

def stop_training(self):
"""Gracefully stop training.

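The _iteration_length setter above normalizes strings through ensure_time and rejects any unit other than epochs. A small sketch of what it accepts and rejects, assuming ensure_time's usual parsing of time-strings:

from composer.core.time import Time, TimeUnit, ensure_time

length = ensure_time('2ep', TimeUnit.EPOCH)  # '2ep' parses to 2 epochs
assert length == Time(2, TimeUnit.EPOCH)

# '100ba' parses to a batch-denominated Time, so assigning
# state._iteration_length = '100ba' would raise NotImplementedError.
assert ensure_time('100ba', TimeUnit.EPOCH).unit == TimeUnit.BATCH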
13 changes: 12 additions & 1 deletion composer/trainer/trainer.py
@@ -2078,7 +2078,7 @@ def _accumulate_time_across_ranks(

def _train_loop(self) -> None:
"""Run training for the specified number of epochs and log results."""
# Log training start
log.info('Using precision %s', self.state.precision)
self.logger.log_hyperparameters(
{'enabled_algorithms/' + algo.__class__.__name__: True for algo in self.state.algorithms})
@@ -2109,6 +2109,9 @@ def _train_loop(self) -> None:

log.debug('Starting training loop')
while self.state.timestamp < self.state.max_duration:
if int(self.state.timestamp.epoch_in_iteration) == 0 and int(self.state.timestamp.batch_in_epoch) == 0:
self.engine.run_event(Event.ITERATION_START)

if int(self.state.timestamp.batch_in_epoch) == 0:
self.engine.run_event(Event.EPOCH_START)
self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})
@@ -2244,6 +2247,14 @@ def _train_loop(self) -> None:

self.engine.run_event(Event.EPOCH_CHECKPOINT)

# Increment iteration
if (self.state._iteration_length is not None and
self.state.timestamp.epoch_in_iteration == self.state._iteration_length):
self.state.previous_timestamp = self.state.timestamp
self.state.timestamp = self.state.timestamp.to_next_iteration()
self.engine.run_event(Event.ITERATION_END)
self.engine.run_event(Event.ITERATION_CHECKPOINT)

# Log final time values
self.logger.log_metrics({
'time/epoch': self.state.timestamp.epoch.value,
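A toy model of the bookkeeping above (not Composer code), assuming a 5-epoch run with a 2-epoch iteration length as in the test below. ITERATION_END and ITERATION_CHECKPOINT fire only when epoch_in_iteration reaches _iteration_length, so a trailing partial iteration never triggers them:

epochs_per_iteration = 2  # stand-in for state._iteration_length = '2ep'
iteration, epoch_in_iteration = 0, 0
for epoch in range(5):  # 5 epochs total
    epoch_in_iteration += 1  # EPOCH_END advances the epoch counters
    if epoch_in_iteration == epochs_per_iteration:
        iteration += 1  # to_next_iteration(); ITERATION_END, then ITERATION_CHECKPOINT
        epoch_in_iteration = 0
print(iteration)  # 2 -- the 5th epoch begins a 3rd iteration that never completes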
3 changes: 3 additions & 0 deletions tests/test_events.py
@@ -152,6 +152,7 @@ def _assert_expected_event_calls(self, trainer: Trainer, eval_interval: Time, nu
Event.INIT: 1,
Event.BEFORE_LOAD: 1,
Event.AFTER_LOAD: 1,
Event.ITERATION_START: 1,
Event.EPOCH_START: num_epochs,
Event.BATCH_START: total_steps,
Event.BEFORE_DATALOADER: total_steps + num_epochs, # extra call per epoch when dataloader is exhausted
@@ -168,6 +169,8 @@
Event.BATCH_CHECKPOINT: total_steps,
Event.EPOCH_END: num_epochs,
Event.EPOCH_CHECKPOINT: num_epochs,
Event.ITERATION_END: 0,
Event.ITERATION_CHECKPOINT: 0,
Event.EVAL_BEFORE_ALL: total_evals,
Event.EVAL_START: total_evals_start,
Event.EVAL_BATCH_START: total_eval_steps,
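Note that ITERATION_START is expected exactly once while ITERATION_END and ITERATION_CHECKPOINT are expected zero times: this test trainer never sets state._iteration_length, so the boundary check added to _train_loop is never satisfied and the single implicit iteration never completes.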
19 changes: 19 additions & 0 deletions tests/trainer/test_trainer.py
@@ -1138,6 +1138,25 @@ def test_compile_uncompile_model_weights_trainer_fit(
assert (torch.equal(next(compiled_model_trainer.state.model.parameters()),
next(uncompiled_model_trainer.state.model.parameters())))

def test_iteration(
self,
train_dataloader: DataLoader,
model: ComposerModel,
):
"""Tests iteration is properly incremented during training when _iteration_length is set."""

# Train for 5 epochs (max_duration) with 2 epochs per iteration
trainer = Trainer(
model=model,
max_duration='5ep',
train_dataloader=train_dataloader,
)
trainer.state._iteration_length = '2ep'
trainer.fit()

assert trainer.state.timestamp.epoch == Time(5, TimeUnit.EPOCH)
assert trainer.state.timestamp.iteration == Time(2, TimeUnit.ITERATION)
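# 5 epochs at 2 epochs per iteration: iterations complete after epochs 2 and 4;
# the 5th epoch leaves a third iteration in progress, hence iteration == 2.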


@world_size(1, 2)
@device('cpu', 'gpu', 'gpu-amp', precision=True)