Skip to content

Commit

Permalink
Add iteration to TimeUnit
Browse files Browse the repository at this point in the history
commit-id:e4729f79
  • Loading branch information
b-chu committed Feb 14, 2024
1 parent 15b8ef8 commit 6c35d51
Showing 1 changed file with 130 additions and 1 deletion.
131 changes: 130 additions & 1 deletion composer/core/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ class TimeUnit(StringEnum):
"""Enum class to represent units of time for the training process.

Attributes:
ITERATION (str): Iterations.
EPOCH (str): Epochs.
BATCH (str): Batches (i.e. number of optimization steps)
SAMPLE (str): Samples.
TOKEN (str): Tokens. Applicable for natural language processing (NLP) models.
DURATION (str): Fraction of the training process complete, on ``[0.0, 1.0)``
"""
ITERATION = 'iter'
EPOCH = 'ep'
BATCH = 'ba'
SAMPLE = 'sp'
Expand Down Expand Up @@ -122,6 +124,20 @@ def __init__(
raise TypeError(f'value {value} is of type {type(value)}. Units {unit} require integer values.')
self._value, self._unit = value, TimeUnit(unit)

@classmethod
def from_iteration(cls, iteration: int) -> Time:
"""Create a :class:`Time` with units of :attr:`TimeUnit.ITERATION`.

Equivalent to ``Time(epoch, TimeUnit.EPOCH)``.

Args:
epoch (int): Number of epochs.

Returns:
Time: :class:`Time` instance, in epochs.
"""
return cls(iteration, TimeUnit.ITERATION)

@classmethod
def from_epoch(cls, epoch: int) -> Time:
"""Create a :class:`Time` with units of :attr:`TimeUnit.EPOCH`.
Expand Down Expand Up @@ -391,37 +407,48 @@ def from_timestring(cls, timestring: str) -> Time:
class Timestamp(Serializable):
"""Timestamp represents a snapshot of the current training progress.

The timestamp measures training progress in terms of epochs, batches, samples, tokens, and wall clock time.
The timestamp measures training progress in terms of iterations, epochs, batches, samples, tokens, and wall clock time.
Timestamps are not updated in-place.

See the :doc:`Time Guide </trainer/time>` for more details on tracking time during training.

Args:
iteration (int | Time[int], optional): The iteration.
epoch (int | Time[int], optional): The epoch.
batch (int | Time[int], optional): the batch.
sample (int | Time[int], optional): The sample.
token (int | Time[int], optional): The token.
epoch_in_iteration (int | Time[int], optional): The epoch in the iteration.
batch_in_epoch (int | Time[int], optional): The batch in the epoch.
sample_in_epoch (int | Time[int], optional): The sample in the epoch.
token_in_epoch (int | Time[int], optional): The token in the epoch.
total_wct (datetime.timedelta, optional): The total wall-clock duration.
iteration_wct (datetime.timedelta, optional): The wall-clock duration of the last iteration.
epoch_wct (datetime.timedelta, optional): The wall-clock duration of the last epoch.
batch_wct (datetime.timedelta, optional): The wall-clock duration of the last batch.
"""

def __init__(
self,
iteration: Union[int, Time[int]] = 0,
epoch: Union[int, Time[int]] = 0,
batch: Union[int, Time[int]] = 0,
sample: Union[int, Time[int]] = 0,
token: Union[int, Time[int]] = 0,
epoch_in_iteration: Union[int, Time[int]] = 0,
batch_in_epoch: Union[int, Time[int]] = 0,
sample_in_epoch: Union[int, Time[int]] = 0,
token_in_epoch: Union[int, Time[int]] = 0,
total_wct: Optional[datetime.timedelta] = None,
iteration_wct: Optional[datetime.timedelta] = None,
epoch_wct: Optional[datetime.timedelta] = None,
batch_wct: Optional[datetime.timedelta] = None,
):
iteration = Time.from_input(iteration, TimeUnit.ITERATION)
if iteration.unit != TimeUnit.ITERATION:
raise ValueError(f'The `iteration` argument has units of {iteration.unit}; not {TimeUnit.ITERATION}.')
self._iteration = iteration

epoch = Time.from_input(epoch, TimeUnit.EPOCH)
if epoch.unit != TimeUnit.EPOCH:
raise ValueError(f'The `epoch` argument has units of {epoch.unit}; not {TimeUnit.EPOCH}.')
Expand All @@ -442,6 +469,12 @@ def __init__(
raise ValueError(f'The `token` argument has units of {token.unit}; not {TimeUnit.TOKEN}.')
self._token = token

epoch_in_iteration = Time.from_input(epoch_in_iteration, TimeUnit.BATCH)
if epoch_in_iteration.unit != TimeUnit.BATCH:
raise ValueError((f'The `epoch_in_iteration` argument has units of {epoch_in_iteration.unit}; '
f'not {TimeUnit.EPOCH}.'))
self._epoch_in_iteration = epoch_in_iteration

batch_in_epoch = Time.from_input(batch_in_epoch, TimeUnit.BATCH)
if batch_in_epoch.unit != TimeUnit.BATCH:
raise ValueError((f'The `batch_in_epoch` argument has units of {batch_in_epoch.unit}; '
Expand All @@ -464,6 +497,10 @@ def __init__(
total_wct = datetime.timedelta(seconds=0)
self._total_wct = total_wct

if iteration_wct is None:
iteration_wct = datetime.timedelta(seconds=0)
self._iteration_wct = iteration_wct

if epoch_wct is None:
epoch_wct = datetime.timedelta(seconds=0)
self._epoch_wct = epoch_wct
Expand All @@ -474,14 +511,17 @@ def __init__(

def state_dict(self) -> Dict[str, Any]:
return {
'iteration': self.iteration.value,
'epoch': self.epoch.value,
'batch': self.batch.value,
'sample': self.sample.value,
'token': self.token.value,
'epoch_in_iteration': self.epoch_in_iteration.value,
'batch_in_epoch': self.batch_in_epoch.value,
'sample_in_epoch': self.sample_in_epoch.value,
'token_in_epoch': self.token_in_epoch.value,
'total_wct': self.total_wct,
'iteration_wct': self.iteration_wct,
'epoch_wct': self.epoch_wct,
'batch_wct': self.batch_wct,
}
Expand All @@ -493,35 +533,47 @@ def get_state(self) -> Dict[str, Union[Time[int], datetime.timedelta]]:
Dict[str, Union[Time[int], datetime.timedelta]]: All values of the timestamp object.
"""
return {
'iteration': self.iteration.value,
'epoch': self.epoch,
'batch': self.batch,
'sample': self.sample,
'token': self.token,
'epoch_in_iteration': self.epoch_in_iteration.value,
'batch_in_epoch': self.batch_in_epoch,
'sample_in_epoch': self.sample_in_epoch,
'token_in_epoch': self.token_in_epoch,
'total_wct': self.total_wct,
'iteration_wct': self.iteration_wct,
'epoch_wct': self.epoch_wct,
'batch_wct': self.batch_wct,
}

def load_state_dict(self, state: Dict[str, Any]) -> None:
self._iteration = Time(state['iteration'], TimeUnit.ITERATION)
self._epoch = Time(state['epoch'], TimeUnit.EPOCH)
self._batch = Time(state['batch'], TimeUnit.BATCH)
self._sample = Time(state['sample'], TimeUnit.SAMPLE)
self._token = Time(state['token'], TimeUnit.TOKEN)
self._epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH)
self._batch_in_epoch = Time(state['batch_in_epoch'], TimeUnit.BATCH)
self._sample_in_epoch = Time(state['sample_in_epoch'], TimeUnit.SAMPLE)
self._token_in_epoch = Time(state['token_in_epoch'], TimeUnit.TOKEN)
# Wall clock time tracking was added in composer v0.7.0
# Using conditional checks as not to break old checkpoints
if 'total_wct' in state:
self._total_wct = state['total_wct']
if 'iteration_wct' in state:
self._iteration_wct = state['iteration_wct']
if 'epoch_wct' in state:
self._epoch_wct = state['epoch_wct']
if 'batch_wct' in state:
self._batch_wct = state['batch_wct']

