From 9dda1e252ef3c0b94268d5038b0be1bce7ee55b8 Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Tue, 18 Jun 2024 13:29:34 -0400 Subject: [PATCH] Add setter for epoch in iteration (#3407) --- composer/core/time.py | 25 +++++++++++++++++-------- tests/test_time.py | 7 +++++++ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/composer/core/time.py b/composer/core/time.py index 3916dd7659..00af1fd456 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -525,13 +525,8 @@ def __init__( raise ValueError(f'The `token` argument has units of {token.unit}; not {TimeUnit.TOKEN}.') self._token = token - epoch_in_iteration = Time.from_input(epoch_in_iteration, TimeUnit.EPOCH) - if epoch_in_iteration.unit != TimeUnit.EPOCH: - raise ValueError(( - f'The `epoch_in_iteration` argument has units of {epoch_in_iteration.unit}; ' - f'not {TimeUnit.EPOCH}.' - )) - self._epoch_in_iteration = epoch_in_iteration + self._epoch_in_iteration = Time(0, TimeUnit.EPOCH) + self.epoch_in_iteration = epoch_in_iteration token_in_iteration = Time.from_input(token_in_iteration, TimeUnit.TOKEN) if token_in_iteration.unit != TimeUnit.TOKEN: @@ -619,7 +614,7 @@ def load_state_dict(self, state: dict[str, Any]) -> None: if 'iteration' in state: self._iteration = Time(state['iteration'], TimeUnit.ITERATION) if 'epoch_in_iteration' in state: - self._epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH) + self.epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH) if 'token_in_iteration' in state: self._token_in_iteration = Time(state['token_in_iteration'], TimeUnit.TOKEN) if 'iteration_wct' in state: @@ -655,6 +650,20 @@ def epoch_in_iteration(self) -> Time[int]: """The epoch count in the current iteration (resets at 0 at the beginning of every iteration).""" return self._epoch_in_iteration + @epoch_in_iteration.setter + def epoch_in_iteration( + self, + epoch_in_iteration: Union[int, Time[int]], # pyright: ignore[reportPropertyTypeMismatch] + ): + """Sets epoch count in the current iteration.""" + epoch_in_iteration = Time.from_input(epoch_in_iteration, TimeUnit.EPOCH) + if epoch_in_iteration.unit != TimeUnit.EPOCH: + raise ValueError(( + f'The `epoch_in_iteration` argument has units of {epoch_in_iteration.unit}; ' + f'not {TimeUnit.EPOCH}.' + )) + self._epoch_in_iteration = epoch_in_iteration + @property def token_in_iteration(self) -> Time[int]: """The token count in the current iteration (resets at 0 at the beginning of every iteration).""" diff --git a/tests/test_time.py b/tests/test_time.py index 1545eaa3b1..d585d9af36 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -146,6 +146,13 @@ def test_timestamp_update(): assert timestamp is not timestamp_2 +def test_set_timestamp(): + timestamp = Timestamp(epoch_in_iteration=1) + assert timestamp.epoch_in_iteration == 1 + timestamp.epoch_in_iteration = 2 + assert timestamp.epoch_in_iteration == 2 + + def test_timestamp_to_next_batch_epoch_iteration(): timestamp = Timestamp() # Step batch 0 in epoch 0