diff --git a/composer/core/state.py b/composer/core/state.py index fa4feaec75..a1bb14f0af 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -759,7 +759,7 @@ def _iteration_length(self): def _iteration_length(self, iteration_length: Optional[Union[str, Time[int]]]): """Sets the length of an iteration. - An iteration must be defined as multiple epochs. See composer/core/event.py. + An iteration must be defined as multiple epochs or tokens. See composer/core/event.py. """ if iteration_length is None: self.__iteration_length = None @@ -777,7 +777,7 @@ def stop_training(self): logging, and evaluation for that batch, as well as any epoch end events. """ # Set the max_duration to the current time in its unit, except if the unit is TimeUnit.EPOCH. This is because TimeUnit.EPOCH is a very crude way to measure max duration. For example, it will result in division by zero error while computing get_elapsed_duration: https://github.com/mosaicml/composer/blob/1b9c6d3c0592183b947fd89890de0832366e33a7/composer/core/state.py#L641 - if self.max_duration is not None and Time.from_input(self.max_duration,).unit != TimeUnit.EPOCH: + if self.max_duration is not None and Time.from_input(self.max_duration).unit != TimeUnit.EPOCH: max_duration_unit = Time.from_input(self.max_duration).unit self.max_duration = self.timestamp.get(max_duration_unit) else: