
Commit

Work on test
michaelbenayoun committed Aug 30, 2024
1 parent b890d11 commit 797521c
Showing 4 changed files with 59 additions and 10 deletions.
6 changes: 6 additions & 0 deletions optimum/neuron/accelerate/accelerator.py
@@ -482,6 +482,7 @@ def prepare_model(
model = self._prepare_model_for_mp(
model, device_placement=device_placement, evaluation_mode=evaluation_mode
)
xm.master_print(model)
if should_apply_activation_checkpointing:
apply_activation_checkpointing(model)
else:
@@ -491,6 +492,11 @@ def prepare_model(
device_placement = False
model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
xm.mark_step()
xm.master_print(model)
# xm.master_print(model.lm_head.base_layer.input_size, model.lm_head.base_layer.output_size, model.lm_head.base_layer.output_size_per_partition)
# xm.master_print(model.lm_head.lora_B["default"].input_size, model.lm_head.lora_B["default"].output_size, model.lm_head.lora_B["default"].output_size_per_partition)
# print(model.lm_head.base_layer(torch.randn((4096, ), device="xla")).shape, xm.get_ordinal())
# exit()
return model

def backward(self, loss, **kwargs):
8 changes: 3 additions & 5 deletions optimum/neuron/distributed/utils.py
@@ -1406,9 +1406,9 @@ def wrapper(
weight_map=self._weight_map,
)
setattr(embeddings, "weight", torch.nn.Parameter(load_tensor_for_weight(embeddings_weight_info)))
self._weight_map.pop(embeddings_qualified_name)
else:
self._init_weights(embeddings)
self._weight_map.pop(embeddings_qualified_name)

if lm_head is not None and lm_head.weight.device == torch.device("meta"):
lm_head_qualified_name = param2name[lm_head.weight]
@@ -1421,6 +1421,7 @@ def wrapper(
weight_map=self._weight_map,
)
setattr(lm_head, "weight", torch.nn.Parameter(load_tensor_for_weight(lm_head_weight_info)))
self._weight_map.pop(lm_head_qualified_name)

if lm_head.bias is not None:
lm_head_bias_qualified_name = param2name[lm_head.bias]
@@ -1431,13 +1432,10 @@ def wrapper(
weight_map=self._weight_map,
)
setattr(lm_head, "bias", torch.nn.Parameter(load_tensor_for_weight(lm_head_bias_weight_info)))
self._weight_map.pop(lm_head_bias_qualified_name)
else:
self._init_weights(lm_head)

self._weight_map.pop(lm_head_qualified_name)
if lm_head.bias is not None:
self._weight_map.pop(lm_head_bias_qualified_name)

return orig_resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)
new_embedding_shape = resized_embedding.weight.shape
if embedding_shape != new_embedding_shape:
51 changes: 48 additions & 3 deletions tests/distributed/test_model_parallelization.py
@@ -20,7 +20,7 @@
import pytest
import torch
import torch.utils._pytree as pytree
from transformers import AutoTokenizer, LlamaForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
@@ -44,7 +44,7 @@

import optimum
from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager
from optimum.neuron.distributed.utils import compute_query_indices_for_rank
from optimum.neuron.distributed.utils import compute_query_indices_for_rank, lazy_load_for_parallelism
from optimum.neuron.utils.cache_utils import (
get_num_neuron_cores,
)
@@ -228,7 +228,7 @@ def early_skip(self, fixtures_kwargs):
pp_size = parallel_sizes[-1]
model_specs = fixtures_kwargs.get("model_specs", None)

if pp_size > 1 and model_specs is not None:
if pp_size is not None and pp_size > 1 and model_specs is not None:
model_type = model_specs[0]
manager = ParallelizersManager.parallelizer_for_model(model_type)
if not manager.supports_pipeline_parallelism():
@@ -561,6 +561,51 @@ def test_llama_v2_gqa(
parallelize_embeddings,
)

@pytest.mark.parallel_sizes((8, 2, 1))
def test_resize_embedding(self):
tp_size = get_tensor_model_parallel_size()
tp_group = get_tensor_model_parallel_group()

static_seed_patcher = StaticSeedPatcher(42)

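# Reference model: load on CPU with a fixed seed and resize its token embeddings.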
with static_seed_patcher:
orig_model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME)
orig_model.eval()
vocab_size = orig_model.config.vocab_size
new_vocab_size = (vocab_size // tp_size) * (tp_size + 1)
orig_model.resize_token_embeddings(new_vocab_size)

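# Parallel model: lazy-load with the same seed, resize, then let the accelerator shard it across the TP group.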
with lazy_load_for_parallelism(tensor_parallel_size=tp_size):
with static_seed_patcher:
model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME)
model.eval()
model.resize_token_embeddings(new_vocab_size)
accelerator = create_accelerator(
tp_size,
1,
parallelize_embeddings=True,
sequence_parallel_enabled=True,
)
with static_seed_patcher:
model = accelerator.prepare_model(model)
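# Gather the embedding shards from every TP rank and check they match the reference embedding.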
gathered = [torch.empty_like(model.model.embed_tokens.weight) for _ in range(tp_size)]
torch.distributed.all_gather(gathered, model.model.embed_tokens.weight, group=tp_group)
gathered_embedding = torch.cat(gathered, dim=0)
xm.mark_step()
torch.testing.assert_close(orig_model.model.embed_tokens.weight, gathered_embedding.to("cpu"))

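# Compare the logits of the reference model and of the parallel model on the same padded input.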
tok = AutoTokenizer.from_pretrained(LLAMA_V2_MODEL_NAME)
tok.pad_token = tok.eos_token
inputs = tok("This is a test", max_length=24, padding="max_length", return_tensors="pt")
orig_logits = orig_model(**inputs).logits
logits = model(**{k: v.to("xla") for k, v in inputs.items()}).logits
gathered = [torch.empty_like(logits) for _ in range(tp_size)]
torch.distributed.all_gather(gathered, logits, group=tp_group)
gathered_logits = torch.cat(gathered, dim=2)
xm.master_print(logits)
xm.master_print(gathered_logits)
torch.testing.assert_close(orig_logits, gathered_logits.to("cpu"))


@pytest.mark.parametrize(
"tp_size,num_attention_heads,num_key_value_heads,kv_size_multiplier,ground_truth",
4 changes: 2 additions & 2 deletions tests/distributed_utils.py
@@ -291,15 +291,15 @@ def try_to_override_via_pytest_mark(mark, name):
world_size = try_to_override_via_pytest_mark(mark, "world_size")
tp_size = try_to_override_via_pytest_mark(mark, "tp_size")
pp_size = try_to_override_via_pytest_mark(mark, "pp_size")
parallel_sizes = try_to_override_via_pytest_mark(mark, "parallel_size")
parallel_sizes = try_to_override_via_pytest_mark(mark, "parallel_sizes")

# Catch world_size, tp_size or pp_size override via fixture.
def try_to_override_via_fixture(name, current_value):
if name in self._fixture_kwargs:
if current_value is not None:
raise ValueError(f"It is not possible to override {name} both via pytest.mark and fixtures.")
return self._fixture_kwargs[name]
return None
return current_value

world_size = try_to_override_via_fixture("world_size", world_size)
tp_size = try_to_override_via_fixture("tp_size", tp_size)
