diff --git a/litgpt/data/dolly.py b/litgpt/data/dolly.py
index 1e0789fae2..98cd4a9054 100644
--- a/litgpt/data/dolly.py
+++ b/litgpt/data/dolly.py
@@ -11,7 +11,9 @@
 from litgpt.prompts import PromptStyle
 from litgpt.data import Alpaca, SFTDataset

-_URL: str = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
+_URL: str = (
+    "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
+)


 @dataclass
@@ -70,7 +72,8 @@ def setup(self, stage: str = "") -> None:
         )


+# TODO: break test with old behavior
 def _transform(item: dict) -> dict:
-    item["input"] = item.pop("context")
-    item["output"] = item.pop("response")
+    item["input"] = item.get("context", "")
+    item["output"] = item.get("response", "")
     return item
diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index ba2ec24d95..dd6f92d8e7 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -47,7 +47,9 @@ def setup(
     checkpoint_dir: Path,
     out_dir: Path = Path("out/finetune/lora"),
     precision: Optional[str] = None,
-    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
+    quantize: Optional[
+        Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]
+    ] = None,
     devices: Union[int, str] = 1,
     num_nodes: int = 1,
     lora_r: int = 8,
@@ -123,7 +125,9 @@ def setup(
     )

     precision = precision or get_default_supported_precision(training=True)
-    logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval)
+    logger = choose_logger(
+        logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval
+    )

     plugins = None
     if quantize is not None and quantize.startswith("bnb."):
@@ -134,7 +138,9 @@ def setup(
                 "LitGPT only supports bitsandbytes v0.42.0. "
                 "This may result in errors when using quantization."
             )
-        dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
+        dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[
+            precision
+        ]
         plugins = BitsandbytesPrecision(quantize[4:], dtype)
         precision = None

@@ -166,7 +172,9 @@ def setup(
     if torch.cuda.is_available() and devices > 1:
         check_nvlink_connectivity(fabric)

-    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
+    fabric.launch(
+        main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer
+    )


 def main(
@@ -199,7 +207,9 @@ def main(
     mark_only_lora_as_trainable(model)

     fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
-    fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")
+    fabric.print(
+        f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}"
+    )

     model = fabric.setup_module(model)

@@ -209,7 +219,9 @@ def main(

     optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
     optimizer = fabric.setup_optimizers(optimizer)
-    scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
+    scheduler = get_lr_scheduler(
+        optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps
+    )

     # strict=False because missing keys due to LoRA weights not contained in state dict
     load_checkpoint(fabric, model, checkpoint_path, strict=False)
@@ -235,10 +247,14 @@ def main(

     # Final evaluation
     if eval.final_validation:
-        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
+        val_loss = validate(
+            fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))
+        )
         metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
         fabric.log_dict(metrics)
-        fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
+        fabric.print(
+            f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}"
+        )

     # Save the final LoRA checkpoint at the end of training
     save_path = out_dir / "final" / "lit_model.pth.lora"
@@ -267,7 +283,9 @@ def fit(
     data: DataModule,
 ) -> None:
     tokenizer = Tokenizer(checkpoint_dir)
-    longest_seq_length, longest_seq_ix = get_longest_seq_length(ConcatDataset([train_dataloader.dataset, val_dataloader.dataset]))
+    longest_seq_length, longest_seq_ix = get_longest_seq_length(
+        ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])
+    )
     model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
     fabric.print(
         f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
@@ -275,18 +293,22 @@ def fit(
     )

     if eval.initial_validation:
-        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
+        val_loss = validate(
+            fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))
+        )
         val_loss = f"{val_loss:.3f}"
     else:
         fabric.print("Verifying settings ...")
-        validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False)  # sanity check
+        validate(
+            fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False
+        )  # sanity check
         val_loss = "n/a"

     train_iterator = CycleIterator(train_dataloader)
     throughput = ThroughputMonitor(fabric, window_size=50)
-    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
-        fabric.device
-    )
+    running_loss = RunningMean(
+        window=train.gradient_accumulation_iters(devices), sync_on_compute=False
+    ).to(fabric.device)
     max_steps = train.max_steps or float("inf")
     step_count = 0
     iter_num = 0
@@ -320,7 +342,10 @@ def fit(
             loss = running_loss.compute().item()  # expensive device-to-host synchronization
             t1 = time.perf_counter()
             throughput.update(
-                time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
+                time=t1 - total_t0,
+                batches=iter_num,
+                samples=iter_num * train.micro_batch_size,
+                lengths=total_lengths,
             )
             throughput.compute_and_log(step=iter_num)
             metrics = {
@@ -330,7 +355,9 @@ def fit(
                 "epoch": train_iterator.epoch,
                 "iter_time": t1 - iter_t0,
                 "tokens": iter_num * train.micro_batch_size * model.config.block_size,
-                "total_tokens": (iter_num * train.micro_batch_size * model.config.block_size * fabric.world_size),
+                "total_tokens": (
+                    iter_num * train.micro_batch_size * model.config.block_size * fabric.world_size
+                ),
                 "learning_rate": scheduler.get_last_lr()[0],
             }
             if isinstance(val_loss, torch.Tensor):
@@ -349,12 +376,18 @@ def fit(
             val_loss = validate(fabric, model, val_dataloader, eval)
             generate_example(fabric, model, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
-            fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
+            fabric.print(
+                f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms\n"
+            )
             metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
             fabric.log_dict(metrics, step=iter_num)
             fabric.barrier()

-        if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
+        if (
+            train.save_interval is not None
+            and not is_accumulating
+            and step_count % train.save_interval == 0
+        ):
             checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora"
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_lora_checkpoint(fabric, model, checkpoint_file)
@@ -366,7 +399,9 @@ def fit(

 # FSDP has issues with `inference_mode`
 @torch.no_grad()
-def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True) -> torch.Tensor:
+def validate(
+    fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True
+) -> torch.Tensor:
     if verbose:
         fabric.print("Validating ...")
     model.eval()
@@ -385,9 +420,11 @@ def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: Eva


 @torch.no_grad()
-def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
+def generate_example(
+    fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule
+):
+    fabric.print("Generating sample ...")
     instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
-    fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
     model.eval()
@@ -399,12 +436,16 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
             # do not set `max_seq_length=max_returned_token` because memory is not a concern here
             model.set_kv_cache(batch_size=1)
         output = generate(
-            model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
+            model,
+            encoded,
+            max_returned_tokens=max_returned_tokens,
+            temperature=0.8,
+            eos_id=tokenizer.eos_id,
         )
         model.clear_kv_cache()
         model.train()
         output = tokenizer.decode(output)
-        fabric.print(output)
+        fabric.print(f"{output}\n")
     else:
         print(
             f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "
@@ -416,14 +457,20 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
 def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
     # linear warmup followed by cosine annealing
     scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
-    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
-    return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])
+    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
+        optimizer, T_max=(max_steps - warmup_steps)
+    )
+    return torch.optim.lr_scheduler.SequentialLR(
+        optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]
+    )


 def get_dataloaders(
     fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs
 ) -> Tuple[DataLoader, DataLoader]:
-    data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
+    data.connect(
+        tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length
+    )
     with fabric.rank_zero_first():
         data.prepare_data()
         data.setup()
@@ -452,13 +499,17 @@ def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
     for args, names in unsupported:
         for name in names:
             if getattr(args, name) is not None:
-                issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
+                issues.append(
+                    f"{__file__} doesn't support the {name!r} argument. This is set in {args}"
+                )
     required = [(train, ["epochs"]), (eval, ["max_new_tokens"])]
     for args, names in required:
         for name in names:
             if getattr(args, name) is None:
                 issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
     if not train.epochs and not train.max_steps:
-        issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}")
+        issues.append(
+            f"{__file__} requires either epochs or max_steps to be set. This is set in {train}"
+        )
+        )
     if issues:
         raise ValueError("\n".join(issues))
diff --git a/tests/data/test_dolly.py b/tests/data/test_dolly.py
index 8371c98fe4..52a5cebc46 100644
--- a/tests/data/test_dolly.py
+++ b/tests/data/test_dolly.py
@@ -5,7 +5,12 @@


 def test_dolly(mock_tokenizer, dolly_path):
-    dolly = Dolly(val_split_fraction=0.5, download_dir=dolly_path.parent, file_name=dolly_path.name, num_workers=0)
+    dolly = Dolly(
+        val_split_fraction=0.5,
+        download_dir=dolly_path.parent,
+        file_name=dolly_path.name,
+        num_workers=0,
+    )
     assert isinstance(dolly.prompt_style, AlpacaPromptStyle)
     dolly.connect(mock_tokenizer, batch_size=2, max_seq_length=10)
     dolly.prepare_data()
@@ -29,3 +34,55 @@ def test_dolly(mock_tokenizer, dolly_path):

     # has attributes from super class `LightningDataModule`
     assert dolly.prepare_data_per_node
+
+
+def test_dolly_missing_keys(mock_tokenizer, dolly_path):
+    """
+    Notes
+    -----
+    - Added only for the Dolly dataset.
+
+    References
+    ----------
+    - Reference issue: https://github.com/Lightning-AI/litgpt/issues/1760
+
+    Methodology
+    -----------
+    - Simulate the original behavior by popping the `context` key.
+    - Run the dataloader, which applies the `transform`.
+    - Previously this would have raised a KeyError for the missing `context` key because it had been popped.
+    - Now the `get` method is used, so the key(s) are not removed.
+    """
+
+    dolly = Dolly(
+        val_split_fraction=0.5,
+        download_dir=dolly_path.parent,
+        file_name=dolly_path.name,
+        num_workers=0,
+    )
+    dolly.connect(mock_tokenizer, batch_size=2, max_seq_length=10)
+    dolly.prepare_data()
+    dolly.setup()
+
+    # Check that the datasets were created without errors
+    assert dolly.train_dataset is not None
+    assert dolly.test_dataset is not None
+
+    # Verify the transform kept the original keys, then drop them to simulate missing keys
+    for dataset in [dolly.train_dataset, dolly.test_dataset]:
+        for item in dataset.data:
+            assert "context" in item
+            assert "response" in item
+            assert isinstance(item["context"], str)
+            assert isinstance(item["response"], str)
+            # Drop the `context` and `response` keys
+            # to reproduce the missing-key scenario from the original issue (`item.pop`)
+            item.pop("context")
+            item.pop("response")
+
+    # Check that we can iterate through the dataloader without errors
+    # The previous approach would throw a KeyError here since we already popped the keys
+    train_dataloader = dolly.train_dataloader()
+    train_batch = next(iter(train_dataloader))
+    assert "input_ids" in train_batch
+    assert "labels" in train_batch
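
Reviewer note (not part of the patch): a minimal sketch of why the `pop` -> `get` change in `_transform` matters. `dict.pop` removes `context`/`response` from the item in place, so the first pass over an item mutates it and any later pass raises a KeyError; `dict.get` reads the values without mutating the item and falls back to "" when a key is missing. The two functions below are copied from the old and new versions of the diff; the `_old`/`_new` names and the `sample` dict are made-up for illustration and are not taken from the Dolly dataset.

    # Old behavior: pop() removes the keys from the item on the first call.
    def _transform_old(item: dict) -> dict:
        item["input"] = item.pop("context")
        item["output"] = item.pop("response")
        return item

    # New behavior: get() reads the keys without mutating the item and
    # tolerates items where the keys are missing.
    def _transform_new(item: dict) -> dict:
        item["input"] = item.get("context", "")
        item["output"] = item.get("response", "")
        return item

    sample = {"instruction": "Name a mammal.", "context": "", "response": "A dolphin."}

    _transform_new(sample)
    _transform_new(sample)  # fine: "context" and "response" are still present

    try:
        _transform_old(sample)
        _transform_old(sample)  # fails: "context" was removed by the first call
    except KeyError as missing:
        print(f"old behavior breaks on a second pass: missing {missing}")

The new regression test can be run on its own with the usual pytest selector, e.g. `pytest tests/data/test_dolly.py -k missing_keys`.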