[WIP]
michaelbenayoun committed Sep 20, 2024
1 parent fbf0b9a commit e11ec07
Showing 3 changed files with 87 additions and 5 deletions.
10 changes: 10 additions & 0 deletions optimum/neuron/distributed/utils.py
@@ -1009,6 +1009,7 @@ def maybe_load_linear_weight_to_parallel_linear(


@requires_peft
@requires_neuronx_distributed
def _parallelize_active_adapters(
    tuner_layer: "BaseTunerLayer",
    axis: Union[Literal["row"], Literal["column"]],
@@ -1020,6 +1021,8 @@ def _parallelize_active_adapters(
    device: Optional["torch.device"] = None,
):
    from peft.tuners.lora import Linear as LoraLinear
    from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size


    try:
        peft_config = tuner_layer._peft_config
@@ -1042,8 +1045,15 @@
        )
        if axis == "row":
            layer_to_parallelize = tuner_layer.lora_A[adapter_name]
            # Row-parallel layers shard the input features; nn.Linear weights are (out_features, in_features).
            dim_to_partition = layer_to_parallelize.weight.size(1)
        else:
            layer_to_parallelize = tuner_layer.lora_B[adapter_name]
            # Column-parallel layers shard the output features.
            dim_to_partition = layer_to_parallelize.weight.size(0)

        tp_size = get_tensor_model_parallel_size()

        if dim_to_partition % tp_size != 0:
            raise RuntimeError(
                f"The LoRA adapter dimension to parallelize ({dim_to_partition}) is not divisible by the TP size "
                f"({tp_size})."
            )

        # TODO: handle the case where weights already exist for this adapter.
        parallel_layer = linear_to_parallel_linear(
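For reference, a shape-only sketch of the constraint this divisibility check enforces (plain PyTorch, no Neuron runtime; the layer sizes and TP degree below are made up for illustration): whichever dimension gets sharded across tensor-parallel ranks must divide evenly by the TP size.

import torch
from torch import nn

tp_size = 8  # hypothetical tensor-parallel degree

# LoRA pair for a base Linear(in_features=4096, out_features=1024), rank r=64.
lora_A = nn.Linear(4096, 64, bias=False)   # weight shape: (r, in_features)
lora_B = nn.Linear(64, 1024, bias=False)   # weight shape: (out_features, r)

# Row-parallel sharding splits the input features of lora_A across ranks.
row_partition_dim = lora_A.weight.size(1)      # in_features
# Column-parallel sharding splits the output features of lora_B across ranks.
column_partition_dim = lora_B.weight.size(0)   # out_features

for name, dim in [("row", row_partition_dim), ("column", column_partition_dim)]:
    if dim % tp_size != 0:
        raise RuntimeError(f"{name}: dimension {dim} is not divisible by tp_size={tp_size}")
    print(f"{name}: each rank holds a shard of size {dim // tp_size}")

In this sketch each rank would hold 4096 / 8 = 512 of lora_A's input features and 1024 / 8 = 128 of lora_B's output features; the rank dimension r stays replicated.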
78 changes: 75 additions & 3 deletions tests/distributed/test_model_parallelization.py
@@ -14,13 +14,16 @@
# limitations under the License.
"""Tests validating that models can be parallelized correctly."""

from contextlib import nullcontext
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Type, Union

import pytest
import torch
import torch.utils._pytree as pytree
-from transformers import AutoTokenizer, LlamaForCausalLM
+from peft import LoraConfig
+from peft import get_peft_model as orig_get_peft_model
+from transformers import AutoTokenizer, LlamaForCausalLM, AutoModelForCausalLM, AutoConfig
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
@@ -44,7 +47,7 @@

import optimum
from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager
-from optimum.neuron.distributed.utils import compute_query_indices_for_rank
+from optimum.neuron.distributed.utils import compute_query_indices_for_rank, lazy_load_for_parallelism
from optimum.neuron.utils.cache_utils import (
    get_num_neuron_cores,
)
@@ -53,6 +56,7 @@
    is_neuronx_distributed_available,
    is_torch_xla_available,
)
from optimum.neuron.utils.peft_utils import get_peft_model
from optimum.neuron.utils.testing_utils import is_trainium_test

from .. import DistributedTest
@@ -228,7 +232,7 @@ def early_skip(self, fixtures_kwargs):
        pp_size = parallel_sizes[-1]
        model_specs = fixtures_kwargs.get("model_specs", None)

-        if pp_size > 1 and model_specs is not None:
+        if pp_size is not None and pp_size > 1 and model_specs is not None:
            model_type = model_specs[0]
            manager = ParallelizersManager.parallelizer_for_model(model_type)
            if not manager.supports_pipeline_parallelism():
@@ -424,6 +428,74 @@ def test_parallel_model_matches_original_model_from_config(
            model_class, model_name_or_path, config_overwrite, parallel_sizes, False, True, False, False
        )

    @pytest.mark.parallel_sizes((8, 8, 1))
    def test_llama_v2_gqa_and_lora(self, lazy_load):
        tp_size = get_tensor_model_parallel_size()
        tp_group = get_tensor_model_parallel_group()

        static_seed_patcher = StaticSeedPatcher(42)

        config = AutoConfig.from_pretrained(LLAMA_V2_MODEL_NAME)
        config.num_hidden_layers = 2
        config.num_key_value_heads = 2
        config.hidden_size = 128
        config.tie_word_embeddings = True

        lora_config = LoraConfig(
            r=64,
            lora_alpha=128,
            lora_dropout=0.0,
            target_modules=["q_proj", "v_proj"],
            task_type="CAUSAL_LM",
        )

        with static_seed_patcher:
            orig_model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, config=config, ignore_mismatched_sizes=True)
            orig_model.eval()

        with static_seed_patcher:
            orig_model = orig_get_peft_model(orig_model, lora_config)

        orig_model.to("xla")

        ctx = lazy_load_for_parallelism(tensor_parallel_size=tp_size) if lazy_load else nullcontext()
        with ctx:
            with static_seed_patcher:
                model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, config=config, ignore_mismatched_sizes=True)
                model.eval()

        with static_seed_patcher:
            model = get_peft_model(model, lora_config)

        accelerator = create_accelerator(
            tp_size,
            1,
            parallelize_embeddings=True,
            sequence_parallel_enabled=True,
        )
        with static_seed_patcher:
            model = accelerator.prepare_model(model)

        tok = AutoTokenizer.from_pretrained(LLAMA_V2_MODEL_NAME)
        tok.pad_token = tok.eos_token

        inputs = tok("This is a curious test.", padding="max_length", max_length=24, return_tensors="pt")
        print(inputs["input_ids"])
        inputs = {k: v.to("xla") for k, v in inputs.items()}

        orig_logits = orig_model(**inputs).logits
        xm.mark_step()

        logits = model(**inputs).logits
        xm.mark_step()

        gathered = [torch.empty_like(logits) for _ in range(tp_size)]
        torch.distributed.all_gather(gathered, logits, group=tp_group)
        gathered_logits = torch.cat(gathered, dim=2)
        xm.mark_step()
        xm.master_print(torch.nonzero(orig_logits - gathered_logits))
        torch.testing.assert_close(orig_logits, gathered_logits)

    @pytest.mark.skipif(
        NUM_NEURON_CORES_AVAILABLE < 32,
        reason=f"This test requires 32 Neuron cores, but only {NUM_NEURON_CORES_AVAILABLE} are available",
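As a side note, the gather-and-concatenate step used in test_llama_v2_gqa_and_lora above can be exercised without Neuron hardware. A minimal sketch, assuming plain torch.distributed with the gloo backend and made-up tensor sizes, of rebuilding full logits from per-rank shards of the vocabulary dimension:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def reconstruct_logits(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29512"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    batch_size, seq_len, vocab_shard = 1, 24, 500  # made-up sizes
    # Each rank owns one contiguous slice of the vocabulary dimension.
    local_logits = torch.full((batch_size, seq_len, vocab_shard), float(rank))

    gathered = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(gathered, local_logits)
    full_logits = torch.cat(gathered, dim=2)  # dim 2 is the vocabulary axis

    assert full_logits.shape == (batch_size, seq_len, vocab_shard * world_size)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(reconstruct_logits, args=(2,), nprocs=2)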
4 changes: 2 additions & 2 deletions tests/distributed_utils.py
@@ -291,15 +291,15 @@ def try_to_override_via_pytest_mark(mark, name):
        world_size = try_to_override_via_pytest_mark(mark, "world_size")
        tp_size = try_to_override_via_pytest_mark(mark, "tp_size")
        pp_size = try_to_override_via_pytest_mark(mark, "pp_size")
-        parallel_sizes = try_to_override_via_pytest_mark(mark, "parallel_size")
+        parallel_sizes = try_to_override_via_pytest_mark(mark, "parallel_sizes")

        # Catch world_size, tp_size or pp_size override via fixture.
        def try_to_override_via_fixture(name, current_value):
            if name in self._fixture_kwargs:
                if current_value is not None:
                    raise ValueError(f"It is not possible to override {name} both via pytest.mark and fixtures.")
                return self._fixture_kwargs[name]
-            return None
+            return current_value

        world_size = try_to_override_via_fixture("world_size", world_size)
        tp_size = try_to_override_via_fixture("tp_size", tp_size)
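A standalone sketch of the precedence rule the helper above implements (the function name and values below are illustrative, not part of the test suite): a size set via pytest.mark survives when no fixture overrides it, setting it both ways raises an error, and a fixture value is used only when the mark left it unset.

from typing import Optional


def override_via_fixture(fixture_kwargs: dict, name: str, current_value: Optional[int]) -> Optional[int]:
    # A fixture value wins only if the mark did not already set one;
    # otherwise the mark value must be kept, not discarded.
    if name in fixture_kwargs:
        if current_value is not None:
            raise ValueError(f"It is not possible to override {name} both via pytest.mark and fixtures.")
        return fixture_kwargs[name]
    return current_value


tp_size = 8  # pretend this came from @pytest.mark
assert override_via_fixture({}, "tp_size", tp_size) == 8           # mark value preserved
assert override_via_fixture({"pp_size": 2}, "pp_size", None) == 2   # fixture value used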
