Skip to content

Commit

Permalink
Add iteration to TimeUnit
Browse files Browse the repository at this point in the history
commit-id:d19d88d1
  • Loading branch information
b-chu committed Feb 14, 2024
1 parent ecd6e50 commit 5cbb1c8
Showing 1 changed file with 107 additions and 2 deletions.
109 changes: 107 additions & 2 deletions composer/core/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ class TimeUnit(StringEnum):
"""Enum class to represent units of time for the training process.
Attributes:
ITERATION (str): Iterations.
EPOCH (str): Epochs.
BATCH (str): Batches (i.e. number of optimization steps)
SAMPLE (str): Samples.
TOKEN (str): Tokens. Applicable for natural language processing (NLP) models.
DURATION (str): Fraction of the training process complete, on ``[0.0, 1.0)``
"""
ITERATION = 'iter'
EPOCH = 'ep'
BATCH = 'ba'
SAMPLE = 'sp'
Expand Down Expand Up @@ -122,6 +124,20 @@ def __init__(
raise TypeError(f'value {value} is of type {type(value)}. Units {unit} require integer values.')
self._value, self._unit = value, TimeUnit(unit)

@classmethod
def from_iteration(cls, iteration: int) -> Time:
    """Create a :class:`Time` with units of :attr:`TimeUnit.ITERATION`.

    Equivalent to ``Time(iteration, TimeUnit.ITERATION)``.

    Args:
        iteration (int): Number of iterations.

    Returns:
        Time: :class:`Time` instance, in iterations.
    """
    return cls(iteration, TimeUnit.ITERATION)

@classmethod
def from_epoch(cls, epoch: int) -> Time:
"""Create a :class:`Time` with units of :attr:`TimeUnit.EPOCH`.
Expand Down Expand Up @@ -391,37 +407,48 @@ def from_timestring(cls, timestring: str) -> Time:
class Timestamp(Serializable):
"""Timestamp represents a snapshot of the current training progress.
The timestamp measures training progress in terms of epochs, batches, samples, tokens, and wall clock time.
The timestamp measures training progress in terms of iterations, epochs, batches, samples, tokens, and wall clock time.
Timestamps are not updated in-place.
See the :doc:`Time Guide </trainer/time>` for more details on tracking time during training.
Args:
iteration (int | Time[int], optional): The iteration.
epoch (int | Time[int], optional): The epoch.
batch (int | Time[int], optional): The batch.
sample (int | Time[int], optional): The sample.
token (int | Time[int], optional): The token.
epoch_in_iteration (int | Time[int], optional): The epoch in the iteration.
batch_in_epoch (int | Time[int], optional): The batch in the epoch.
sample_in_epoch (int | Time[int], optional): The sample in the epoch.
token_in_epoch (int | Time[int], optional): The token in the epoch.
total_wct (datetime.timedelta, optional): The total wall-clock duration.
iteration_wct (datetime.timedelta, optional): The wall-clock duration of the last iteration.
epoch_wct (datetime.timedelta, optional): The wall-clock duration of the last epoch.
batch_wct (datetime.timedelta, optional): The wall-clock duration of the last batch.
"""

def __init__(
self,
iteration: Union[int, Time[int]] = 0,
epoch: Union[int, Time[int]] = 0,
batch: Union[int, Time[int]] = 0,
sample: Union[int, Time[int]] = 0,
token: Union[int, Time[int]] = 0,
epoch_in_iteration: Union[int, Time[int]] = 0,
batch_in_epoch: Union[int, Time[int]] = 0,
sample_in_epoch: Union[int, Time[int]] = 0,
token_in_epoch: Union[int, Time[int]] = 0,
total_wct: Optional[datetime.timedelta] = None,
iteration_wct: Optional[datetime.timedelta] = None,
epoch_wct: Optional[datetime.timedelta] = None,
batch_wct: Optional[datetime.timedelta] = None,
):
iteration = Time.from_input(iteration, TimeUnit.ITERATION)
if iteration.unit != TimeUnit.ITERATION:
raise ValueError(f'The `iteration` argument has units of {iteration.unit}; not {TimeUnit.ITERATION}.')
self._iteration = iteration

epoch = Time.from_input(epoch, TimeUnit.EPOCH)
if epoch.unit != TimeUnit.EPOCH:
raise ValueError(f'The `epoch` argument has units of {epoch.unit}; not {TimeUnit.EPOCH}.')
Expand All @@ -442,6 +469,12 @@ def __init__(
raise ValueError(f'The `token` argument has units of {token.unit}; not {TimeUnit.TOKEN}.')
self._token = token

# ``epoch_in_iteration`` counts epochs, so it is converted to and validated
# against EPOCH units. This matches the error message below and
# ``load_state_dict``, which rebuilds the field with ``TimeUnit.EPOCH``.
epoch_in_iteration = Time.from_input(epoch_in_iteration, TimeUnit.EPOCH)
if epoch_in_iteration.unit != TimeUnit.EPOCH:
    raise ValueError((f'The `epoch_in_iteration` argument has units of {epoch_in_iteration.unit}; '
                      f'not {TimeUnit.EPOCH}.'))
self._epoch_in_iteration = epoch_in_iteration

batch_in_epoch = Time.from_input(batch_in_epoch, TimeUnit.BATCH)
if batch_in_epoch.unit != TimeUnit.BATCH:
raise ValueError((f'The `batch_in_epoch` argument has units of {batch_in_epoch.unit}; '
Expand All @@ -464,6 +497,10 @@ def __init__(
total_wct = datetime.timedelta(seconds=0)
self._total_wct = total_wct

if iteration_wct is None:
iteration_wct = datetime.timedelta(seconds=0)
self._iteration_wct = iteration_wct

if epoch_wct is None:
epoch_wct = datetime.timedelta(seconds=0)
self._epoch_wct = epoch_wct
Expand All @@ -474,14 +511,17 @@ def __init__(

def state_dict(self) -> Dict[str, Any]:
    """Serialize the timestamp into a plain dictionary.

    Counter fields are stored as their raw ``.value``; wall-clock fields are
    stored as the objects held by the timestamp.

    Returns:
        Dict[str, Any]: Mapping from field name to its serialized value.
    """
    # Fields whose ``.value`` is stored, in the order they appear in the output.
    counter_fields = (
        'iteration',
        'epoch',
        'batch',
        'sample',
        'token',
        'epoch_in_iteration',
        'batch_in_epoch',
        'sample_in_epoch',
        'token_in_epoch',
    )
    # Wall-clock fields are stored unchanged.
    wct_fields = ('total_wct', 'iteration_wct', 'epoch_wct', 'batch_wct')

    serialized: Dict[str, Any] = {name: getattr(self, name).value for name in counter_fields}
    serialized.update((name, getattr(self, name)) for name in wct_fields)
    return serialized
Expand All @@ -493,35 +533,47 @@ def get_state(self) -> Dict[str, Union[Time[int], datetime.timedelta]]:
Dict[str, Union[Time[int], datetime.timedelta]]: All values of the timestamp object.
"""
return {
'iteration': self.iteration.value,
'epoch': self.epoch,
'batch': self.batch,
'sample': self.sample,
'token': self.token,
'epoch_in_iteration': self.epoch_in_iteration.value,
'batch_in_epoch': self.batch_in_epoch,
'sample_in_epoch': self.sample_in_epoch,
'token_in_epoch': self.token_in_epoch,
'total_wct': self.total_wct,
'iteration_wct': self.iteration_wct,
'epoch_wct': self.epoch_wct,
'batch_wct': self.batch_wct,
}

def load_state_dict(self, state: Dict[str, Any]) -> None:
    """Restore the timestamp's fields from a dictionary shaped like :meth:`state_dict` output.

    The ``epoch``, ``batch``, ``sample``, ``token`` and ``*_in_epoch`` counters are
    required. The iteration counters and all wall-clock fields are optional: when
    a key is absent, the corresponding attribute is left untouched.

    Args:
        state (Dict[str, Any]): Mapping from field name to raw counter value
            (for counters) or wall-clock duration (for ``*_wct`` fields).

    Raises:
        KeyError: If one of the required counter keys is missing from ``state``.
    """
    self._epoch = Time(state['epoch'], TimeUnit.EPOCH)
    self._batch = Time(state['batch'], TimeUnit.BATCH)
    self._sample = Time(state['sample'], TimeUnit.SAMPLE)
    self._token = Time(state['token'], TimeUnit.TOKEN)
    self._batch_in_epoch = Time(state['batch_in_epoch'], TimeUnit.BATCH)
    self._sample_in_epoch = Time(state['sample_in_epoch'], TimeUnit.SAMPLE)
    self._token_in_epoch = Time(state['token_in_epoch'], TimeUnit.TOKEN)
    # Iteration tracking was added after the counters above, so checkpoints
    # written before it lack these keys. Load them conditionally, like the
    # wall-clock fields below, so old checkpoints do not raise KeyError.
    if 'iteration' in state:
        self._iteration = Time(state['iteration'], TimeUnit.ITERATION)
    if 'epoch_in_iteration' in state:
        self._epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH)
    # Wall clock time tracking was added in composer v0.7.0
    # Using conditional checks as not to break old checkpoints
    if 'total_wct' in state:
        self._total_wct = state['total_wct']
    if 'iteration_wct' in state:
        self._iteration_wct = state['iteration_wct']
    if 'epoch_wct' in state:
        self._epoch_wct = state['epoch_wct']
    if 'batch_wct' in state:
        self._batch_wct = state['batch_wct']

