
Commit 29ece59

revert style changes, improve print
rasbt committed Oct 4, 2024
1 parent 981d041 commit 29ece59
Showing 5 changed files with 31 additions and 84 deletions.
litgpt/data/dolly.py (4 changes: 1 addition & 3 deletions)

@@ -11,9 +11,7 @@
 from litgpt.prompts import PromptStyle
 from litgpt.data import Alpaca, SFTDataset

-_URL: str = (
-"https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
-)
+_URL: str = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"


 @dataclass
litgpt/finetune/adapter.py (2 changes: 1 addition & 1 deletion)

@@ -382,7 +382,7 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
 model.clear_kv_cache()
 model.train()
 output = tokenizer.decode(output)
-fabric.print(output)
+fabric.print(f"{output}\n")
 else:
 print(
 f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "
litgpt/finetune/adapter_v2.py (2 changes: 1 addition & 1 deletion)

@@ -378,7 +378,7 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
 model.clear_kv_cache()
 model.train()
 output = tokenizer.decode(output)
-fabric.print(output)
+fabric.print(f"{output}\n")
 else:
 print(
 f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "
litgpt/finetune/full.py (2 changes: 1 addition & 1 deletion)

@@ -354,7 +354,7 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
 model.clear_kv_cache()
 model.train()
 output = tokenizer.decode(output)
-fabric.print(output)
+fabric.print(f"{output}\n")
 else:
 print(
 f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "
litgpt/finetune/lora.py (105 changes: 27 additions & 78 deletions)

@@ -47,9 +47,7 @@ def setup(
 checkpoint_dir: Path,
 out_dir: Path = Path("out/finetune/lora"),
 precision: Optional[str] = None,
-quantize: Optional[
-Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]
-] = None,
+quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
 devices: Union[int, str] = 1,
 num_nodes: int = 1,
 lora_r: int = 8,
@@ -125,9 +123,7 @@ def setup(
 )

 precision = precision or get_default_supported_precision(training=True)
-logger = choose_logger(
-logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval
-)
+logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval)

 plugins = None
 if quantize is not None and quantize.startswith("bnb."):
@@ -138,9 +134,7 @@
 "LitGPT only supports bitsandbytes v0.42.0. "
 "This may result in errors when using quantization."
 )
-dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[
-precision
-]
+dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
 plugins = BitsandbytesPrecision(quantize[4:], dtype)
 precision = None

@@ -172,9 +166,7 @@ def setup(
 if torch.cuda.is_available() and devices > 1:
 check_nvlink_connectivity(fabric)

-fabric.launch(
-main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer
-)
+fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)


 def main(
@@ -207,9 +199,7 @@ def main(
 mark_only_lora_as_trainable(model)

 fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
-fabric.print(
-f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}"
-)
+fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")

 model = fabric.setup_module(model)
 if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
@@ -225,9 +215,7 @@ def main(
 optimizer = instantiate_torch_optimizer(optimizer, model.parameters())

 optimizer = fabric.setup_optimizers(optimizer)
-scheduler = get_lr_scheduler(
-optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps
-)
+scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)

 # strict=False because missing keys due to LoRA weights not contained in state dict
 load_checkpoint(fabric, model, checkpoint_path, strict=False)
@@ -253,14 +241,10 @@ def main(

 # Final evaluation
 if eval.final_validation:
-val_loss = validate(
-fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))
-)
+val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
 metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
 fabric.log_dict(metrics)
-fabric.print(
-f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}"
-)
+fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")

 # Save the final LoRA checkpoint at the end of training
 save_path = out_dir / "final" / "lit_model.pth.lora"
@@ -289,32 +273,26 @@ def fit(
 data: DataModule,
 ) -> None:
 tokenizer = Tokenizer(checkpoint_dir)
-longest_seq_length, longest_seq_ix = get_longest_seq_length(
-ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])
-)
+longest_seq_length, longest_seq_ix = get_longest_seq_length(ConcatDataset([train_dataloader.dataset, val_dataloader.dataset]))
 model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
 fabric.print(
 f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
 f" {model.max_seq_length} and context length is {model.config.block_size}"
 )

 if eval.initial_validation:
-val_loss = validate(
-fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))
-)
+val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
 val_loss = f"{val_loss:.3f}"
 else:
 fabric.print("Verifying settings ...")
