Support resize embeddings #670

Merged: 13 commits, Oct 2, 2024
3 changes: 3 additions & 0 deletions optimum/neuron/distributed/base.py
@@ -743,6 +743,9 @@ def should_parallelize_layer_predicate_func(layer):
cls._initialize_or_load_weights(model, names_of_the_parameters_to_consider, device=device)
gc.collect()

# It is important to do that here because initialization can untie weights.
model.tie_weights()

# Because we initialize new parameters, we need to make sure that only the ones that required grads before
# parallelization require grad after parallelization.
for name, parameter in model.named_parameters():
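For context, re-initializing a parameter replaces its underlying tensor, so an input embedding and an LM head that previously shared storage end up with independent weights; calling tie_weights() afterwards restores the sharing. A minimal standalone sketch of that behavior, using a tiny Llama config purely for illustration:

import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny illustrative config; tie_word_embeddings=True ties embed_tokens and lm_head.
config = LlamaConfig(
    vocab_size=32,
    hidden_size=16,
    intermediate_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    tie_word_embeddings=True,
)
model = LlamaForCausalLM(config)

embeddings = model.get_input_embeddings()
lm_head = model.get_output_embeddings()
assert embeddings.weight.data_ptr() == lm_head.weight.data_ptr()  # tied: same storage

# Re-initializing the embedding replaces its Parameter object and silently breaks the tie.
embeddings.weight = torch.nn.Parameter(torch.empty_like(embeddings.weight))
assert embeddings.weight.data_ptr() != lm_head.weight.data_ptr()

# tie_weights() makes the two modules share storage again.
model.tie_weights()
assert embeddings.weight.data_ptr() == lm_head.weight.data_ptr()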
13 changes: 7 additions & 6 deletions optimum/neuron/distributed/parallel_layers.py
@@ -221,12 +221,13 @@ def _transform(
embedding_weight_name = f"{layer_qualified_name}.{embedding_name}.weight"
else:
embedding_weight_name = f"{embedding_name}.weight"
embedding_weight_info = WeightInformation(
weight_map[embedding_weight_name],
embedding_weight_name,
weight_map=weight_map,
device=device,
)
if embedding_weight_name in weight_map:
embedding_weight_info = WeightInformation(
weight_map[embedding_weight_name],
embedding_weight_name,
weight_map=weight_map,
device=device,
)
if model_has_lm_head:
if layer_qualified_name:
lm_head_weight_name = f"{layer_qualified_name}.{lm_head_name}.weight"
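For reference, the weight map keys are fully qualified parameter names such as "model.embed_tokens.weight", not bare module names, and an entry disappears once the resize wrapper below has materialized that weight. A minimal sketch of the guarded lookup, using a hypothetical map and file name:

from typing import Dict, Optional

# Hypothetical weight map: fully qualified parameter name -> checkpoint shard file.
weight_map: Dict[str, str] = {"model.embed_tokens.weight": "model-00001-of-00002.safetensors"}

def checkpoint_file_for(embedding_weight_name: str) -> Optional[str]:
    # Only reference the checkpoint when the entry is still present; after
    # resize_token_embeddings the weight already lives in memory instead.
    if embedding_weight_name in weight_map:
        return weight_map[embedding_weight_name]
    return None

assert checkpoint_file_for("model.embed_tokens.weight") is not None
assert checkpoint_file_for("lm_head.weight") is None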
57 changes: 57 additions & 0 deletions optimum/neuron/distributed/utils.py
@@ -1387,6 +1387,60 @@ def parameter_can_be_initialized(model: torch.nn.Module, parent_module: torch.nn
)


def create_wrapper_for_resize_token_embedding(orig_resize_token_embeddings):
    """Wraps a bound `resize_token_embeddings` so that any weight still on the meta device is
    materialized first, either from the checkpoint via the weight map or via `_init_weights`."""

@functools.wraps(orig_resize_token_embeddings)
def wrapper(
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
) -> torch.nn.Embedding:
embeddings = self.get_input_embeddings()
lm_head = self.get_output_embeddings()
param2name = {param: name for name, param in self.named_parameters()}
if embeddings.weight.device == torch.device("meta"):
embeddings_qualified_name = param2name[embeddings.weight]
if embeddings_qualified_name in self._weight_map:
filename = self._weight_map[embeddings_qualified_name]
embeddings_weight_info = WeightInformation(
filename=filename,
qualified_name=embeddings_qualified_name,
weight_map=self._weight_map,
)
setattr(embeddings, "weight", torch.nn.Parameter(load_tensor_for_weight(embeddings_weight_info)))
self._weight_map.pop(embeddings_qualified_name)
else:
self._init_weights(embeddings)

if lm_head is not None and lm_head.weight.device == torch.device("meta"):
lm_head_qualified_name = param2name[lm_head.weight]
if lm_head_qualified_name in self._weight_map:
lm_head_weight_filename = self._weight_map[lm_head_qualified_name]
lm_head_weight_info = WeightInformation(
filename=lm_head_weight_filename,
qualified_name=lm_head_qualified_name,
weight_map=self._weight_map,
)
setattr(lm_head, "weight", torch.nn.Parameter(load_tensor_for_weight(lm_head_weight_info)))
self._weight_map.pop(lm_head_qualified_name)

if lm_head.bias is not None:
lm_head_bias_qualified_name = param2name[lm_head.bias]
lm_head_bias_filename = self._weight_map[lm_head_bias_qualified_name]
lm_head_bias_weight_info = WeightInformation(
filename=lm_head_bias_filename,
qualified_name=lm_head_bias_qualified_name,
weight_map=self._weight_map,
)
setattr(lm_head, "bias", torch.nn.Parameter(load_tensor_for_weight(lm_head_bias_weight_info)))
self._weight_map.pop(lm_head_bias_qualified_name)
else:
self._init_weights(lm_head)

return orig_resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)

bound_wrapper = wrapper.__get__(orig_resize_token_embeddings.__self__)
return bound_wrapper


@classmethod
@requires_torch_xla
def from_pretrained_for_mp(
@@ -1551,6 +1605,9 @@ def from_pretrained_for_mp(

model._weight_map = weight_map

resize_token_embeddings = create_wrapper_for_resize_token_embedding(model.resize_token_embeddings)
model.resize_token_embeddings = resize_token_embeddings

return model


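The wrapper above is installed per instance: orig_resize_token_embeddings is a bound method, so __self__ recovers the model, and wrapper.__get__(model) produces a new bound method that from_pretrained_for_mp then assigns to model.resize_token_embeddings. A minimal standalone sketch of this bind-and-patch pattern, with toy names:

import functools

class Model:
    def resize(self, n: int) -> int:
        return n

def create_wrapper(orig_resize):
    @functools.wraps(orig_resize)
    def wrapper(self, n: int) -> int:
        # A pre-processing hook runs here with access to the instance as `self`...
        print(f"resizing {type(self).__name__} to {n}")
        # ...before delegating to the original bound method.
        return orig_resize(n)

    # orig_resize is a bound method, so __self__ is the owning instance;
    # __get__ binds the wrapper to that same instance.
    return wrapper.__get__(orig_resize.__self__)

model = Model()
model.resize = create_wrapper(model.resize)  # patches only this instance
assert model.resize(8) == 8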
71 changes: 68 additions & 3 deletions tests/distributed/test_model_parallelization.py
@@ -14,13 +14,14 @@
# limitations under the License.
"""Tests validating that models can be parallelized correctly."""

from contextlib import nullcontext
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Type, Union

import pytest
import torch
import torch.utils._pytree as pytree
from transformers import AutoTokenizer, LlamaForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
@@ -44,7 +45,7 @@

import optimum
from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager
from optimum.neuron.distributed.utils import compute_query_indices_for_rank
from optimum.neuron.distributed.utils import compute_query_indices_for_rank, lazy_load_for_parallelism
from optimum.neuron.utils.cache_utils import (
get_num_neuron_cores,
)
@@ -223,14 +224,18 @@ def sequence_parallel_enabled(self, request):
def parallelize_embeddings(self, request):
return request.param

@pytest.fixture(scope="class", params=[False, True], ids=["embeddings_not_tied", "tied_embeddings"])
def tie_embeddings(self, request):
return request.param

def early_skip(self, fixtures_kwargs):
pp_size = fixtures_kwargs.get("pp_size", None)
parallel_sizes = fixtures_kwargs.get("parallel_sizes", None)
if pp_size is None and parallel_sizes is not None:
pp_size = parallel_sizes[-1]
model_specs = fixtures_kwargs.get("model_specs", None)

if pp_size > 1 and model_specs is not None:
if pp_size is not None and pp_size > 1 and model_specs is not None:
model_type = model_specs[0]
manager = ParallelizersManager.parallelizer_for_model(model_type)
if not manager.supports_pipeline_parallelism():
@@ -563,6 +568,66 @@ def test_llama_v2_gqa(
parallelize_embeddings,
)

@pytest.mark.parallel_sizes((2, 2, 1))
def test_resize_embedding(self, tie_embeddings, lazy_load):
tp_size = get_tensor_model_parallel_size()
tp_group = get_tensor_model_parallel_group()

static_seed_patcher = StaticSeedPatcher(42)

config = AutoConfig.from_pretrained(LLAMA_V2_MODEL_NAME)
config.tie_word_embeddings = tie_embeddings

with static_seed_patcher:
orig_model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, config=config)
orig_model.eval()
vocab_size = orig_model.config.vocab_size
new_vocab_size = vocab_size + tp_size

with static_seed_patcher:
orig_model.resize_token_embeddings(new_vocab_size)

ctx = lazy_load_for_parallelism(tensor_parallel_size=tp_size) if lazy_load else nullcontext()
with ctx:
with static_seed_patcher:
model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, config=config)
model.eval()

with static_seed_patcher:
model.resize_token_embeddings(new_vocab_size)

accelerator = create_accelerator(
tp_size,
1,
parallelize_embeddings=True,
sequence_parallel_enabled=True,
)
with static_seed_patcher:
model = accelerator.prepare_model(model)

# First we check that the embedding weights match
gathered = [torch.empty_like(model.model.embed_tokens.weight) for _ in range(tp_size)]
torch.distributed.all_gather(gathered, model.model.embed_tokens.weight, group=tp_group)
gathered_embedding = torch.cat(gathered, dim=0)
xm.mark_step()
torch.testing.assert_close(orig_model.model.embed_tokens.weight, gathered_embedding.to("cpu"))

# Second we check that logits match
tok = AutoTokenizer.from_pretrained(LLAMA_V2_MODEL_NAME)
tok.pad_token = tok.eos_token
inputs = tok("This is a test", max_length=24, padding="max_length", return_tensors="pt")
inputs = {k: v.to("xla") for k, v in inputs.items()}
orig_model = orig_model.to("xla")
orig_logits = orig_model(**inputs).logits
xm.mark_step()
logits = model(**inputs).logits
xm.mark_step()
gathered = [torch.empty_like(logits) for _ in range(tp_size)]
torch.distributed.all_gather(gathered, logits, group=tp_group)
gathered_logits = torch.cat(gathered, dim=2)
xm.mark_step()
torch.testing.assert_close(orig_logits, gathered_logits)


@pytest.mark.parametrize(
"tp_size,num_attention_heads,num_key_value_heads,kv_size_multiplier,ground_truth",
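The test picks new_vocab_size = vocab_size + tp_size so the resized vocabulary stays divisible by the tensor-parallel degree, then rebuilds full tensors from the per-rank shards with all_gather and torch.cat before comparing against the unparallelized reference. A minimal single-process sketch of that reconstruction, with made-up sizes:

import torch

tp_size = 2
vocab_size, hidden_size = 10, 4
new_vocab_size = vocab_size + tp_size  # stays divisible by tp_size
assert new_vocab_size % tp_size == 0

full_embedding = torch.randn(new_vocab_size, hidden_size)

# A vocab-parallel embedding shards rows across ranks; each rank holds one contiguous slice.
shards = list(full_embedding.chunk(tp_size, dim=0))

# What torch.distributed.all_gather followed by torch.cat(dim=0) reconstructs in the test.
gathered = torch.cat(shards, dim=0)
torch.testing.assert_close(gathered, full_embedding)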
4 changes: 2 additions & 2 deletions tests/distributed_utils.py
@@ -291,15 +291,15 @@ def try_to_override_via_pytest_mark(mark, name):
world_size = try_to_override_via_pytest_mark(mark, "world_size")
tp_size = try_to_override_via_pytest_mark(mark, "tp_size")
pp_size = try_to_override_via_pytest_mark(mark, "pp_size")
parallel_sizes = try_to_override_via_pytest_mark(mark, "parallel_size")
parallel_sizes = try_to_override_via_pytest_mark(mark, "parallel_sizes")

# Catch world_size, tp_size or pp_size override via fixture.
def try_to_override_via_fixture(name, current_value):
if name in self._fixture_kwargs:
if current_value is not None:
raise ValueError(f"It is not possible to override {name} both via pytest.mark and fixtures.")
return self._fixture_kwargs[name]
return None
return current_value

world_size = try_to_override_via_fixture("world_size", world_size)
tp_size = try_to_override_via_fixture("tp_size", tp_size)
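The two fixes above work together: the pytest.mark attribute is parallel_sizes (plural), and the fixture helper must fall through to the value it was given instead of returning None, otherwise a size already resolved from pytest.mark is silently discarded. A small standalone sketch of the corrected fall-through, with illustrative values:

from typing import Any, Dict, Optional

# Hypothetical fixture kwargs resolved for a test.
fixture_kwargs: Dict[str, Any] = {"tp_size": 2}

def try_to_override_via_fixture(name: str, current_value: Optional[int]) -> Optional[int]:
    if name in fixture_kwargs:
        if current_value is not None:
            raise ValueError(f"It is not possible to override {name} both via pytest.mark and fixtures.")
        return fixture_kwargs[name]
    # Fall through to whatever was already resolved (e.g. via pytest.mark) instead of dropping it.
    return current_value

assert try_to_override_via_fixture("tp_size", None) == 2
assert try_to_override_via_fixture("pp_size", 4) == 4  # previously this returned None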