From ecd6e50559edfc88bc19d16b6502cdb4649c3d93 Mon Sep 17 00:00:00 2001 From: root <23239305+b-chu@users.noreply.github.com> Date: Wed, 14 Feb 2024 21:25:14 +0000 Subject: [PATCH] Add duration to to_next_epoch Adds duration to the to_next_epoch function to increment total_wct. commit-id:20ddfe85 --- composer/core/time.py | 8 +++++++- tests/test_time.py | 8 ++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/composer/core/time.py b/composer/core/time.py index ab2d6a60eea..646448b66ba 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -709,7 +709,10 @@ def to_next_batch( batch_wct=duration, ) - def to_next_epoch(self): + def to_next_epoch( + self, + duration: Optional[datetime.timedelta] = None, + ): """Create a new :class:`.Timestamp`, advanced to the next epoch. Equivalent to: @@ -720,6 +723,7 @@ def to_next_epoch(self): import datetime timestamp = Timestamp() + duration = datetime.timedelta(seconds=0) .. doctest:: @@ -728,6 +732,7 @@ def to_next_epoch(self): ... batch_in_epoch=0, ... sample_in_epoch=0, ... token_in_epoch=0, + ... total_wct=self.total_wct + duration, ... epoch_wct=datetime.timedelta(seconds=0), ... batch_wct=datetime.timedelta(seconds=0), ... ) @@ -739,6 +744,7 @@ def to_next_epoch(self): batch_in_epoch=0, sample_in_epoch=0, token_in_epoch=0, + total_wct=self.total_wct + duration, epoch_wct=datetime.timedelta(seconds=0), batch_wct=datetime.timedelta(seconds=0), ) diff --git a/tests/test_time.py b/tests/test_time.py index 58f1cf9747a..9ca4946822f 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -152,7 +152,7 @@ def test_timestamp_to_next_batch_epoch(): assert timestamp.batch_wct == datetime.timedelta(seconds=5) # Finish epoch 0 - timestamp = timestamp.to_next_epoch() + timestamp = timestamp.to_next_epoch(datetime.timedelta(seconds=5)) assert timestamp.epoch == 1 assert timestamp.batch == 1 assert timestamp.batch_in_epoch == 0 @@ -160,7 +160,7 @@ def test_timestamp_to_next_batch_epoch(): assert timestamp.sample_in_epoch == 0 assert timestamp.token == 20 assert timestamp.token_in_epoch == 0 - assert timestamp.total_wct == datetime.timedelta(seconds=5) + assert timestamp.total_wct == datetime.timedelta(seconds=10) assert timestamp.epoch_wct == datetime.timedelta(seconds=0) assert timestamp.batch_wct == datetime.timedelta(seconds=0) @@ -173,7 +173,7 @@ def test_timestamp_to_next_batch_epoch(): assert timestamp.sample_in_epoch == 5 assert timestamp.token == 20 assert timestamp.token_in_epoch == 0 - assert timestamp.total_wct == datetime.timedelta(seconds=15) + assert timestamp.total_wct == datetime.timedelta(seconds=20) assert timestamp.epoch_wct == datetime.timedelta(seconds=10) assert timestamp.batch_wct == datetime.timedelta(seconds=10) @@ -186,7 +186,7 @@ def test_timestamp_to_next_batch_epoch(): assert timestamp.sample_in_epoch == 10 assert timestamp.token == 21 assert timestamp.token_in_epoch == 1 - assert timestamp.total_wct == datetime.timedelta(seconds=25) + assert timestamp.total_wct == datetime.timedelta(seconds=30) assert timestamp.epoch_wct == datetime.timedelta(seconds=20) assert timestamp.batch_wct == datetime.timedelta(seconds=10)