@property
def iteration(self) -> Time[int]:
    """Total number of iterations recorded by this timestamp.

    Returns:
        Time[int]: The stored iteration counter.
    """
    return self._iteration

@property
def epoch(self) -> Time[int]:
"""The total epoch count."""
Expand All @@ -541,6 +593,11 @@ def sample(self) -> Time[int]:
def token(self) -> Time[int]:
"""The total token count."""
return self._token

@property
def epoch_in_iteration(self) -> Time[int]:
    """Number of epochs counted within the current iteration.

    The count starts again from 0 at the beginning of every iteration.

    Returns:
        Time[int]: The stored epoch-in-iteration counter.
    """
    return self._epoch_in_iteration

@property
def batch_in_epoch(self) -> Time[int]:
Expand All @@ -562,6 +619,11 @@ def total_wct(self) -> datetime.timedelta:
"""The wall-clock duration (in seconds) from the beginning of training."""
return self._total_wct

@property
def iteration_wct(self) -> datetime.timedelta:
    """Wall-clock time spent in the current iteration.

    Returns:
        datetime.timedelta: The stored iteration wall-clock duration.
    """
    return self._iteration_wct

@property
def epoch_wct(self) -> datetime.timedelta:
"""The wall-clock duration (in seconds) for the current epoch."""
Expand All @@ -582,6 +644,8 @@ def get(self, unit: Union[str, TimeUnit]) -> Time[int]:
Time: The current time, in the specified unit.
"""
unit = TimeUnit(unit)
if unit == TimeUnit.ITERATION:
return self.iteration
if unit == TimeUnit.EPOCH:
return self.epoch
if unit == TimeUnit.BATCH:
Expand Down Expand Up @@ -678,6 +742,7 @@ def to_next_batch(
... token = timestamp.token + tokens,
... token_in_epoch=timestamp.token_in_epoch + tokens,
... total_wct=timestamp.total_wct + duration,
... iteration_wct=timestamp.iteration_wct + duration,
... epoch_wct=timestamp.epoch_wct + duration,
... batch_wct=duration,
... )
Expand Down Expand Up @@ -705,6 +770,7 @@ def to_next_batch(
token=self.token + tokens,
token_in_epoch=self.token_in_epoch + tokens,
total_wct=self.total_wct + duration,
iteration_wct=self.iteration_wct + duration,
epoch_wct=self.epoch_wct + duration,
batch_wct=duration,
)
Expand All @@ -728,7 +794,45 @@ def to_next_epoch(
.. doctest::
>>> timestamp.copy(
... epoch=timestamp.epoch+1,
... epoch=timestamp.epoch + 1,
... epoch_in_iteration=timestamp.epoch_in_iteration + 1,
... batch_in_epoch=0,
... sample_in_epoch=0,
... token_in_epoch=0,
... iteration_wct=datetime.timedelta(seconds=0),
... epoch_wct=datetime.timedelta(seconds=0),
... batch_wct=datetime.timedelta(seconds=0),
... )
Timestamp(...)
"""
return self.copy(
epoch=self.epoch + 1,
epoch_in_iteration=self.epoch_in_iteration + 1,
batch_in_epoch=0,
sample_in_epoch=0,
token_in_epoch=0,
epoch_wct=datetime.timedelta(seconds=0),
batch_wct=datetime.timedelta(seconds=0),
)

def to_next_iteration(self):
"""Create a new :class:`.Timestamp`, advanced to the next iteration.
Equivalent to:
.. testsetup::
from composer.core.time import Timestamp
import datetime
timestamp = Timestamp()
.. doctest::
>>> timestamp.copy(
... iteration=timestamp.iteration + 1,
... epoch_in_iteration=0,
... batch_in_epoch=0,
... sample_in_epoch=0,
... token_in_epoch=0,
Expand All @@ -741,6 +845,7 @@ def to_next_epoch(
"""
return self.copy(
epoch=self.epoch + 1,
epoch_in_iteration=self.epoch_in_iteration + 1,
batch_in_epoch=0,
sample_in_epoch=0,
token_in_epoch=0,
Expand Down

0 comments on commit 5cbb1c8

Please sign in to comment.