Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add iteration to TimeUnit #3013

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 135 additions & 3 deletions composer/core/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ class TimeUnit(StringEnum):
"""Enum class to represent units of time for the training process.
Attributes:
ITERATION (str): Iterations.
EPOCH (str): Epochs.
BATCH (str): Batches (i.e. number of optimization steps)
SAMPLE (str): Samples.
TOKEN (str): Tokens. Applicable for natural language processing (NLP) models.
DURATION (str): Fraction of the training process complete, on ``[0.0, 1.0)``
"""
ITERATION = 'iter'
EPOCH = 'ep'
BATCH = 'ba'
SAMPLE = 'sp'
Expand Down Expand Up @@ -122,6 +124,20 @@ def __init__(
raise TypeError(f'value {value} is of type {type(value)}. Units {unit} require integer values.')
self._value, self._unit = value, TimeUnit(unit)

@classmethod
def from_iteration(cls, iteration: int) -> Time:
    """Create a :class:`Time` with units of :attr:`TimeUnit.ITERATION`.

    Equivalent to ``Time(iteration, TimeUnit.ITERATION)``.

    Args:
        iteration (int): Number of iterations.

    Returns:
        Time: :class:`Time` instance, in iterations.
    """
    return cls(iteration, TimeUnit.ITERATION)

@classmethod
def from_epoch(cls, epoch: int) -> Time:
"""Create a :class:`Time` with units of :attr:`TimeUnit.EPOCH`.
Expand Down Expand Up @@ -391,37 +407,48 @@ def from_timestring(cls, timestring: str) -> Time:
class Timestamp(Serializable):
"""Timestamp represents a snapshot of the current training progress.
The timestamp measures training progress in terms of epochs, batches, samples, tokens, and wall clock time.
The timestamp measures training progress in terms of iterations, epochs, batches, samples, tokens, and wall clock time.
Timestamps are not updated in-place.
See the :doc:`Time Guide </trainer/time>` for more details on tracking time during training.
Args:
iteration (int | Time[int], optional): The iteration.
epoch (int | Time[int], optional): The epoch.
batch (int | Time[int], optional): The batch.
sample (int | Time[int], optional): The sample.
token (int | Time[int], optional): The token.
epoch_in_iteration (int | Time[int], optional): The epoch in the iteration.
batch_in_epoch (int | Time[int], optional): The batch in the epoch.
sample_in_epoch (int | Time[int], optional): The sample in the epoch.
token_in_epoch (int | Time[int], optional): The token in the epoch.
total_wct (datetime.timedelta, optional): The total wall-clock duration.
epoch_wct (datetime.timedelta, optional): The wall-clock duration of the last epoch.
iteration_wct (datetime.timedelta, optional): The wall-clock duration of the current iteration.
epoch_wct (datetime.timedelta, optional): The wall-clock duration of the current epoch.
batch_wct (datetime.timedelta, optional): The wall-clock duration of the last batch.
"""

def __init__(
self,
iteration: Union[int, Time[int]] = 0,
epoch: Union[int, Time[int]] = 0,
batch: Union[int, Time[int]] = 0,
sample: Union[int, Time[int]] = 0,
token: Union[int, Time[int]] = 0,
epoch_in_iteration: Union[int, Time[int]] = 0,
batch_in_epoch: Union[int, Time[int]] = 0,
sample_in_epoch: Union[int, Time[int]] = 0,
token_in_epoch: Union[int, Time[int]] = 0,
total_wct: Optional[datetime.timedelta] = None,
iteration_wct: Optional[datetime.timedelta] = None,
epoch_wct: Optional[datetime.timedelta] = None,
batch_wct: Optional[datetime.timedelta] = None,
):
iteration = Time.from_input(iteration, TimeUnit.ITERATION)
if iteration.unit != TimeUnit.ITERATION:
b-chu marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f'The `iteration` argument has units of {iteration.unit}; not {TimeUnit.ITERATION}.')
self._iteration = iteration

epoch = Time.from_input(epoch, TimeUnit.EPOCH)
if epoch.unit != TimeUnit.EPOCH:
raise ValueError(f'The `epoch` argument has units of {epoch.unit}; not {TimeUnit.EPOCH}.')
Expand All @@ -442,6 +469,12 @@ def __init__(
raise ValueError(f'The `token` argument has units of {token.unit}; not {TimeUnit.TOKEN}.')
self._token = token

epoch_in_iteration = Time.from_input(epoch_in_iteration, TimeUnit.EPOCH)
if epoch_in_iteration.unit != TimeUnit.EPOCH:
raise ValueError((f'The `epoch_in_iteration` argument has units of {epoch_in_iteration.unit}; '
f'not {TimeUnit.EPOCH}.'))
self._epoch_in_iteration = epoch_in_iteration

batch_in_epoch = Time.from_input(batch_in_epoch, TimeUnit.BATCH)
if batch_in_epoch.unit != TimeUnit.BATCH:
raise ValueError((f'The `batch_in_epoch` argument has units of {batch_in_epoch.unit}; '
Expand All @@ -464,6 +497,10 @@ def __init__(
total_wct = datetime.timedelta(seconds=0)
self._total_wct = total_wct

if iteration_wct is None:
iteration_wct = datetime.timedelta(seconds=0)
self._iteration_wct = iteration_wct

if epoch_wct is None:
epoch_wct = datetime.timedelta(seconds=0)
self._epoch_wct = epoch_wct
Expand All @@ -474,14 +511,17 @@ def __init__(

def state_dict(self) -> Dict[str, Any]:
    """Return the timestamp as a plain dictionary.

    Counter fields are stored as their raw ``.value``; the wall-clock fields
    (``*_wct``) are stored as-is.

    Returns:
        Dict[str, Any]: Mapping from field name to the stored value.
    """
    return {
        'iteration': self.iteration.value,
        'epoch': self.epoch.value,
        'batch': self.batch.value,
        'sample': self.sample.value,
        'token': self.token.value,
        'epoch_in_iteration': self.epoch_in_iteration.value,
        'batch_in_epoch': self.batch_in_epoch.value,
        'sample_in_epoch': self.sample_in_epoch.value,
        'token_in_epoch': self.token_in_epoch.value,
        'total_wct': self.total_wct,
        'iteration_wct': self.iteration_wct,
        'epoch_wct': self.epoch_wct,
        'batch_wct': self.batch_wct,
    }
Expand All @@ -493,14 +533,17 @@ def get_state(self) -> Dict[str, Union[Time[int], datetime.timedelta]]:
Dict[str, Union[Time[int], datetime.timedelta]]: All values of the timestamp object.
"""
return {
'iteration': self.iteration,
'epoch': self.epoch,
'batch': self.batch,
'sample': self.sample,
'token': self.token,
'epoch_in_iteration': self.epoch_in_iteration,
'batch_in_epoch': self.batch_in_epoch,
'sample_in_epoch': self.sample_in_epoch,
'token_in_epoch': self.token_in_epoch,
'total_wct': self.total_wct,
'iteration_wct': self.iteration_wct,
'epoch_wct': self.epoch_wct,
'batch_wct': self.batch_wct,
}
Expand All @@ -513,14 +556,26 @@ def load_state_dict(self, state: Dict[str, Any]) -> None:
self._batch_in_epoch = Time(state['batch_in_epoch'], TimeUnit.BATCH)
self._sample_in_epoch = Time(state['sample_in_epoch'], TimeUnit.SAMPLE)
self._token_in_epoch = Time(state['token_in_epoch'], TimeUnit.TOKEN)
# Wall clock time tracking was added in composer v0.7.0
# Using conditional checks as not to break old checkpoints
# Wall clock time tracking was added in composer v0.7.0
if 'total_wct' in state:
self._total_wct = state['total_wct']
if 'epoch_wct' in state:
self._epoch_wct = state['epoch_wct']
if 'batch_wct' in state:
self._batch_wct = state['batch_wct']
# Iteration was added in composer v0.19.1
if 'iteration' in state:
self._iteration = Time(state['iteration'], TimeUnit.ITERATION)
if 'epoch_in_iteration' in state:
self._epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH)
if 'iteration_wct' in state:
self._iteration_wct = state['iteration_wct']

