From a456fa09fb20c08d5ed721a4eae4287836011abf Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Tue, 27 Feb 2024 13:08:12 -0500 Subject: [PATCH] Deprecate get_state and remove deprecations (#3017) * Add duration to to_next_epoch Adds duration to the to_next_epoch function to increment total_wct. commit-id:5f8ef9be * Add iteration to TimeUnit commit-id:e4729f79 * Deprecate get_state and remove deprecations commit-id:5d87db4b --- composer/callbacks/mlperf.py | 18 +- composer/callbacks/utils.py | 43 ----- composer/core/state.py | 3 +- composer/core/time.py | 162 ++++++++++++++++-- composer/loggers/in_memory_logger.py | 5 +- tests/algorithms/test_algorithm_resumption.py | 2 + tests/common/state.py | 1 + tests/test_time.py | 72 ++++++-- tests/trainer/test_checkpoint.py | 2 + 9 files changed, 218 insertions(+), 90 deletions(-) delete mode 100644 composer/callbacks/utils.py diff --git a/composer/callbacks/mlperf.py b/composer/callbacks/mlperf.py index ddd2f02a76..970e08cacf 100644 --- a/composer/callbacks/mlperf.py +++ b/composer/callbacks/mlperf.py @@ -263,22 +263,8 @@ def _get_dataloader_stats(self, dataloader: Iterable): if isinstance(dataloader.dataset, IterableDataset): num_samples *= dist.get_world_size() return (dataloader.batch_size, num_samples) - try: - # attempt to import ffcv and test if its an ffcv loader. - import ffcv # type: ignore - - warnings.warn(DeprecationWarning('ffcv is deprecated and will be removed in v0.18')) - - if isinstance(dataloader, ffcv.loader.Loader): - # Use the cached attribute ffcv.init_traversal_order to compute number of samples - return ( - dataloader.batch_size, # type: ignore - len(dataloader.next_traversal_order()) * dist.get_world_size() # type: ignore - ) - except ImportError: - pass - - raise TypeError(f'torch dataloader or ffcv dataloader required (and ffcv installed)') + + raise TypeError(f'torch dataloader required') def fit_start(self, state: State, logger: Logger) -> None: if _global_rank_zero(): diff --git a/composer/callbacks/utils.py b/composer/callbacks/utils.py deleted file mode 100644 index 7a4097cecf..0000000000 --- a/composer/callbacks/utils.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2022 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -"""Callback utils.""" - -import warnings -from typing import Callable, Optional, Set, Union - -from composer.core import Event, State, Time -from composer.utils.misc import create_interval_scheduler as _create_interval_scheduler - - -def create_interval_scheduler(interval: Union[str, int, Time], - include_end_of_training: bool = True, - checkpoint_events: bool = True, - final_events: Optional[Set[Event]] = None) -> Callable[[State, Event], bool]: - """Helper function to create a scheduler according to a specified interval. - - Args: - interval (Union[str, int, :class:`.Time`]): If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`. - Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`, - :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`. - include_end_of_training (bool): If true, the returned callable will return true at the end of training as well. - Otherwise, the returned callable will return true at intervals only. - checkpoint_events (bool): If true, will use the EPOCH_CHECKPOINT and BATCH_CHECKPOINT events. If False, will use - the EPOCH_END and BATCH_END events. - final_events (Optional[Set[Event]]): The set of events to trigger on at the end of training. - - Returns: - Callable[[State, Event], bool]: A function that returns true at interval and at the end of training if specified. - For example, it can be passed as the ``save_interval`` argument into the :class:`.CheckpointSaver`. - """ - warnings.warn( - '`composer.callbacks.utils.create_interval_scheduler has been moved to `composer.utils.misc.create_interval_scheduler` ' - + 'and will be removed in a future release.', - DeprecationWarning, - ) - return _create_interval_scheduler( - interval=interval, - include_end_of_training=include_end_of_training, - checkpoint_events=checkpoint_events, - final_events=final_events, - ) diff --git a/composer/core/state.py b/composer/core/state.py index cc97cb8391..f99f6050ef 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -733,7 +733,8 @@ def fsdp_sharded_state_dict_enabled(self): @property def fsdp_elastic_sharded_enabled(self): - warnings.warn('state.fsdp_elastic_sharded_enabled is deprecated and will be removed v0.21.0') + warnings.warn('state.fsdp_elastic_sharded_enabled is deprecated and will be removed v0.21.0', + DeprecationWarning) return self.fsdp_sharded_state_dict_enabled @property diff --git a/composer/core/time.py b/composer/core/time.py index ab2d6a60ee..c034263ec2 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -19,6 +19,7 @@ import datetime import re +import warnings from typing import Any, Dict, Generic, Optional, TypeVar, Union, cast from composer.core.serializable import Serializable @@ -31,12 +32,14 @@ class TimeUnit(StringEnum): """Enum class to represent units of time for the training process. Attributes: + ITERATION (str): Iterations. EPOCH (str): Epochs. BATCH (str): Batches (i.e. number of optimization steps) SAMPLE (str): Samples. TOKEN (str): Tokens. Applicable for natural language processing (NLP) models. DURATION (str): Fraction of the training process complete, on ``[0.0, 1.0)`` """ + ITERATION = 'iter' EPOCH = 'ep' BATCH = 'ba' SAMPLE = 'sp' @@ -122,6 +125,20 @@ def __init__( raise TypeError(f'value {value} is of type {type(value)}. Units {unit} require integer values.') self._value, self._unit = value, TimeUnit(unit) + @classmethod + def from_iteration(cls, iteration: int) -> Time: + """Create a :class:`Time` with units of :attr:`TimeUnit.ITERATION`. + + Equivalent to ``Time(iteration, TimeUnit.ITERATION)``. + + Args: + iteration (int): Number of iterations. + + Returns: + Time: :class:`Time` instance, in iterations. + """ + return cls(iteration, TimeUnit.ITERATION) + @classmethod def from_epoch(cls, epoch: int) -> Time: """Create a :class:`Time` with units of :attr:`TimeUnit.EPOCH`. @@ -391,37 +408,48 @@ def from_timestring(cls, timestring: str) -> Time: class Timestamp(Serializable): """Timestamp represents a snapshot of the current training progress. - The timestamp measures training progress in terms of epochs, batches, samples, tokens, and wall clock time. + The timestamp measures training progress in terms of iterations, epochs, batches, samples, tokens, and wall clock time. Timestamps are not updated in-place. See the :doc:`Time Guide ` for more details on tracking time during training. Args: + iteration (int | Time[int], optional): The iteration. epoch (int | Time[int], optional): The epoch. batch (int | Time[int], optional): the batch. sample (int | Time[int], optional): The sample. token (int | Time[int], optional): The token. + epoch_in_iteration (int | Time[int], optional): The epoch in the iteration. batch_in_epoch (int | Time[int], optional): The batch in the epoch. sample_in_epoch (int | Time[int], optional): The sample in the epoch. token_in_epoch (int | Time[int], optional): The token in the epoch. total_wct (datetime.timedelta, optional): The total wall-clock duration. - epoch_wct (datetime.timedelta, optional): The wall-clock duration of the last epoch. + iteration_wct (datetime.timedelta, optional): The wall-clock duration of the current iteration. + epoch_wct (datetime.timedelta, optional): The wall-clock duration of the current epoch. batch_wct (datetime.timedelta, optional): The wall-clock duration of the last batch. """ def __init__( self, + iteration: Union[int, Time[int]] = 0, epoch: Union[int, Time[int]] = 0, batch: Union[int, Time[int]] = 0, sample: Union[int, Time[int]] = 0, token: Union[int, Time[int]] = 0, + epoch_in_iteration: Union[int, Time[int]] = 0, batch_in_epoch: Union[int, Time[int]] = 0, sample_in_epoch: Union[int, Time[int]] = 0, token_in_epoch: Union[int, Time[int]] = 0, total_wct: Optional[datetime.timedelta] = None, + iteration_wct: Optional[datetime.timedelta] = None, epoch_wct: Optional[datetime.timedelta] = None, batch_wct: Optional[datetime.timedelta] = None, ): + iteration = Time.from_input(iteration, TimeUnit.ITERATION) + if iteration.unit != TimeUnit.ITERATION: + raise ValueError(f'The `iteration` argument has units of {iteration.unit}; not {TimeUnit.ITERATION}.') + self._iteration = iteration + epoch = Time.from_input(epoch, TimeUnit.EPOCH) if epoch.unit != TimeUnit.EPOCH: raise ValueError(f'The `epoch` argument has units of {epoch.unit}; not {TimeUnit.EPOCH}.') @@ -442,6 +470,12 @@ def __init__( raise ValueError(f'The `token` argument has units of {token.unit}; not {TimeUnit.TOKEN}.') self._token = token + epoch_in_iteration = Time.from_input(epoch_in_iteration, TimeUnit.EPOCH) + if epoch_in_iteration.unit != TimeUnit.EPOCH: + raise ValueError((f'The `epoch_in_iteration` argument has units of {epoch_in_iteration.unit}; ' + f'not {TimeUnit.EPOCH}.')) + self._epoch_in_iteration = epoch_in_iteration + batch_in_epoch = Time.from_input(batch_in_epoch, TimeUnit.BATCH) if batch_in_epoch.unit != TimeUnit.BATCH: raise ValueError((f'The `batch_in_epoch` argument has units of {batch_in_epoch.unit}; ' @@ -464,6 +498,10 @@ def __init__( total_wct = datetime.timedelta(seconds=0) self._total_wct = total_wct + if iteration_wct is None: + iteration_wct = datetime.timedelta(seconds=0) + self._iteration_wct = iteration_wct + if epoch_wct is None: epoch_wct = datetime.timedelta(seconds=0) self._epoch_wct = epoch_wct @@ -474,14 +512,17 @@ def __init__( def state_dict(self) -> Dict[str, Any]: return { + 'iteration': self.iteration.value, 'epoch': self.epoch.value, 'batch': self.batch.value, 'sample': self.sample.value, 'token': self.token.value, + 'epoch_in_iteration': self.epoch_in_iteration.value, 'batch_in_epoch': self.batch_in_epoch.value, 'sample_in_epoch': self.sample_in_epoch.value, 'token_in_epoch': self.token_in_epoch.value, 'total_wct': self.total_wct, + 'iteration_wct': self.iteration_wct, 'epoch_wct': self.epoch_wct, 'batch_wct': self.batch_wct, } @@ -492,18 +533,8 @@ def get_state(self) -> Dict[str, Union[Time[int], datetime.timedelta]]: Returns: Dict[str, Union[Time[int], datetime.timedelta]]: All values of the timestamp object. """ - return { - 'epoch': self.epoch, - 'batch': self.batch, - 'sample': self.sample, - 'token': self.token, - 'batch_in_epoch': self.batch_in_epoch, - 'sample_in_epoch': self.sample_in_epoch, - 'token_in_epoch': self.token_in_epoch, - 'total_wct': self.total_wct, - 'epoch_wct': self.epoch_wct, - 'batch_wct': self.batch_wct, - } + warnings.warn('core.time.Timestamp.get_state is deprecated and will be removed v0.21.0', DeprecationWarning) + return self.state_dict() def load_state_dict(self, state: Dict[str, Any]) -> None: self._epoch = Time(state['epoch'], TimeUnit.EPOCH) @@ -513,14 +544,26 @@ def load_state_dict(self, state: Dict[str, Any]) -> None: self._batch_in_epoch = Time(state['batch_in_epoch'], TimeUnit.BATCH) self._sample_in_epoch = Time(state['sample_in_epoch'], TimeUnit.SAMPLE) self._token_in_epoch = Time(state['token_in_epoch'], TimeUnit.TOKEN) - # Wall clock time tracking was added in composer v0.7.0 # Using conditional checks as not to break old checkpoints + # Wall clock time tracking was added in composer v0.7.0 if 'total_wct' in state: self._total_wct = state['total_wct'] if 'epoch_wct' in state: self._epoch_wct = state['epoch_wct'] if 'batch_wct' in state: self._batch_wct = state['batch_wct'] + # Iteration was added in composer v0.19.1 + if 'iteration' in state: + self._iteration = Time(state['iteration'], TimeUnit.ITERATION) + if 'epoch_in_iteration' in state: + self._epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH) + if 'iteration_wct' in state: + self._iteration_wct = state['iteration_wct'] + + @property + def iteration(self) -> Time[int]: + """The total iteration count.""" + return self._iteration @property def epoch(self) -> Time[int]: @@ -542,6 +585,11 @@ def token(self) -> Time[int]: """The total token count.""" return self._token + @property + def epoch_in_iteration(self) -> Time[int]: + """The epoch count in the current iteration (resets at 0 at the beginning of every iteration).""" + return self._epoch_in_iteration + @property def batch_in_epoch(self) -> Time[int]: """The batch count in the current epoch (resets at 0 at the beginning of every epoch).""" @@ -562,6 +610,11 @@ def total_wct(self) -> datetime.timedelta: """The wall-clock duration (in seconds) from the beginning of training.""" return self._total_wct + @property + def iteration_wct(self) -> datetime.timedelta: + """The wall-clock duration (in seconds) for the current iteration.""" + return self._iteration_wct + @property def epoch_wct(self) -> datetime.timedelta: """The wall-clock duration (in seconds) for the current epoch.""" @@ -582,6 +635,8 @@ def get(self, unit: Union[str, TimeUnit]) -> Time[int]: Time: The current time, in the specified unit. """ unit = TimeUnit(unit) + if unit == TimeUnit.ITERATION: + return self.iteration if unit == TimeUnit.EPOCH: return self.epoch if unit == TimeUnit.BATCH: @@ -678,6 +733,7 @@ def to_next_batch( ... token = timestamp.token + tokens, ... token_in_epoch=timestamp.token_in_epoch + tokens, ... total_wct=timestamp.total_wct + duration, + ... iteration_wct=timestamp.iteration_wct + duration, ... epoch_wct=timestamp.epoch_wct + duration, ... batch_wct=duration, ... ) @@ -705,11 +761,15 @@ def to_next_batch( token=self.token + tokens, token_in_epoch=self.token_in_epoch + tokens, total_wct=self.total_wct + duration, + iteration_wct=self.iteration_wct + duration, epoch_wct=self.epoch_wct + duration, batch_wct=duration, ) - def to_next_epoch(self): + def to_next_epoch( + self, + duration: Optional[datetime.timedelta] = None, + ): """Create a new :class:`.Timestamp`, advanced to the next epoch. Equivalent to: @@ -720,39 +780,96 @@ def to_next_epoch(self): import datetime timestamp = Timestamp() + duration = datetime.timedelta(seconds=0) .. doctest:: >>> timestamp.copy( - ... epoch=timestamp.epoch+1, + ... epoch=timestamp.epoch + 1, + ... epoch_in_iteration=timestamp.epoch_in_iteration + 1, ... batch_in_epoch=0, ... sample_in_epoch=0, ... token_in_epoch=0, + ... total_wct=timestamp.total_wct + duration, + ... iteration_wct=timestamp.iteration_wct + duration, ... epoch_wct=datetime.timedelta(seconds=0), ... batch_wct=datetime.timedelta(seconds=0), ... ) Timestamp(...) """ + if duration is None: + duration = datetime.timedelta(seconds=0) return self.copy( epoch=self.epoch + 1, + epoch_in_iteration=self.epoch_in_iteration + 1, batch_in_epoch=0, sample_in_epoch=0, token_in_epoch=0, + total_wct=self.total_wct + duration, + iteration_wct=self.iteration_wct + duration, + epoch_wct=datetime.timedelta(seconds=0), + batch_wct=datetime.timedelta(seconds=0), + ) + + def to_next_iteration( + self, + duration: Optional[datetime.timedelta] = None, + ): + """Create a new :class:`.Timestamp`, advanced to the next iteration. + + Equivalent to: + + .. testsetup:: + + from composer.core.time import Timestamp + import datetime + + timestamp = Timestamp() + + .. doctest:: + + >>> timestamp.copy( + ... iteration=timestamp.iteration + 1, + ... epoch_in_iteration=0, + ... batch_in_epoch=0, + ... sample_in_epoch=0, + ... token_in_epoch=0, + ... total_wct=timestamp.total_wct + duration, + ... iteration_wct=datetime.timedelta(seconds=0), + ... epoch_wct=datetime.timedelta(seconds=0), + ... batch_wct=datetime.timedelta(seconds=0), + ... ) + Timestamp(...) + + """ + if duration is None: + duration = datetime.timedelta(seconds=0) + return self.copy( + iteration=self.iteration + 1, + epoch_in_iteration=0, + batch_in_epoch=0, + sample_in_epoch=0, + token_in_epoch=0, + total_wct=self.total_wct + duration, + iteration_wct=datetime.timedelta(seconds=0), epoch_wct=datetime.timedelta(seconds=0), batch_wct=datetime.timedelta(seconds=0), ) def copy( self, + iteration: Optional[Union[int, Time[int]]] = None, epoch: Optional[Union[int, Time[int]]] = None, batch: Optional[Union[int, Time[int]]] = None, sample: Optional[Union[int, Time[int]]] = None, token: Optional[Union[int, Time[int]]] = None, + epoch_in_iteration: Optional[Union[int, Time[int]]] = None, batch_in_epoch: Optional[Union[int, Time[int]]] = None, sample_in_epoch: Optional[Union[int, Time[int]]] = None, token_in_epoch: Optional[Union[int, Time[int]]] = None, total_wct: Optional[datetime.timedelta] = None, + iteration_wct: Optional[datetime.timedelta] = None, epoch_wct: Optional[datetime.timedelta] = None, batch_wct: Optional[datetime.timedelta] = None, ) -> Timestamp: @@ -761,42 +878,53 @@ def copy( Any specified values will override the existing values in the returned copy. Args: + iteration (int | Time[int], optional): The iteration. epoch (int | Time[int], optional): The epoch. batch (int | Time[int], optional): the batch. sample (int | Time[int], optional): The sample. token (int | Time[int], optional): The token. + epoch_in_iteration (int | Time[int], optional): The epoch in the iteration. batch_in_epoch (int | Time[int], optional): The batch in the epoch. sample_in_epoch (int | Time[int], optional): The sample in the epoch. token_in_epoch (int | Time[int], optional): The token in the epoch. total_wct (datetime.timedelta, optional): The elapsed duration from the beginning of training. + iteration_wct (datetime.timedelta, optional): The wall-clock duration of the current iteration. + epoch_wct (datetime.timedelta, optional): The wall-clock duration of the current epoch. + batch_wct (datetime.timedelta, optional): The wall-clock duration of the last batch. Returns: Timestamp: A new timestamp instance, created from a copy, but with any specified values overriding the existing values. """ return Timestamp( + iteration=iteration if iteration is not None else self.iteration, epoch=epoch if epoch is not None else self.epoch, batch=batch if batch is not None else self.batch, sample=sample if sample is not None else self.sample, token=token if token is not None else self.token, + epoch_in_iteration=epoch_in_iteration if epoch_in_iteration is not None else self.epoch_in_iteration, batch_in_epoch=batch_in_epoch if batch_in_epoch is not None else self.batch_in_epoch, sample_in_epoch=sample_in_epoch if sample_in_epoch is not None else self.sample_in_epoch, token_in_epoch=token_in_epoch if token_in_epoch is not None else self.token_in_epoch, total_wct=total_wct if total_wct is not None else self.total_wct, + iteration_wct=iteration_wct if iteration_wct is not None else self.iteration_wct, epoch_wct=epoch_wct if epoch_wct is not None else self.epoch_wct, batch_wct=batch_wct if batch_wct is not None else self.batch_wct, ) def __repr__(self) -> str: return (f'Timestamp(' + f'iteration={int(self.iteration)}, ' f'epoch={int(self.epoch)}, ' f'batch={int(self.batch)}, ' f'sample={int(self.sample)}, ' f'token={int(self.token)}, ' + f'epoch_in_iteration={int(self.epoch_in_iteration)}, ' f'batch_in_epoch={int(self.batch_in_epoch)}, ' f'sample_in_epoch={int(self.sample_in_epoch)}, ' f'token_in_epoch={int(self.token_in_epoch)}, ' f'total_wct={repr(self.total_wct)}, ' + f'iteration_wct={repr(self.iteration_wct)}, ' f'epoch_wct={repr(self.epoch_wct)}, ' f'batch_wct={repr(self.batch_wct)}' ')') diff --git a/composer/loggers/in_memory_logger.py b/composer/loggers/in_memory_logger.py index 8f5a2c0ea3..ae11beb755 100644 --- a/composer/loggers/in_memory_logger.py +++ b/composer/loggers/in_memory_logger.py @@ -14,7 +14,6 @@ import numpy as np from torch import Tensor -from composer.core.time import Time from composer.loggers.logger import Logger from composer.loggers.logger_destination import LoggerDestination from composer.utils.import_helpers import MissingConditionalImportError @@ -157,8 +156,8 @@ def get_timeseries(self, metric: str) -> Dict[str, Any]: timestamp, metric_value = datapoint timeseries.setdefault(metric, []).append(metric_value) # Iterate through time units and add them all! - for field, time in timestamp.get_state().items(): - time_value = time.value if isinstance(time, Time) else time.total_seconds() + for field, time in timestamp.state_dict().items(): + time_value = time if isinstance(time, int) else time.total_seconds() timeseries.setdefault(field, []).append(time_value) # Convert to numpy arrays for k, v in timeseries.items(): diff --git a/tests/algorithms/test_algorithm_resumption.py b/tests/algorithms/test_algorithm_resumption.py index 9f243caeae..68fa90baf6 100644 --- a/tests/algorithms/test_algorithm_resumption.py +++ b/tests/algorithms/test_algorithm_resumption.py @@ -127,9 +127,11 @@ def _assert_checkpoints_equal(file1, file2): # compare state # remove the wall clock time fields since they will always differ del checkpoint1['state']['timestamp']['Timestamp']['total_wct'] + del checkpoint1['state']['timestamp']['Timestamp']['iteration_wct'] del checkpoint1['state']['timestamp']['Timestamp']['epoch_wct'] del checkpoint1['state']['timestamp']['Timestamp']['batch_wct'] del checkpoint2['state']['timestamp']['Timestamp']['total_wct'] + del checkpoint2['state']['timestamp']['Timestamp']['iteration_wct'] del checkpoint2['state']['timestamp']['Timestamp']['epoch_wct'] del checkpoint2['state']['timestamp']['Timestamp']['batch_wct'] diff --git a/tests/common/state.py b/tests/common/state.py index 3a9d17eac3..1dea9a55f6 100644 --- a/tests/common/state.py +++ b/tests/common/state.py @@ -10,6 +10,7 @@ def _del_wct_timestamp_fields(timestamp_state_dict: Dict[str, Any]): del timestamp_state_dict['Timestamp']['total_wct'] + del timestamp_state_dict['Timestamp']['iteration_wct'] del timestamp_state_dict['Timestamp']['epoch_wct'] del timestamp_state_dict['Timestamp']['batch_wct'] diff --git a/tests/test_time.py b/tests/test_time.py index 58f1cf9747..f259444633 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -9,6 +9,7 @@ @pytest.mark.parametrize('time_string,expected_value,expected_unit', [ + ['2iter', 2, TimeUnit.ITERATION], ['1ep', 1, TimeUnit.EPOCH], ['2ba', 2, TimeUnit.BATCH], ['3e10sp', 3 * 10**10, TimeUnit.SAMPLE], @@ -25,6 +26,7 @@ def test_time_parse(time_string: str, expected_value: int, expected_unit: TimeUn @pytest.mark.parametrize('expected_timestring,time', [ + ['2iter', Time(2, TimeUnit.ITERATION)], ['1ep', Time(1, TimeUnit.EPOCH)], ['2ba', Time(2, TimeUnit.BATCH)], ['3sp', Time(3, TimeUnit.SAMPLE)], @@ -136,9 +138,9 @@ def test_timestamp_update(): assert timestamp is not timestamp_2 -def test_timestamp_to_next_batch_epoch(): +def test_timestamp_to_next_batch_epoch_iteration(): timestamp = Timestamp() - # Step batch 0, epoch 0 + # Step batch 0 in epoch 0 timestamp = timestamp.to_next_batch(10, 20, datetime.timedelta(seconds=5)) assert timestamp.batch == 1 assert timestamp.batch_in_epoch == 1 @@ -148,11 +150,12 @@ def test_timestamp_to_next_batch_epoch(): assert timestamp.token == 20 assert timestamp.token_in_epoch == 20 assert timestamp.total_wct == datetime.timedelta(seconds=5) + assert timestamp.iteration_wct == datetime.timedelta(seconds=5) assert timestamp.epoch_wct == datetime.timedelta(seconds=5) assert timestamp.batch_wct == datetime.timedelta(seconds=5) # Finish epoch 0 - timestamp = timestamp.to_next_epoch() + timestamp = timestamp.to_next_epoch(datetime.timedelta(seconds=5)) assert timestamp.epoch == 1 assert timestamp.batch == 1 assert timestamp.batch_in_epoch == 0 @@ -160,24 +163,27 @@ def test_timestamp_to_next_batch_epoch(): assert timestamp.sample_in_epoch == 0 assert timestamp.token == 20 assert timestamp.token_in_epoch == 0 - assert timestamp.total_wct == datetime.timedelta(seconds=5) + assert timestamp.total_wct == datetime.timedelta(seconds=10) + assert timestamp.iteration_wct == datetime.timedelta(seconds=10) assert timestamp.epoch_wct == datetime.timedelta(seconds=0) assert timestamp.batch_wct == datetime.timedelta(seconds=0) - # Step a batch 0 in epoch 1 + # Step batch 0 in epoch 1 timestamp = timestamp.to_next_batch(5, 0, datetime.timedelta(seconds=10)) assert timestamp.epoch == 1 assert timestamp.batch == 2 + assert timestamp.epoch_in_iteration == 1 assert timestamp.batch_in_epoch == 1 assert timestamp.sample == 15 assert timestamp.sample_in_epoch == 5 assert timestamp.token == 20 assert timestamp.token_in_epoch == 0 - assert timestamp.total_wct == datetime.timedelta(seconds=15) + assert timestamp.total_wct == datetime.timedelta(seconds=20) + assert timestamp.iteration_wct == datetime.timedelta(seconds=20) assert timestamp.epoch_wct == datetime.timedelta(seconds=10) assert timestamp.batch_wct == datetime.timedelta(seconds=10) - # Step batch 1 in epoch 0 + # Step batch 1 in epoch 1 timestamp = timestamp.to_next_batch(5, 1, datetime.timedelta(seconds=10)) assert timestamp.epoch == 1 assert timestamp.batch == 3 @@ -186,22 +192,68 @@ def test_timestamp_to_next_batch_epoch(): assert timestamp.sample_in_epoch == 10 assert timestamp.token == 21 assert timestamp.token_in_epoch == 1 - assert timestamp.total_wct == datetime.timedelta(seconds=25) + assert timestamp.total_wct == datetime.timedelta(seconds=30) + assert timestamp.iteration_wct == datetime.timedelta(seconds=30) assert timestamp.epoch_wct == datetime.timedelta(seconds=20) assert timestamp.batch_wct == datetime.timedelta(seconds=10) + # Finish epoch 1 + timestamp = timestamp.to_next_epoch() + assert timestamp.epoch == 2 + assert timestamp.batch == 3 + assert timestamp.epoch_in_iteration == 2 + assert timestamp.batch_in_epoch == 0 + assert timestamp.sample == 20 + assert timestamp.sample_in_epoch == 0 + assert timestamp.token == 21 + assert timestamp.token_in_epoch == 0 + assert timestamp.total_wct == datetime.timedelta(seconds=30) + assert timestamp.epoch_wct == datetime.timedelta(seconds=0) + assert timestamp.batch_wct == datetime.timedelta(seconds=0) + + # Step batch 0 in epoch 2 + timestamp = timestamp.to_next_batch(5, 1, datetime.timedelta(seconds=10)) + assert timestamp.epoch == 2 + assert timestamp.batch == 4 + assert timestamp.epoch_in_iteration == 2 + assert timestamp.batch_in_epoch == 1 + assert timestamp.sample == 25 + assert timestamp.sample_in_epoch == 5 + assert timestamp.token == 22 + assert timestamp.token_in_epoch == 1 + assert timestamp.total_wct == datetime.timedelta(seconds=40) + assert timestamp.iteration_wct == datetime.timedelta(seconds=40) + assert timestamp.epoch_wct == datetime.timedelta(seconds=10) + assert timestamp.batch_wct == datetime.timedelta(seconds=10) + + # Finish iteration 0 + timestamp = timestamp.to_next_iteration() + assert timestamp.iteration == 1 + assert timestamp.epoch == 2 + assert timestamp.batch == 4 + assert timestamp.epoch_in_iteration == 0 + assert timestamp.batch_in_epoch == 0 + assert timestamp.sample == 25 + assert timestamp.sample_in_epoch == 0 + assert timestamp.token == 22 + assert timestamp.token_in_epoch == 0 + assert timestamp.total_wct == datetime.timedelta(seconds=40) + assert timestamp.iteration_wct == datetime.timedelta(seconds=0) + assert timestamp.epoch_wct == datetime.timedelta(seconds=0) + assert timestamp.batch_wct == datetime.timedelta(seconds=0) + def test_timestamp_repr(): timestamp = Timestamp() assert timestamp == eval(repr(timestamp)) -@pytest.mark.parametrize('time_string', ['1.5ep', '2.1ba', '3.2sp', '3.4tok']) +@pytest.mark.parametrize('time_string', ['1.1iter', '1.5ep', '2.1ba', '3.2sp', '3.4tok']) def test_timestep_bad_strings(time_string: str): with pytest.raises(TypeError): Time.from_timestring(time_string) -@pytest.mark.parametrize('time_string', ['0.5dur', '2.0ep', '3.000ba', '030.0sp']) +@pytest.mark.parametrize('time_string', ['0.5dur', '1.0iter', '2.0ep', '3.000ba', '030.0sp']) def test_timestep_valid_strings(time_string: str): Time.from_timestring(time_string) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 36aed8c9c6..a24f5d3c91 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -71,9 +71,11 @@ def _assert_checkpoints_equivalent(file1, file2, atol=0.0, rtol=0.0): # Remove the wall clock time del checkpoint_1['state']['timestamp']['Timestamp']['total_wct'] + del checkpoint_1['state']['timestamp']['Timestamp']['iteration_wct'] del checkpoint_1['state']['timestamp']['Timestamp']['epoch_wct'] del checkpoint_1['state']['timestamp']['Timestamp']['batch_wct'] del checkpoint_2['state']['timestamp']['Timestamp']['total_wct'] + del checkpoint_2['state']['timestamp']['Timestamp']['iteration_wct'] del checkpoint_2['state']['timestamp']['Timestamp']['epoch_wct'] del checkpoint_2['state']['timestamp']['Timestamp']['batch_wct']