Support resize embeddings (#670)
michaelbenayoun authored Oct 2, 2024
1 parent 1ef1e12 commit 8646596
Showing 5 changed files with 137 additions and 11 deletions.
3 changes: 3 additions & 0 deletions optimum/neuron/distributed/base.py
@@ -743,6 +743,9 @@ def should_parallelize_layer_predicate_func(layer):
cls._initialize_or_load_weights(model, names_of_the_parameters_to_consider, device=device)
gc.collect()

# It is important to do that here because initialization can untie weights.
model.tie_weights()

# Because we initialize new parameters, we need to make sure that only the ones that required grads before
# parallelization require grad after parallelization.
for name, parameter in model.named_parameters():
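
The `model.tie_weights()` call added above matters because (re-)initializing a parameter assigns a brand-new tensor to the module, which silently breaks any tying between the input embedding and the LM head. A minimal, self-contained sketch of the failure mode and the fix, using plain PyTorch and an illustrative `ToyLM` class (not the optimum-neuron API):

import torch

# Illustrative sketch only (ToyLM is not part of optimum-neuron): show how
# re-initializing a weight unties it from the LM head, and how tie_weights()
# restores the sharing afterwards.
class ToyLM(torch.nn.Module):
    def __init__(self, vocab_size: int = 10, hidden_size: int = 4):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self):
        # Share one tensor between the input embedding and the LM head.
        self.lm_head.weight = self.embed.weight


model = ToyLM()
assert model.lm_head.weight is model.embed.weight      # tied

# Initialization replaces the Parameter object, so the tie is silently lost.
model.embed.weight = torch.nn.Parameter(torch.randn_like(model.embed.weight))
assert model.lm_head.weight is not model.embed.weight  # untied

# Re-tying right after initialization, as done above, restores the sharing.
model.tie_weights()
assert model.lm_head.weight is model.embed.weight
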
13 changes: 7 additions & 6 deletions optimum/neuron/distributed/parallel_layers.py
@@ -221,12 +221,13 @@ def _transform(
embedding_weight_name = f"{layer_qualified_name}.{embedding_name}.weight"
else:
embedding_weight_name = f"{embedding_name}.weight"
embedding_weight_info = WeightInformation(
weight_map[embedding_weight_name],
embedding_weight_name,
weight_map=weight_map,
device=device,
)
if embedding_weight_name in weight_map:
embedding_weight_info = WeightInformation(
weight_map[embedding_weight_name],
embedding_weight_name,
weight_map=weight_map,
device=device,
)
if model_has_lm_head:
if layer_qualified_name:
lm_head_weight_name = f"{layer_qualified_name}.{lm_head_name}.weight"
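
The new membership check works together with the resize support added in utils.py below: once the wrapped `resize_token_embeddings` materializes an embedding weight, it pops that entry from the checkpoint weight map, so the transform can no longer assume a `WeightInformation` can be built for it and must fall back to the tensor already held by the module. A small, self-contained sketch of that decision (function and key names are illustrative, not the real API):

from typing import Dict, Optional

# Illustrative sketch only: decide where an embedding weight should come from,
# given the (possibly pruned) checkpoint weight map.
def pick_weight_source(weight_name: str, weight_map: Dict[str, str]) -> str:
    checkpoint_file: Optional[str] = weight_map.get(weight_name)
    if checkpoint_file is not None:
        # Still tracked in the checkpoint: load it from disk, as the real
        # code does through WeightInformation.
        return f"load from {checkpoint_file}"
    # Entry was popped (e.g. by a resize), so the module already holds the
    # up-to-date tensor and nothing should be re-loaded.
    return "keep the module's in-memory tensor"


weight_map = {"model.embed_tokens.weight": "model-00001-of-00002.safetensors"}
print(pick_weight_source("model.embed_tokens.weight", weight_map))  # load from ...
weight_map.pop("model.embed_tokens.weight")  # what the resize wrapper does
print(pick_weight_source("model.embed_tokens.weight", weight_map))  # keep the module's ...
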
57 changes: 57 additions & 0 deletions optimum/neuron/distributed/utils.py
@@ -1387,6 +1387,60 @@ def parameter_can_be_initialized(model: torch.nn.Module, parent_module: torch.nn
)


def create_wrapper_for_resize_token_embedding(orig_resize_token_embeddings):

@functools.wraps(orig_resize_token_embeddings)
def wrapper(
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
) -> torch.nn.Embedding:
embeddings = self.get_input_embeddings()
lm_head = self.get_output_embeddings()
param2name = {param: name for name, param in self.named_parameters()}
if embeddings.weight.device == torch.device("meta"):
embeddings_qualified_name = param2name[embeddings.weight]
if embeddings_qualified_name in self._weight_map:
filename = self._weight_map[embeddings_qualified_name]
embeddings_weight_info = WeightInformation(
filename=filename,
qualified_name=embeddings_qualified_name,
weight_map=self._weight_map,
)
setattr(embeddings, "weight", torch.nn.Parameter(load_tensor_for_weight(embeddings_weight_info)))
self._weight_map.pop(embeddings_qualified_name)
else:
self._init_weights(embeddings)

if lm_head is not None and lm_head.weight.device == torch.device("meta"):
lm_head_qualified_name = param2name[lm_head.weight]
if lm_head_qualified_name in self._weight_map:
lm_head_weight_filename = self._weight_map[lm_head_qualified_name]
lm_head_weight_info = WeightInformation(
filename=lm_head_weight_filename,
qualified_name=lm_head_qualified_name,
weight_map=self._weight_map,
)
setattr(lm_head, "weight", torch.nn.Parameter(load_tensor_for_weight(lm_head_weight_info)))
self._weight_map.pop(lm_head_qualified_name)

if lm_head.bias is not None:
lm_head_bias_qualified_name = param2name[lm_head.bias]
lm_head_bias_filename = self._weight_map[lm_head_bias_qualified_name]
lm_head_bias_weight_info = WeightInformation(
filename=lm_head_bias_filename,
qualified_name=lm_head_bias_qualified_name,
weight_map=self._weight_map,
)
setattr(lm_head, "bias", torch.nn.Parameter(load_tensor_for_weight(lm_head_bias_weight_info)))
self._weight_map.pop(lm_head_bias_qualified_name)
else:
self._init_weights(lm_head)

return orig_resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)

bound_wrapper = wrapper.__get__(orig_resize_token_embeddings.__self__)
return bound_wrapper


@classmethod
@requires_torch_xla
def from_pretrained_for_mp(
@@ -1551,6 +1605,9 @@ def from_pretrained_for_mp(

model._weight_map = weight_map

resize_token_embeddings = create_wrapper_for_resize_token_embedding(model.resize_token_embeddings)
model.resize_token_embeddings = resize_token_embeddings

return model


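
A detail worth noting in the wrapper above: `orig_resize_token_embeddings` is the already-bound `model.resize_token_embeddings`, so `orig_resize_token_embeddings.__self__` is the model, and `wrapper.__get__(...)` uses the descriptor protocol to turn the plain `wrapper` function into a method bound to that same model. A minimal, self-contained sketch of the pattern (`ToyModel` and `resize` are illustrative names only):

import functools


class ToyModel:
    def resize(self, new_size: int) -> int:
        return new_size


def create_wrapper(orig_bound_method):
    @functools.wraps(orig_bound_method)
    def wrapper(self, new_size: int) -> int:
        # `self` is available here, so a real wrapper can touch model state
        # (materialize meta weights, update a weight map, ...) before delegating.
        print(f"resizing {type(self).__name__} to {new_size}")
        return orig_bound_method(new_size)

    # Plain functions are descriptors: __get__(instance) returns a bound method,
    # mirroring wrapper.__get__(orig_resize_token_embeddings.__self__) above.
    return wrapper.__get__(orig_bound_method.__self__)


model = ToyModel()
model.resize = create_wrapper(model.resize)  # same monkey-patching pattern
assert model.resize(12) == 12                # prints "resizing ToyModel to 12"
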
71 changes: 68 additions & 3 deletions tests/distributed/test_model_parallelization.py
@@ -14,13 +14,14 @@
# limitations under the License.
"""Tests validating that models can be parallelized correctly."""

from contextlib import nullcontext
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Type, Union

import pytest
import torch
import torch.utils._pytree as pytree
from transformers import AutoTokenizer, LlamaForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
@@ -44,7 +45,7 @@

import optimum
from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager
from optimum.neuron.distributed.utils import compute_query_indices_for_rank
from optimum.neuron.distributed.utils import compute_query_indices_for_rank, lazy_load_for_parallelism
from optimum.neuron.utils.cache_utils import (
get_num_neuron_cores,
)
@@ -223,14 +224,18 @@ def sequence_parallel_enabled(self, request):
def parallelize_embeddings(self, request):
return request.param

@pytest.fixture(scope="class", params=[False, True], ids=["embeddings_not_tied", "tied_embeddings"])
def tie_embeddings(self, request):
return request.param

def early_skip(self, fixtures_kwargs):
pp_size = fixtures_kwargs.get("pp_size", None)
parallel_sizes = fixtures_kwargs.get("parallel_sizes", None)
if pp_size is None and parallel_sizes is not None:
pp_size = parallel_sizes[-1]
model_specs = fixtures_kwargs.get("model_specs", None)

if pp_size > 1 and model_specs is not None:
if pp_size is not None and pp_size > 1 and model_specs is not None:
model_type = model_specs[0]
manager = ParallelizersManager.parallelizer_for_model(model_type)
if not manager.supports_pipeline_parallelism():
@@ -563,6 +568,66 @@ def test_llama_v2_gqa(
parallelize_embeddings,
)

@pytest.mark.parallel_sizes((2, 2, 1))
def test_resize_embedding(self, tie_embeddings, lazy_load):
tp_size = get_tensor_model_parallel_size()
tp_group = get_tensor_model_parallel_group()

static_seed_patcher = StaticSeedPatcher(42)

config = AutoConfig.from_pretrained(LLAMA_V2_MODEL_NAME)
config.tie_word_embeddings = tie_embeddings

with static_seed_patcher:
orig_model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, config=config)
orig_model.eval()
vocab_size = orig_model.config.vocab_size
new_vocab_size = vocab_size + tp_size

with static_seed_patcher:
orig_model.resize_token_embeddings(new_vocab_size)

ctx = lazy_load_for_parallelism(tensor_parallel_size=tp_size) if lazy_load else nullcontext()
with ctx:
with static_seed_patcher:
model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, config=config)
model.eval()

with static_seed_patcher:
model.resize_token_embeddings(new_vocab_size)

accelerator = create_accelerator(
tp_size,
1,
parallelize_embeddings=True,
sequence_parallel_enabled=True,
)
with static_seed_patcher:
model = accelerator.prepare_model(model)

# First we check that the embedding weights match
gathered = [torch.empty_like(model.model.embed_tokens.weight) for _ in range(tp_size)]
torch.distributed.all_gather(gathered, model.model.embed_tokens.weight, group=tp_group)
gathered_embedding = torch.cat(gathered, dim=0)
xm.mark_step()
torch.testing.assert_close(orig_model.model.embed_tokens.weight, gathered_embedding.to("cpu"))

# Second we check that logits match
tok = AutoTokenizer.from_pretrained(LLAMA_V2_MODEL_NAME)
tok.pad_token = tok.eos_token
inputs = tok("This is a test", max_length=24, padding="max_length", return_tensors="pt")
inputs = {k: v.to("xla") for k, v in inputs.items()}
orig_model = orig_model.to("xla")
orig_logits = orig_model(**inputs).logits
xm.mark_step()
logits = model(**inputs).logits
xm.mark_step()
gathered = [torch.empty_like(logits) for _ in range(tp_size)]
torch.distributed.all_gather(gathered, logits, group=tp_group)
gathered_logits = torch.cat(gathered, dim=2)
xm.mark_step()
torch.testing.assert_close(orig_logits, gathered_logits)


@pytest.mark.parametrize(
"tp_size,num_attention_heads,num_key_value_heads,kv_size_multiplier,ground_truth",
4 changes: 2 additions & 2 deletions tests/distributed_utils.py
@@ -291,15 +291,15 @@ def try_to_override_via_pytest_mark(mark, name):
world_size = try_to_override_via_pytest_mark(mark, "world_size")
tp_size = try_to_override_via_pytest_mark(mark, "tp_size")
pp_size = try_to_override_via_pytest_mark(mark, "pp_size")
parallel_sizes = try_to_override_via_pytest_mark(mark, "parallel_size")
parallel_sizes = try_to_override_via_pytest_mark(mark, "parallel_sizes")

# Catch world_size, tp_size or pp_size override via fixture.
def try_to_override_via_fixture(name, current_value):
if name in self._fixture_kwargs:
if current_value is not None:
raise ValueError(f"It is not possible to override {name} both via pytest.mark and fixtures.")
return self._fixture_kwargs[name]
return None
return current_value

world_size = try_to_override_via_fixture("world_size", world_size)
tp_size = try_to_override_via_fixture("tp_size", tp_size)
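
The one-line change from `return None` to `return current_value` restores the intended precedence: a size set via `pytest.mark` is kept when no fixture overrides it, a fixture value wins when only the fixture sets it, and setting both is an error. A tiny self-contained sketch of that logic (names are illustrative):

from typing import Any, Dict, Optional

# Illustrative sketch only: resolve a parallel-size setting that may come from
# a pytest.mark value and/or from fixture kwargs.
def resolve(name: str, mark_value: Optional[Any], fixture_kwargs: Dict[str, Any]) -> Optional[Any]:
    if name in fixture_kwargs:
        if mark_value is not None:
            raise ValueError(f"It is not possible to override {name} both via pytest.mark and fixtures.")
        return fixture_kwargs[name]
    # Returning None here (the old behavior) would drop the pytest.mark value.
    return mark_value


assert resolve("tp_size", 2, {}) == 2                  # mark only: kept after the fix
assert resolve("tp_size", None, {"tp_size": 4}) == 4   # fixture only
assert resolve("tp_size", None, {}) is None            # neither set
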