@property
def iteration(self) -> Time[int]:
    """The total iteration count."""
    return self._iteration

@property
def epoch(self) -> Time[int]:
Expand All @@ -542,6 +597,11 @@ def token(self) -> Time[int]:
"""The total token count."""
return self._token

@property
def epoch_in_iteration(self) -> Time[int]:
    """The epoch count in the current iteration (resets at 0 at the beginning of every iteration)."""
    return self._epoch_in_iteration

@property
def batch_in_epoch(self) -> Time[int]:
"""The batch count in the current epoch (resets at 0 at the beginning of every epoch)."""
Expand All @@ -562,6 +622,11 @@ def total_wct(self) -> datetime.timedelta:
"""The wall-clock duration (in seconds) from the beginning of training."""
return self._total_wct

@property
def iteration_wct(self) -> datetime.timedelta:
    """The wall-clock duration of the current iteration, as a :class:`datetime.timedelta`."""
    return self._iteration_wct

@property
def epoch_wct(self) -> datetime.timedelta:
"""The wall-clock duration (in seconds) for the current epoch."""
Expand All @@ -582,6 +647,8 @@ def get(self, unit: Union[str, TimeUnit]) -> Time[int]:
Time: The current time, in the specified unit.
"""
unit = TimeUnit(unit)
if unit == TimeUnit.ITERATION:
return self.iteration
if unit == TimeUnit.EPOCH:
return self.epoch
if unit == TimeUnit.BATCH:
Expand Down Expand Up @@ -678,6 +745,7 @@ def to_next_batch(
... token = timestamp.token + tokens,
... token_in_epoch=timestamp.token_in_epoch + tokens,
... total_wct=timestamp.total_wct + duration,
... iteration_wct=timestamp.iteration_wct + duration,
... epoch_wct=timestamp.epoch_wct + duration,
... batch_wct=duration,
... )
Expand Down Expand Up @@ -705,6 +773,7 @@ def to_next_batch(
token=self.token + tokens,
token_in_epoch=self.token_in_epoch + tokens,
total_wct=self.total_wct + duration,
iteration_wct=self.iteration_wct + duration,
epoch_wct=self.epoch_wct + duration,
batch_wct=duration,
)
Expand All @@ -729,10 +798,12 @@ def to_next_epoch(
>>> timestamp.copy(
... epoch=timestamp.epoch + 1,
... epoch_in_iteration=timestamp.epoch_in_iteration + 1,
... batch_in_epoch=0,
... sample_in_epoch=0,
... token_in_epoch=0,
... total_wct=timestamp.total_wct + duration,
... iteration_wct=timestamp.iteration_wct + duration,
... epoch_wct=datetime.timedelta(seconds=0),
... batch_wct=datetime.timedelta(seconds=0),
... )
Expand All @@ -743,24 +814,74 @@ def to_next_epoch(
duration = datetime.timedelta(seconds=0)
return self.copy(
epoch=self.epoch + 1,
epoch_in_iteration=self.epoch_in_iteration + 1,
batch_in_epoch=0,
sample_in_epoch=0,
token_in_epoch=0,
total_wct=self.total_wct + duration,
iteration_wct=self.iteration_wct + duration,
epoch_wct=datetime.timedelta(seconds=0),
batch_wct=datetime.timedelta(seconds=0),
)

def to_next_iteration(
    self,
    duration: Optional[datetime.timedelta] = None,
):
    """Create a new :class:`.Timestamp`, advanced to the next iteration.

    The iteration counter is incremented, the per-iteration and per-epoch
    counters are reset to zero, and the iteration, epoch, and batch wall-clock
    durations are reset. The cumulative ``epoch``, ``batch``, ``sample`` and
    ``token`` counts are carried over unchanged.

    Equivalent to:

    .. testsetup::

        from composer.core.time import Timestamp
        import datetime

        timestamp = Timestamp()
        duration = datetime.timedelta(seconds=0)

    .. doctest::

        >>> timestamp.copy(
        ...     iteration=timestamp.iteration + 1,
        ...     epoch_in_iteration=0,
        ...     batch_in_epoch=0,
        ...     sample_in_epoch=0,
        ...     token_in_epoch=0,
        ...     total_wct=timestamp.total_wct + duration,
        ...     iteration_wct=datetime.timedelta(seconds=0),
        ...     epoch_wct=datetime.timedelta(seconds=0),
        ...     batch_wct=datetime.timedelta(seconds=0),
        ... )
        Timestamp(...)

    Args:
        duration (datetime.timedelta, optional): Wall-clock time to add to
            ``total_wct``. Defaults to ``None``, which is treated as zero.

    Returns:
        Timestamp: A new timestamp positioned at the start of the next iteration.
    """
    if duration is None:
        duration = datetime.timedelta(seconds=0)
    return self.copy(
        iteration=self.iteration + 1,
        epoch_in_iteration=0,
        batch_in_epoch=0,
        sample_in_epoch=0,
        token_in_epoch=0,
        # Only the cumulative wall-clock total advances; the iteration, epoch and
        # batch clocks restart at zero.
        total_wct=self.total_wct + duration,
        iteration_wct=datetime.timedelta(seconds=0),
        epoch_wct=datetime.timedelta(seconds=0),
        batch_wct=datetime.timedelta(seconds=0),
    )

def copy(
self,
iteration: Optional[Union[int, Time[int]]] = None,
epoch: Optional[Union[int, Time[int]]] = None,
batch: Optional[Union[int, Time[int]]] = None,
sample: Optional[Union[int, Time[int]]] = None,
token: Optional[Union[int, Time[int]]] = None,
epoch_in_iteration: Optional[Union[int, Time[int]]] = None,
batch_in_epoch: Optional[Union[int, Time[int]]] = None,
sample_in_epoch: Optional[Union[int, Time[int]]] = None,
token_in_epoch: Optional[Union[int, Time[int]]] = None,
total_wct: Optional[datetime.timedelta] = None,
iteration_wct: Optional[datetime.timedelta] = None,
epoch_wct: Optional[datetime.timedelta] = None,
batch_wct: Optional[datetime.timedelta] = None,
) -> Timestamp:
Expand All @@ -769,42 +890,53 @@ def copy(
Any specified values will override the existing values in the returned copy.
Args:
iteration (int | Time[int], optional): The iteration.
epoch (int | Time[int], optional): The epoch.
batch (int | Time[int], optional): The batch.
sample (int | Time[int], optional): The sample.
token (int | Time[int], optional): The token.
epoch_in_iteration (int | Time[int], optional): The epoch in the iteration.
batch_in_epoch (int | Time[int], optional): The batch in the epoch.
sample_in_epoch (int | Time[int], optional): The sample in the epoch.
token_in_epoch (int | Time[int], optional): The token in the epoch.
total_wct (datetime.timedelta, optional): The elapsed duration from the beginning of training.
iteration_wct (datetime.timedelta, optional): The wall-clock duration of the current iteration.
epoch_wct (datetime.timedelta, optional): The wall-clock duration of the current epoch.
batch_wct (datetime.timedelta, optional): The wall-clock duration of the last batch.
Returns:
Timestamp: A new timestamp instance, created from a copy, but with any specified values
overriding the existing values.
"""
return Timestamp(
iteration=iteration if iteration is not None else self.iteration,
epoch=epoch if epoch is not None else self.epoch,
batch=batch if batch is not None else self.batch,
sample=sample if sample is not None else self.sample,
token=token if token is not None else self.token,
epoch_in_iteration=epoch_in_iteration if epoch_in_iteration is not None else self.epoch_in_iteration,
batch_in_epoch=batch_in_epoch if batch_in_epoch is not None else self.batch_in_epoch,
sample_in_epoch=sample_in_epoch if sample_in_epoch is not None else self.sample_in_epoch,
token_in_epoch=token_in_epoch if token_in_epoch is not None else self.token_in_epoch,
total_wct=total_wct if total_wct is not None else self.total_wct,
iteration_wct=iteration_wct if iteration_wct is not None else self.iteration_wct,
epoch_wct=epoch_wct if epoch_wct is not None else self.epoch_wct,
batch_wct=batch_wct if batch_wct is not None else self.batch_wct,
)

