diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py index 015b6f6523..388675fe57 100644 --- a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -247,6 +247,7 @@ def fit( while state["step_count"] < max_steps: state["iter_num"] += 1 iter_t0 = time.perf_counter() + batch = next(train_iterator) if train_iterator.epoch >= train.epochs: break input_ids, targets = batch["input_ids"], batch["labels"]