Add Iteration to training loop
Increments the iteration counter during training based on the iteration
length defined in State. The iteration length is a private variable on
State and has no effect by default (it is None unless explicitly set).
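For orientation, here is a minimal, dependency-free sketch of the bookkeeping described above; count_iterations is a hypothetical helper, not code from this commit:

from typing import Optional, Tuple

def count_iterations(num_epochs: int, iteration_length: Optional[int] = None) -> Tuple[int, int]:
    """Return (iterations completed, epochs elapsed in the current iteration)."""
    iteration, epoch_in_iteration = 0, 0
    for _ in range(num_epochs):
        epoch_in_iteration += 1
        # An iteration boundary is crossed only when a length is configured.
        if iteration_length is not None and epoch_in_iteration == iteration_length:
            iteration += 1
            epoch_in_iteration = 0
    return iteration, epoch_in_iteration

assert count_iterations(6) == (0, 6)                      # default: counter never advances
assert count_iterations(6, iteration_length=2) == (3, 0)  # three 2-epoch iterations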

commit-id:bdbe33f2
b-chu committed Mar 1, 2024
1 parent d95bbd7 commit 056afc3
Showing 3 changed files with 31 additions and 3 deletions.
20 changes: 19 additions & 1 deletion composer/core/state.py
@@ -29,7 +29,7 @@
 from composer.core.event import Event
 from composer.core.precision import Precision
 from composer.core.serializable import Serializable
-from composer.core.time import Time, Timestamp, TimeUnit
+from composer.core.time import Time, Timestamp, TimeUnit, ensure_time
 from composer.devices import Device
 from composer.utils import (batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed,
                             reproducibility)
@@ -412,6 +412,8 @@ def __init__(
         self.dataset_resumption = dataset_resumption or {}
         self._max_duration = None
         self.max_duration = max_duration
+        self.__iteration_length = None
+        self._iteration_length = self.__iteration_length
         self.save_metrics = save_metrics
 
         self._train_dataloader = train_dataloader
@@ -606,6 +608,22 @@ def get_elapsed_duration(self) -> Optional[Time[float]]:
             return None
         return self.timestamp.get(self.max_duration.unit) / self.max_duration
 
+    @property
+    def _iteration_length(self):
+        """The length of an iteration."""
+        return self.__iteration_length
+
+    @_iteration_length.setter
+    def _iteration_length(self, iteration_length: Optional[Union[str, Time[int]]]):
+        if iteration_length is None:
+            self.__iteration_length = None
+            return
+        if isinstance(iteration_length, str):
+            iteration_length = ensure_time(iteration_length, TimeUnit.EPOCH)
+        if iteration_length.unit != TimeUnit.EPOCH:
+            raise NotImplementedError(f'{iteration_length.unit} is not allowed as a unit for iteration_length.')
+        self.__iteration_length = iteration_length
+
     def stop_training(self):
         """Gracefully stop training.
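An illustrative check of the coercion rules the new setter enforces (a sketch, assuming composer is importable; ensure_time, Time, and TimeUnit come from the diff above, while the specific values are hypothetical):

from composer.core.time import Time, TimeUnit, ensure_time

length = ensure_time('2ep', TimeUnit.EPOCH)  # strings are parsed into Time objects
assert length == Time(2, TimeUnit.EPOCH)
assert length.unit == TimeUnit.EPOCH         # the only unit the setter accepts

batches = ensure_time('100ba', TimeUnit.EPOCH)  # parses as Time(100, TimeUnit.BATCH)
assert batches.unit != TimeUnit.EPOCH           # the setter would raise NotImplementedError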
12 changes: 11 additions & 1 deletion composer/trainer/trainer.py
@@ -2078,7 +2078,7 @@ def _accumulate_time_across_ranks(
 
     def _train_loop(self) -> None:
         """Run training for the specified number of epochs and log results."""
-        # print training start
+        # Log training start
         log.info('Using precision %s', self.state.precision)
         self.logger.log_hyperparameters(
             {'enabled_algorithms/' + algo.__class__.__name__: True for algo in self.state.algorithms})
@@ -2109,6 +2109,9 @@ def _train_loop(self) -> None:
 
         log.debug('Starting training loop')
         while self.state.timestamp < self.state.max_duration:
+            if int(self.state.timestamp.epoch_in_iteration) == 0 and int(self.state.timestamp.batch_in_epoch) == 0:
+                self.engine.run_event(Event.ITERATION_START)
+
             if int(self.state.timestamp.batch_in_epoch) == 0:
                 self.engine.run_event(Event.EPOCH_START)
                 self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})
@@ -2244,6 +2247,13 @@ def _train_loop(self) -> None:
 
             self.engine.run_event(Event.EPOCH_CHECKPOINT)
 
+            # Increment iteration
+            if (self.state._iteration_length is not None and self.state.timestamp.epoch_in_iteration == self.state._iteration_length):
+                self.state.previous_timestamp = self.state.timestamp
+                self.state.timestamp = self.state.timestamp.to_next_iteration()
+                self.engine.run_event(Event.ITERATION_END)
+                self.engine.run_event(Event.ITERATION_CHECKPOINT)
+
         # Log final time values
         self.logger.log_metrics({
             'time/epoch': self.state.timestamp.epoch.value,
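To make the event schedule above concrete, here is a dependency-free simulation; simulate_iteration_events is a hypothetical helper that collapses the loop to epoch granularity:

from typing import List, Optional, Tuple

def simulate_iteration_events(num_epochs: int, iteration_length: Optional[int]) -> List[Tuple[int, str]]:
    events: List[Tuple[int, str]] = []
    epoch_in_iteration = 0
    for epoch in range(num_epochs):
        # Mirrors the check at the top of the while loop above.
        if epoch_in_iteration == 0:
            events.append((epoch, 'ITERATION_START'))
        events.append((epoch, 'EPOCH_START'))
        events.append((epoch, 'EPOCH_END'))
        epoch_in_iteration += 1
        # Mirrors the "Increment iteration" block above.
        if iteration_length is not None and epoch_in_iteration == iteration_length:
            events.append((epoch, 'ITERATION_END'))
            epoch_in_iteration = 0  # what timestamp.to_next_iteration() resets
    return events

# With no iteration length configured, ITERATION_START fires exactly once and
# ITERATION_END never fires, which is why the test below changes the expected
# ITERATION_START count from 0 to 1.
events = simulate_iteration_events(num_epochs=3, iteration_length=None)
assert sum(name == 'ITERATION_START' for _, name in events) == 1
assert all(name != 'ITERATION_END' for _, name in events)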
2 changes: 1 addition & 1 deletion tests/test_events.py
@@ -152,7 +152,7 @@ def _assert_expected_event_calls(self, trainer: Trainer, eval_interval: Time, nu
             Event.INIT: 1,
             Event.BEFORE_LOAD: 1,
             Event.AFTER_LOAD: 1,
-            Event.ITERATION_START: 0,
+            Event.ITERATION_START: 1,
             Event.EPOCH_START: num_epochs,
             Event.BATCH_START: total_steps,
             Event.BEFORE_DATALOADER: total_steps + num_epochs,  # extra call per epoch when dataloader is exhausted