-validate(
-fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False
-) # sanity check
+validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check
 val_loss = "n/a"

 train_iterator = CycleIterator(train_dataloader)
 throughput = ThroughputMonitor(fabric, window_size=50)
-running_loss = RunningMean(
-window=train.gradient_accumulation_iters(devices), sync_on_compute=False
-).to(fabric.device)
+running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
+fabric.device
+)
 max_steps = train.max_steps or float("inf")
 step_count = 0
 iter_num = 0
@@ -348,10 +326,7 @@ def fit(
 loss = running_loss.compute().item() # expensive device-to-host synchronization
 t1 = time.perf_counter()
 throughput.update(
-time=t1 - total_t0,
-batches=iter_num,
-samples=iter_num * train.micro_batch_size,
-lengths=total_lengths,
+time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
 )
 throughput.compute_and_log(step=iter_num)
 metrics = {
@@ -361,9 +336,7 @@
 "epoch": train_iterator.epoch,
 "iter_time": t1 - iter_t0,
 "tokens": iter_num * train.micro_batch_size * model.config.block_size,
-"total_tokens": (
-iter_num * train.micro_batch_size * model.config.block_size * fabric.world_size
-),
+"total_tokens": (iter_num * train.micro_batch_size * model.config.block_size * fabric.world_size),
 "learning_rate": scheduler.get_last_lr()[0],
 }
 if isinstance(val_loss, torch.Tensor):
@@ -382,18 +355,12 @@ def fit(
 val_loss = validate(fabric, model, val_dataloader, eval)
 generate_example(fabric, model, tokenizer, eval, data)
 t1 = time.perf_counter() - t0
-fabric.print(
-f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms\n"
-)
+fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
 metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
 fabric.log_dict(metrics, step=iter_num)
 fabric.barrier()

-if (
-train.save_interval is not None
-and not is_accumulating
-and step_count % train.save_interval == 0
-):
+if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
 checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora"
 checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
 save_lora_checkpoint(fabric, model, checkpoint_file)
@@ -405,9 +372,7 @@

 # FSDP has issues with `inference_mode`
 @torch.no_grad()
-def validate(
-fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True
-) -> torch.Tensor:
+def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True) -> torch.Tensor:
 if verbose:
 fabric.print("Validating ...")
 model.eval()
@@ -426,11 +391,9 @@ def validate(


 @torch.no_grad()
-def generate_example(
-fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule
-):
-fabric.print("Generating sample ...")
+def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
 instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+fabric.print(instruction)
 prompt = data.prompt_style.apply(instruction)
 encoded = tokenizer.encode(prompt, device=fabric.device)
 model.eval()
@@ -442,11 +405,7 @@ def generate_example(
 # do not set `max_seq_length=max_returned_token` because memory is not a concern here
 model.set_kv_cache(batch_size=1)
 output = generate(
-model,
-encoded,
-max_returned_tokens=max_returned_tokens,
-temperature=0.8,
-eos_id=tokenizer.eos_id,
+model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
 )
 model.clear_kv_cache()
 model.train()
@@ -463,20 +422,14 @@ def generate_example(
 def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
 # linear warmup followed by cosine annealing
 scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
-scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
-optimizer, T_max=(max_steps - warmup_steps)
-)
-return torch.optim.lr_scheduler.SequentialLR(
-optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]
-)
+scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
+return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])


 def get_dataloaders(
 fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs
 ) -> Tuple[DataLoader, DataLoader]:
-data.connect(
-tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length
-)
+data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
 with fabric.rank_zero_first():
 data.prepare_data()
 data.setup()
@@ -505,17 +458,13 @@ def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
 for args, names in unsupported:
 for name in names:
 if getattr(args, name) is not None:
-issues.append(
-f"{__file__} doesn't support the {name!r} argument. This is set in {args}"
-)
+issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
 required = [(train, ["epochs"]), (eval, ["max_new_tokens"])]
 for args, names in required:
 for name in names:
 if getattr(args, name) is None:
 issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
 if not train.epochs and not train.max_steps:
-issues.append(
-f"{__file__} requires either epochs or max_steps to be set. This is set in {train}"
-)
+issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}")
 if issues:
 raise ValueError("\n".join(issues))