@property
def iteration(self) -> Time[int]:
"""The total iteration count."""
return self._iteration

@property
def epoch(self) -> Time[int]:
"""The total epoch count."""
Expand All @@ -541,6 +593,11 @@ def sample(self) -> Time[int]:
def token(self) -> Time[int]:
"""The total token count."""
return self._token

@property
def epoch_in_iteration(self) -> Time[int]:
"""The epoch count in the current iteration (resets at 0 at the beginning of every iteration)."""
return self._epoch_in_iteration

@property
def batch_in_epoch(self) -> Time[int]:
Expand All @@ -562,6 +619,11 @@ def total_wct(self) -> datetime.timedelta:
"""The wall-clock duration (in seconds) from the beginning of training."""
return self._total_wct

@property
def iteration_wct(self) -> datetime.timedelta:
"""The wall-clock duration (in seconds) for the current iteration."""
return self._iteration_wct

@property
def epoch_wct(self) -> datetime.timedelta:
"""The wall-clock duration (in seconds) for the current epoch."""
Expand All @@ -582,6 +644,8 @@ def get(self, unit: Union[str, TimeUnit]) -> Time[int]:
Time: The current time, in the specified unit.
"""
unit = TimeUnit(unit)
if unit == TimeUnit.ITERATION:
return self.iteration
if unit == TimeUnit.EPOCH:
return self.epoch
if unit == TimeUnit.BATCH:
Expand Down Expand Up @@ -678,6 +742,7 @@ def to_next_batch(
... token = timestamp.token + tokens,
... token_in_epoch=timestamp.token_in_epoch + tokens,
... total_wct=timestamp.total_wct + duration,
... iteration_wct=timestamp.iteration_wct + duration,
... epoch_wct=timestamp.epoch_wct + duration,
... batch_wct=duration,
... )
Expand Down Expand Up @@ -705,6 +770,7 @@ def to_next_batch(
token=self.token + tokens,
token_in_epoch=self.token_in_epoch + tokens,
total_wct=self.total_wct + duration,
iteration_wct=self.iteration_wct + duration,
epoch_wct=self.epoch_wct + duration,
batch_wct=duration,
)
Expand All @@ -729,10 +795,12 @@ def to_next_epoch(

>>> timestamp.copy(
... epoch=timestamp.epoch + 1,
... epoch_in_iteration=timestamp.epoch_in_iteration + 1,
... batch_in_epoch=0,
... sample_in_epoch=0,
... token_in_epoch=0,
... total_wct=timestamp.total_wct + duration,
... iteration_wct=timestamp.iteration_wct + duration,
... epoch_wct=datetime.timedelta(seconds=0),
... batch_wct=datetime.timedelta(seconds=0),
... )
Expand All @@ -743,24 +811,74 @@ def to_next_epoch(
duration = datetime.timedelta(seconds=0)
return self.copy(
epoch=self.epoch + 1,
epoch_in_iteration=self.epoch_in_iteration + 1,
batch_in_epoch=0,
sample_in_epoch=0,
token_in_epoch=0,
total_wct=self.total_wct + duration,
iteration_wct=self.iteration_wct + duration,
epoch_wct=datetime.timedelta(seconds=0),
batch_wct=datetime.timedelta(seconds=0),
)

def to_next_iteration(
self,
duration: Optional[datetime.timedelta] = None,
):
"""Create a new :class:`.Timestamp`, advanced to the next iteration.

