Add bnb.nn.StableEmbedding for quantized training
rasbt committed Oct 4, 2024
1 parent c03f3f0 commit 461c33d
Showing 3 changed files with 19 additions and 0 deletions.
6 changes: 6 additions & 0 deletions litgpt/finetune/adapter.py
@@ -175,6 +175,12 @@ def main(

     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
         optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
+
+        from bitsandbytes.nn import StableEmbedding
+        old_embedding = model.transformer.wte
+        model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
+        with torch.no_grad():
+            model.transformer.wte.weight.copy_(old_embedding.weight)
     else:
         optimizer = instantiate_torch_optimizer(optimizer, model.parameters())

6 changes: 6 additions & 0 deletions litgpt/finetune/adapter_v2.py
@@ -175,6 +175,12 @@ def main(

     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
         optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
+
+        from bitsandbytes.nn import StableEmbedding
+        old_embedding = model.transformer.wte
+        model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
+        with torch.no_grad():
+            model.transformer.wte.weight.copy_(old_embedding.weight)
     else:
         optimizer = instantiate_torch_optimizer(optimizer, model.parameters())

7 changes: 7 additions & 0 deletions litgpt/finetune/lora.py
@@ -205,6 +205,13 @@ def main(

     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
         optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
+
+        from bitsandbytes.nn import StableEmbedding
+        old_embedding = model.transformer.wte
+        model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
+        with torch.no_grad():
+            model.transformer.wte.weight.copy_(old_embedding.weight)
+
     else:
         optimizer = instantiate_torch_optimizer(optimizer, model.parameters())

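For context: bitsandbytes recommends its StableEmbedding layer (a standard nn.Embedding followed by a LayerNorm, with its optimizer state kept in 32-bit even under 8-bit optimizers) when training embedding layers with quantized optimizers, which is what the three changes above swap in while copying over the pretrained weights. Below is a minimal, self-contained sketch of that swap outside LitGPT; the vocabulary size, embedding dimension, and the stand-in transformer module are illustrative assumptions, not taken from the repository.

import torch
import torch.nn as nn
from bitsandbytes.nn import StableEmbedding

# Stand-in for a GPT-style model whose token embedding lives at transformer.wte
# (sizes are illustrative).
transformer = nn.ModuleDict({"wte": nn.Embedding(num_embeddings=50257, embedding_dim=768)})

old_embedding = transformer["wte"]
# StableEmbedding takes the same (num_embeddings, embedding_dim) arguments as nn.Embedding
# but adds a LayerNorm and forces 32-bit optimizer state for its weights.
new_embedding = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
with torch.no_grad():
    new_embedding.weight.copy_(old_embedding.weight)  # preserve the pretrained embedding weights
transformer["wte"] = new_embedding

tokens = torch.randint(0, 50257, (1, 8))
print(transformer["wte"](tokens).shape)  # torch.Size([1, 8, 768])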
