diff --git a/composer/core/time.py b/composer/core/time.py index ab2d6a60ee..90b3bfdb97 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -709,7 +709,10 @@ def to_next_batch( batch_wct=duration, ) - def to_next_epoch(self): + def to_next_epoch( + self, + duration: Optional[datetime.timedelta] = None, + ): """Create a new :class:`.Timestamp`, advanced to the next epoch. Equivalent to: @@ -720,25 +723,30 @@ def to_next_epoch(self): import datetime timestamp = Timestamp() + duration = datetime.timedelta(seconds=0) .. doctest:: >>> timestamp.copy( - ... epoch=timestamp.epoch+1, + ... epoch=timestamp.epoch + 1, ... batch_in_epoch=0, ... sample_in_epoch=0, ... token_in_epoch=0, + ... total_wct=timestamp.total_wct + duration, ... epoch_wct=datetime.timedelta(seconds=0), ... batch_wct=datetime.timedelta(seconds=0), ... ) Timestamp(...) """ + if duration is None: + duration = datetime.timedelta(seconds=0) return self.copy( epoch=self.epoch + 1, batch_in_epoch=0, sample_in_epoch=0, token_in_epoch=0, + total_wct=self.total_wct + duration, epoch_wct=datetime.timedelta(seconds=0), batch_wct=datetime.timedelta(seconds=0), ) diff --git a/tests/test_time.py b/tests/test_time.py index 58f1cf9747..611bf83f72 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -138,7 +138,7 @@ def test_timestamp_update(): def test_timestamp_to_next_batch_epoch(): timestamp = Timestamp() - # Step batch 0, epoch 0 + # Step batch 0 in epoch 0 timestamp = timestamp.to_next_batch(10, 20, datetime.timedelta(seconds=5)) assert timestamp.batch == 1 assert timestamp.batch_in_epoch == 1 @@ -152,7 +152,7 @@ def test_timestamp_to_next_batch_epoch(): assert timestamp.batch_wct == datetime.timedelta(seconds=5) # Finish epoch 0 - timestamp = timestamp.to_next_epoch() + timestamp = timestamp.to_next_epoch(datetime.timedelta(seconds=5)) assert timestamp.epoch == 1 assert timestamp.batch == 1 assert timestamp.batch_in_epoch == 0 @@ -160,11 +160,11 @@ def test_timestamp_to_next_batch_epoch(): assert timestamp.sample_in_epoch == 0 assert timestamp.token == 20 assert timestamp.token_in_epoch == 0 - assert timestamp.total_wct == datetime.timedelta(seconds=5) + assert timestamp.total_wct == datetime.timedelta(seconds=10) assert timestamp.epoch_wct == datetime.timedelta(seconds=0) assert timestamp.batch_wct == datetime.timedelta(seconds=0) - # Step a batch 0 in epoch 1 + # Step batch 0 in epoch 1 timestamp = timestamp.to_next_batch(5, 0, datetime.timedelta(seconds=10)) assert timestamp.epoch == 1 assert timestamp.batch == 2 @@ -173,11 +173,11 @@ def test_timestamp_to_next_batch_epoch(): assert timestamp.sample_in_epoch == 5 assert timestamp.token == 20 assert timestamp.token_in_epoch == 0 - assert timestamp.total_wct == datetime.timedelta(seconds=15) + assert timestamp.total_wct == datetime.timedelta(seconds=20) assert timestamp.epoch_wct == datetime.timedelta(seconds=10) assert timestamp.batch_wct == datetime.timedelta(seconds=10) - # Step batch 1 in epoch 0 + # Step batch 1 in epoch 1 timestamp = timestamp.to_next_batch(5, 1, datetime.timedelta(seconds=10)) assert timestamp.epoch == 1 assert timestamp.batch == 3 @@ -186,10 +186,23 @@ def test_timestamp_to_next_batch_epoch(): assert timestamp.sample_in_epoch == 10 assert timestamp.token == 21 assert timestamp.token_in_epoch == 1 - assert timestamp.total_wct == datetime.timedelta(seconds=25) + assert timestamp.total_wct == datetime.timedelta(seconds=30) assert timestamp.epoch_wct == datetime.timedelta(seconds=20) assert timestamp.batch_wct == datetime.timedelta(seconds=10) + # Finish epoch 1 + timestamp = timestamp.to_next_epoch() + assert timestamp.epoch == 2 + assert timestamp.batch == 3 + assert timestamp.batch_in_epoch == 0 + assert timestamp.sample == 20 + assert timestamp.sample_in_epoch == 0 + assert timestamp.token == 21 + assert timestamp.token_in_epoch == 0 + assert timestamp.total_wct == datetime.timedelta(seconds=30) + assert timestamp.epoch_wct == datetime.timedelta(seconds=0) + assert timestamp.batch_wct == datetime.timedelta(seconds=0) + def test_timestamp_repr(): timestamp = Timestamp()