Skip to content

Commit

Permalink
Add setter for epoch in iteration (mosaicml#3407)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu committed Jun 18, 2024
1 parent c425fa3 commit 9dda1e2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
25 changes: 17 additions & 8 deletions composer/core/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,13 +525,8 @@ def __init__(
raise ValueError(f'The `token` argument has units of {token.unit}; not {TimeUnit.TOKEN}.')
self._token = token

epoch_in_iteration = Time.from_input(epoch_in_iteration, TimeUnit.EPOCH)
if epoch_in_iteration.unit != TimeUnit.EPOCH:
raise ValueError((
f'The `epoch_in_iteration` argument has units of {epoch_in_iteration.unit}; '
f'not {TimeUnit.EPOCH}.'
))
self._epoch_in_iteration = epoch_in_iteration
self._epoch_in_iteration = Time(0, TimeUnit.EPOCH)
self.epoch_in_iteration = epoch_in_iteration

token_in_iteration = Time.from_input(token_in_iteration, TimeUnit.TOKEN)
if token_in_iteration.unit != TimeUnit.TOKEN:
Expand Down Expand Up @@ -619,7 +614,7 @@ def load_state_dict(self, state: dict[str, Any]) -> None:
if 'iteration' in state:
self._iteration = Time(state['iteration'], TimeUnit.ITERATION)
if 'epoch_in_iteration' in state:
self._epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH)
self.epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH)
if 'token_in_iteration' in state:
self._token_in_iteration = Time(state['token_in_iteration'], TimeUnit.TOKEN)
if 'iteration_wct' in state:
Expand Down Expand Up @@ -655,6 +650,20 @@ def epoch_in_iteration(self) -> Time[int]:
"""The epoch count in the current iteration (resets at 0 at the beginning of every iteration)."""
return self._epoch_in_iteration

@epoch_in_iteration.setter
def epoch_in_iteration(
self,
epoch_in_iteration: Union[int, Time[int]], # pyright: ignore[reportPropertyTypeMismatch]
):
"""Sets epoch count in the current iteration."""
epoch_in_iteration = Time.from_input(epoch_in_iteration, TimeUnit.EPOCH)
if epoch_in_iteration.unit != TimeUnit.EPOCH:
raise ValueError((
f'The `epoch_in_iteration` argument has units of {epoch_in_iteration.unit}; '
f'not {TimeUnit.EPOCH}.'
))
self._epoch_in_iteration = epoch_in_iteration

@property
def token_in_iteration(self) -> Time[int]:
"""The token count in the current iteration (resets at 0 at the beginning of every iteration)."""
Expand Down
7 changes: 7 additions & 0 deletions tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@ def test_timestamp_update():
assert timestamp is not timestamp_2


def test_set_timestamp():
timestamp = Timestamp(epoch_in_iteration=1)
assert timestamp.epoch_in_iteration == 1
timestamp.epoch_in_iteration = 2
assert timestamp.epoch_in_iteration == 2


def test_timestamp_to_next_batch_epoch_iteration():
timestamp = Timestamp()
# Step batch 0 in epoch 0
Expand Down

0 comments on commit 9dda1e2

Please sign in to comment.