Add epochs to levanter #768

Open · wants to merge 22 commits into main
Changes from 5 commits
6 changes: 6 additions & 0 deletions config/llama_7b_with_olmo_config.yaml
@@ -15,6 +15,10 @@ trainer:
    project: "marin"
    tags: ["dolma", "olmo", "llama"]

  checkpointer:
    keep:
      - every: 250

  mp: p=f32,c=bfloat16
  train_batch_size: 2048
  num_train_steps: 750000  # 3,000,000,000,000 / 4,000,000 = 750,000
@@ -27,3 +31,5 @@ optimizer:
  weight_decay: 0.1
  min_lr_ratio: 0.1
  warmup: 0.01

data_shuffle: true
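As a quick check of the num_train_steps comment in this config: the 4,000,000 tokens-per-step figure implies a sequence length of roughly 2048 tokens alongside the batch size of 2048 (the sequence length is an assumption here, not stated in this file).

# Back-of-envelope check of the num_train_steps arithmetic above.
total_training_tokens = 3_000_000_000_000
train_batch_size = 2048
assumed_seq_len = 2048                                  # illustrative assumption only
tokens_per_step = train_batch_size * assumed_seq_len    # ~4.19M, rounded to 4M in the comment
print(total_training_tokens // 4_000_000)               # 750000, matching num_train_steps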
4 changes: 3 additions & 1 deletion examples/alpaca/alpaca.py
@@ -162,11 +162,13 @@ def _prepare_example(ex: dict) -> LmExample:
    # mask out padding and anything before the start of the target
    Pos = input_ids.resolve_axis("position")
    if config.mask_inputs:
        loss_mask = hax.arange(Pos) >= ex["source_lens"]
        loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1  # should be minus 1?

        # don't predict the padding
        targets = hax.roll(input_ids, -1, Pos)
        loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
        # to not predict EOS token since we don't have target!
        loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
    else:
        loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32)
    lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
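For intuition about the mask change above, here is a self-contained NumPy sketch of the intended behavior with toy values (haliax named arrays replaced by plain arrays; all numbers invented). Position i is scored on predicting token i+1, so the first target token is predicted at position source_lens - 1, which is what the "- 1" buys.

import numpy as np

seq_len = 8
source_len = 3                    # prompt occupies positions 0..2; target starts at position 3
pad_token_id = 0
input_ids = np.array([11, 12, 13, 21, 22, 23, 0, 0])   # prompt, target, padding

# position i predicts token i+1, so the first target token is predicted at source_len - 1
loss_mask = np.arange(seq_len) >= source_len - 1

# don't score positions whose next token is padding
targets = np.roll(input_ids, -1)
loss_mask &= targets != pad_token_id

# never score the last position: its "target" wrapped around to position 0
loss_mask[-1] = False

print(loss_mask)  # [False False  True  True  True False False False]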
1 change: 0 additions & 1 deletion pyproject.toml
@@ -48,7 +48,6 @@ dependencies = [
    "pydantic<3",
    "rich~=13.0",
    "filelock~=3.13",
    # "ai2-olmo",
    "async-lru~=2.0",
    "tqdm-loggable>=0.2",
    "deepdiff"
44 changes: 44 additions & 0 deletions src/levanter/callbacks.py
@@ -25,11 +25,55 @@
from levanter.utils import flop_utils
from levanter.utils.jax_utils import barrier_sync, jnp_to_python
from levanter.visualization import compute_and_visualize_log_probs as viz_probs
from levanter.data.text import TokenSeqEpochDataset
from concurrent.futures import ThreadPoolExecutor



logger = pylogging.getLogger(__name__)


def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size):
    total_tokens = None

    def log_epoch(step_info: StepInfo):
        nonlocal total_tokens
        if total_tokens is None:
            if not total_tokens_future.done():
                return  # We don't have the total tokens yet, so we can't calculate epoch
            total_tokens = total_tokens_future.result()

        # Get the total processed tokens from the metrics logged by log_performance_stats
        processed_tokens = tokens_per_example * batch_size * step_info.step
        if processed_tokens is None:
            return  # No token count available yet

        current_epoch = processed_tokens / total_tokens
        levanter.tracker.log_metrics({"train/current_epoch": current_epoch}, step=step_info.step)

    return log_epoch

def get_total_dataset_tokens(ds: TokenSeqEpochDataset, seq_length: int):
    def log_length():
        # If ds.async_len() is the only option, run it in an event loop inside the thread
        import asyncio

        async def compute_length():
            length = await ds.async_len()
            return length

        # Run the async function synchronously in this thread
        length = asyncio.run(compute_length())
        total_tokens = length * seq_length
        levanter.tracker.log_summary({"dataset/total_tokens": total_tokens})
        return total_tokens

    # Create a ThreadPoolExecutor with a single worker thread
    executor = ThreadPoolExecutor(max_workers=1)
    # Submit the log_length function to be executed in a separate thread
    future = executor.submit(log_length)
    return future

def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None):
    total_loss = 0.0
    total_load_time = 0.0
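To sanity-check the arithmetic in log_epoch_progress, a small sketch with assumed numbers (sequence length, batch size, and corpus size below are illustrative, not taken from this PR):

tokens_per_example = 4096           # Pos.size (assumed value)
batch_size = 2048                   # trainer.config.train_batch_size (assumed value)
total_tokens = 3_000_000_000_000    # result of get_total_dataset_tokens (assumed value)

step = 100_000
processed_tokens = tokens_per_example * batch_size * step   # ~8.4e11 tokens seen so far
current_epoch = processed_tokens / total_tokens             # logged as train/current_epoch
print(round(current_epoch, 2))      # 0.28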
69 changes: 67 additions & 2 deletions src/levanter/data/text.py
@@ -63,6 +63,57 @@

DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index

class TokenSeqEpochDataset(AsyncDataset[np.ndarray]):
Member:
let's just make EpochDataset that wraps an arbitrary dataset.

Member:
chatgpt and i made this

from typing import Sequence, Optional, TypeVar
import asyncio
import numpy as np

T_co = TypeVar('T_co', covariant=True)

class EpochDataset(AsyncDataset[T_co]):
    """
    A dataset that wraps another dataset, providing infinite epochs by recycling indices.
    If `max_epochs` is specified, it limits the number of cycles before raising StopIteration.
    
    :param dataset: The dataset to wrap.
    :param max_epochs: The maximum number of epochs to cycle through. If None, cycle indefinitely.
    """
    def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None):
        if dataset.is_finite():
            raise ValueError("Cannot apply epoching to a finite dataset.")

        self.dataset = dataset
        self.max_epochs = max_epochs

    async def async_len(self) -> int:
        if self.max_epochs is None:
            raise ValueError("Cannot determine length of an infinite dataset without max_epochs.")
        # Return the total number of samples: max_epochs * length of the dataset
        return self.max_epochs * await self.dataset.async_len()

    async def final_length_is_known(self) -> bool:
        return await self.dataset.final_length_is_known()

    def is_finite(self) -> bool:
        # EpochDataset can be finite if max_epochs is set.
        return self.max_epochs is not None

    async def current_len(self) -> Optional[int]:
        # If max_epochs is None, the dataset is effectively infinite.
        if self.max_epochs is None:
            return None

        # If the final length of the dataset is not known, return the current length of the underlying dataset.
        if not await self.dataset.final_length_is_known():
            return await self.dataset.current_len()

        # If the final length is known, return the max_epochs * async_len of the dataset.
        return self.max_epochs * await self.dataset.async_len()

    async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
        # Use self.wait_until_len_at_least to ensure we have enough data for the batch.
        max_index = max(indices)
        ds_len = await self.wait_until_len_at_least(max_index + 1)

        # Determine the epoch based on the largest index
        epoch = max_index // ds_len

        # If max_epochs is specified, raise an error if the epoch exceeds the allowed number of epochs
        if self.max_epochs is not None and epoch >= self.max_epochs:
            raise StopIteration(f"Reached maximum number of epochs: epoch {epoch} exceeds the maximum allowed {self.max_epochs}")

        # Wrap the indices within the bounds of the dataset length
        wrapped_indices = [idx % ds_len for idx in indices]

        # Delegate to the underlying dataset's get_batch
        return await self.dataset.get_batch(wrapped_indices)

    async def wait_until_len_at_least(self, length: int) -> int:
        """
        Returns the length of the dataset once it is at least `length` or if the dataset has a known (finished) length.

        If the dataset's actual length is less than `length`, it returns the minimum of async_len and the current length.
        """
        # Wait until the underlying dataset's length is at least `length`
        if not self.is_finite(): return length
        
        if await self.dataset.final_length_is_known(): 
            base_length = await self.dataset.async_len()
        else:
            base_length = await self.dataset.wait_until_len_at_least(length)

        if base_length < length:  # hit epoch boundary
            assert self.max_epochs is not None
            return self.max_epochs * base_length

        return base_length

Contributor Author:
FYI I removed the "cannot apply epoching to a finite dataset" since that seems like a bug
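If the EpochDataset wrapper sketched in the comment above lands (and assuming the is_finite guard it contains is dropped, per the reply), usage would presumably look something like this; the names and values are assumptions based on the comment, not code in this diff:

base_ds = TokenSeqDataset(cache, seq_len)         # any finite AsyncDataset
three_pass = EpochDataset(base_ds, max_epochs=3)  # recycle indices for exactly 3 passes
endless = EpochDataset(base_ds)                   # or cycle indefinitely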

    def __init__(self, doc_cache: TreeCache[dict], seq_len: int):
        self.doc_cache = doc_cache
        self.seq_len = seq_len
        self._store: Optional[TreeStore] = None
        self._cached_len: Optional[int] = None

    async def async_len(self) -> int:
        await self.doc_cache.finished()
        token_arrays = await self._await_token_cache()
        return token_arrays.data_size // self.seq_len

    async def _await_token_cache(self) -> JaggedArrayStore:
        if self._store is None:
            self._store = await self.doc_cache.store_async()
        return self._store.tree["input_ids"]

    async def final_length_is_known(self) -> bool:
        return await self.doc_cache.final_length_is_known()

    def is_finite(self) -> bool:
        return False  # Now infinite due to epoch wrapping

    async def current_len(self) -> Optional[int]:
        store = await self._await_token_cache()
        return store.data_size // self.seq_len

    async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
        token_arrays = await self._await_token_cache()
        dataset_len = await self.async_len()

        wrapped_indices = [idx % dataset_len for idx in indices]
        offsets = np.array(wrapped_indices) * self.seq_len

        with ts.Batch():
            out = []
            for offset in offsets:
                out.append(token_arrays.data[offset : offset + self.seq_len].read())

        out = await asyncio.gather(*out)
        return out

    async def wait_until_len_at_least(self, length: int) -> int:
        # length is brutally slow to compute, so we cache it
        if self._cached_len is not None:
            return self._cached_len

        # TODO: would be better to listen for cache updates
        length = await super().wait_until_len_at_least(length)
        self._cached_len = length
        return length
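
The wrap-around in get_batch above is plain modular arithmetic over a flat token store. A toy NumPy version of the same index math (all values invented):

import numpy as np

seq_len = 4
tokens = np.arange(20)                  # stand-in for the flat token store
dataset_len = tokens.size // seq_len    # 5 sequences per pass over the data

indices = [3, 4, 7]                     # 7 falls into the second pass
wrapped = [i % dataset_len for i in indices]          # -> [3, 4, 2]
offsets = [i * seq_len for i in wrapped]              # -> [12, 16, 8]
batch = [tokens[o : o + seq_len] for o in offsets]
print([b.tolist() for b in batch])      # [[12, 13, 14, 15], [16, 17, 18, 19], [8, 9, 10, 11]]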

class TokenSeqDataset(AsyncDataset[np.ndarray]):
"""
@@ -640,9 +691,15 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
    cache_dir: Optional[str] = "cache/"

    def train_set(
        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None
        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, epochs: bool = False
    ) -> AsyncDataset[np.ndarray]:
        ds = self.token_seq_dataset("train", seq_len, monitors)

        if epochs:
            ds = self.token_epoch_dataset("train", seq_len, monitors)
        else:
            ds = self.token_seq_dataset("train", seq_len, monitors)

        # add epoch flag here.
        if ds is None:
            raise ValueError("No training set!")

Expand Down Expand Up @@ -693,6 +750,14 @@ def token_seq_dataset(
        if cache is None:
            return None
        return TokenSeqDataset(cache, seq_len)

    def token_epoch_dataset(
        self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
    ) -> Optional[TokenSeqDataset]:
        cache = self.build_or_load_cache(split, monitors=monitors)
        if cache is None:
            return None
        return TokenSeqEpochDataset(cache, seq_len)

    def build_or_load_cache(
        self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None
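With the train_set change above, the two code paths would be selected roughly as follows (the sequence length is an illustrative value, and data_config stands for an LMDatasetConfig instance):

plain_ds = data_config.train_set(4096)                  # TokenSeqDataset: one finite pass
epoching_ds = data_config.train_set(4096, epochs=True)  # TokenSeqEpochDataset: wraps indices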
11 changes: 10 additions & 1 deletion src/levanter/main/train_lm.py
@@ -54,6 +54,7 @@ class TrainLmConfig:
    data_seed: Optional[int] = None  # if provided, will override the data seed from the trainer
    initialize_from_checkpoint_path: Optional[str] = None
    # if provided, will initialize from this checkpoint, used for llama style data mixture
    epoch: bool = False  # if true, will keep epoching over the dataset and track epochs


def main(config: TrainLmConfig):
@@ -117,10 +118,17 @@ def main(config: TrainLmConfig):

    # TODO: fix this
    tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size)
    # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key)

    train_dataset = CausalLmDataset(
        config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, ignore_index=config.data.ignore_token_id
        config.data.train_set(Pos.size, key=data_key, epochs=config.epoch), Pos, KeyPos, ignore_index=config.data.ignore_token_id
    )

    if config.epoch:
        # add epoch logging
        total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len)
        trainer.add_hook(callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1)

    # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
    # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
    # tokens: gpt-2 has 50257, for example. So we round up.
@@ -236,6 +244,7 @@ def compute_log_probs(model, example):

    ## OK, actually run training!
    trainer.train(state, train_loader)

    # checkpointer.on_step(last_step, force=True)


1 change: 0 additions & 1 deletion src/levanter/trainer.py
@@ -376,7 +376,6 @@ def training_steps(self, state: S, train_loader, run_hooks: bool = True) -> typi
        while int(state.step) < self.num_train_steps:
            with capture_time() as loading_time:
                example = next(iter_data)

            info = self.train_step(state, example)
            state = info.state
