
Commit

Fix GQA permutation computation and sequential weight initialization / loading when doing TP (#531)
michaelbenayoun authored Mar 28, 2024
1 parent f5c909e commit 1bc0405
Showing 8 changed files with 118 additions and 40 deletions.
2 changes: 2 additions & 0 deletions optimum/neuron/accelerate/utils/dataclasses.py
@@ -150,6 +150,7 @@ class ModelParallelismPlugin:
pipeline_parallel_use_zero1_optimizer: bool = False
gradient_checkpointing: bool = False
checkpoint_dir: Optional[Union[str, Path]] = None
num_ranks_per_loading_step: int = -1

def __post_init__(self):
if self.tensor_parallel_size < 1:
@@ -181,5 +182,6 @@ def parallelize_model(
pipeline_parallel_use_zero1_optimizer=self.pipeline_parallel_use_zero1_optimizer,
pipeline_parallel_gradient_checkpointing_enabled=self.gradient_checkpointing,
checkpoint_dir=self.checkpoint_dir,
num_ranks_per_loading_step=self.num_ranks_per_loading_step,
)
return parallelized_model
29 changes: 21 additions & 8 deletions optimum/neuron/distributed/base.py
@@ -15,6 +15,8 @@
"""Base class related to `neuronx_distributed` to perform parallelism."""

import contextlib
import gc
import math
import shutil
from abc import ABC, abstractclassmethod
from collections import defaultdict
@@ -540,6 +542,7 @@ def parallelize(
pipeline_parallel_use_zero1_optimizer: bool = False,
pipeline_parallel_gradient_checkpointing_enabled: bool = False,
checkpoint_dir: Optional[Union[str, Path]] = None,
num_ranks_per_loading_step: int = -1,
) -> "PreTrainedModel":
"""
Parallelizes the model by transforming regular layers into their parallel counterparts using
@@ -572,6 +575,9 @@
checkpoint_dir (`Optional[Union[str, Path]]`):
Path to a sharded checkpoint. If specified, the checkpoint weights will be loaded to the parallelized
model.
num_ranks_per_loading_step (`int`, defaults to `-1`):
Corresponds to the number of ranks that can initialize and load the model weights at the same time.
If the value is lower than 0, the maximum number of ranks will be used.
Returns:
`PreTrainedModel`: The parallelized model.
@@ -586,6 +592,7 @@
get_tensor_model_parallel_size,
)
from neuronx_distributed.parallel_layers.random import _MODEL_PARALLEL_RNG_TRACKER_NAME, get_xla_rng_tracker
from neuronx_distributed.parallel_layers.utils import get_local_world_size
from neuronx_distributed.pipeline import NxDPPModel

tp_size = get_tensor_model_parallel_size()
@@ -698,14 +705,20 @@ def should_parallelize_layer_predicate_func(layer):
if is_precompilation():
cls._initialize_for_precompilation(model, names_of_the_parameters_to_consider)
else:
if skip_linear_weight_load:
# Load the weights to the parallel linears if the loading was skipped during parallelization.
cls._maybe_load_weights_to_parallel_linears(model)

if skip_linear_weight_load or any(p.device == torch.device("meta") for p in model.parameters()):
# Initialize or load the weights for the parallelized model if it was lazily loaded.
cls._initialize_or_load_weights(model, names_of_the_parameters_to_consider, device=device)

local_rank = xm.get_local_ordinal()
if num_ranks_per_loading_step < 0:
num_ranks_per_loading_step = get_local_world_size()
for worker in range(math.ceil(get_local_world_size() / num_ranks_per_loading_step)):
if local_rank // num_ranks_per_loading_step == worker:
if skip_linear_weight_load:
# Load the weights to the parallel linears if the loading was skipped during parallelization.
cls._maybe_load_weights_to_parallel_linears(model)

if skip_linear_weight_load or any(p.device == torch.device("meta") for p in model.parameters()):
# Initialize or load the weights for the parallelized model if it was lazily loaded.
cls._initialize_or_load_weights(model, names_of_the_parameters_to_consider, device=device)
gc.collect()
xm.rendezvous(f"weight_loading_and_initialization_{worker}")
xm.rendezvous("End of initalization")

if is_main_worker():
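For readers skimming the diff, here is a minimal, self-contained sketch (not part of the commit) of the staggered loading loop added above: local ranks are split into groups of num_ranks_per_loading_step, each group initializes or loads its weights in turn, and a barrier between steps bounds how many ranks hold full weights at once. The load_fn and barrier callables below are stand-ins for the actual weight-loading helpers and xm.rendezvous.

import math
from typing import Callable

def sequential_weight_loading(
    local_rank: int,
    local_world_size: int,
    num_ranks_per_loading_step: int,
    load_fn: Callable[[], None],
    barrier: Callable[[str], None],
) -> None:
    # A negative value means "let every local rank load at the same time",
    # which reproduces the previous behavior.
    if num_ranks_per_loading_step < 0:
        num_ranks_per_loading_step = local_world_size
    num_steps = math.ceil(local_world_size / num_ranks_per_loading_step)
    for step in range(num_steps):
        if local_rank // num_ranks_per_loading_step == step:
            # Only the ranks in the current group materialize their weights.
            load_fn()
        # Every rank synchronizes after each step, so at most
        # num_ranks_per_loading_step ranks hold full weights concurrently.
        barrier(f"weight_loading_and_initialization_{step}")

# Example: 8 local ranks loading 2 at a time -> 4 loading steps.
for rank in range(8):
    sequential_weight_loading(rank, 8, 2, load_fn=lambda: None, barrier=lambda tag: None)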
12 changes: 6 additions & 6 deletions optimum/neuron/distributed/checkpointing.py
@@ -23,7 +23,7 @@
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME

from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors
from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indicies_for_rank
from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indices_for_rank


def create_gqa_query_or_output_projection_weight_from_full_weight(
@@ -42,15 +42,15 @@ def create_gqa_query_or_output_projection_weight_from_full_weight(
hidden_size = full_weight.size(0)
full_weight = full_weight.transpose(0, 1)

indicies = [
compute_query_indicies_for_rank(tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier)
indices = [
compute_query_indices_for_rank(tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier)
for tp_rank in range(tp_size)
]
indicies = torch.cat(indicies, dim=0)
reversed_indicies = torch.sort(indicies, dim=0).indices
indices = torch.cat(indices, dim=0)
reversed_indices = torch.sort(indices, dim=0).indices

full_weight = full_weight.reshape(num_attention_heads, -1, hidden_size)
full_weight = full_weight[reversed_indicies]
full_weight = full_weight[reversed_indices]
full_weight = full_weight.reshape(-1, hidden_size)

if query_or_output == "output":
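The consolidation above relies on two facts: concatenating the per-rank query indices over all TP ranks yields a permutation of the attention heads, and torch.sort(indices).indices is the inverse of that permutation. A small sanity check of the inverse-permutation property (illustrative only, using an arbitrary permutation rather than the real GQA indices):

import torch

num_attention_heads, head_dim = 32, 128

# Stand-in for the concatenated per-rank query indices: any permutation works.
indices = torch.randperm(num_attention_heads)

full_weight = torch.arange(num_attention_heads * head_dim, dtype=torch.float32)
full_weight = full_weight.reshape(num_attention_heads, head_dim)
sharded_order = full_weight[indices]  # head order as stored across the TP shards

# Sorting the permutation gives its inverse, so indexing with it restores the
# original head order, which is what the reshape + reversed_indices indexing
# above does for the query / output projection weights.
reversed_indices = torch.sort(indices, dim=0).indices
assert torch.equal(sharded_order[reversed_indices], full_weight)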
44 changes: 23 additions & 21 deletions optimum/neuron/distributed/utils.py
@@ -508,7 +508,7 @@ def create_kv_proj_local_weight_from_regular_weight(
return torch.cat(split[tp_rank::tp_size], dim=0)


def compute_query_indicies_for_rank(
def compute_query_indices_for_rank(
tp_size: int, tp_rank: int, num_attention_heads: int, num_key_value_heads: int, kv_size_multiplier: int
):
"""
@@ -519,32 +519,34 @@ def compute_query_indicies_for_rank(
query_group_size = num_attention_heads // num_key_value_heads
query_group_size_per_rank = num_attention_heads_per_rank // num_key_value_heads_per_rank

queries_indicies = [torch.arange(query_group_size_per_rank) for _ in range(num_key_value_heads_per_rank)]
queries_indices = [torch.arange(query_group_size_per_rank) for _ in range(num_key_value_heads_per_rank)]

keys_indicies = torch.arange(num_key_value_heads).repeat(kv_size_multiplier)
keys_indicies = torch.repeat_interleave(
keys_indicies, num_attention_heads_per_rank // num_key_value_heads_per_rank
)
keys_indicies = torch.chunk(keys_indicies, tp_size)
keys_indices = torch.arange(num_key_value_heads).repeat(kv_size_multiplier)
keys_indices = torch.repeat_interleave(keys_indices, num_attention_heads_per_rank // num_key_value_heads_per_rank)
keys_indices = torch.chunk(keys_indices, tp_size)

shift_per_key = torch.arange(0, num_attention_heads, query_group_size)

shift_within_query_group = torch.arange(0, query_group_size, query_group_size_per_rank)
num_ranks_to_fit_all_key_value_heads = num_key_value_heads // num_key_value_heads_per_rank
num_query_heads_before_next_head_of_same_group = (
num_ranks_to_fit_all_key_value_heads * num_attention_heads_per_rank
)
shift_within_query_group = torch.repeat_interleave(
shift_within_query_group, num_attention_heads_per_rank * num_key_value_heads_per_rank
shift_within_query_group, num_query_heads_before_next_head_of_same_group
)
shift_within_query_group = torch.chunk(shift_within_query_group, tp_size)

indicies = []
for idx, q_indicies in enumerate(queries_indicies):
indices = []
for idx, q_indices in enumerate(queries_indices):
s = slice(idx * query_group_size_per_rank, (idx + 1) * query_group_size_per_rank)
k_indicies = keys_indicies[tp_rank][s]
k_shift = shift_per_key[k_indicies]
k_indices = keys_indices[tp_rank][s]
k_shift = shift_per_key[k_indices]
group_shift = shift_within_query_group[tp_rank][s]
indicies.append(q_indicies + k_shift + group_shift)
indices.append(q_indices + k_shift + group_shift)

indicies = torch.cat(indicies, dim=0)
return indicies
indices = torch.cat(indices, dim=0)
return indices


@requires_neuronx_distributed
@@ -578,11 +580,11 @@ def create_query_or_output_projection_local_weight_from_regular_weight(
head_dim = weight_data.size(1) // num_attention_heads
weight_data = weight_data.transpose(0, 1)

indicies = compute_query_indicies_for_rank(
indices = compute_query_indices_for_rank(
tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier
)
reshaped_weight = weight_data.view(num_attention_heads, head_dim, hidden_size)
shuffled_weight = reshaped_weight[indicies]
shuffled_weight = reshaped_weight[indices]
shuffled_weight = shuffled_weight.reshape(-1, hidden_size)

if query_or_output_proj == "output":
@@ -620,21 +622,21 @@ def create_local_bias_from_regular_bias(

else:
if gather_output:
indicies = torch.cat(
indices = torch.cat(
[
compute_query_indicies_for_rank(
compute_query_indices_for_rank(
tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier
)
for tp_rank in range(tp_size)
],
dim=0,
)
else:
indicies = compute_query_indicies_for_rank(
indices = compute_query_indices_for_rank(
tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier
)
reshaped_bias_weight = bias_weigth_data.view(num_attention_heads, -1)
shuffled_bias_weight = reshaped_bias_weight[indicies]
shuffled_bias_weight = reshaped_bias_weight[indices]
local_bias_weight = shuffled_bias_weight.reshape(-1)
return local_bias_weight

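As a usage sketch (assuming a build of optimum-neuron that includes this commit, so the renamed helper is importable), the per-rank index tensors concatenated over all TP ranks should cover every attention head exactly once; this is the invariant that both the weight sharding above and the checkpoint consolidation depend on:

import torch

from optimum.neuron.distributed.utils import compute_query_indices_for_rank

tp_size, num_attention_heads, num_key_value_heads, kv_size_multiplier = 8, 32, 4, 2

per_rank_indices = [
    compute_query_indices_for_rank(
        tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier
    )
    for tp_rank in range(tp_size)
]
all_indices = torch.cat(per_rank_indices, dim=0)

# Each of the 32 attention heads is assigned to exactly one TP rank position.
assert torch.equal(torch.sort(all_indices).values, torch.arange(num_attention_heads))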
5 changes: 3 additions & 2 deletions optimum/neuron/trainers.py
@@ -339,7 +339,8 @@ def compute_loss(self, model, inputs, return_outputs: bool = False):
loss = model.run_train(**inputs)
return loss

return super().compute_loss(model, inputs, return_outputs=return_outputs)
loss = super().compute_loss(model, inputs, return_outputs=return_outputs)
return loss

def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
from neuronx_distributed.pipeline import NxDPPModel
@@ -397,10 +398,10 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
)

if self.args.mp_plugin.should_parallelize:

dp_size = get_data_parallel_size()
pp_size = get_pipeline_model_parallel_size()
pp_rank = get_pipeline_model_parallel_rank()

tr_loss_div = tr_loss / dp_size

if pp_size > 1 and pp_rank == pp_size - 1:
20 changes: 19 additions & 1 deletion optimum/neuron/training_args.py
@@ -38,7 +38,7 @@
from ..utils import check_if_transformers_greater, logging
from .accelerate import NeuronAcceleratorState, NeuronPartialState
from .accelerate.utils import ModelParallelismPlugin, patch_accelerate_is_tpu_available
from .utils import is_accelerate_available, is_torch_xla_available
from .utils import is_accelerate_available, is_main_worker, is_torch_xla_available
from .utils.patching import Patcher
from .utils.training_utils import TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP

@@ -100,6 +100,15 @@ class NeuronTrainingArgumentsMixin:
)
},
)
num_ranks_per_loading_step: int = field(
default=-1,
metadata={
"help": (
"The number of ranks to use concurrently during weight initialization and loading when tensor "
"parallelism is enabled. If left unspecified, the maximum number of ranks will be used."
)
},
)

def __post_init__(self):
# Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
@@ -141,6 +150,14 @@ def __post_init__(self):
resume_from_checkpoint = checkpoint

if self.pipeline_parallel_size > 1:
if self.gradient_accumulation_steps > 1:
if is_main_worker():
logger.info(
"Pipeline parallel used, setting gradient_accumulation_steps to 1 and scaling the pipeline batch size."
)
self.per_device_train_batch_size *= self.gradient_accumulation_steps
self.per_device_eval_batch_size *= self.gradient_accumulation_steps
self.gradient_accumulation_steps = 1
if self.pipeline_parallel_num_microbatches == -1:
self.pipeline_parallel_num_microbatches = self.per_device_train_batch_size
if self.per_device_train_batch_size % self.pipeline_parallel_num_microbatches != 0:
@@ -164,6 +181,7 @@
pipeline_parallel_use_zero1_optimizer=self.zero_1,
gradient_checkpointing=self.gradient_checkpointing,
checkpoint_dir=resume_from_checkpoint,
num_ranks_per_loading_step=self.num_ranks_per_loading_step,
)

# This is required to be able to use bf16, otherwise a check in super().__post_init__() fails.
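A worked example of the new pipeline-parallel batch handling above (the numbers are illustrative, not from the commit): with per_device_train_batch_size=4 and gradient_accumulation_steps=8, the arguments are rewritten so a single optimizer step processes a pipeline batch of 32, and pipeline_parallel_num_microbatches defaults to that batch size when left at -1.

# Illustrative restatement of the __post_init__ logic when pipeline_parallel_size > 1.
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
pipeline_parallel_num_microbatches = -1

if gradient_accumulation_steps > 1:
    # Gradient accumulation is folded into the pipeline batch instead of
    # being handled by the trainer loop.
    per_device_train_batch_size *= gradient_accumulation_steps  # -> 32
    gradient_accumulation_steps = 1

if pipeline_parallel_num_microbatches == -1:
    pipeline_parallel_num_microbatches = per_device_train_batch_size  # -> 32

assert per_device_train_batch_size % pipeline_parallel_num_microbatches == 0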
1 change: 1 addition & 0 deletions optimum/neuron/utils/__init__.py
@@ -39,6 +39,7 @@
DiffusersPretrainedConfig,
check_if_weights_replacable,
get_stable_diffusion_configs,
is_main_worker,
replace_weights,
)
from .optimization_utils import get_attention_scores_sd, get_attention_scores_sdxl
45 changes: 43 additions & 2 deletions tests/distributed/test_model_parallelization.py
@@ -45,7 +45,7 @@
import optimum
from optimum.neuron.accelerate.accelerator import NeuronAccelerator
from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager
from optimum.neuron.distributed.utils import compute_query_indicies_for_rank
from optimum.neuron.distributed.utils import compute_query_indices_for_rank
from optimum.neuron.utils.cache_utils import (
get_num_neuron_cores,
)
@@ -609,11 +609,52 @@ def test_llama_v2_gqa(
[7, 15, 23, 31],
],
],
[
32,
32,
4,
8,
[
[0],
[8],
[16],
[24],
[1],
[9],
[17],
[25],
[2],
[10],
[18],
[26],
[3],
[11],
[19],
[27],
[4],
[12],
[20],
[28],
[5],
[13],
[21],
[29],
[6],
[14],
[22],
[30],
[7],
[15],
[23],
[31],
],
],
],
ids=[
"32-heads-4kv-heads-kv-mul-2,one kv head per rank",
"32-heads-4kv-heads-kv-mul-4,multiple kv heads per rank",
"32-heads-4kv-heads-kv-mul-8,all kv heads per rank",
"tp=32,32-heads-4kv-heads-kv-mul-8,one query head per rank",
],
)
@is_trainium_test
@@ -622,7 +663,7 @@ def test_compute_query_indices_for_rank(
):
for tp_rank in range(tp_size):
expected = torch.tensor(ground_truth[tp_rank])
computed = compute_query_indicies_for_rank(
computed = compute_query_indices_for_rank(
tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier
)
print(f"TP rank = {tp_rank}")
