
Commit

Work on test
michaelbenayoun committed Sep 2, 2024
1 parent 797521c commit 51160c8
Showing 1 changed file with 7 additions and 2 deletions.
tests/distributed/test_model_parallelization.py (7 additions, 2 deletions)
@@ -570,16 +570,20 @@ def test_resize_embedding(self):
 
         with static_seed_patcher:
             orig_model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME)
+        orig_model.tie_weights()
+        orig_model.eval()
         vocab_size = orig_model.config.vocab_size
         new_vocab_size = (vocab_size // tp_size) * (tp_size + 1)
         with static_seed_patcher:
             orig_model.resize_token_embeddings(new_vocab_size)
 
         with lazy_load_for_parallelism(tensor_parallel_size=tp_size):
             with static_seed_patcher:
                 model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME)
+        model.eval()
         with static_seed_patcher:
             model.resize_token_embeddings(new_vocab_size)
+        model.tie_weights()
         accelerator = create_accelerator(
             tp_size,
             1,
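
Aside, not part of the diff: the expression (vocab_size // tp_size) * (tp_size + 1) grows the vocabulary while keeping it divisible by the tensor-parallel degree, so the resized embedding still shards evenly across ranks. A quick check with assumed numbers (Llama-2's 32,000-token vocabulary and tp_size = 8, neither stated in this hunk):

    vocab_size, tp_size = 32000, 8
    new_vocab_size = (vocab_size // tp_size) * (tp_size + 1)
    # 32000 // 8 == 4000 rows per rank; 4000 * 9 == 36000 total,
    # and 36000 % 8 == 0, so every rank still gets an equal slice.
    assert new_vocab_size == 36000 and new_vocab_size % tp_size == 0
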
@@ -588,6 +592,9 @@ def test_resize_embedding(self):
         )
         with static_seed_patcher:
             model = accelerator.prepare_model(model)
+
+        xm.master_print(orig_model.lm_head.weight)
+        xm.master_print(model.lm_head.weight)
         gathered = [torch.empty_like(model.model.embed_tokens.weight) for _ in range(tp_size)]
         torch.distributed.all_gather(gathered, model.model.embed_tokens.weight, group=tp_group)
         gathered_embedding = torch.cat(gathered, dim=0)
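
A note on the debug lines added above: xm.master_print comes from torch_xla (torch_xla.core.xla_model) and writes only on the master ordinal, so the lm_head weights are printed once rather than once per tensor-parallel rank. Minimal usage sketch:

    import torch_xla.core.xla_model as xm

    # Emitted a single time across the whole process group, not tp_size times.
    xm.master_print("hello from the master ordinal")
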
@@ -602,8 +609,6 @@ def test_resize_embedding(self):
         gathered = [torch.empty_like(logits) for _ in range(tp_size)]
         torch.distributed.all_gather(gathered, logits, group=tp_group)
         gathered_logits = torch.cat(gathered, dim=2)
-        xm.master_print(logits)
-        xm.master_print(gathered_logits)
         torch.testing.assert_close(orig_logits, gathered_logits.to("cpu"))
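
Taken together, the test verifies that resizing the embedding under tensor parallelism matches the unsharded reference model: the row-parallel embedding shards are reassembled with all_gather and concatenated along dim 0, while the column-parallel LM-head logits are concatenated along the last axis (dim 2). A minimal self-contained sketch of that reconstruction; the process group argument and the shard layout are assumptions here, not optimum-neuron API:

    import torch
    import torch.distributed as dist

    def gather_rows(shard: torch.Tensor, tp_group) -> torch.Tensor:
        # Each rank holds a [vocab_size // tp_size, hidden_size] slice of the
        # embedding; all_gather collects every slice, and cat(dim=0) rebuilds
        # the full matrix so it can be compared against the reference weights.
        tp_size = dist.get_world_size(group=tp_group)
        gathered = [torch.empty_like(shard) for _ in range(tp_size)]
        dist.all_gather(gathered, shard.contiguous(), group=tp_group)
        return torch.cat(gathered, dim=0)

    # Logits from a column-parallel head are split on the last axis instead,
    # so the same pattern ends with torch.cat(gathered, dim=-1).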


