Skip to content

Commit

Permalink
Add duration to to_next_epoch
Browse files Browse the repository at this point in the history
Adds duration to the to_next_epoch function to increment total_wct.

commit-id:20ddfe85
  • Loading branch information
b-chu committed Feb 14, 2024
1 parent 9e60fa3 commit ecd6e50
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
8 changes: 7 additions & 1 deletion composer/core/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,10 @@ def to_next_batch(
batch_wct=duration,
)

def to_next_epoch(self):
def to_next_epoch(
self,
duration: Optional[datetime.timedelta] = None,
):
"""Create a new :class:`.Timestamp`, advanced to the next epoch.
Equivalent to:
Expand All @@ -720,6 +723,7 @@ def to_next_epoch(self):
import datetime
timestamp = Timestamp()
duration = datetime.timedelta(seconds=0)
.. doctest::
Expand All @@ -728,6 +732,7 @@ def to_next_epoch(self):
... batch_in_epoch=0,
... sample_in_epoch=0,
... token_in_epoch=0,
... total_wct=self.total_wct + duration,
... epoch_wct=datetime.timedelta(seconds=0),
... batch_wct=datetime.timedelta(seconds=0),
... )
Expand All @@ -739,6 +744,7 @@ def to_next_epoch(self):
batch_in_epoch=0,
sample_in_epoch=0,
token_in_epoch=0,
total_wct=self.total_wct + duration,
epoch_wct=datetime.timedelta(seconds=0),
batch_wct=datetime.timedelta(seconds=0),
)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ def test_timestamp_to_next_batch_epoch():
assert timestamp.batch_wct == datetime.timedelta(seconds=5)

# Finish epoch 0
timestamp = timestamp.to_next_epoch()
timestamp = timestamp.to_next_epoch(datetime.timedelta(seconds=5))
assert timestamp.epoch == 1
assert timestamp.batch == 1
assert timestamp.batch_in_epoch == 0
assert timestamp.sample == 10
assert timestamp.sample_in_epoch == 0
assert timestamp.token == 20
assert timestamp.token_in_epoch == 0
assert timestamp.total_wct == datetime.timedelta(seconds=5)
assert timestamp.total_wct == datetime.timedelta(seconds=10)
assert timestamp.epoch_wct == datetime.timedelta(seconds=0)
assert timestamp.batch_wct == datetime.timedelta(seconds=0)

Expand All @@ -173,7 +173,7 @@ def test_timestamp_to_next_batch_epoch():
assert timestamp.sample_in_epoch == 5
assert timestamp.token == 20
assert timestamp.token_in_epoch == 0
assert timestamp.total_wct == datetime.timedelta(seconds=15)
assert timestamp.total_wct == datetime.timedelta(seconds=20)
assert timestamp.epoch_wct == datetime.timedelta(seconds=10)
assert timestamp.batch_wct == datetime.timedelta(seconds=10)

Expand All @@ -186,7 +186,7 @@ def test_timestamp_to_next_batch_epoch():
assert timestamp.sample_in_epoch == 10
assert timestamp.token == 21
assert timestamp.token_in_epoch == 1
assert timestamp.total_wct == datetime.timedelta(seconds=25)
assert timestamp.total_wct == datetime.timedelta(seconds=30)
assert timestamp.epoch_wct == datetime.timedelta(seconds=20)
assert timestamp.batch_wct == datetime.timedelta(seconds=10)

Expand Down

0 comments on commit ecd6e50

Please sign in to comment.