From ae798abaa565418dd8af70d9284ca3a2e8f98019 Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Fri, 4 Oct 2024 09:44:16 -0500
Subject: [PATCH] Add bnb.nn.StableEmbedding for quantized training (#1770)

---
 litgpt/finetune/adapter.py    | 8 +++++++-
 litgpt/finetune/adapter_v2.py | 8 +++++++-
 litgpt/finetune/lora.py       | 8 +++++++-
 tests/test_adapter.py         | 2 ++
 tests/test_adapter_v2.py      | 2 ++
 tests/test_lora.py            | 3 +++
 6 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py
index f479cade8a..c498d903a2 100644
--- a/litgpt/finetune/adapter.py
+++ b/litgpt/finetune/adapter.py
@@ -172,9 +172,15 @@ def main(
     fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")
 
     model = fabric.setup_module(model)
-
     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
         optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
+
+        from bitsandbytes.nn import StableEmbedding
+        old_embedding = model.transformer.wte
+        model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
+        with torch.no_grad():
+            model.transformer.wte.weight.copy_(old_embedding.weight)
+        model.transformer.wte = model.transformer.wte.to(device=old_embedding.weight.device, dtype=old_embedding.weight.dtype)
     else:
         optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
 
diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py
index a122da2f69..8e5b4c40c9 100644
--- a/litgpt/finetune/adapter_v2.py
+++ b/litgpt/finetune/adapter_v2.py
@@ -172,9 +172,15 @@ def main(
     fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")
 
     model = fabric.setup_module(model)
-
     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
         optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
+
+        from bitsandbytes.nn import StableEmbedding
+        old_embedding = model.transformer.wte
+        model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
+        with torch.no_grad():
+            model.transformer.wte.weight.copy_(old_embedding.weight)
+        model.transformer.wte = model.transformer.wte.to(device=old_embedding.weight.device, dtype=old_embedding.weight.dtype)
     else:
         optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
 
diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index ba2ec24d95..8632cd9988 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -202,9 +202,15 @@ def main(
     fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")
 
     model = fabric.setup_module(model)
-
     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
         optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
+
+        from bitsandbytes.nn import StableEmbedding
+        old_embedding = model.transformer.wte
+        model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
+        with torch.no_grad():
+            model.transformer.wte.weight.copy_(old_embedding.weight)
+        model.transformer.wte = model.transformer.wte.to(device=old_embedding.weight.device, dtype=old_embedding.weight.dtype)
     else:
         optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
 
diff --git a/tests/test_adapter.py b/tests/test_adapter.py
index df9cc8a6f7..979ffa9228 100644
--- a/tests/test_adapter.py
+++ b/tests/test_adapter.py
@@ -182,6 +182,8 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca
     assert dtype_to_name == {
         "torch.float16": {
             "transformer.wte.weight",
+            "transformer.wte.norm.weight",
+            "transformer.wte.norm.bias",
             "transformer.h.0.norm_1.weight",
             "transformer.h.0.norm_1.bias",
             "transformer.h.0.attn.gating_factor",
diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py
index f45f78065e..869d41b7a6 100644
--- a/tests/test_adapter_v2.py
+++ b/tests/test_adapter_v2.py
@@ -405,6 +405,8 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp
             "transformer.h.1.mlp.fc.adapter_scale",
             "transformer.h.1.attn.attn.linear.bias",
             "transformer.wte.weight",
+            "transformer.wte.norm.weight",
+            "transformer.wte.norm.bias",
             "transformer.h.0.norm_2.weight",
             "transformer.h.1.mlp.proj.linear.bias",
             "transformer.h.0.attn.gating_factor",
diff --git a/tests/test_lora.py b/tests/test_lora.py
index 8fd7364e55..f61793e3a8 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -724,6 +724,7 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa
 
     args, kwargs = train_mock.call_args
     fabric, model, optimizer, *_ = args
+    model.transformer.wte = model.transformer.wte.half()
     assert isinstance(fabric.strategy.precision, BitsandbytesPrecision)
     assert isinstance(optimizer, _FabricOptimizer)
     assert isinstance(optimizer._optimizer, PagedAdamW)
@@ -748,6 +749,8 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa
             "transformer.h.0.attn.attn.lora_B",
             "transformer.h.0.norm_2.weight",
             "transformer.wte.weight",
+            "transformer.wte.norm.weight",
+            "transformer.wte.norm.bias",
             "transformer.h.1.mlp.fc.linear.bias",
             "transformer.ln_f.bias",
             "transformer.h.1.attn.attn.lora_B",
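
Note on the change: when the training precision is BitsandbytesPrecision, the patch replaces the model's token embedding at model.transformer.wte (a plain torch.nn.Embedding) with bitsandbytes' StableEmbedding, copying over the pretrained weights and keeping the original device and dtype. StableEmbedding adds a LayerNorm after the embedding lookup, which is why the tests above now expect transformer.wte.norm.weight and transformer.wte.norm.bias parameters. A minimal standalone sketch of the same idea follows; it assumes bitsandbytes is installed and that the model exposes its token embedding at model.transformer.wte, and the helper name swap_in_stable_embedding is hypothetical, not part of the patch.

# Standalone sketch of the embedding swap performed in the patch (hypothetical helper).
import torch
from bitsandbytes.nn import StableEmbedding  # requires the bitsandbytes package

def swap_in_stable_embedding(model: torch.nn.Module) -> None:
    """Replace model.transformer.wte (assumed to be a torch.nn.Embedding) with a
    bitsandbytes StableEmbedding while preserving the pretrained weights."""
    old_embedding = model.transformer.wte
    new_embedding = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
    with torch.no_grad():
        new_embedding.weight.copy_(old_embedding.weight)
    # Keep the new module on the same device and in the same dtype as the original weights.
    model.transformer.wte = new_embedding.to(
        device=old_embedding.weight.device, dtype=old_embedding.weight.dtype
    )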