From 4a7df1a900a9c7226ef3e0d9f10c2885c34b3292 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 20 Mar 2024 17:11:34 +0100 Subject: [PATCH] GQA optimization for TP (#498) --- notebooks/text-generation/scripts/run_clm.py | 4 +- optimum/commands/neuron/subcommands.py | 4 +- optimum/neuron/accelerate/accelerator.py | 15 +- .../neuron/accelerate/utils/dataclasses.py | 2 + optimum/neuron/distributed/__init__.py | 4 +- optimum/neuron/distributed/base.py | 207 +++++-- optimum/neuron/distributed/checkpointing.py | 151 ++++- optimum/neuron/distributed/decoder_models.py | 6 +- .../distributed/encoder_decoder_models.py | 4 +- optimum/neuron/distributed/parallel_layers.py | 288 ++++++--- optimum/neuron/distributed/utils.py | 580 +++++++++++++++++- optimum/neuron/training_args.py | 12 + optimum/neuron/utils/misc.py | 27 +- tests/distributed/test_common.py | 65 +- .../distributed/test_model_parallelization.py | 301 ++++++--- tests/distributed/utils.py | 7 + tests/test_cache_utils.py | 2 + 17 files changed, 1394 insertions(+), 285 deletions(-) diff --git a/notebooks/text-generation/scripts/run_clm.py b/notebooks/text-generation/scripts/run_clm.py index 5e2c6744a..6dfa51a4b 100644 --- a/notebooks/text-generation/scripts/run_clm.py +++ b/notebooks/text-generation/scripts/run_clm.py @@ -47,11 +47,11 @@ def training_function(script_args, training_args): # if (int(os.environ.get("RANK", -1)) == 0) and int(training_args.tensor_parallel_size) > 1: # print("Converting sharded checkpoint to consolidated format") # from optimum.neuron.distributed.checkpointing import ( - # consolidate_tensor_parallel_checkpoints_to_unified_checkpoint, + # consolidate_model_parallel_checkpoints_to_unified_checkpoint, # ) # from shutil import rmtree - # consolidate_tensor_parallel_checkpoints_to_unified_checkpoint( + # consolidate_model_parallel_checkpoints_to_unified_checkpoint( # training_args.output_dir, training_args.output_dir, "pytorch" # ) # rmtree(os.path.join(training_args.output_dir, "tensor_parallel_shards")) # remove sharded checkpoint files diff --git a/optimum/commands/neuron/subcommands.py b/optimum/commands/neuron/subcommands.py index e07ae5fb0..070eb36e1 100644 --- a/optimum/commands/neuron/subcommands.py +++ b/optimum/commands/neuron/subcommands.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING -from ...neuron.distributed import consolidate_tensor_parallel_checkpoints_to_unified_checkpoint +from ...neuron.distributed import consolidate_model_parallel_checkpoints_to_unified_checkpoint from ...utils import logging from ..base import BaseOptimumCLICommand @@ -53,7 +53,7 @@ def parse_args(parser: "ArgumentParser"): def run(self): checkpoint_format = "safetensors" if self.args.format == "safetensors" else "pytorch" logger.info(f"Consolidating checkpoints from {self.args.checkpoint_dir} to the {checkpoint_format} format...") - consolidate_tensor_parallel_checkpoints_to_unified_checkpoint( + consolidate_model_parallel_checkpoints_to_unified_checkpoint( self.args.checkpoint_dir, self.args.output_dir, save_format=self.args.format, diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index e9bb53369..06e4c3660 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -413,6 +413,7 @@ def prepare_model_for_xla_fsdp( def _prepare_model_for_mp( self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False ): + import torch_xla.core.xla_model as xm from neuronx_distributed.pipeline import 
NxDPPModel if model in self._models or Parallelizer.was_parallelized(model): @@ -421,7 +422,7 @@ def _prepare_model_for_mp( cpu_ids = {name: id(param) for name, param in model.named_parameters()} tied_parameters_dict = get_tied_parameters_dict(model) model_main_input_name = getattr(model, "main_input_name", None) - # TODO: enable self.device (if needed). + # TODO: use self.device. model = self.state.mp_plugin.parallelize_model(model, device=None) if model_main_input_name is not None: @@ -435,6 +436,11 @@ def _prepare_model_for_mp( else: model_to_cast = model + # Update CPU ids + original_parameter_names_to_gqa_qkv_names = model._gqa_qkv_metadata["original_names_to_gqa_qkv_names"] + for key in list(cpu_ids.keys()): + cpu_ids[original_parameter_names_to_gqa_qkv_names.get(key, key)] = cpu_ids.pop(key) + model_to_cast = model.local_module if isinstance(model, NxDPPModel) else model if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1": model_to_cast.to(torch.bfloat16) @@ -460,6 +466,7 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings): move_model_to_device(model, self.device) tie_parameters(model, tied_parameters_dict) xla_params = dict(model.named_parameters()) + symmetric_diff = set(cpu_ids.keys()).symmetric_difference((xla_params.keys())) if symmetric_diff: raise ValueError( @@ -470,6 +477,7 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings): cpu_ids[name]: xla_params[name] for name, _ in model.named_parameters() } + xm.mark_step() device_placement = False return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode) @@ -485,8 +493,11 @@ def prepare_model( model = self.patch_model_for_neuron(model) - # We do not want to use the cache here as it would imply more communication that we do not need. + # We do not want to use the cache, or output unused tensors as it would imply more communication that we do not + # need. 
model.config.use_cache = False + model.config.output_attentions = False + model.config.output_hidden_states = False if self.distributed_type is NeuronDistributedType.XLA_FSDP: return self.prepare_model_for_xla_fsdp( diff --git a/optimum/neuron/accelerate/utils/dataclasses.py b/optimum/neuron/accelerate/utils/dataclasses.py index a1ec11154..8f4ce5b45 100644 --- a/optimum/neuron/accelerate/utils/dataclasses.py +++ b/optimum/neuron/accelerate/utils/dataclasses.py @@ -144,6 +144,7 @@ class ModelParallelismPlugin: tensor_parallel_size: int = 1 parallelize_embeddings: bool = True sequence_parallel_enabled: bool = False + kv_size_multiplier: Optional[int] = None pipeline_parallel_size: int = 1 pipeline_parallel_num_microbatches: int = 1 pipeline_parallel_use_zero1_optimizer: bool = False @@ -175,6 +176,7 @@ def parallelize_model( device=device, parallelize_embeddings=self.parallelize_embeddings, sequence_parallel_enabled=self.sequence_parallel_enabled, + kv_size_multiplier=self.kv_size_multiplier, pipeline_parallel_num_microbatches=self.pipeline_parallel_num_microbatches, pipeline_parallel_use_zero1_optimizer=self.pipeline_parallel_use_zero1_optimizer, pipeline_parallel_gradient_checkpointing_enabled=self.gradient_checkpointing, diff --git a/optimum/neuron/distributed/__init__.py b/optimum/neuron/distributed/__init__.py index 7af896662..5ea728233 100644 --- a/optimum/neuron/distributed/__init__.py +++ b/optimum/neuron/distributed/__init__.py @@ -15,8 +15,8 @@ from .base import Parallelizer from .checkpointing import ( - consolidate_tensor_parallel_checkpoints, - consolidate_tensor_parallel_checkpoints_to_unified_checkpoint, + consolidate_model_parallel_checkpoints, + consolidate_model_parallel_checkpoints_to_unified_checkpoint, ) from .parallelizers_manager import ParallelizersManager from .utils import lazy_load_for_parallelism, make_optimizer_constructor_lazy diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py index 2033d1705..0b21f03b7 100644 --- a/optimum/neuron/distributed/base.py +++ b/optimum/neuron/distributed/base.py @@ -36,19 +36,26 @@ IOSequenceParallelizer, LayerNormSequenceParallelizer, LayerNormType, - ParallelLayer, SequenceCollectiveOpInfo, ) from .utils import ( TENSOR_PARALLEL_SHARDS_DIR_NAME, + OptimumGQAQKVColumnParallelLinear, + OptimumNeuronFXTracer, ParameterMetadata, WeightInformation, apply_activation_checkpointing, + get_linear_weight_info, + get_output_projection_qualified_names_after_qga_qkv_replacement, + get_parameter_names_mapping_after_gqa_qkv_replacement, initialize_parallel_linear, initialize_torch_nn_module, linear_to_parallel_linear, load_tensor_for_weight, + maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear, maybe_load_linear_weight_to_parallel_linear, + maybe_load_weights_to_gqa_qkv_column_parallel_linear, + maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear, named_parameters, parameter_can_be_initialized, try_to_hf_initialize, @@ -274,25 +281,51 @@ def _parallelize( @classmethod @requires_neuronx_distributed def _maybe_load_weights_to_parallel_linears(cls, model: "PreTrainedModel"): - from neuronx_distributed.parallel_layers.layers import BaseParallelLinear + from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, + ) weight_map = getattr(model, "_weight_map", {}) + name_to_module = dict(model.named_modules()) + + gqa_output_projections = {} + for fully_qualified_name, layer in name_to_module.items(): + if isinstance(layer, 
OptimumGQAQKVColumnParallelLinear): + parent_name = fully_qualified_name.rsplit(".", maxsplit=1)[0] + output_projection_name = f"{parent_name}.{layer.output_proj_name}" + gqa_output_projections[output_projection_name] = ( + layer.num_attention_heads, + layer.num_key_value_heads, + layer.kv_size_multiplier, + ) - for fully_qualified_name, layer in model.named_modules(): - if isinstance(layer, BaseParallelLinear): - try: - linear_weight_info, linear_bias_weight_info = ParallelLayer._get_linear_weight_info( - weight_map, fully_qualified_name - ) - except ValueError: - linear_weight_info = None - linear_bias_weight_info = None + for fully_qualified_name, layer in name_to_module.items(): + if isinstance(layer, (RowParallelLinear, ColumnParallelLinear)): + linear_weight_info, linear_bias_weight_info = get_linear_weight_info( + weight_map, fully_qualified_name, fail_if_not_found=False + ) if linear_weight_info is not None: - maybe_load_linear_weight_to_parallel_linear( - layer, - linear_layer_weight_info=linear_weight_info, - linear_layer_bias_weight_info=linear_bias_weight_info, - ) + if fully_qualified_name in gqa_output_projections: + num_attention_heads, num_key_value_heads, kv_size_multiplier = gqa_output_projections[ + fully_qualified_name + ] + maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear( + layer, + num_attention_heads, + num_key_value_heads, + kv_size_multiplier, + linear_layer_weight_info=linear_weight_info, + linear_layer_bias_weight_info=linear_bias_weight_info, + ) + else: + maybe_load_linear_weight_to_parallel_linear( + layer, + linear_layer_weight_info=linear_weight_info, + linear_layer_bias_weight_info=linear_bias_weight_info, + ) + elif isinstance(layer, OptimumGQAQKVColumnParallelLinear): + maybe_load_weights_to_gqa_qkv_column_parallel_linear(model, layer) @classmethod @requires_neuronx_distributed @@ -303,9 +336,8 @@ def _initialize_or_load_weights( device: Optional[torch.device] = None, ): from neuronx_distributed import parallel_layers - from neuronx_distributed.parallel_layers.parallel_state import ( - get_tensor_model_parallel_rank, - ) + from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear + from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_rank weight_map = getattr(model, "_weight_map", {}) with torch.no_grad(): @@ -405,26 +437,61 @@ def _initialize_or_load_weights( if isinstance(mod, parallel_layers.layers.RowParallelLinear): axis = "row" input_is_parallel = mod.input_is_parallel - else: + elif isinstance(mod, parallel_layers.layers.ColumnParallelLinear): axis = "column" gather_output = mod.gather_output - fake_linear_mod = torch.nn.Linear(mod.input_size, mod.output_size) - left_uninitialized = try_to_hf_initialize(model, fake_linear_mod, parameter_names) - if left_uninitialized: - initialize_parallel_linear(mod, left_uninitialized) + elif isinstance(mod, GQAQKVColumnParallelLinear): + axis = "qga_qkv_column" + gather_output = mod.gather_output else: - fake_parallel_linear_mod = linear_to_parallel_linear( - fake_linear_mod, - axis, - input_is_parallel=input_is_parallel, - gather_output=gather_output, - sequence_parallel_enabled=mod.sequence_parallel_enabled, + raise RuntimeError( + f"This kind of parallel linear is not supported yet: {mod.__class__.__name__}" ) - mod.weight.data = fake_parallel_linear_mod.weight.data.clone() - if mod.bias is not None: - mod.bias.data = fake_parallel_linear_mod.bias.data.clone() + + if axis in ["row", "column"]: + fake_linear_mod = 
torch.nn.Linear(mod.input_size, mod.output_size) + left_uninitialized = try_to_hf_initialize(model, fake_linear_mod, parameter_names) + if left_uninitialized: + initialize_parallel_linear(mod, left_uninitialized) + else: + fake_parallel_linear_mod = linear_to_parallel_linear( + fake_linear_mod, + axis, + input_is_parallel=input_is_parallel, + gather_output=gather_output, + sequence_parallel_enabled=mod.sequence_parallel_enabled, + ) + mod.weight.copy_(fake_parallel_linear_mod.weight.data) + if mod.bias is not None: + mod.bias.copy_(fake_parallel_linear_mod.bias.data) + del fake_parallel_linear_mod del fake_linear_mod - del fake_parallel_linear_mod + else: + + def initialize(mod: GQAQKVColumnParallelLinear, proj_name: str, output_size: int): + fake_linear_mod = torch.nn.Linear(mod.input_size, output_size) + parameter_names_to_consider = [ + name for name in parameter_names if name.endswith(f"_{proj_name}") + ] + mapping = { + f"weight_{proj_name}": "weight", + f"bias_{proj_name}": "bias", + } + left_uninitialized = try_to_hf_initialize( + model, fake_linear_mod, parameter_names_to_consider, parameter_names_mapping=mapping + ) + if left_uninitialized: + initialize_parallel_linear(mod, left_uninitialized) + else: + # TODO: change kv heads. + maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( + mod, f"weight_{proj_name}", linear_layer=fake_linear_mod + ) + del fake_linear_mod + + initialize(mod, "q", mod.output_sizes[0]) + initialize(mod, "k", mod.output_sizes[1]) + initialize(mod, "v", mod.output_sizes[1]) else: left_uninitialized = try_to_hf_initialize(model, mod, parameter_names) if left_uninitialized and hasattr(mod, "reset_parameters"): @@ -438,6 +505,7 @@ def parallelize( device: Optional[torch.device] = None, parallelize_embeddings: bool = True, sequence_parallel_enabled: bool = False, + kv_size_multiplier: Optional[int] = None, pipeline_parallel_input_names: Optional[Union[Tuple[str, ...], List[str]]] = None, pipeline_parallel_num_microbatches: int = 1, pipeline_parallel_use_zero1_optimizer: bool = False, @@ -461,6 +529,10 @@ def parallelize( This can be disabled in the case when the TP size does not divide the vocabulary size. sequence_parallel_enabled (`bool`, defaults to `False`): Whether or not sequence parallelism is enabled. + kv_size_multiplier (`Optional[int], defaults to `None`): + The number of times to replicate the KV heads when the TP size is bigger than the number of KV heads. + If left unspecified, the smallest multiplier that makes the number of KV heads divisible by the TP size + will be used. pipeline_parallel_num_microbatches (`int`, defaults to 1): The number of microbatches used for pipeline execution. pipeline_parallel_use_zero1_optimizer (`bool`, defaults to `False`): @@ -484,6 +556,7 @@ def parallelize( get_pipeline_model_parallel_size, get_tensor_model_parallel_size, ) + from neuronx_distributed.parallel_layers.random import _MODEL_PARALLEL_RNG_TRACKER_NAME, get_xla_rng_tracker from neuronx_distributed.pipeline import NxDPPModel tp_size = get_tensor_model_parallel_size() @@ -493,13 +566,18 @@ def parallelize( # Parallelizing the model. # This needs to be done prior to preparing the model for sequence parallelism because modules can be overriden. 
+ name_to_parameter = dict(named_parameters(model, remove_duplicate=False)) + parameter_to_name = {p: n for n, p in name_to_parameter.items()} names_of_the_parameters_to_consider = cls._get_parameter_names_for_current_pipeline( model, remove_duplicate=True ) - name_to_parameter = dict(named_parameters(model, remove_duplicate=False)) - parameter_to_name = {p: n for n, p in name_to_parameter.items()} + # We delay weight loading when the model was instantiated from pretrained lazily. + # We do not skip for cases such as: + # - Loaded a model `from_config`: in this case we simply initialize later in `_initialize_or_load_weights`. + # - Loaded a model `from_pretrained` but not lazily. + skip_linear_weight_load = hasattr(model, "_weight_map") def should_parallelize_layer_predicate_func(layer): if pp_size == 1: @@ -511,18 +589,55 @@ def should_parallelize_layer_predicate_func(layer): return names < names_of_the_parameters_to_consider if tp_size > 1: + # TODO: remove that once it is solved on the `neuronx_distributed` side. + try: + get_xla_rng_tracker().add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 42) + except Exception: + # It means that `_MODEL_PARALLEL_RNG_TRACKER_NAME` was already added to the rng tracker, we can ignore. + pass + model = cls._parallelize( model, device=device, parallelize_embeddings=parallelize_embeddings, sequence_parallel_enabled=sequence_parallel_enabled, should_parallelize_layer_predicate_func=should_parallelize_layer_predicate_func, - skip_linear_weight_load=True, + skip_linear_weight_load=skip_linear_weight_load, + kv_size_multiplier=kv_size_multiplier, ) xm.rendezvous("End of tensor parallelism") if is_main_worker(): logger.info("Tensor parallelism done.") + # We need to refresh the names because they might have changed after `_parallelize`. + # For instance if we changed regular linears to GQAQKVColumnParallelLinear. + names_of_the_parameters_to_consider = cls._get_parameter_names_for_current_pipeline( + model, remove_duplicate=True + ) + + # We need to retrieve this mapping here because PP works with `torch.fx` so we will not end-up with the same + # names after tracing. + gqa_qkv_metadata = { + "original_names_to_gqa_qkv_names": {}, + "output_projections_names": set(), + "num_attention_heads": None, + "num_key_value_heads": None, + "kv_size_multiplier": None, + } + for mod in model.modules(): + if isinstance(mod, OptimumGQAQKVColumnParallelLinear): + num_attention_heads = mod.num_attention_heads + num_key_value_heads = mod.num_key_value_heads + kv_size_multiplier = mod.kv_size_multiplier + gqa_qkv_metadata = { + "original_names_to_gqa_qkv_names": get_parameter_names_mapping_after_gqa_qkv_replacement(model), + "output_projections_names": get_output_projection_qualified_names_after_qga_qkv_replacement(model), + "num_attention_heads": num_attention_heads, + "num_key_value_heads": num_key_value_heads, + "kv_size_multiplier": kv_size_multiplier, + } + break + # Preparing the model for sequence parallelism: sp_specs_cls = cls.SEQUENCE_PARALLELSIM_SPECS_CLS @@ -551,11 +666,14 @@ def should_parallelize_layer_predicate_func(layer): if is_main_worker(): logger.info("Loading and initializing the weights, this might take a while on large models.") - # Load the weights to the parallel linears if the loading was skipped during parallelization. - cls._maybe_load_weights_to_parallel_linears(model) + if skip_linear_weight_load: + # Load the weights to the parallel linears if the loading was skipped during parallelization. 
+ cls._maybe_load_weights_to_parallel_linears(model) + + if skip_linear_weight_load or any(p.device == torch.device("meta") for p in model.parameters()): + # Initialize or load the weights for the parallelized model if it was lazily loaded. + cls._initialize_or_load_weights(model, names_of_the_parameters_to_consider, device=device) - # Initialize or load the weights for the parallelized model. - cls._initialize_or_load_weights(model, names_of_the_parameters_to_consider, device=device) xm.rendezvous("End of initalization") if is_main_worker(): @@ -565,8 +683,8 @@ def should_parallelize_layer_predicate_func(layer): if not cls.supports_pipeline_parallelism(): raise NotImplementedError("{cls} does not support pipeline parallelism.") - model.config.return_dict = False model.config.use_cache = False + model.config.return_dict = False model.config.output_attentions = False model.config.output_hidden_states = False @@ -582,6 +700,7 @@ def should_parallelize_layer_predicate_func(layer): pipeline_cuts=cls.PIPELINE_PARALLELISM_SPECS_CLS.create_pipeline_cuts(model, pp_size), leaf_module_cls=cls.PIPELINE_PARALLELISM_SPECS_CLS.leaf_module_cls(), use_zero1_optimizer=pipeline_parallel_use_zero1_optimizer, + tracer_cls=OptimumNeuronFXTracer, ) if pipeline_parallel_gradient_checkpointing_enabled: apply_activation_checkpointing(model) @@ -590,9 +709,12 @@ def should_parallelize_layer_predicate_func(layer): if is_main_worker(): logger.info("Pipeline parallelism done.") + # TODO: can we optimize by skipping initialization and weight loading when `checkpoint_dir` is not None. if checkpoint_dir is not None: cls.load_model_checkpoint(model, checkpoint_dir) + model._gqa_qkv_metadata = gqa_qkv_metadata + return model @classmethod @@ -823,6 +945,7 @@ def save_model_checkpoint_as_sharded( state_dict["sharded_metadata"] = { k: asdict(v) for k, v in cls._get_parameters_tp_metadata(dict(model.named_parameters())).items() } + state_dict["gqa_qkv_metadata"] = model._gqa_qkv_metadata if optimizer is not None: # TODO: have metadata working for the optimizer. 
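
A minimal sketch of the KV replication scheme behind the GQAQKVColumnParallelLinear path added above, using made-up sizes (8 KV heads, TP size 32, head dim 4). It mirrors the slicing performed by create_kv_proj_local_weight_from_regular_weight introduced further down in optimum/neuron/distributed/utils.py and is not part of the patch itself:

import torch

tp_size = 32
num_key_value_heads = 8
head_dim = 4
hidden_size = 16

# Smallest multiplier that makes the number of KV heads divisible by the TP size,
# as described in the `kv_size_multiplier` docstring of `parallelize`.
kv_size_multiplier = tp_size // num_key_value_heads  # 4

kv_weight = torch.randn(num_key_value_heads * head_dim, hidden_size)  # stand-in for a k_proj/v_proj weight
output_size_per_partition = num_key_value_heads * kv_size_multiplier * head_dim // tp_size  # 4

def local_kv_weight(tp_rank: int) -> torch.Tensor:
    # Replicate every KV head kv_size_multiplier times, then take every tp_size-th
    # slice starting at tp_rank, so each rank ends up with its own copy of one KV head.
    repeated = kv_weight.repeat(kv_size_multiplier, 1)                 # (128, 16)
    splits = torch.split(repeated, output_size_per_partition, dim=0)  # 32 slices of (4, 16)
    return torch.cat(splits[tp_rank::tp_size], dim=0)                 # (4, 16) for this rank

assert local_kv_weight(0).shape == (output_size_per_partition, hidden_size)
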
diff --git a/optimum/neuron/distributed/checkpointing.py b/optimum/neuron/distributed/checkpointing.py index f2ee9aed8..3466e8cc1 100644 --- a/optimum/neuron/distributed/checkpointing.py +++ b/optimum/neuron/distributed/checkpointing.py @@ -16,32 +16,58 @@ import json from pathlib import Path -from typing import Dict, Literal, Union +from typing import Any, Callable, Dict, List, Literal, Union import torch from transformers.modeling_utils import shard_checkpoint from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME -from ..utils.require_utils import requires_safetensors -from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata +from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors +from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indicies_for_rank -def consolidate_tensor_parallel_checkpoints(checkpoint_dir: Union[str, Path]) -> Dict[str, "torch.Tensor"]: - if not isinstance(checkpoint_dir, Path): - checkpoint_dir = Path(checkpoint_dir) - - if checkpoint_dir.name != TENSOR_PARALLEL_SHARDS_DIR_NAME: - if (checkpoint_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir(): - checkpoint_dir = checkpoint_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME - else: - raise ValueError(f"Could not find the tensor parallel shards from {checkpoint_dir}") - +def create_gqa_query_or_output_projection_weight_from_full_weight( + full_weight: torch.Tensor, + tp_size: int, + num_attention_heads: int, + num_key_value_heads: int, + kv_size_multiplier: int, + query_or_output: Union[Literal["query"], Literal["output"]], +): + assert query_or_output in ["query", "output"] + assert full_weight.device == torch.device("cpu") + if query_or_output == "query": + hidden_size = full_weight.size(1) + else: + hidden_size = full_weight.size(0) + full_weight = full_weight.transpose(0, 1) + + indicies = [ + compute_query_indicies_for_rank(tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier) + for tp_rank in range(tp_size) + ] + indicies = torch.cat(indicies, dim=0) + reversed_indicies = torch.sort(indicies, dim=0).indices + + full_weight = full_weight.reshape(num_attention_heads, -1, hidden_size) + full_weight = full_weight[reversed_indicies] + full_weight = full_weight.reshape(-1, hidden_size) + + if query_or_output == "output": + full_weight = full_weight.transpose(0, 1) + + return full_weight + + +def consolidate_tensor_parallel_checkpoints( + sharded_checkpoints: List[Path], load_function: Callable[[Union[str, Path]], Dict[str, Any]] +) -> Dict[str, "torch.Tensor"]: state_dicts = [] - - for sharded_checkpoint in sorted(checkpoint_dir.glob("tp_rank_*/checkpoint.pt")): + sharded_checkpoints = sorted(sharded_checkpoints) + for sharded_checkpoint in sharded_checkpoints: if not sharded_checkpoint.is_file(): continue - state_dicts.append(torch.load(sharded_checkpoint)) + state_dicts.append(load_function(sharded_checkpoint.as_posix())) parameter_names = state_dicts[0]["model"].keys() sharded_metadatas = { @@ -56,25 +82,106 @@ def consolidate_tensor_parallel_checkpoints(checkpoint_dir: Union[str, Path]) -> for name in parameter_names } + gqa_qkv_metadata = state_dicts[0]["gqa_qkv_metadata"] + original_parameter_names_to_gqa_qkv_names = gqa_qkv_metadata["original_names_to_gqa_qkv_names"] + gqa_qkv_output_projections_names = gqa_qkv_metadata["output_projections_names"] + gqa_qkv_names_to_original_names = {v: k for k, v in original_parameter_names_to_gqa_qkv_names.items()} + 
consolidated_state_dict = {} for name in parameter_names: + # We need to handle the mapping between the GQA parameter names and the original names. + is_gqa_qkv_weight = name in gqa_qkv_names_to_original_names + if is_gqa_qkv_weight: + original_name = gqa_qkv_names_to_original_names[name] + weight_name = name.rsplit(".", maxsplit=1)[1] + else: + original_name = name + weight_name = "" # Not needed. + # For now all parameter metadatas are equal so it is enough to take the first element. # This might not be the case anymore when `ParameterMetadata` uses slices. metadata = sharded_metadatas[name][0] if metadata.is_tied: - consolidated_state_dict[name] = state_dicts[0]["model"][name] + consolidated_state_dict[original_name] = state_dicts[0]["model"][name].to("cpu") else: - params = [state_dict["model"][name] for state_dict in state_dicts] - consolidated_state_dict[name] = torch.cat( - params, + weights = [state_dict["model"][name] for state_dict in state_dicts] + tp_size = len(weights) + full_weight = torch.cat( + weights, dim=metadata.partition_dim, ) + full_weight = full_weight.to("cpu") + if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]: + full_weight = ( + torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone() + ) + elif weight_name == "weight_q" or original_name in gqa_qkv_output_projections_names: + full_weight = create_gqa_query_or_output_projection_weight_from_full_weight( + full_weight, + tp_size, + gqa_qkv_metadata["num_attention_heads"], + gqa_qkv_metadata["num_key_value_heads"], + gqa_qkv_metadata["kv_size_multiplier"], + "query" if weight_name == "weight_q" else "output", + ) + consolidated_state_dict[original_name] = full_weight + + return consolidated_state_dict + + +@requires_neuronx_distributed +def consolidate_model_parallel_checkpoints(checkpoint_dir: Union[str, Path]) -> Dict[str, "torch.Tensor"]: + from neuronx_distributed.parallel_layers.checkpointing import _xser_load + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + if checkpoint_dir.name != TENSOR_PARALLEL_SHARDS_DIR_NAME: + if (checkpoint_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir(): + checkpoint_dir = checkpoint_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME + else: + raise ValueError(f"Could not find the tensor parallel shards from {checkpoint_dir}") + + # Regular case: the checkpoint was saved without xser. + sharded_checkpoints = list(checkpoint_dir.glob("tp_rank_*/checkpoint.pt")) + load_function = torch.load + + # If no file was found, maybe the checkpoint was saved with xser. 
+ if not sharded_checkpoints: + sharded_checkpoints = checkpoint_dir.glob("tp_rank_*") + sharded_checkpoints = [p for p in sharded_checkpoints if not p.name.endswith("tensors")] + load_function = _xser_load + + if not sharded_checkpoints: + raise ValueError(f"Could not find any sharded checkpoint in {checkpoint_dir.as_posix()}") + + def get_checkpoint_name(checkpoint_path: Path) -> str: + name = checkpoint_path.name + if name == "checkpoint.pt": + name = checkpoint_path.parent.name + return name + + pp_size = max((int(get_checkpoint_name(checkpoint_path)[-2:]) for checkpoint_path in sharded_checkpoints)) + 1 + checkpoints_grouped_by_pp_ranks = [[] for _ in range(pp_size)] + for pp_rank in range(pp_size): + for checkpoint_path in sharded_checkpoints: + checkpoint_name = get_checkpoint_name(checkpoint_path) + if int(checkpoint_name[-2:]) == pp_rank: + checkpoints_grouped_by_pp_ranks[pp_rank].append(checkpoint_path) + + consolidated_state_dict = {} + for checkpoint_group_for_pp_rank in checkpoints_grouped_by_pp_ranks: + consolidated_for_pp_rank = consolidate_tensor_parallel_checkpoints(checkpoint_group_for_pp_rank, load_function) + consolidated_state_dict.update(**consolidated_for_pp_rank) + + for key, tensor in consolidated_state_dict.items(): + consolidated_state_dict[key] = tensor return consolidated_state_dict @requires_safetensors -def consolidate_tensor_parallel_checkpoints_to_unified_checkpoint( +def consolidate_model_parallel_checkpoints_to_unified_checkpoint( checkpoint_dir: Union[str, Path], output_dir: Union[str, Path], save_format: Literal["pytorch", "safetensors"] = "safetensors", @@ -86,7 +193,7 @@ def consolidate_tensor_parallel_checkpoints_to_unified_checkpoint( output_dir.mkdir(parents=True, exist_ok=True) - state_dict = consolidate_tensor_parallel_checkpoints(checkpoint_dir) + state_dict = consolidate_model_parallel_checkpoints(checkpoint_dir) shards, index = shard_checkpoint( state_dict, weights_name=SAFE_WEIGHTS_NAME if save_format == "safetensors" else WEIGHTS_NAME ) diff --git a/optimum/neuron/distributed/decoder_models.py b/optimum/neuron/distributed/decoder_models.py index 61e9e2e33..766927977 100644 --- a/optimum/neuron/distributed/decoder_models.py +++ b/optimum/neuron/distributed/decoder_models.py @@ -45,7 +45,7 @@ ParallelSelfAttentionWithFusedQKV, SequenceCollectiveOpInfo, ) -from .utils import linear_to_parallel_linear +from .utils import get_linear_weight_info, linear_to_parallel_linear if TYPE_CHECKING: @@ -396,7 +396,7 @@ def _transform( if weight_map is not None: layer_to_fully_qualified_name = {id(module): name for name, module in model.named_modules()} layer_qualified_name = layer_to_fully_qualified_name[id(layer)] - linear_layer_weight_info, linear_layer_bias_weight_info = cls._get_linear_weight_info( + linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info( weight_map, f"{layer_qualified_name}.{attribute_name}", device=device, @@ -699,7 +699,7 @@ def transform( if weight_map is not None: layer_to_fully_qualified_name = {id(module): name for name, module in model.named_modules()} layer_qualified_name = layer_to_fully_qualified_name[id(layer)] - linear_layer_weight_info, linear_layer_bias_weight_info = cls._get_linear_weight_info( + linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info( weight_map, f"{layer_qualified_name}.{attribute_name}", device=device, diff --git a/optimum/neuron/distributed/encoder_decoder_models.py b/optimum/neuron/distributed/encoder_decoder_models.py index 8faf953f7..0af70494c 
100644 --- a/optimum/neuron/distributed/encoder_decoder_models.py +++ b/optimum/neuron/distributed/encoder_decoder_models.py @@ -29,7 +29,7 @@ ParallelSelfAttention, SequenceCollectiveOpInfo, ) -from .utils import linear_to_parallel_linear +from .utils import get_linear_weight_info, linear_to_parallel_linear if TYPE_CHECKING: @@ -139,7 +139,7 @@ def transform( module, attribute_name = cls._get_module_and_attribute_name(layer, cls.FIRST_LINEAR_NAME) if weight_map is not None: layer_qualified_name = layer_to_fully_qualified_name[id(module)] - linear_layer_weight_info, linear_layer_bias_weight_info = cls._get_linear_weight_info( + linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info( weight_map, f"{layer_qualified_name}.{attribute_name}", device=device, diff --git a/optimum/neuron/distributed/parallel_layers.py b/optimum/neuron/distributed/parallel_layers.py index 1ae22762b..c0f97bc5d 100644 --- a/optimum/neuron/distributed/parallel_layers.py +++ b/optimum/neuron/distributed/parallel_layers.py @@ -19,7 +19,6 @@ from abc import ABC, abstractclassmethod from dataclasses import dataclass from enum import Enum -from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union import torch @@ -30,11 +29,15 @@ from ..utils.misc import is_main_worker from ..utils.require_utils import requires_neuronx_distributed from .utils import ( - GroupedQueryAttentionInfo, + FakeProj, + OptimumGQAQKVColumnParallelLinear, WeightInformation, embedding_to_parallel_embedding, - gqa_key_value_slicing_when_tp_size_greater_than_num_key_value_heads, + get_linear_weight_info, linear_to_parallel_linear, + mark_parameter_init_status_during_parallelization, + maybe_load_weights_to_gqa_qkv_column_parallel_linear, + maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear, ) @@ -68,44 +71,6 @@ def _get_module_and_attribute_name( attribute_name = split[1] return leaf_module, attribute_name - @classmethod - def _get_linear_weight_info( - cls, - weight_map: Dict[str, Union[Path, str]], - linear_layer_qualified_name: str, - device: Optional["torch.device"] = None, - fail_if_not_found: bool = True, - ) -> Tuple[Optional[WeightInformation], Optional[WeightInformation]]: - linear_layer_weight_qualified_name = f"{linear_layer_qualified_name}.weight" - if linear_layer_weight_qualified_name not in weight_map: - if fail_if_not_found: - raise ValueError( - f"Could not find the linear weight called {linear_layer_weight_qualified_name} in the weight map." - ) - else: - linear_layer_weight_info = None - else: - linear_layer_weight_info = WeightInformation( - weight_map[linear_layer_weight_qualified_name], - linear_layer_weight_qualified_name, - weight_map=weight_map, - device=device, - ) - - linear_layer_bias_qualified_name = f"{linear_layer_qualified_name}.bias" - linear_layer_bias_filename = weight_map.get(linear_layer_bias_qualified_name, None) - if linear_layer_bias_filename is not None: - linear_layer_bias_weight_info = WeightInformation( - linear_layer_bias_filename, - linear_layer_bias_qualified_name, - weight_map=weight_map, - device=device, - ) - else: - linear_layer_bias_weight_info = None - - return linear_layer_weight_info, linear_layer_bias_weight_info - @abstractclassmethod def _transform( cls, @@ -329,7 +294,7 @@ class ParallelSelfAttention(ParallelLayer): If left unspecified, the attribute will be fetched by using the NormalizedConfig associated to the model. 
""" - PARALLEL_LAYER_SPECIFIC_KWARGS = {"skip_linear_weight_load": False} + PARALLEL_LAYER_SPECIFIC_KWARGS = {"skip_linear_weight_load": False, "kv_size_multiplier": None} QUERIES_NAME = "query" KEYS_NAME = "key" @@ -340,6 +305,126 @@ class ParallelSelfAttention(ParallelLayer): NUM_KEY_VALUE_GROUPS_NAME: Optional[str] = None ALL_HEAD_SIZE_NAME: Optional[str] = None + GQA_QKV_PROJ_NAME: str = "qkv_proj" + + @classmethod + def get_layer_qualified_name(cls, model: torch.nn.Module, layer: torch.nn.Module) -> str: + layer_to_fully_qualified_name = {id(module): name for name, module in model.named_modules()} + return layer_to_fully_qualified_name[id(layer)] + + @classmethod + def patch_proj_to_use_gqa_qkv_column_parallel_linear( + cls, + attention_layer: torch.nn.Module, + attention_layer_qualified_name: str, + proj_qualified_name: str, + proj_name: str, + output_index: int, + ): + fake_proj = FakeProj( + proj_qualified_name, + proj_name, + output_index, + lambda: attention_layer, + attention_layer_qualified_name, + cls.GQA_QKV_PROJ_NAME, + ) + + setattr(attention_layer, proj_name, fake_proj) + + @classmethod + @requires_neuronx_distributed + def replace_qkv_by_gqa_qkv_column_parallel_linear( + cls, + model: "torch.nn.Module", + attention_layer: "torch.nn.Module", + sequence_parallel_enabled: bool = False, + kv_size_multiplier: Optional[int] = None, + skip_linear_weight_load: bool = False, + ): + from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size + + if cls.NUM_KEY_VALUE_HEADS_NAME is None: + raise ValueError(f"{cls} does not defined the name of the number of key value heads.") + tp_size = get_tensor_model_parallel_size() + num_key_value_heads = getattr(attention_layer, cls.NUM_KEY_VALUE_HEADS_NAME) + if tp_size < num_key_value_heads: + raise ValueError( + f"The TP size ({tp_size}) is lower than the number of key value heads, using " + "GQAQKVColumnParallelLinear is not needed." 
+ ) + + num_attention_heads = getattr(attention_layer, cls.NUM_ATTENTION_HEADS_NAME) + query_linear = getattr(attention_layer, cls.QUERIES_NAME) + key_linear = getattr(attention_layer, cls.KEYS_NAME) + + hidden_size = query_linear.weight.size(1) + query_in_features = query_linear.weight.size(0) + key_value_in_features = key_linear.weight.size(0) + + if kv_size_multiplier is None: + kv_size_multiplier = get_tensor_model_parallel_size() // num_key_value_heads + + device = query_linear.weight.device + if device == torch.device("meta"): + device = None + + gqa_qkv_column_parallel_linear = OptimumGQAQKVColumnParallelLinear( + cls.QUERIES_NAME, + cls.KEYS_NAME, + cls.VALUES_NAME, + cls.OUTPUT_PROJECTION_NAME, + num_attention_heads, + num_key_value_heads, + hidden_size, + [query_in_features, key_value_in_features], + gather_output=False, + bias=query_linear.bias is not None, + sequence_parallel_enabled=sequence_parallel_enabled, + device=device, + kv_size_multiplier=kv_size_multiplier, + ) + + setattr(attention_layer, cls.GQA_QKV_PROJ_NAME, gqa_qkv_column_parallel_linear) + + maybe_load_weights_to_gqa_qkv_column_parallel_linear( + model, + gqa_qkv_column_parallel_linear, + try_from_checkpoint=not skip_linear_weight_load, + try_from_original_layer=not skip_linear_weight_load, + ) + + attention_layer_qualified_name = cls.get_layer_qualified_name(model, attention_layer) + fake_q_proj = FakeProj( + f"{attention_layer_qualified_name}.{cls.QUERIES_NAME}", + "q", + 0, + lambda: attention_layer, + attention_layer_qualified_name, + cls.GQA_QKV_PROJ_NAME, + ) + setattr(attention_layer, cls.QUERIES_NAME, fake_q_proj) + + fake_k_proj = FakeProj( + f"{attention_layer_qualified_name}.{cls.KEYS_NAME}", + "k", + 1, + lambda: attention_layer, + attention_layer_qualified_name, + cls.GQA_QKV_PROJ_NAME, + ) + setattr(attention_layer, cls.KEYS_NAME, fake_k_proj) + + fake_v_proj = FakeProj( + f"{attention_layer_qualified_name}.{cls.VALUES_NAME}", + "v", + 2, + lambda: attention_layer, + attention_layer_qualified_name, + cls.GQA_QKV_PROJ_NAME, + ) + setattr(attention_layer, cls.VALUES_NAME, fake_v_proj) + @classmethod @requires_neuronx_distributed def _transform( @@ -356,6 +441,7 @@ def _transform( raise AttributeError("Both NUM_KEY_VALUE_HEADS_NAME and NUM_KEY_VALUE_GROUPS_NAME must be specified.") skip_linear_weight_load = parallel_layer_specific_kwargs["skip_linear_weight_load"] + kv_size_multiplier = parallel_layer_specific_kwargs["kv_size_multiplier"] from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size @@ -365,11 +451,8 @@ def _transform( config = model.config normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) - if weight_map is not None: - layer_to_fully_qualified_name = {id(module): name for name, module in model.named_modules()} - layer_qualified_name = layer_to_fully_qualified_name[id(layer)] - else: - layer_qualified_name = "" + layer_to_fully_qualified_name = {id(module): name for name, module in model.named_modules()} + layer_qualified_name = layer_to_fully_qualified_name[id(layer)] if cls.NUM_ATTENTION_HEADS_NAME is None: num_attention_heads_name = normalized_config.NUM_ATTENTION_HEADS @@ -395,37 +478,27 @@ def _transform( raise ValueError( "Only the cases where the number of key value heads is divisible by the TP size, or the other way around are supported." 
) - elif is_main_worker() and num_key_value_heads < tp_size: - logger.warning( - f"The TP size ({tp_size}) is bigger than the number of key value heads ({num_key_value_heads}). " - "This is not ideal because the key and value projections will not be sharded accross the TP ranks. " - "For better performance choose the number of key value heads to be divisible by the TP size." - ) - kv_heads_are_parallelized = num_key_value_heads >= tp_size + needs_gqa_qkv_column_parallel_linear = num_key_value_heads < tp_size else: num_key_value_heads = getattr(layer, num_attention_heads_name) - kv_heads_are_parallelized = True + needs_gqa_qkv_column_parallel_linear = False - for name in [cls.QUERIES_NAME, cls.KEYS_NAME, cls.VALUES_NAME]: - linear_layer_weight_info, linear_layer_bias_weight_info = None, None - if weight_map is not None: - linear_layer_weight_info, linear_layer_bias_weight_info = cls._get_linear_weight_info( + if needs_gqa_qkv_column_parallel_linear: + cls.replace_qkv_by_gqa_qkv_column_parallel_linear( + model, + layer, + sequence_parallel_enabled=sequence_parallel_enabled, + kv_size_multiplier=kv_size_multiplier, + skip_linear_weight_load=skip_linear_weight_load, + ) + else: + for name in [cls.QUERIES_NAME, cls.KEYS_NAME, cls.VALUES_NAME]: + linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info( weight_map, f"{layer_qualified_name}.{name}", device=device, + fail_if_not_found=False, ) - # Under GQA setting with num_key_value_heads < tp_size, the key and value projections are replicated accross - # workers. - if not kv_heads_are_parallelized and name in [cls.KEYS_NAME, cls.VALUES_NAME]: - gqa_info = GroupedQueryAttentionInfo(num_attention_heads, num_key_value_heads) - parallel_linear = gqa_key_value_slicing_when_tp_size_greater_than_num_key_value_heads( - gqa_info, - getattr(layer, name), - linear_layer_weight_info=linear_layer_weight_info, - linear_layer_bias_weight_info=linear_layer_bias_weight_info, - device=device, - ) - else: parallel_linear = linear_to_parallel_linear( getattr(layer, name), "column", @@ -436,30 +509,43 @@ def _transform( skip_weight_load=skip_linear_weight_load, device=device, ) - setattr(layer, name, parallel_linear) + setattr(layer, name, parallel_linear) if cls.OUTPUT_PROJECTION_NAME is not None: - linear_layer_weight_info, linear_layer_bias_weight_info = None, None - if weight_map is not None: - linear_layer_weight_info, linear_layer_bias_weight_info = cls._get_linear_weight_info( - weight_map, - f"{layer_qualified_name}.{cls.OUTPUT_PROJECTION_NAME}", - device=device, - ) - setattr( - layer, - cls.OUTPUT_PROJECTION_NAME, - linear_to_parallel_linear( - getattr(layer, cls.OUTPUT_PROJECTION_NAME), - "row", - input_is_parallel=True, + linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info( + weight_map, + f"{layer_qualified_name}.{cls.OUTPUT_PROJECTION_NAME}", + device=device, + fail_if_not_found=False, + ) + parallel_output_proj = linear_to_parallel_linear( + getattr(layer, cls.OUTPUT_PROJECTION_NAME), + "row", + input_is_parallel=True, + linear_layer_weight_info=linear_layer_weight_info, + linear_layer_bias_weight_info=linear_layer_bias_weight_info, + sequence_parallel_enabled=sequence_parallel_enabled, + skip_weight_load=skip_linear_weight_load, + device=device, + ) + + if needs_gqa_qkv_column_parallel_linear: + qga_qkv_layer = getattr(layer, cls.GQA_QKV_PROJ_NAME) + # We need to re-initialize the output projection in this case since the queries are "shuffled". 
+ mark_parameter_init_status_during_parallelization(parallel_output_proj.weight, False) + maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear( + parallel_output_proj, + qga_qkv_layer.num_attention_heads, + qga_qkv_layer.num_key_value_heads, + qga_qkv_layer.kv_size_multiplier, + original_output_projection=getattr(layer, cls.OUTPUT_PROJECTION_NAME), linear_layer_weight_info=linear_layer_weight_info, linear_layer_bias_weight_info=linear_layer_bias_weight_info, - sequence_parallel_enabled=sequence_parallel_enabled, - skip_weight_load=skip_linear_weight_load, - device=device, - ), - ) + try_from_checkpoint=not skip_linear_weight_load, + try_from_original_layer=not skip_linear_weight_load, + ) + + setattr(layer, cls.OUTPUT_PROJECTION_NAME, parallel_output_proj) setattr( layer, @@ -472,7 +558,7 @@ def _transform( # Since those heads end-up sharded accross TP ranks just as the query heads, only the number of kv heads # needs to be updated. The number of query groups remains the same here because it is the ratio between the # number of query heads and the number of kv heads. - if kv_heads_are_parallelized: + if not needs_gqa_qkv_column_parallel_linear: setattr( layer, cls.NUM_KEY_VALUE_HEADS_NAME, @@ -483,15 +569,17 @@ def _transform( # In this case, multiple ranks will end-up with the same kv head, and each rank will only have one kv head # and query group. else: + gqa_qkv_proj = getattr(layer, cls.GQA_QKV_PROJ_NAME) + new_num_key_value_heads = (num_key_value_heads * gqa_qkv_proj.kv_size_multiplier) // tp_size setattr( layer, cls.NUM_KEY_VALUE_HEADS_NAME, - 1, + new_num_key_value_heads, ) setattr( layer, cls.NUM_KEY_VALUE_GROUPS_NAME, - 1, + getattr(layer, num_attention_heads_name) // new_num_key_value_heads, ) setattr( @@ -573,7 +661,7 @@ def _transform( linear_layer_weight_info, linear_layer_bias_weight_info = None, None if weight_map is not None: - linear_layer_weight_info, linear_layer_bias_weight_info = cls._get_linear_weight_info( + linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info( weight_map, f"{layer_qualified_name}.{cls.QUERY_KEY_VALUE_NAME}", device=device, @@ -596,7 +684,7 @@ def _transform( if cls.OUTPUT_PROJECTION_NAME is not None: linear_layer_weight_info, linear_layer_bias_weight_info = None, None if weight_map is not None: - linear_layer_weight_info, linear_layer_bias_weight_info = cls._get_linear_weight_info( + linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info( weight_map, f"{layer_qualified_name}.{cls.OUTPUT_PROJECTION_NAME}", device=device, @@ -658,7 +746,7 @@ def _transform( if weight_map is not None: layer_to_fully_qualified_name = {id(module): name for name, module in model.named_modules()} layer_qualified_name = layer_to_fully_qualified_name[id(layer)] - linear_layer_weight_info, linear_layer_bias_weight_info = cls._get_linear_weight_info( + linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info( weight_map, f"{layer_qualified_name}.{cls.OUTPUT_PROJECTION_NAME}", device=device, @@ -714,7 +802,7 @@ def _transform( module, attribute_name = cls._get_module_and_attribute_name(layer, cls.FIRST_LINEAR_NAME) if weight_map is not None: layer_qualified_name = layer_to_fully_qualified_name[id(module)] - linear_layer_weight_info, linear_layer_bias_weight_info = cls._get_linear_weight_info( + linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info( weight_map, f"{layer_qualified_name}.{attribute_name}", device=device, @@ -739,7 +827,7 @@ def 
_transform( linear_layer_weight_info, linear_layer_bias_weight_info = None, None if weight_map is not None: layer_qualified_name = layer_to_fully_qualified_name[id(module)] - linear_layer_weight_info, linear_layer_bias_weight_info = cls._get_linear_weight_info( + linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info( weight_map, f"{layer_qualified_name}.{attribute_name}", device=device, @@ -906,7 +994,7 @@ def _transform( layer_to_fully_qualified_name = {id(module): name for name, module in model.named_modules()} linear_projection_qualified_name = layer_to_fully_qualified_name[id(linear_projection)] try: - linear_projection_weight_info, linear_projection_bias_weight_info = cls._get_linear_weight_info( + linear_projection_weight_info, linear_projection_bias_weight_info = get_linear_weight_info( weight_map, linear_projection_qualified_name, device=device, diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index 5320f8bc8..5a34fb3a3 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -22,7 +22,7 @@ import os from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Set, Tuple, Type, Union import torch from transformers import PretrainedConfig @@ -41,8 +41,19 @@ if is_neuronx_distributed_available(): + from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear from neuronx_distributed.parallel_layers import layers from neuronx_distributed.pipeline import NxDPPModel + from neuronx_distributed.pipeline.trace import HFTracerWrapper +else: + + class GQAQKVColumnParallelLinear(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + from transformers.utils.fx import HFTracer + + HFTracerWrapper = HFTracer if TYPE_CHECKING: @@ -135,6 +146,148 @@ def __post_init__(self): ) +class FakeProj(torch.nn.Module): + """ + Dummy layer that replaces a Linear projection by gathering the result from its associated merged + QGAQKVColumnParallelLinear. + """ + + def __init__( + self, + fully_qualified_name: str, + proj_name: str, + output_index: int, + get_parent_module: Callable[[], torch.nn.Module], + parent_module_fully_qualified_name: str, + gqa_qkv_proj_name: str, + ): + super().__init__() + self.fully_qualified_name = fully_qualified_name + self.proj_name = proj_name + self.output_index = output_index + self.get_parent_module = get_parent_module + self.parent_module_fully_qualified_name = parent_module_fully_qualified_name + self.gqa_qkv_proj_name = gqa_qkv_proj_name + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + parent_module = self.get_parent_module() + gqa_qkv_column_parallel_linear = getattr(parent_module, self.gqa_qkv_proj_name) + if not hasattr(parent_module, "_gqa_qkv_output"): + parent_module._gqa_qkv_output = gqa_qkv_column_parallel_linear(hidden_states) + parent_module._gqa_qkv_output_fetch_counter = 0 + parent_module._gqa_qkv_output_fetch_counter += 1 + output = parent_module._gqa_qkv_output[self.output_index] + if parent_module._gqa_qkv_output_fetch_counter == 3: + del parent_module._gqa_qkv_output + return output + + +class OptimumGQAQKVColumnParallelLinear(GQAQKVColumnParallelLinear): + """ + Same as GQAQKVColumnParallelLinear with the needed metadata for `optimum-neuron`. 
+ """ + + def __init__( + self, + query_proj_name: str, + key_proj_name: str, + value_proj_name: str, + output_proj_name: str, + num_attention_heads: int, + num_key_value_heads: int, + input_size: int, + output_sizes: int, + bias: bool = True, + gather_output: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + init_method: Optional[Callable] = None, + sequence_parallel_enabled: bool = False, + keep_master_weight: bool = False, + kv_size_multiplier: int = 1, + ): + super().__init__( + input_size, + output_sizes, + bias=bias, + gather_output=gather_output, + dtype=dtype, + device=device, + init_method=init_method, + sequence_parallel_enabled=sequence_parallel_enabled, + keep_master_weight=keep_master_weight, + kv_size_multiplier=kv_size_multiplier, + ) + + self.query_proj_name = query_proj_name + self.key_proj_name = key_proj_name + self.value_proj_name = value_proj_name + self.output_proj_name = output_proj_name + + self._qkv_proj_name_to_proj_name = {"q": query_proj_name, "k": key_proj_name, "v": value_proj_name} + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + def get_parameter_names_mapping( + self, module_to_name: Dict[torch.nn.Module, str], reversed: bool = False + ) -> Dict[str, str]: + fully_qualified_name = module_to_name[self] + parent_module_name, _ = fully_qualified_name.rsplit(".", maxsplit=1) + mapping = {} + for qkv_proj_name, proj_name in self._qkv_proj_name_to_proj_name.items(): + mapping[f"{parent_module_name}.{proj_name}.weight"] = f"{fully_qualified_name}.weight_{qkv_proj_name}" + if self.use_bias: + mapping[f"{parent_module_name}.{proj_name}.bias"] = f"{fully_qualified_name}.bias_{qkv_proj_name}" + if reversed: + mapping = {v: k for k, v in mapping.items()} + return mapping + + +@requires_neuronx_distributed +def get_parameter_names_mapping_after_gqa_qkv_replacement( + model: torch.nn.Module, reversed: bool = False +) -> Dict[str, str]: + """ + Returns the mapping between the original projection names and their names after replacing them with + GQAQKVColumnParallelLinear. + """ + from neuronx_distributed.pipeline import NxDPPModel + + mapping = {} + if isinstance(model, NxDPPModel): + named_modules = dict(model.local_named_modules()) + else: + named_modules = dict(model.named_modules()) + module_to_name = {v: k for k, v in named_modules.items()} + for _, mod in named_modules.items(): + if isinstance(mod, OptimumGQAQKVColumnParallelLinear): + mapping.update(**mod.get_parameter_names_mapping(module_to_name, reversed=reversed)) + return mapping + + +@requires_neuronx_distributed +def get_output_projection_qualified_names_after_qga_qkv_replacement(model: torch.nn.Module) -> Set[str]: + """ + Returns the names of the output projections inside the attention layer, these are needed when using + GQAQKVColumnParallelLinear. 
+ """ + from neuronx_distributed.pipeline import NxDPPModel + + qualified_names = set() + if isinstance(model, NxDPPModel): + named_modules = dict(model.local_named_modules()) + else: + named_modules = dict(model.named_modules()) + for name, mod in named_modules.items(): + if isinstance(mod, OptimumGQAQKVColumnParallelLinear): + parent_name = name.rsplit(".", maxsplit=1)[0] + output_projection_name = f"{parent_name}.{mod.output_proj_name}" + qualified_names.add(f"{output_projection_name}.weight") + if model.get_submodule(output_projection_name).bias is not None: + qualified_names.add(f"{output_projection_name}.bias") + return qualified_names + + @requires_safetensors def load_tensor_for_weight( weight_info: WeightInformation, tensor_slices: Optional[Tuple[Optional[Tuple[int, ...]], ...]] = None @@ -157,7 +310,9 @@ def load_tensor_for_weight( """ from safetensors import safe_open - device = str(weight_info.device) + # TODO: for now `safetensors` does not support loading directly to the `xla` device. + # device = str(weight_info.device) + device = "cpu" with safe_open(weight_info.filename, framework="pt", device=device) as fp: if tensor_slices is None: tensor = fp.get_tensor(weight_info.qualified_name) @@ -293,6 +448,346 @@ def embedding_to_parallel_embedding( return parallel_embedding_layer, parallel_lm_head_layer +def get_linear_weight_info( + weight_map: Optional[Dict[str, Union[Path, str]]], + linear_layer_qualified_name: str, + device: Optional[torch.device] = None, + fail_if_not_found: bool = True, +) -> Tuple[Optional[WeightInformation], Optional[WeightInformation]]: + linear_layer_weight_qualified_name = f"{linear_layer_qualified_name}.weight" + if weight_map is None: + weight_map = {} + if linear_layer_weight_qualified_name not in weight_map: + if fail_if_not_found: + raise ValueError( + f"Could not find the linear weight called {linear_layer_weight_qualified_name} in the weight map." + ) + else: + linear_layer_weight_info = None + else: + linear_layer_weight_info = WeightInformation( + weight_map[linear_layer_weight_qualified_name], + linear_layer_weight_qualified_name, + weight_map=weight_map, + device=device, + ) + + linear_layer_bias_qualified_name = f"{linear_layer_qualified_name}.bias" + linear_layer_bias_filename = weight_map.get(linear_layer_bias_qualified_name, None) + if linear_layer_bias_filename is not None: + linear_layer_bias_weight_info = WeightInformation( + linear_layer_bias_filename, + linear_layer_bias_qualified_name, + weight_map=weight_map, + device=device, + ) + else: + linear_layer_bias_weight_info = None + + return linear_layer_weight_info, linear_layer_bias_weight_info + + +@requires_neuronx_distributed +def create_kv_proj_local_weight_from_regular_weight( + weight_data: torch.Tensor, kv_size_multiplier: int, output_size_per_partition: int +) -> torch.Tensor: + """ + Creates the local version of the key or value projections weight for the given TP rank when using + GQAQKVColumnParallelLinear. 
+ """ + assert not isinstance(weight_data, torch.nn.Parameter) + from neuronx_distributed.parallel_layers.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_size, + ) + + tp_size = get_tensor_model_parallel_size() + tp_rank = get_tensor_model_parallel_rank() + repeated_weight = weight_data.repeat(kv_size_multiplier, 1) + split = torch.split(repeated_weight, output_size_per_partition, dim=0) + return torch.cat(split[tp_rank::tp_size], dim=0) + + +def compute_query_indicies_for_rank( + tp_size: int, tp_rank: int, num_attention_heads: int, num_key_value_heads: int, kv_size_multiplier: int +): + """ + Computes the permutation for the query weight wheun using GQAQKVColumnParallelLinear. + """ + num_attention_heads_per_rank = num_attention_heads // tp_size + num_key_value_heads_per_rank = (num_key_value_heads * kv_size_multiplier) // tp_size + query_group_size = num_attention_heads // num_key_value_heads + query_group_size_per_rank = num_attention_heads_per_rank // num_key_value_heads_per_rank + + queries_indicies = [torch.arange(query_group_size_per_rank) for _ in range(num_key_value_heads_per_rank)] + + keys_indicies = torch.arange(num_key_value_heads).repeat(kv_size_multiplier) + keys_indicies = torch.repeat_interleave( + keys_indicies, num_attention_heads_per_rank // num_key_value_heads_per_rank + ) + keys_indicies = torch.chunk(keys_indicies, tp_size) + + shift_per_key = torch.arange(0, num_attention_heads, query_group_size) + + shift_within_query_group = torch.arange(0, query_group_size, query_group_size_per_rank) + shift_within_query_group = torch.repeat_interleave( + shift_within_query_group, num_attention_heads_per_rank * num_key_value_heads_per_rank + ) + shift_within_query_group = torch.chunk(shift_within_query_group, tp_size) + + indicies = [] + for idx, q_indicies in enumerate(queries_indicies): + s = slice(idx * query_group_size_per_rank, (idx + 1) * query_group_size_per_rank) + k_indicies = keys_indicies[tp_rank][s] + k_shift = shift_per_key[k_indicies] + group_shift = shift_within_query_group[tp_rank][s] + indicies.append(q_indicies + k_shift + group_shift) + + indicies = torch.cat(indicies, dim=0) + return indicies + + +@requires_neuronx_distributed +def create_query_or_output_projection_local_weight_from_regular_weight( + weight_data: torch.Tensor, + num_attention_heads: int, + num_key_value_heads: int, + kv_size_multiplier: int, + query_or_output_proj: Union[Literal["query"], Literal["output"]], +) -> torch.Tensor: + """ + Creates the local version of the query or output projections weight for the given TP rank when using + GQAQKVColumnParallelLinear. 
+ """ + assert query_or_output_proj in ["query", "output"] + assert not isinstance(weight_data, torch.nn.Parameter) + + from neuronx_distributed.parallel_layers.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_size, + ) + + tp_size = get_tensor_model_parallel_size() + tp_rank = get_tensor_model_parallel_rank() + + if query_or_output_proj == "query": + hidden_size = weight_data.size(1) + head_dim = weight_data.size(0) // num_attention_heads + else: + hidden_size = weight_data.size(0) + head_dim = weight_data.size(1) // num_attention_heads + weight_data = weight_data.transpose(0, 1) + + indicies = compute_query_indicies_for_rank( + tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier + ) + reshaped_weight = weight_data.view(num_attention_heads, head_dim, hidden_size) + shuffled_weight = reshaped_weight[indicies] + shuffled_weight = shuffled_weight.reshape(-1, hidden_size) + + if query_or_output_proj == "output": + shuffled_weight = shuffled_weight.transpose(0, 1) + + return shuffled_weight + + +def create_local_bias_from_regular_bias( + bias_weigth_data: torch.Tensor, + num_attention_heads: int, + num_key_value_heads: int, + kv_size_multiplier: int, + query_or_key_value_bias: Union[Literal["query"], Literal["key_value"]], + gather_output: bool, +) -> torch.Tensor: + """ + Creates the local version of the query, key and value projections bias for the given TP rank when using + GQAQKVColumnParallelLinear. + """ + assert query_or_key_value_bias in ["query", "key_value"] + assert not isinstance(bias_weigth_data, torch.nn.Parameter) + from neuronx_distributed.parallel_layers.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_size, + ) + + tp_size = get_tensor_model_parallel_size() + tp_rank = get_tensor_model_parallel_rank() + + if query_or_key_value_bias == "key_value": + local_bias_weight = bias_weigth_data.repeat(kv_size_multiplier) + if not gather_output: + local_bias_weight = local_bias_weight.chunk(tp_size)[tp_rank] + + else: + if gather_output: + indicies = torch.cat( + [ + compute_query_indicies_for_rank( + tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier + ) + for tp_rank in range(tp_size) + ], + dim=0, + ) + else: + indicies = compute_query_indicies_for_rank( + tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier + ) + reshaped_bias_weight = bias_weigth_data.view(num_attention_heads, -1) + shuffled_bias_weight = reshaped_bias_weight[indicies] + local_bias_weight = shuffled_bias_weight.reshape(-1) + return local_bias_weight + + +@requires_neuronx_distributed +def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( + layer: OptimumGQAQKVColumnParallelLinear, + weight_name: str, + linear_layer_weight_info: Optional[WeightInformation] = None, + linear_layer_bias_weight_info: Optional[WeightInformation] = None, + linear_layer: Optional["torch.nn.Linear"] = None, +): + if ( + linear_layer_weight_info is not None or linear_layer_bias_weight_info is not None + ) and linear_layer is not None: + raise ValueError( + "Specify either a linear layer's WeightInformation, or a linear layer to copy the weights from, but not both." + ) + if linear_layer_weight_info is None and linear_layer_bias_weight_info is None and linear_layer is None: + raise ValueError( + "A linear's layer WeightInformation or a linear layer to copy the weights from need to specified." 
+ ) + + proj_name = weight_name[-1] + weight = getattr(layer, weight_name) + bias = getattr(layer, f"bias_{proj_name}") + + num_attention_heads = layer.num_attention_heads + num_key_value_heads = layer.num_key_value_heads + kv_size_multiplier = layer.kv_size_multiplier + + with torch.no_grad(): + if not was_already_initialized_during_parallelization(weight): + weight_data = None + if linear_layer_weight_info is not None: + weight_data = load_tensor_for_weight(linear_layer_weight_info) + elif linear_layer is not None and linear_layer.weight.device != torch.device("meta"): + weight_data = linear_layer.weight.data + if weight_data is not None: + if proj_name in "kv": + weight_data = create_kv_proj_local_weight_from_regular_weight( + weight_data, kv_size_multiplier, weight.size(0) + ) + else: + weight_data = create_query_or_output_projection_local_weight_from_regular_weight( + weight_data, num_attention_heads, num_key_value_heads, kv_size_multiplier, "query" + ) + weight.copy_(weight_data) + mark_parameter_init_status_during_parallelization(weight, True) + else: + mark_parameter_init_status_during_parallelization(weight, False) + + if bias is not None: + if not was_already_initialized_during_parallelization(bias): + bias_weight_data = None + if linear_layer_bias_weight_info is not None: + bias_weight_data = load_tensor_for_weight(linear_layer_bias_weight_info) + elif linear_layer is not None and linear_layer.bias.device != torch.device("meta"): + bias_weight_data = linear_layer.bias.data + if bias_weight_data is not None: + local_bias_weight_data = create_local_bias_from_regular_bias( + bias_weight_data, + num_attention_heads, + num_key_value_heads, + kv_size_multiplier, + "key_value" if proj_name in "kv" else "query", + layer.gather_output, + ) + bias.copy_(local_bias_weight_data) + mark_parameter_init_status_during_parallelization(bias, True) + else: + mark_parameter_init_status_during_parallelization(bias, False) + + +def maybe_load_weights_to_gqa_qkv_column_parallel_linear( + model: torch.nn.Module, + layer: OptimumGQAQKVColumnParallelLinear, + try_from_checkpoint: bool = True, + try_from_original_layer: bool = False, +): + weight_map = getattr(model, "_weight_map", {}) + named_modules = {v: k for k, v in model.named_modules()} + original_to_gqa = layer.get_parameter_names_mapping(named_modules) + + for orig_name, gqa_name in original_to_gqa.items(): + linear_layer_qualified_name, _ = orig_name.rsplit(".", maxsplit=1) + linear_weight_info, linear_bias_weight_info = get_linear_weight_info( + weight_map, linear_layer_qualified_name, fail_if_not_found=False + ) + weight_name = gqa_name.split(".")[-1] + if try_from_checkpoint and linear_weight_info is not None: + maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( + layer, + weight_name, + linear_layer_weight_info=linear_weight_info, + linear_layer_bias_weight_info=linear_bias_weight_info, + ) + elif try_from_original_layer: + orig_layer_name, _ = orig_name.rsplit(".", maxsplit=1) + maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( + layer, + weight_name, + linear_layer=model.get_submodule(orig_layer_name), + ) + + +def maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear( + output_projection: "layers.RowParallelLinear", + num_attention_heads: int, + num_key_value_heads: int, + kv_size_multiplier: int, + original_output_projection: Optional[torch.nn.Linear] = None, + linear_layer_weight_info: Optional[WeightInformation] = None, + linear_layer_bias_weight_info: Optional[WeightInformation] = None, + 
try_from_checkpoint: bool = True, + try_from_original_layer: bool = False, +): + weight = output_projection.weight + bias = output_projection.bias + with torch.no_grad(): + if not was_already_initialized_during_parallelization(weight): + weight_data = None + if try_from_checkpoint and linear_layer_weight_info is not None: + weight_data = load_tensor_for_weight(linear_layer_weight_info) + elif ( + try_from_original_layer + and original_output_projection is not None + and original_output_projection.weight.device != torch.device("meta") + ): + weight_data = original_output_projection.weight.data + if weight_data is not None: + weight_data = create_query_or_output_projection_local_weight_from_regular_weight( + weight_data, num_attention_heads, num_key_value_heads, kv_size_multiplier, "output" + ) + weight.copy_(weight_data) + mark_parameter_init_status_during_parallelization(weight, True) + else: + mark_parameter_init_status_during_parallelization(weight, False) + if bias is not None and not was_already_initialized_during_parallelization(bias): + bias_weight_data = None + if linear_layer_bias_weight_info is not None: + bias_weight_data = load_tensor_for_weight(linear_layer_bias_weight_info) + elif original_output_projection is not None and original_output_projection.bias.device != torch.device( + "meta" + ): + bias_weight_data = original_output_projection.bias.data + if bias_weight_data is not None: + output_projection.bias.copy_(bias_weight_data) + mark_parameter_init_status_during_parallelization(output_projection.bias, True) + else: + mark_parameter_init_status_during_parallelization(output_projection.bias, False) + + @requires_neuronx_distributed def maybe_load_linear_weight_to_parallel_linear( parallel_linear_layer: "layers.BaseParallelLinear", @@ -332,7 +827,7 @@ def maybe_load_linear_weight_to_parallel_linear( mark_parameter_init_status_during_parallelization(parallel_linear_layer.weight, True) elif linear_layer.weight.device != torch.device("meta"): parallel_linear_layer.weight.copy_( - linear_layer.weight[:, tp_rank * col_size : (tp_rank + 1) * col_size] + linear_layer.weight.data[:, tp_rank * col_size : (tp_rank + 1) * col_size] ) mark_parameter_init_status_during_parallelization(parallel_linear_layer.weight, True) else: @@ -345,7 +840,7 @@ def maybe_load_linear_weight_to_parallel_linear( parallel_linear_layer.bias.copy_(bias_weight_data) mark_parameter_init_status_during_parallelization(parallel_linear_layer.bias, True) elif linear_layer.bias.device != torch.device("meta"): - parallel_linear_layer.bias.copy_(linear_layer.bias) + parallel_linear_layer.bias.copy_(linear_layer.bias.data) mark_parameter_init_status_during_parallelization(parallel_linear_layer.bias, True) else: mark_parameter_init_status_during_parallelization(parallel_linear_layer.bias, False) @@ -365,7 +860,7 @@ def maybe_load_linear_weight_to_parallel_linear( del weight_data elif linear_layer.weight.device != torch.device("meta"): parallel_linear_layer.weight.copy_( - linear_layer.weight[tp_rank * row_size : (tp_rank + 1) * row_size, :] + linear_layer.weight.data[tp_rank * row_size : (tp_rank + 1) * row_size, :] ) mark_parameter_init_status_during_parallelization(parallel_linear_layer.weight, True) else: @@ -392,10 +887,10 @@ def maybe_load_linear_weight_to_parallel_linear( del bias_weight_data elif linear_layer.bias.device != torch.device("meta"): if parallel_linear_layer.gather_output: - parallel_linear_layer.bias.copy_(linear_layer.bias) + parallel_linear_layer.bias.copy_(linear_layer.bias.data) else: 
parallel_linear_layer.bias.copy_( - linear_layer.bias[tp_rank * row_size : (tp_rank + 1) * row_size] + linear_layer.bias.data[tp_rank * row_size : (tp_rank + 1) * row_size] ) mark_parameter_init_status_during_parallelization(parallel_linear_layer.bias, True) else: @@ -495,10 +990,6 @@ def linear_to_parallel_linear( if embedding_weight_to_tie is not None: parallel_linear_layer.weight = embedding_weight_to_tie - del linear_layer.weight - if linear_layer.bias is not None: - del linear_layer.bias - return parallel_linear_layer @@ -596,37 +1087,52 @@ def delete_tensor_model_parallel_attributes(tensor: torch.Tensor): delattr(tensor, attr_name) -def try_to_hf_initialize(model: "PreTrainedModel", mod: torch.nn.Module, parameter_names: List[str]) -> List[str]: +def try_to_hf_initialize( + model: "PreTrainedModel", + mod: torch.nn.Module, + parameter_names: List[str], + parameter_names_mapping: Optional[Dict[str, str]] = None, +) -> List[str]: """ Tries to initialize the parameters in `parameter_names` that belong to the module `mod` by using the `model._init_weights` method. It returns the names of the parameters that were left uninitialized. """ - cached_params_data = {name: param.data.clone() for name, param in mod.named_parameters()} + cached_params_data = {name: param.data.detach().clone().to("cpu") for name, param in mod.named_parameters()} model._init_weights(mod) - dummy_mod = copy.deepcopy(mod) + if parameter_names_mapping is None: + parameter_names_mapping = {} + reverse_parameter_names_mapping = {v: k for k, v in parameter_names_mapping.items()} + + def name_in_mod(name: str): + return parameter_names_mapping.get(name, name) + + dummy_mod = copy.deepcopy(mod).to("cpu") for name in parameter_names: - getattr(dummy_mod, name).random_() + getattr(dummy_mod, name_in_mod(name)).random_() model._init_weights(dummy_mod) left_uninitialized = [] with torch.no_grad(): - for name in parameter_names: + for param_name in parameter_names: + name = name_in_mod(param_name) # The parameter was left unchanged. - if torch.all(getattr(mod, name).data == cached_params_data[name]): + param_on_cpu = getattr(mod, name).data.to("cpu") + if torch.all(param_on_cpu == cached_params_data[name]): # There are two possible reasons: # 1. The model cannot initialize the module that owns the parameter. # 2. The parameter already had the proper value. # We check if a dummy copy of the module, filled with random values is modified to know if the model # can initialize the module. 
-            dummy_param_was_changed = torch.all(getattr(dummy_mod, name).data == getattr(mod, name).data)
+                dummy_param_was_changed = torch.all(getattr(dummy_mod, name).data == param_on_cpu)
                 if not dummy_param_was_changed:
-                    left_uninitialized.append(name)
+                    left_uninitialized.append(param_name)
 
     for name, cached_data in cached_params_data.items():
-        if name not in parameter_names:
+        param_name = reverse_parameter_names_mapping.get(name, name)
+        if param_name not in parameter_names:
             param = getattr(mod, name)
             param.data = cached_data
 
@@ -639,7 +1145,7 @@ def initialize_torch_nn_module(mod: torch.nn.Module, parameter_names: List[str])
     """
     if not hasattr(mod, "reset_parameters"):
         raise ValueError(f"{mod} does not have a `reset_parameters` method.")
-    cached_parameters = {name: param.data.clone() for name, param in mod.named_parameters()}
+    cached_parameters = {name: param.data.detach().clone() for name, param in mod.named_parameters()}
     mod.reset_parameters()
     with torch.no_grad():
         for name, param in mod.named_parameters():
@@ -647,18 +1153,29 @@ def initialize_torch_nn_module(mod: torch.nn.Module, parameter_names: List[str])
             param.data = cached_parameters[name]
 
 
+@requires_neuronx_distributed
 def initialize_parallel_linear(mod: "layers.BaseParallelLinear", parameter_names: List[str]):
     """
     Initializes the parameters in `parameter_names` of a parallel linear module.
     """
-    if "weight" in parameter_names:
-        delete_tensor_model_parallel_attributes(mod.weight)
-        # It is needed to use `init_weight_cpu` instead of `_init_weights` because the initialization
-        # needs to happen on the full parameter and then scatter it accross TP ranks otherwise it will
-        # not be equivalent to the non-parallel case.
-        mod.init_weight_cpu()
-    if mod.bias is not None and "bias" in parameter_names:
-        mod._init_bias()
+    from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear
+    from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear
+
+    if isinstance(mod, (RowParallelLinear, ColumnParallelLinear)):
+        if "weight" in parameter_names:
+            delete_tensor_model_parallel_attributes(mod.weight)
+            # `init_weight_cpu` must be used instead of `_init_weights` because the initialization needs to
+            # happen on the full parameter, which is then scattered across the TP ranks; otherwise it would
+            # not be equivalent to the non-parallel case.
+            mod.init_weight_cpu()
+        if mod.bias is not None and "bias" in parameter_names:
+            mod._init_bias()
+    elif isinstance(mod, GQAQKVColumnParallelLinear):
+        # This ignores `parameter_names`, so it might initialize some parameters that should be left unchanged.
+        # To be improved if it becomes needed.
+        mod.initialize_weight_biases()
+    else:
+        raise RuntimeError(f"This kind of parallel linear is not supported yet: {mod.__class__.__name__}")
 
 
 def parameter_can_be_initialized(model: torch.nn.Module, parent_module: torch.nn.Module, parameter_name: str) -> bool:
@@ -958,3 +1475,8 @@ def is_tied(self):
     @property
     def is_sharded(self):
         return self.kind == "sharded"
+
+
+class OptimumNeuronFXTracer(HFTracerWrapper):
+    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
+        return super().is_leaf_module(m, module_qualified_name) or isinstance(m, FakeProj)
diff --git a/optimum/neuron/training_args.py b/optimum/neuron/training_args.py
index 85a8a4a58..a49ccade6 100644
--- a/optimum/neuron/training_args.py
+++ b/optimum/neuron/training_args.py
@@ -20,6 +20,7 @@ import warnings
 from dataclasses import dataclass, field
 from datetime import timedelta
+from typing import Optional
 
 import torch
 from accelerate.utils import DistributedType
@@ -89,6 +90,16 @@ class NeuronTrainingArgumentsMixin:
         default=-1,
         metadata={"help": "The number of microbatches used for pipeline execution."},
     )
+    kv_size_multiplier: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The number of times to replicate the KV heads when the TP size is bigger than the number of KV heads. "
+                "If left unspecified, the smallest multiplier that makes the number of KV heads divisible by the TP "
+                "size will be used."
+            )
+        },
+    )
 
     def __post_init__(self):
         # Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
@@ -147,6 +158,7 @@ def __post_init__(self):
                 self.tensor_parallel_size,
                 parallelize_embeddings=not self.disable_embedding_parallelization,
                 sequence_parallel_enabled=not self.disable_sequence_parallel,
+                kv_size_multiplier=self.kv_size_multiplier,
                 pipeline_parallel_size=self.pipeline_parallel_size,
                 pipeline_parallel_num_microbatches=self.pipeline_parallel_num_microbatches,
                 pipeline_parallel_use_zero1_optimizer=self.zero_1,
diff --git a/optimum/neuron/utils/misc.py b/optimum/neuron/utils/misc.py
index c619b2627..8df5b14cf 100644
--- a/optimum/neuron/utils/misc.py
+++ b/optimum/neuron/utils/misc.py
@@ -14,6 +14,7 @@
 # limitations under the License.
"""Utilities of various sorts.""" +import functools import inspect import os import re @@ -40,7 +41,7 @@ from ...utils import logging from .import_utils import is_torch_xla_available -from .require_utils import requires_safetensors +from .require_utils import requires_safetensors, requires_torch_xla if TYPE_CHECKING: @@ -191,6 +192,19 @@ def convert_checkpoint_to_safetensors( return safetensors_path +@requires_torch_xla +@functools.wraps(cached_file) +def distributed_friendly_cached_file(*args, **kwargs): + import torch_xla.core.xla_model as xm + + if is_main_worker(): + output = cached_file(*args, **kwargs) + xm.rendezvous("Cached file done") + if not is_main_worker(): + output = cached_file(*args, **kwargs) + return output + + def download_checkpoints_in_cache( pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], cache_dir: Optional[Union[str, os.PathLike]] = None, @@ -374,13 +388,16 @@ def download_checkpoints_in_cache( "_raise_exceptions_for_missing_entries": False, "_commit_hash": commit_hash, } - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + resolved_archive_file = distributed_friendly_cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # result when internet is up, the repo and revision exist, but the file does not. if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): # Maybe the checkpoint is sharded, we try to grab the index name in this case. - resolved_archive_file = cached_file( + resolved_archive_file = distributed_friendly_cached_file( pretrained_model_name_or_path, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), **cached_file_kwargs, @@ -397,12 +414,12 @@ def download_checkpoints_in_cache( else: # This repo has no safetensors file of any kind, we switch to PyTorch. filename = _add_variant(WEIGHTS_NAME, variant) - resolved_archive_file = cached_file( + resolved_archive_file = distributed_friendly_cached_file( pretrained_model_name_or_path, filename, **cached_file_kwargs ) if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): # Maybe the checkpoint is sharded, we try to grab the index name in this case. 
- resolved_archive_file = cached_file( + resolved_archive_file = distributed_friendly_cached_file( pretrained_model_name_or_path, _add_variant(WEIGHTS_INDEX_NAME, variant), **cached_file_kwargs, diff --git a/tests/distributed/test_common.py b/tests/distributed/test_common.py index 4cc99a741..aa9c44982 100644 --- a/tests/distributed/test_common.py +++ b/tests/distributed/test_common.py @@ -24,6 +24,7 @@ from optimum.neuron.accelerate.optimizer import NeuronAcceleratedOptimizer from optimum.neuron.accelerate.utils.dataclasses import NeuronDistributedType +from optimum.neuron.distributed.checkpointing import consolidate_model_parallel_checkpoints_to_unified_checkpoint from optimum.neuron.distributed.utils import ( TENSOR_PARALLEL_SHARDS_DIR_NAME, make_optimizer_constructor_lazy, @@ -56,6 +57,7 @@ from transformers import PreTrainedModel MODEL_NAME = "michaelbenayoun/llama-2-tiny-16layers-random" +MODEL_NAME_WITH_4_KV_HEADS = "michaelbenayoun/llama-2-tiny-4kv-heads-16layers-random" def get_tiny_llama_model( @@ -113,7 +115,7 @@ class TestCommonDistributed(DistributedTest): def parallel_sizes(self, request): return request.param - @pytest.fixture(scope="class", params=[False, True], ids=["no_lazy_load", "lazy_load"]) + @pytest.fixture(scope="class", params=[False, True], ids=["regular_load", "lazy_load"]) def lazy_load(self, request): return request.param @@ -121,7 +123,7 @@ def lazy_load(self, request): def from_config(self, request): return request.param - @pytest.fixture(scope="class", params=[False, True], ids=["no_lazy_optimizer", "lazy_optimizer"]) + @pytest.fixture(scope="class", params=[False, True], ids=["regular_optimizer", "lazy_optimizer"]) def lazy_optimizer(self, request): return request.param @@ -129,15 +131,15 @@ def lazy_optimizer(self, request): def with_groups(self, request): return request.param - @pytest.fixture(scope="class", params=[False, True], ids=["no_zero_1", "zero_1"]) + @pytest.fixture(scope="class", params=[False, True], ids=["without_zero_1", "with_zero_1"]) def zero_1(self, request): return request.param - @pytest.fixture(scope="class", params=[1, 12], ids=["no_grad_acc", "grad_acc=12"]) + @pytest.fixture(scope="class", params=[1, 12], ids=["without_grad_acc", "with_grad_acc=12"]) def gradient_accumulation_steps(self, request): return request.param - @pytest.fixture(scope="class", params=[None, 0.01], ids=["no_clip_grad_norm", "clip_grad_norm"]) + @pytest.fixture(scope="class", params=[None, 0.01], ids=["without_clip_grad_norm", "with_clip_grad_norm"]) def max_grad_norm(self, request): return request.param @@ -413,3 +415,56 @@ def test_save_model_and_load_model(self, parallel_sizes, tmpdir, monkeypatch): if dp_rank == 0: assert all(torch.all(p1 == p2) for p1, p2 in zip(model_parameters, new_model_parameters)) + + @pytest.mark.parametrize( + "world_size,tp_size,pp_size,kv_size_multiplier,model_name", + [ + [8, 2, 1, None, MODEL_NAME_WITH_4_KV_HEADS], + [8, 1, 2, None, MODEL_NAME_WITH_4_KV_HEADS], + [16, 2, 2, None, MODEL_NAME_WITH_4_KV_HEADS], + [16, 8, 2, None, MODEL_NAME_WITH_4_KV_HEADS], + [16, 8, 2, 4, MODEL_NAME_WITH_4_KV_HEADS], + ], + ids=[ + "tp=2", + "pp=2", + "dp=4,tp=pp=2", + "dp=1,tp=8,pp=2,kv_size_multiplier=None,GQAQKVColumnParallelLinear", + "dp=1,tp=8,pp=2,kv_size_multiplier=4,GQAQKVColumnParallelLinear", + ], + ) + def test_consolidate_model_parallel_checkpoints( + self, tmpdir, world_size, tp_size, pp_size, kv_size_multiplier, model_name + ): + orig_model = get_model( + LlamaForCausalLM, + model_name, + use_static_seed_patcher=True, + ) + 
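+        # Overall flow of the test: save the original (non-parallel) model, shard it with the model-parallel
+        # accelerator and save the sharded checkpoint, consolidate the shards back into a single checkpoint on
+        # rank 0, and finally check that the consolidated state dict matches the original one.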
orig_model_path = Path(tmpdir) / "orig_model" + if xm.get_ordinal() == 0: + # Saving to pytorch instead of safetensors because it fails otherwise for pickling issues with distributed tests. + orig_model.save_pretrained(orig_model_path, safe_serialization=False) + + accelerator = create_accelerator_for_mp(tp_size, pp_size, kv_size_multiplier=kv_size_multiplier) + _ = accelerator.prepare(orig_model) + + output_dir = Path(tmpdir) / "parallel_model" + accelerator.save_state(output_dir.as_posix()) + + xm.rendezvous("Saving done.") + + consolidation_dir = Path(tmpdir) / "consolidated" + if xm.get_ordinal() == 0: + consolidate_model_parallel_checkpoints_to_unified_checkpoint( + output_dir, consolidation_dir, save_format="pytorch" + ) + consolidated_state_dict = torch.load(consolidation_dir / "pytorch_model.bin") + orig_state_dict = torch.load(orig_model_path / "pytorch_model.bin") + + assert orig_state_dict.keys() == consolidated_state_dict.keys() + for key in orig_state_dict: + orig_tensor = orig_state_dict[key] + consolidated_tensor = consolidated_state_dict[key] + print(f"Testing that {key} match") + torch.testing.assert_close(orig_tensor, consolidated_tensor) diff --git a/tests/distributed/test_model_parallelization.py b/tests/distributed/test_model_parallelization.py index 4bb9abaaf..dbbb0abd1 100644 --- a/tests/distributed/test_model_parallelization.py +++ b/tests/distributed/test_model_parallelization.py @@ -14,12 +14,13 @@ # limitations under the License. """Tests validating that models can be parallelized correctly.""" +from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Type, Union import pytest import torch import torch.utils._pytree as pytree -from transformers import LlamaForCausalLM +from transformers import AutoTokenizer, LlamaForCausalLM from transformers.models.auto.configuration_auto import CONFIG_MAPPING from transformers.models.auto.modeling_auto import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, @@ -44,6 +45,7 @@ import optimum from optimum.neuron.accelerate.accelerator import NeuronAccelerator from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager +from optimum.neuron.distributed.utils import compute_query_indicies_for_rank from optimum.neuron.utils.cache_utils import ( get_num_neuron_cores, ) @@ -63,6 +65,7 @@ import torch_xla.core.xla_model as xm if is_neuronx_distributed_available(): + from neuronx_distributed.modules.qkv_linear import get_kv_shared_group from neuronx_distributed.parallel_layers.parallel_state import ( get_pipeline_model_parallel_rank, get_tensor_model_parallel_group, @@ -182,65 +185,7 @@ def _generate_supported_model_classes( ] -LLAMA_GQA_VARIANTS_TO_TEST = { - "MHA-setup": ( - 8, - 2, - 1, - { - "num_hidden_layers": "2", - "num_attention_heads": "8", - "num_key_value_heads": "8", - }, - ), - "num_key_value_heads > tp_size": ( - 8, - 2, - 1, - { - "num_hidden_layers": "2", - "num_attention_heads": "8", - "num_key_value_heads": "4", - }, - ), - "num_key_value_heads = tp_size": ( - 8, - 8, - 1, - { - "num_hidden_layers": "2", - "hidden_size": "32", - "num_attention_heads": "16", - "num_key_value_heads": "8", - }, - ), - "num_key_value_heads < tp_size": ( - 8, - 8, - 1, - { - "num_hidden_layers": "2", - "hidden_size": "32", - "num_attention_heads": "16", - "num_key_value_heads": "2", - }, - ), - "MQA-setup": ( - 8, - 8, - 1, - { - "num_hidden_layers": "2", - "hidden_size": "32", - "num_attention_heads": "16", - "num_key_value_heads": "1", - }, - ), -} -# LLAMA_V2_MODEL_NAME = 
"michaelbenayoun/llama-2-tiny-16layers-32kv-heads-random" -LLAMA_V2_MODEL_NAME = "anushehchaudry/llama-2-tiny-random" -# LLAMA_V2_MODEL_NAME = "michaelbenayoun/llama-2-tiny-16layers-random" -# LLAMA_V2_MODEL_NAME = "michaelbenayoun/llama-2-tiny-16layers-32kv-heads-random" +LLAMA_V2_MODEL_NAME = "michaelbenayoun/llama-2-tiny-16layers-32kv-heads-random" @is_trainium_test @@ -259,6 +204,24 @@ def parallel_sizes(self, request): def model_specs(self, request): return request.param + @pytest.fixture(scope="class", params=[True, False], ids=["from_pretrained", "from_config"]) + def from_pretrained(self, request): + return request.param + + @pytest.fixture(scope="class", params=[False, True], ids=["regular_load", "lazy_load"]) + def lazy_load(self, request): + return request.param + + @pytest.fixture( + scope="class", params=[False, True], ids=["sequence_parallel_disabled", "sequence_parallel_enabled"] + ) + def sequence_parallel_enabled(self, request): + return request.param + + @pytest.fixture(scope="class", params=[False, True], ids=["embeddings_not_parallel", "parallelized_embeddings"]) + def parallelize_embeddings(self, request): + return request.param + def early_skip(self, fixtures_kwargs): pp_size = fixtures_kwargs.get("pp_size", None) parallel_sizes = fixtures_kwargs.get("parallel_sizes", None) @@ -287,16 +250,24 @@ def _check_output(self, name: str, original_output, output): elif isinstance(original_output, torch.Tensor): xm.master_print(f"Comparing output named {name}") tp_size = get_tensor_model_parallel_size() + tp_group = get_tensor_model_parallel_group() if original_output.shape != output.shape: gather_dim = min( idx for idx in range(original_output.dim()) if original_output.shape[idx] != output.shape[idx] ) output = output.to(xm.xla_device()) gathered = [torch.empty_like(output) for _ in range(tp_size)] - torch.distributed.all_gather(gathered, output, group=get_tensor_model_parallel_group()) + torch.distributed.all_gather(gathered, output, group=tp_group) gathered_output = torch.cat(gathered, dim=gather_dim) xm.mark_step() output = gathered_output.to("cpu") + + # In this case, we assume GQAQKVColumnParallelLinear was used, we retrieve only the non-repeated KV heads. + if "past" in name and original_output.size(1) != output.size(1): + kv_size_multiplier = len(get_kv_shared_group(as_list=True)[0]) + output = torch.chunk(output, kv_size_multiplier, dim=1)[0] + + xm.master_print("Diff tensor:", original_output - output) torch.testing.assert_close(original_output, output) else: assert original_output == output, f"Output named {name} do not match." @@ -328,6 +299,9 @@ def _parallel_model_matches_original_model( ) orig_model = NeuronAccelerator.patch_model_for_neuron(orig_model) + # TODO: enable that again once it's working, seems to be an AWS issue. 
+ orig_model.config.use_cache = False + set_neuron_cc_optlevel_for_model(orig_model) move_model_to_device(orig_model, xm.xla_device()) @@ -341,6 +315,9 @@ def _parallel_model_matches_original_model( if sequence_parallel_enabled and not manager.supports_sequence_parallelism(): pytest.skip(f"Sequence parallelism is not supported for {model_class.__name__}.") + if not from_pretrained and lazy_load: + pytest.skip("This is not supported, issue with tying weights.") + pad_to_multiple_of = None if not sequence_parallel_enabled else tp_size inputs = get_model_inputs( orig_model, model_name_or_path, batch_size=dp_size, pad_to_multiple_of=pad_to_multiple_of @@ -449,20 +426,206 @@ def test_parallel_model_matches_original_model_from_config( ) @pytest.mark.parametrize( "world_size,tp_size,pp_size,config_overwrite", - LLAMA_GQA_VARIANTS_TO_TEST.values(), - ids=LLAMA_GQA_VARIANTS_TO_TEST.keys(), + [ + [ + 8, + 2, + 1, + { + "num_hidden_layers": "2", + "hidden_size": "32", + "num_attention_heads": "8", + "num_key_value_heads": "8", + }, + ], + [ + 8, + 2, + 1, + { + "num_hidden_layers": "2", + "hidden_size": "32", + "num_attention_heads": "8", + "num_key_value_heads": "4", + }, + ], + [ + 8, + 8, + 1, + { + "num_hidden_layers": "2", + "hidden_size": "32", + "num_attention_heads": "16", + "num_key_value_heads": "8", + }, + ], + [ + 8, + 8, + 1, + { + "num_hidden_layers": "2", + "hidden_size": "32", + "num_attention_heads": "16", + "num_key_value_heads": "2", + }, + ], + [ + 16, + 8, + 2, + { + "num_hidden_layers": "2", + "hidden_size": "32", + "num_attention_heads": "16", + "num_key_value_heads": "2", + }, + ], + [ + 8, + 8, + 1, + { + "num_hidden_layers": "2", + "hidden_size": "32", + "num_attention_heads": "16", + "num_key_value_heads": "1", + }, + ], + ], + ids=[ + "MHA-setup", + "num_key_value_heads bigger than tp_size", + "num_key_value_heads equal to tp_size", + "num_key_value_heads lower than tp_size", + "num_key_value_heads lower than tp_size,pp enabled", + "MQA-setup", + ], ) - def test_llama_v2_gqa_variants(self, world_size, tp_size, pp_size, config_overwrite, monkeypatch): + def test_llama_v2_gqa( + self, + monkeypatch, + tmpdir, + world_size, + tp_size, + pp_size, + config_overwrite, + from_pretrained, + lazy_load, + sequence_parallel_enabled, + parallelize_embeddings, + ): monkeypatch.setattr( optimum.neuron.distributed.parallel_layers, "_PARALLEL_CROSS_ENTROPY_SHOULD_PRESERVE_INPUT", True ) + num_kv_heads = int(config_overwrite["num_key_value_heads"]) + # if num_kv_heads >= tp_size and (from_pretrained or lazy_load or sequence_parallel_enabled): + # pytest.skip("No need to test this setting.") + + # The following case can be skipped because since we set the seed, we would need to shuffle the output + # projections for this case to work. This is not needed in the real-case scenario, and since we test for every + # other setting, we can skip. + if num_kv_heads < tp_size and (not from_pretrained): + pytest.skip("This case will not work here because we set the seed. We can skip.") + + model_name_or_path = Path(tmpdir) / "llama_v2_gqa" + + # Since we are creating the model from config, we actually first create a model locally from config and then + # use that as a `from_pretrained` to have proper initialization. Without that we can end-up with uninitialized + # weights. 
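+        # Only the first rank writes the tokenizer and model to disk; the other ranks wait at the rendezvous
+        # below before loading from that path.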
+ if xm.get_ordinal() == 0: + tokenizer = AutoTokenizer.from_pretrained(LLAMA_V2_MODEL_NAME) + tokenizer.save_pretrained(model_name_or_path) + model = get_model( + LlamaForCausalLM, + LLAMA_V2_MODEL_NAME, + from_config=True, + config_overwrite=config_overwrite, + ) + model.save_pretrained(model_name_or_path) + xm.rendezvous("Model creation done.") + return self._parallel_model_matches_original_model( LlamaForCausalLM, - LLAMA_V2_MODEL_NAME, + model_name_or_path, config_overwrite, (world_size, tp_size, pp_size), - False, - False, - False, - False, + from_pretrained, + lazy_load, + sequence_parallel_enabled, + parallelize_embeddings, + ) + + +@pytest.mark.parametrize( + "tp_size,num_attention_heads,num_key_value_heads,kv_size_multiplier,ground_truth", + [ + [ + 8, + 32, + 4, + 2, + [ + [0, 1, 2, 3], + [8, 9, 10, 11], + [16, 17, 18, 19], + [24, 25, 26, 27], + [4, 5, 6, 7], + [12, 13, 14, 15], + [20, 21, 22, 23], + [28, 29, 30, 31], + ], + ], + [ + 8, + 32, + 4, + 4, + [ + [0, 1, 8, 9], + [16, 17, 24, 25], + [2, 3, 10, 11], + [18, 19, 26, 27], + [4, 5, 12, 13], + [20, 21, 28, 29], + [6, 7, 14, 15], + [22, 23, 30, 31], + ], + ], + [ + 8, + 32, + 4, + 8, + [ + [0, 8, 16, 24], + [1, 9, 17, 25], + [2, 10, 18, 26], + [3, 11, 19, 27], + [4, 12, 20, 28], + [5, 13, 21, 29], + [6, 14, 22, 30], + [7, 15, 23, 31], + ], + ], + ], + ids=[ + "32-heads-4kv-heads-kv-mul-2,one kv head per rank", + "32-heads-4kv-heads-kv-mul-4,multiple kv heads per rank", + "32-heads-4kv-heads-kv-mul-8,all kv heads per rank", + ], +) +@is_trainium_test +def test_compute_query_indices_for_rank( + tp_size, num_attention_heads, num_key_value_heads, kv_size_multiplier, ground_truth +): + for tp_rank in range(tp_size): + expected = torch.tensor(ground_truth[tp_rank]) + computed = compute_query_indicies_for_rank( + tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier ) + print(f"TP rank = {tp_rank}") + print(f"Expected {expected}") + print(f"Computed {computed}") + torch.testing.assert_close(expected, computed) diff --git a/tests/distributed/utils.py b/tests/distributed/utils.py index 8cd35f214..134a6f6b7 100644 --- a/tests/distributed/utils.py +++ b/tests/distributed/utils.py @@ -246,6 +246,11 @@ def create_static_seed_patcher(model_class: Type["PreTrainedModel"], seed: int): ("torch.Tensor.normal_", dynamic_patch), ("neuronx_distributed.parallel_layers.layers.ColumnParallelLinear.init_weight_cpu", dynamic_patch), ("neuronx_distributed.parallel_layers.layers.RowParallelLinear.init_weight_cpu", dynamic_patch), + ( + "neuronx_distributed.modules.qkv_linear.GQAQKVColumnParallelLinear._init_per_layer_weight", + dynamic_patch, + ), + ("neuronx_distributed.modules.qkv_linear.GQAQKVColumnParallelLinear._init_per_layer_bias", dynamic_patch), ] ) with patcher: @@ -354,10 +359,12 @@ def create_accelerator_for_mp( gradient_accumulation_steps: int = 1, parallelize_embeddings: bool = True, sequence_parallel_enabled: bool = True, + kv_size_multiplier: Optional[int] = None, checkpoint_dir: Optional[Union[Path, str]] = None, ) -> NeuronAccelerator: mp_plugin = ModelParallelismPlugin( tensor_parallel_size=tp_size, + kv_size_multiplier=kv_size_multiplier, parallelize_embeddings=parallelize_embeddings, sequence_parallel_enabled=sequence_parallel_enabled, pipeline_parallel_size=pp_size, diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 567de3178..0c83e97ff 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -25,6 +25,7 @@ from unittest import TestCase import huggingface_hub +import 
pytest import torch from huggingface_hub import HfApi, create_repo, delete_repo, get_token, hf_hub_download, login from transformers import BertConfig, BertModel, set_seed @@ -483,6 +484,7 @@ def test_neuron_hash_is_private(self): @is_trainium_test @is_staging_test +@pytest.mark.skip("This is not needed anymore and will be removed.") class CachedModelOnTheHubTestCase(StagingTestMixin, TestCase): def test_push_to_hub_fails_with_private_model_and_public_repo(self): with TemporaryDirectory() as tmpdirname: