Skip to content

Commit

Permalink
Add duration to to_next_epoch
Browse files Browse the repository at this point in the history
Adds duration to the to_next_epoch function to increment total_wct.

commit-id:5f8ef9be
  • Loading branch information
b-chu committed Feb 14, 2024
1 parent 9e60fa3 commit 9784717
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
12 changes: 10 additions & 2 deletions composer/core/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,10 @@ def to_next_batch(
batch_wct=duration,
)

def to_next_epoch(self):
def to_next_epoch(
self,
duration: Optional[datetime.timedelta] = None,
):
"""Create a new :class:`.Timestamp`, advanced to the next epoch.
Equivalent to:
Expand All @@ -720,25 +723,30 @@ def to_next_epoch(self):
import datetime
timestamp = Timestamp()
duration = datetime.timedelta(seconds=0)
.. doctest::
>>> timestamp.copy(
... epoch=timestamp.epoch+1,
... epoch=timestamp.epoch + 1,
... batch_in_epoch=0,
... sample_in_epoch=0,
... token_in_epoch=0,
... total_wct=timestamp.total_wct + duration,
... epoch_wct=datetime.timedelta(seconds=0),
... batch_wct=datetime.timedelta(seconds=0),
... )
Timestamp(...)
"""
if duration is None:
duration = datetime.timedelta(seconds=0)
return self.copy(
epoch=self.epoch + 1,
batch_in_epoch=0,
sample_in_epoch=0,
token_in_epoch=0,
total_wct=self.total_wct + duration,
epoch_wct=datetime.timedelta(seconds=0),
batch_wct=datetime.timedelta(seconds=0),
)
Expand Down
27 changes: 20 additions & 7 deletions tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_timestamp_update():

def test_timestamp_to_next_batch_epoch():
timestamp = Timestamp()
# Step batch 0, epoch 0
# Step batch 0 in epoch 0
timestamp = timestamp.to_next_batch(10, 20, datetime.timedelta(seconds=5))
assert timestamp.batch == 1
assert timestamp.batch_in_epoch == 1
Expand All @@ -152,19 +152,19 @@ def test_timestamp_to_next_batch_epoch():
assert timestamp.batch_wct == datetime.timedelta(seconds=5)

# Finish epoch 0
timestamp = timestamp.to_next_epoch()
timestamp = timestamp.to_next_epoch(datetime.timedelta(seconds=5))
assert timestamp.epoch == 1
assert timestamp.batch == 1
assert timestamp.batch_in_epoch == 0
assert timestamp.sample == 10
assert timestamp.sample_in_epoch == 0
assert timestamp.token == 20
assert timestamp.token_in_epoch == 0
assert timestamp.total_wct == datetime.timedelta(seconds=5)
assert timestamp.total_wct == datetime.timedelta(seconds=10)
assert timestamp.epoch_wct == datetime.timedelta(seconds=0)
assert timestamp.batch_wct == datetime.timedelta(seconds=0)

# Step a batch 0 in epoch 1
# Step batch 0 in epoch 1
timestamp = timestamp.to_next_batch(5, 0, datetime.timedelta(seconds=10))
assert timestamp.epoch == 1
assert timestamp.batch == 2
Expand All @@ -173,11 +173,11 @@ def test_timestamp_to_next_batch_epoch():
assert timestamp.sample_in_epoch == 5
assert timestamp.token == 20
assert timestamp.token_in_epoch == 0
assert timestamp.total_wct == datetime.timedelta(seconds=15)
assert timestamp.total_wct == datetime.timedelta(seconds=20)
assert timestamp.epoch_wct == datetime.timedelta(seconds=10)
assert timestamp.batch_wct == datetime.timedelta(seconds=10)

# Step batch 1 in epoch 0
# Step batch 1 in epoch 1
timestamp = timestamp.to_next_batch(5, 1, datetime.timedelta(seconds=10))
assert timestamp.epoch == 1
assert timestamp.batch == 3
Expand All @@ -186,7 +186,20 @@ def test_timestamp_to_next_batch_epoch():
assert timestamp.sample_in_epoch == 10
assert timestamp.token == 21
assert timestamp.token_in_epoch == 1
assert timestamp.total_wct == datetime.timedelta(seconds=25)
assert timestamp.total_wct == datetime.timedelta(seconds=30)
assert timestamp.epoch_wct == datetime.timedelta(seconds=20)
assert timestamp.batch_wct == datetime.timedelta(seconds=10)

# Finish epoch 1
timestamp = timestamp.to_next_epoch()
assert timestamp.epoch == 2
assert timestamp.batch == 3
assert timestamp.batch_in_epoch == 0
assert timestamp.sample == 20
assert timestamp.sample_in_epoch == 0
assert timestamp.token == 21
assert timestamp.token_in_epoch == 0
assert timestamp.total_wct == datetime.timedelta(seconds=30)
assert timestamp.epoch_wct == datetime.timedelta(seconds=20)
assert timestamp.batch_wct == datetime.timedelta(seconds=10)

Expand Down

0 comments on commit 9784717

Please sign in to comment.