-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add epochs to levanter #768
base: main
Are you sure you want to change the base?
Conversation
src/levanter/data/text.py
Outdated
@@ -63,6 +63,57 @@ | |||
|
|||
DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index | |||
|
|||
class TokenSeqEpochDataset(AsyncDataset[np.ndarray]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's just make EpochDataset
that wraps an arbitrary dataset.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
chatgpt and i made this
from typing import Sequence, Optional, TypeVar
import asyncio
import numpy as np
T_co = TypeVar('T_co', covariant=True)
class EpochDataset(AsyncDataset[T_co]):
    """
    A dataset that wraps another dataset, providing epochs by recycling indices
    modulo the underlying dataset's length.

    If `max_epochs` is specified, requesting an index beyond the final epoch raises
    an error; otherwise indices cycle indefinitely.

    :param dataset: The dataset to wrap.
    :param max_epochs: The maximum number of epochs to cycle through. If None, cycle indefinitely.
    """

    def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None):
        # NOTE: an earlier revision raised ValueError when `dataset.is_finite()` was true,
        # but that check was inverted — epoching exists precisely to reuse a finite
        # dataset — so no finiteness check is performed here.
        self.dataset = dataset
        self.max_epochs = max_epochs

    async def async_len(self) -> int:
        """Total number of samples across all epochs: max_epochs * base length."""
        if self.max_epochs is None:
            raise ValueError("Cannot determine length of an infinite dataset without max_epochs.")
        return self.max_epochs * await self.dataset.async_len()

    async def final_length_is_known(self) -> bool:
        # Our final length is known exactly when the wrapped dataset's is.
        return await self.dataset.final_length_is_known()

    def is_finite(self) -> bool:
        # EpochDataset is finite iff a maximum number of epochs is set.
        return self.max_epochs is not None

    async def current_len(self) -> Optional[int]:
        # With no epoch cap, the dataset is effectively infinite.
        if self.max_epochs is None:
            return None
        # If the final length of the underlying dataset is not yet known, report its
        # current length as-is (we cannot meaningfully multiply a moving target).
        if not await self.dataset.final_length_is_known():
            return await self.dataset.current_len()
        # Final length is known: total is max_epochs * base length.
        return self.max_epochs * await self.dataset.async_len()

    async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
        """Fetch a batch, wrapping each global index into the underlying dataset's range.

        :raises RuntimeError: if `max_epochs` is set and an index falls past the last epoch.
        """
        if not indices:
            # max() on an empty sequence raises ValueError; an empty request is an empty batch.
            return []
        max_index = max(indices)
        # BUG FIX: the previous version used self.wait_until_len_at_least, which on an
        # epoch boundary returns the epoch-expanded length (max_epochs * base length);
        # taking `idx % expanded_len` then fails to wrap and forwards out-of-range
        # indices to the underlying dataset. Epoch math must use the BASE length.
        if await self.dataset.final_length_is_known():
            base_len = await self.dataset.async_len()
        else:
            base_len = await self.dataset.wait_until_len_at_least(max_index + 1)
        # Determine which epoch the largest requested index falls in.
        epoch = max_index // base_len
        if self.max_epochs is not None and epoch >= self.max_epochs:
            # PEP 479: a StopIteration raised inside a coroutine is converted to
            # RuntimeError by the interpreter anyway, so raise RuntimeError explicitly.
            raise RuntimeError(
                f"Reached maximum number of epochs: epoch {epoch} exceeds the maximum allowed {self.max_epochs}"
            )
        # Wrap the indices within the bounds of the underlying dataset length.
        wrapped_indices = [idx % base_len for idx in indices]
        # Delegate to the underlying dataset's get_batch.
        return await self.dataset.get_batch(wrapped_indices)

    async def wait_until_len_at_least(self, length: int) -> int:
        """
        Returns the length of the dataset once it is at least `length` or if the dataset
        has a known (finished) length.

        If the underlying dataset's full length is smaller than `length` (i.e. `length`
        crosses an epoch boundary), returns the total epoched length instead.
        """
        # An unbounded (max_epochs=None) dataset can always serve any index.
        if not self.is_finite():
            return length
        if await self.dataset.final_length_is_known():
            base_length = await self.dataset.async_len()
        else:
            # Wait until the underlying dataset's length is at least `length`.
            base_length = await self.dataset.wait_until_len_at_least(length)
        if base_length < length:  # hit epoch boundary
            assert self.max_epochs is not None
            return self.max_epochs * base_length
        return base_length
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI I removed the "cannot apply epoching to a finite dataset" since that seems like a bug
@@ -27,6 +27,7 @@ | |||
|
|||
from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore | |||
from levanter.types import FilterSpec | |||
# from levanter.trainer import StepInfo |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rm
return # Can't calculate epochs without dataset size | ||
|
||
# Calculate current epoch from steps without modifying StepInfo | ||
current_epoch = (step_info.step * self.batch_size) // self.total_dataset_size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should probably just be tracking this explicitly in StepInfo, but this is fine for now
@@ -27,6 +27,7 @@ | |||
|
|||
from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore | |||
from levanter.types import FilterSpec | |||
# from levanter.trainer import StepInfo |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# from levanter.trainer import StepInfo |
Adds epochs behind a boolean flag: when enabled, training continues epoching over the dataset and tracks the epoch count throughout training. Should be backwards compatible with existing checkpoints.