Eval harness #675

Draft: wants to merge 15 commits into main
26 changes: 26 additions & 0 deletions config/gpt2_nano_harness.yaml
@@ -0,0 +1,26 @@
eval_harness:
  task_spec: ["piqa", "hellaswag"]
  max_examples: 32
eval_harness_steps: 50
data:
  id: dlwh/wikitext_103_detokenized
model:
  type: gpt2
  hidden_dim: 32
  num_heads: 4
  num_layers: 2
trainer:
  mp: f32
  num_train_steps: 100

  checkpointer:
    keep:
      - every: 50
    save_interval: 5m

  per_device_parallelism: -1
  train_batch_size: 32

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
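A minimal sketch of the config object the `eval_harness` block above presumably maps onto. The real dataclass is added in `src/levanter/eval_harness.py`, which is not shown in this diff; the field names here are inferred from the YAML and from how `callbacks.py` uses the config further down, so treat this as an assumption rather than the PR's actual code.

from dataclasses import dataclass
from typing import Optional


@dataclass
class LmEvalHarnessConfig:
    # e.g. ["piqa", "hellaswag"]; None means "use a default task list"
    task_spec: Optional[list] = None
    # cap on examples per task; the nano config uses 32 to keep the run cheap
    max_examples: Optional[int] = None

    def task_spec_or_default(self) -> list:
        # callbacks.lm_eval_harness calls this; the default chosen here is a guess
        return self.task_spec if self.task_spec is not None else ["hellaswag"]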
4 changes: 4 additions & 0 deletions config/gpt2_small_fast.yaml
@@ -23,3 +23,7 @@ optimizer:
   learning_rate: 1E-3
   weight_decay: 0.1
   warmup: 0.01
+eval_harness:
+  task_spec: ["lambada", "piqa", "hellaswag"]
+  max_examples: 32
+eval_harness_steps: 1000
24 changes: 24 additions & 0 deletions config/harness/harness_nano.yaml
@@ -0,0 +1,24 @@
eval_harness:
  task_spec: ["hellaswag"]
tokenizer: "gpt2"
model:
  type: gpt2
  hidden_dim: 32
  num_heads: 4
  num_layers: 2
trainer:
  mp: f32
  num_train_steps: 100
  profiler: true

  checkpointer:
    keep:
      - every: 50
    save_interval: 5m

  per_device_parallelism: -1
  train_batch_size: 32

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
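Because this standalone config names the stock `gpt2` tokenizer and a hellaswag-only `task_spec`, the pinned `lm-eval==0.4.2` can produce reference numbers directly from the Hugging Face `gpt2` checkpoint, outside Levanter. A sketch only: the tiny 2-layer model defined above would of course score very differently, so this is just a sanity baseline.

import lm_eval

# "hf" selects lm-eval's Hugging Face-backed model wrapper
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=gpt2",
    tasks=["hellaswag"],   # same task_spec as harness_nano.yaml
    limit=32,              # analogous to max_examples in the training configs
)
print(results["results"]["hellaswag"])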
175 changes: 175 additions & 0 deletions config/olmo/olmo_7b_repro.yaml
@@ -0,0 +1,175 @@
#data: !include data/dolma_olmo_paloma.yaml
data:
  cache_dir: "gs://marin-data/tokenized/OLMo-1B/dolma-v1.7"
  tokenizer: "allenai/OLMo-1B"  # requires `pip install ai2-olmo`
  # tokenizer: "meta-llama/Llama-2-7b-hf"
  stop_strategy: restart
  shuffle_buffer_size: 100000
  configs:
    dolma-algebraic-stack:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz
    dolma-arxiv:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/arxiv-{0000..0099}.json.gz
    dolma-gutenberg:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/books-{0000..0002}.json.gz
    dolma-c4:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/c4-{0000..0170}.json.gz
    dolma-cc:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_head-{0000..0274}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0000..0238}.json.gz  # 239 is missing
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0240..0379}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0000..0152}.json.gz  # 153 is missing
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0154..0444}.json.gz
    dolma-cc-news:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_head-{0000..0004}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_middle-{0000..0002}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_tail-0000.json.gz
    dolma-falcon:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/falcon-{0000..0499}.json.gz
    dolma-megawika:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/megawika-{0000..0261}.json.gz
    dolma-owmath:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/open-web-math-train-{0000..0012}.json.gz
    dolma-pes2o:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/pes2o-{0000..0025}.json.gz
    dolma-reddit:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/reddit-{0000..0077}.json.gz
    dolma-stackexchange:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/stackexchange-{0000..0025}.json.gz
    dolma-starcoder:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/starcoder-{0000..0048}.json.gz
    dolma-flan:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/tulu_flan-{0000..0065}.json.gz
    dolma-wiki:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz
    # these are just for eval
    "paloma/4chan":
      validation_urls:
        - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz
    "paloma/c4_100_domains":
      validation_urls:
        - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz
    "paloma/c4_en":
      validation_urls:
        - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz
    "paloma/dolma-v1_5":
      validation_urls:
        - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz
    "paloma/dolma_100_programing_languages":
      validation_urls:
        - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz
    "paloma/dolma_100_subreddits":
      validation_urls:
        - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz
    "paloma/falcon-refinedweb":
      validation_urls:
        - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz
    "paloma/gab":
      validation_urls:
        - gs://levanter-data/paloma/gab/val/val*.jsonl.gz
    "paloma/m2d2_s2orc_unsplit":
      validation_urls:
        - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz
    "paloma/m2d2_wikipedia_unsplit":
      validation_urls:
        - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz
    "paloma/manosphere_meta_sep":
      validation_urls:
        - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz
    "paloma/mc4":
      validation_urls:
        - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz
    "paloma/ptb":
      validation_urls:
        - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz
    "paloma/redpajama":
      validation_urls:
        - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz
    "paloma/twitterAAE_HELM_fixed":
      validation_urls:
        - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz
    "paloma/wikitext_103":
      validation_urls:
        - gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz
  train_weights:
    # sampling proportion comes from https://huggingface.co/datasets/allenai/dolma
    dolma-algebraic-stack: 12.6  # 12.6 * 1.0
    dolma-arxiv: 28.0  # 28.0 * 1.0
    dolma-gutenberg: 5.3  # 5.3 * 1.0
    dolma-c4: 69.2  # 138.4 * 0.5
    dolma-cc: 597.75  # 1,195.5 * 0.5
    dolma-cc-news: 14.3  # 1.0
    dolma-falcon: 456.4  # 1.0, refined web
    dolma-megawika: 4.6  # 1.0
    dolma-owmath: 12.6  # 1.0
    dolma-pes2o: 57.2  # 1.0
    dolma-reddit: 79.9  # 1.0
    dolma-stackexchange: 19.6  # 1.0
    dolma-starcoder: 263.8  # 1.0
    dolma-flan: 16.5  # 6.5 * 1.0
    dolma-wiki: 7.4  # 3.7 * 2.0
    paloma/4chan: 0.0
    paloma/c4_100_domains: 0.0
    paloma/c4_en: 0.0
    paloma/dolma-v1_5: 0.0
    paloma/dolma_100_programing_languages: 0.0
    paloma/dolma_100_subreddits: 0.0
    paloma/falcon-refinedweb: 0.0
    paloma/gab: 0.0
    paloma/m2d2_s2orc_unsplit: 0.0
    paloma/m2d2_wikipedia_unsplit: 0.0
    paloma/manosphere_meta_sep: 0.0
    paloma/mc4: 0.0
    paloma/ptb: 0.0
    paloma/redpajama: 0.0
    paloma/twitterAAE_HELM_fixed: 0.0
    paloma/wikitext_103: 0.0
model:  # 7B class model
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: True
  # flash_attention_block_size: 1024

  use_bias: false
  use_layer_norm_weight: false
trainer:
  tracker:
    type: wandb
    project: "marin"
    tags: ["dolma", "olmo", "llama"]

  mp: p=f32,c=bfloat16
  train_batch_size: 2048  # olmo actually uses 2160, table 5 of https://arxiv.org/pdf/2402.00838
  num_train_steps: 750000  # 3,000,000,000,000 / 4,000,000 = 750,000
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  replica_dcn_axis_size: 2
optimizer:
  learning_rate: 3E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
  beta1: 0.9
  beta2: 0.95
  warmup: 2000
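The `train_weights` above are raw sampling proportions (per the Dolma dataset card referenced in the comment), not normalized probabilities, and the shard lists use `{NNNN..NNNN}` brace ranges. A sketch of how the weights reduce to mixture probabilities and how one pattern expands to concrete URLs; `braceexpand` is assumed here as the expansion helper and is not part of this diff:

from braceexpand import braceexpand

weights = {
    "dolma-c4": 69.2,
    "dolma-cc": 597.75,
    "dolma-falcon": 456.4,
    # ... remaining dolma-* sources elided; all paloma/* entries are 0.0 (eval only)
}
total = sum(weights.values())
probs = {name: w / total for name, w in weights.items()}
print(probs["dolma-cc"])  # with the full dict, this is the fraction of samples drawn from dolma-cc

shards = list(braceexpand("gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz"))
# -> [".../wiki-0000.json.gz", ".../wiki-0001.json.gz"]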
1 change: 1 addition & 0 deletions pyproject.toml
@@ -54,6 +54,7 @@ dependencies = [
     "pydantic<3", # temporary pin until Ray supports pydantic 2.0
     "rich~=13.0",
     "filelock~=3.13",
+    "lm-eval==0.4.2"
 ]

 [tool.hatch.build]
44 changes: 44 additions & 0 deletions src/levanter/callbacks.py
@@ -15,12 +15,14 @@
 from tqdm import tqdm

 import levanter.tracker
+from levanter.eval_harness import LmEvalHarnessConfig
 from levanter.logging import save_xla_dumps_to_wandb
 from levanter.tracker.helpers import log_optimizer_hyperparams
 from levanter.tracker.wandb import WandbConfig
 from levanter.trainer import StepInfo
 from levanter.utils import flop_utils
 from levanter.utils.jax_utils import barrier_sync, jnp_to_python
+from levanter.utils.tree_utils import inference_mode
 from levanter.visualization import compute_and_visualize_log_probs as viz_probs

@@ -352,3 +354,45 @@ def compute_and_viz_log_probs(step: StepInfo):
         wandb.log({"log_probs": wandb.Html(path)}, step=step.step)

     return compute_and_viz_log_probs
+
+
+def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources):
+    from levanter.eval_harness import run_lm_eval_harness
+
+    def lm_eval_harness(step: StepInfo, force=False):
+        if step.step == 0 and not force:
+            return  # don't run eval on the first step
+
+        model = inference_mode(step.model, True)
+        outputs = run_lm_eval_harness(
+            model,
+            config.task_spec_or_default(),
+            tokenizer,
+            EvalBatch,
+            axis_resources,
+            max_examples=config.max_examples,
+        )
+
+        if jax.process_index() == 0:
+            with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
+                import json
+
+                json.dump(outputs, f)
+                levanter.tracker.current_tracker().log_artifact(
+                    f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output"
+                )
+
+            # also log accuracy statistics etc
+            metrics_to_log = {}
+            for task, metrics in outputs["results"].items():
+                for metric, value in metrics.items():
+                    if metric.endswith(",none"):
+                        metric = metric[: -len(",none")]
+
+                    if metric != "alias":
+                        # levanter.tracker.log_metrics({f"lm_eval/{task}/{metric}": value}, step=step.step)
+                        metrics_to_log[f"lm_eval/{task}/{metric}"] = value
+
+            levanter.tracker.log_metrics(metrics_to_log, step=step.step)
+
+    return lm_eval_harness
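How this factory presumably gets registered in the training script, in the style of Levanter's other callbacks. The `config.eval_harness` / `config.eval_harness_steps` names mirror the new YAML keys, and `compute_axis_mapping` stands in for whatever axis mapping the trainer passes; the exact wiring in train_lm.py may differ in this draft:

if config.eval_harness is not None:
    trainer.add_hook(
        callbacks.lm_eval_harness(config.eval_harness, tokenizer, EvalBatch, compute_axis_mapping),
        every=config.eval_harness_steps,
    )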
8 changes: 6 additions & 2 deletions src/levanter/checkpoint.py
@@ -344,11 +344,15 @@ def load_checkpoint(
         logger.warning("Loading checkpoint in jit. This is not recommended and probably won't work.")

     if discover_latest:
-        checkpoint_path = discover_latest_checkpoint(checkpoint_path)  # type: ignore
+        discovered_checkpoint_path = discover_latest_checkpoint(checkpoint_path)  # type: ignore
+    else:
+        discovered_checkpoint_path = checkpoint_path

-    if checkpoint_path is None or not fs.exists(checkpoint_path):
+    if discovered_checkpoint_path is None or not fs.exists(discovered_checkpoint_path):
         raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}")

+    checkpoint_path = discovered_checkpoint_path
+
     logger.info(f"Loading checkpoint from {checkpoint_path}")
     metadata = load_metadata(checkpoint_path, fs)

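The intent of the checkpoint.py change: previously `discover_latest_checkpoint` overwrote `checkpoint_path`, so when discovery found nothing the error read "Could not find checkpoint at None". The new code keeps the caller's path for the message and only swaps in the discovered path afterwards. A standalone restatement of that resolution logic (not the real `load_checkpoint` signature):

from typing import Optional

def resolve_checkpoint_path(requested: str, discovered: Optional[str], discover_latest: bool) -> str:
    # mirror of the new behavior: the requested path survives for error reporting
    candidate = discovered if discover_latest else requested
    if candidate is None:
        raise FileNotFoundError(f"Could not find checkpoint at {requested}")
    return candidate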
15 changes: 2 additions & 13 deletions src/levanter/data/loader.py
@@ -1,5 +1,4 @@
 import abc
-import functools
 import logging
 from collections import defaultdict
 from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
@@ -21,6 +20,7 @@
 from levanter.mesh import local_devices_mapping, process_mesh_mapping
 from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape
 from levanter.utils.background_iterable import BackgroundIterable
+from levanter.utils.jax_utils import stack_tree
 from levanter.utils.py_utils import non_caching_cycle

@@ -89,7 +89,7 @@ def get_local_batch(begin: int, end: int) -> List[Array]:

             individual_datums = get_batch_items(begin, end)

-            device_batch = _stack_tree(self.Batch.name, individual_datums)
+            device_batch = stack_tree(self.Batch, individual_datums, pad_to_batch_size=False)
             batch_leaves = jtu.tree_leaves(device_batch)

             stacked_local_batch[key] = batch_leaves
@@ -226,17 +226,6 @@ def local_batch_size(self) -> int:
         return self.batch_size // self.num_data_process_groups


-@functools.partial(jax.jit, static_argnums=(0,))
-def _stack_tree(batch_name, individual_datums):
-    def _stack_leaves_unchecked(*leaves):
-        if is_named_array(leaves[0]):
-            return hax.stack(batch_name, leaves)
-        else:
-            return jnp.stack(leaves)
-
-    return jax.tree_map(_stack_leaves_unchecked, *individual_datums, is_leaf=is_named_array)
-
-
 class ReplicatedBatchLoader(BatchLoader[Ex]):
     """A batch loader that creates batches without sharded data loading. All examples are loaded on all machines and then
     sharded. This is useful if you have a small dataset and want to make a single pass over it.
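The module-local `_stack_tree` removed above has been replaced by `stack_tree` from `levanter.utils.jax_utils`, which now takes the `Batch` axis itself plus a `pad_to_batch_size` flag (the call site passes `pad_to_batch_size=False`). That helper is not shown in this diff; a rough sketch of what it plausibly looks like, extrapolated from the old implementation:

import jax
import jax.numpy as jnp
import haliax as hax


def stack_tree(Batch: hax.Axis, individual_datums, *, pad_to_batch_size: bool = True):
    """Stack a list of example pytrees into a single batch along `Batch`.

    Sketch only: if padding is requested and fewer than Batch.size examples are
    provided, the last example is repeated to fill out the batch.
    """
    if pad_to_batch_size and len(individual_datums) < Batch.size:
        individual_datums = list(individual_datums) + [individual_datums[-1]] * (
            Batch.size - len(individual_datums)
        )

    def _stack_leaves(*leaves):
        # named arrays stack along the Batch axis; plain arrays along a new leading dim
        if isinstance(leaves[0], hax.NamedArray):
            return hax.stack(Batch.name, leaves)
        return jnp.stack(leaves)

    return jax.tree_util.tree_map(
        _stack_leaves, *individual_datums, is_leaf=lambda x: isinstance(x, hax.NamedArray)
    )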