def __repr__(self) -> str:
    """Return a ``Timestamp(...)`` string listing every field.

    Counter fields are rendered as integers; wall-clock fields use their ``repr``.
    """
    return (f'Timestamp('
            f'iteration={int(self.iteration)}, '
            f'epoch={int(self.epoch)}, '
            f'batch={int(self.batch)}, '
            f'sample={int(self.sample)}, '
            f'token={int(self.token)}, '
            f'epoch_in_iteration={int(self.epoch_in_iteration)}, '
            f'batch_in_epoch={int(self.batch_in_epoch)}, '
            f'sample_in_epoch={int(self.sample_in_epoch)}, '
            f'token_in_epoch={int(self.token_in_epoch)}, '
            f'total_wct={repr(self.total_wct)}, '
            f'iteration_wct={repr(self.iteration_wct)}, '
            f'epoch_wct={repr(self.epoch_wct)}, '
            f'batch_wct={repr(self.batch_wct)}'
            ')')
Expand Down
2 changes: 2 additions & 0 deletions tests/algorithms/test_algorithm_resumption.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,11 @@ def _assert_checkpoints_equal(file1, file2):
# compare state
# remove the wall clock time fields since they will always differ
del checkpoint1['state']['timestamp']['Timestamp']['total_wct']
del checkpoint1['state']['timestamp']['Timestamp']['iteration_wct']
del checkpoint1['state']['timestamp']['Timestamp']['epoch_wct']
del checkpoint1['state']['timestamp']['Timestamp']['batch_wct']
del checkpoint2['state']['timestamp']['Timestamp']['total_wct']
del checkpoint2['state']['timestamp']['Timestamp']['iteration_wct']
del checkpoint2['state']['timestamp']['Timestamp']['epoch_wct']
del checkpoint2['state']['timestamp']['Timestamp']['batch_wct']

Expand Down
1 change: 1 addition & 0 deletions tests/common/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

def _del_wct_timestamp_fields(timestamp_state_dict: Dict[str, Any]):
    """Remove the wall-clock-time entries from a serialized timestamp, in place.

    Args:
        timestamp_state_dict (Dict[str, Any]): A state dict whose ``'Timestamp'``
            entry is a mapping containing the ``*_wct`` keys.

    Raises:
        KeyError: If ``'Timestamp'`` or any of the wall-clock keys is missing.
    """
    # NOTE(review): presumably stripped so state comparisons ignore timing
    # values that differ between runs -- confirm against callers.
    fields = timestamp_state_dict['Timestamp']
    for key in ('total_wct', 'iteration_wct', 'epoch_wct', 'batch_wct'):
        del fields[key]

Expand Down
Loading
Loading