Equivalent to:

.. testsetup::

from composer.core.time import Timestamp
import datetime

timestamp = Timestamp()

.. doctest::

>>> timestamp.copy(
... iteration=timestamp.iteration + 1,
... epoch_in_iteration=0,
... batch_in_epoch=0,
... sample_in_epoch=0,
... token_in_epoch=0,
... total_wct=timestamp.total_wct + duration,
... iteration_wct=datetime.timedelta(seconds=0),
... epoch_wct=datetime.timedelta(seconds=0),
... batch_wct=datetime.timedelta(seconds=0),
... )
Timestamp(...)

"""
if duration is None:
duration = datetime.timedelta(seconds=0)
return self.copy(
iteration=self.iteration + 1,
epoch_in_iteration=0,
batch_in_epoch=0,
sample_in_epoch=0,
token_in_epoch=0,
total_wct=self.total_wct + duration,
iteration_wct=datetime.timedelta(seconds=0),
epoch_wct=datetime.timedelta(seconds=0),
batch_wct=datetime.timedelta(seconds=0),
)

def copy(
self,
iteration: Optional[Union[int, Time[int]]] = None,
epoch: Optional[Union[int, Time[int]]] = None,
batch: Optional[Union[int, Time[int]]] = None,
sample: Optional[Union[int, Time[int]]] = None,
token: Optional[Union[int, Time[int]]] = None,
epoch_in_iteration: Optional[Union[int, Time[int]]] = None,
batch_in_epoch: Optional[Union[int, Time[int]]] = None,
sample_in_epoch: Optional[Union[int, Time[int]]] = None,
token_in_epoch: Optional[Union[int, Time[int]]] = None,
total_wct: Optional[datetime.timedelta] = None,
iteration_wct: Optional[datetime.timedelta] = None,
epoch_wct: Optional[datetime.timedelta] = None,
batch_wct: Optional[datetime.timedelta] = None,
) -> Timestamp:
Expand All @@ -769,42 +887,53 @@ def copy(
Any specified values will override the existing values in the returned copy.

Args:
iteration (int | Time[int], optional): The iteration.
epoch (int | Time[int], optional): The epoch.
batch (int | Time[int], optional): the batch.
sample (int | Time[int], optional): The sample.
token (int | Time[int], optional): The token.
epoch_in_iteration (int | Time[int], optional): The epoch in the iteration.
batch_in_epoch (int | Time[int], optional): The batch in the epoch.
sample_in_epoch (int | Time[int], optional): The sample in the epoch.
token_in_epoch (int | Time[int], optional): The token in the epoch.
total_wct (datetime.timedelta, optional): The elapsed duration from the beginning of training.
iteration_wct (datetime.timedelta, optional): The wall-clock duration of the last iteration.
epoch_wct (datetime.timedelta, optional): The wall-clock duration of the last epoch.
batch_wct (datetime.timedelta, optional): The wall-clock duration of the last batch.

Returns:
Timestamp: A new timestamp instance, created from a copy, but with any specified values
overriding the existing values.
"""
return Timestamp(
iteration=iteration if iteration is not None else self.iteration,
epoch=epoch if epoch is not None else self.epoch,
batch=batch if batch is not None else self.batch,
sample=sample if sample is not None else self.sample,
token=token if token is not None else self.token,
epoch_in_iteration=epoch_in_iteration if epoch_in_iteration is not None else self.epoch_in_iteration,
batch_in_epoch=batch_in_epoch if batch_in_epoch is not None else self.batch_in_epoch,
sample_in_epoch=sample_in_epoch if sample_in_epoch is not None else self.sample_in_epoch,
token_in_epoch=token_in_epoch if token_in_epoch is not None else self.token_in_epoch,
total_wct=total_wct if total_wct is not None else self.total_wct,
iteration_wct=iteration_wct if iteration_wct is not None else self.iteration_wct,
epoch_wct=epoch_wct if epoch_wct is not None else self.epoch_wct,
batch_wct=batch_wct if batch_wct is not None else self.batch_wct,
)

def __repr__(self) -> str:
return (f'Timestamp('
f'iteration={int(self.iteration)}, '
f'epoch={int(self.epoch)}, '
f'batch={int(self.batch)}, '
f'sample={int(self.sample)}, '
f'token={int(self.token)}, '
f'epoch_in_iteration={int(self.epoch_in_iteration)}, '
f'batch_in_epoch={int(self.batch_in_epoch)}, '
f'sample_in_epoch={int(self.sample_in_epoch)}, '
f'token_in_epoch={int(self.token_in_epoch)}, '
f'total_wct={repr(self.total_wct)}, '
f'iteration_wct={repr(self.iteration_wct)}, '
f'epoch_wct={repr(self.epoch_wct)}, '
f'batch_wct={repr(self.batch_wct)}'
')')
Expand Down

0 comments on commit 6c35d51

Please sign in to comment.