diff --git a/optimum/neuron/accelerate/utils/dataclasses.py b/optimum/neuron/accelerate/utils/dataclasses.py index 8f4ce5b45..1461d6c9f 100644 --- a/optimum/neuron/accelerate/utils/dataclasses.py +++ b/optimum/neuron/accelerate/utils/dataclasses.py @@ -150,6 +150,7 @@ class ModelParallelismPlugin: pipeline_parallel_use_zero1_optimizer: bool = False gradient_checkpointing: bool = False checkpoint_dir: Optional[Union[str, Path]] = None + num_ranks_per_loading_step: int = -1 def __post_init__(self): if self.tensor_parallel_size < 1: @@ -181,5 +182,6 @@ def parallelize_model( pipeline_parallel_use_zero1_optimizer=self.pipeline_parallel_use_zero1_optimizer, pipeline_parallel_gradient_checkpointing_enabled=self.gradient_checkpointing, checkpoint_dir=self.checkpoint_dir, + num_ranks_per_loading_step=self.num_ranks_per_loading_step, ) return parallelized_model diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py index d20dd620d..39f9a39b4 100644 --- a/optimum/neuron/distributed/base.py +++ b/optimum/neuron/distributed/base.py @@ -15,6 +15,8 @@ """Base class related to `neuronx_distributed` to perform parallelism.""" import contextlib +import gc +import math import shutil from abc import ABC, abstractclassmethod from collections import defaultdict @@ -540,6 +542,7 @@ def parallelize( pipeline_parallel_use_zero1_optimizer: bool = False, pipeline_parallel_gradient_checkpointing_enabled: bool = False, checkpoint_dir: Optional[Union[str, Path]] = None, + num_ranks_per_loading_step: int = -1, ) -> "PreTrainedModel": """ Parallelizes the model by transforming regular layer into their parallel counterparts using @@ -572,6 +575,9 @@ def parallelize( checkpoint_dir (`Optional[Union[str, Path]]`): Path to a sharded checkpoint. If specified, the checkpoint weights will be loaded to the parallelized model. 
+ num_ranks_per_loading_step (`int`, defaults to `-1`): + Corresponds to the number of ranks that can initialize and load the model weights at the same time. + If the value is less than 0, the maximum number of ranks will be used. Returns: `PreTrainedModel`: The parallelized model. @@ -586,6 +592,7 @@ def parallelize( get_tensor_model_parallel_size, ) from neuronx_distributed.parallel_layers.random import _MODEL_PARALLEL_RNG_TRACKER_NAME, get_xla_rng_tracker + from neuronx_distributed.parallel_layers.utils import get_local_world_size from neuronx_distributed.pipeline import NxDPPModel tp_size = get_tensor_model_parallel_size() @@ -698,14 +705,20 @@ def should_parallelize_layer_predicate_func(layer): if is_precompilation(): cls._initialize_for_precompilation(model, names_of_the_parameters_to_consider) else: - if skip_linear_weight_load: - # Load the weights to the parallel linears if the loading was skipped during parallelization. - cls._maybe_load_weights_to_parallel_linears(model) - - if skip_linear_weight_load or any(p.device == torch.device("meta") for p in model.parameters()): - # Initialize or load the weights for the parallelized model if it was lazily loaded. - cls._initialize_or_load_weights(model, names_of_the_parameters_to_consider, device=device) - + local_rank = xm.get_local_ordinal() + if num_ranks_per_loading_step < 0: + num_ranks_per_loading_step = get_local_world_size() + for worker in range(math.ceil(get_local_world_size() / num_ranks_per_loading_step)): + if local_rank // num_ranks_per_loading_step == worker: + if skip_linear_weight_load: + # Load the weights to the parallel linears if the loading was skipped during parallelization. + cls._maybe_load_weights_to_parallel_linears(model) + + if skip_linear_weight_load or any(p.device == torch.device("meta") for p in model.parameters()): + # Initialize or load the weights for the parallelized model if it was lazily loaded. 
+ cls._initialize_or_load_weights(model, names_of_the_parameters_to_consider, device=device) + gc.collect() + xm.rendezvous(f"weight_loading_and_initialization_{worker}") xm.rendezvous("End of initalization") if is_main_worker(): diff --git a/optimum/neuron/distributed/checkpointing.py b/optimum/neuron/distributed/checkpointing.py index 3466e8cc1..97a6128e3 100644 --- a/optimum/neuron/distributed/checkpointing.py +++ b/optimum/neuron/distributed/checkpointing.py @@ -23,7 +23,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors -from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indicies_for_rank +from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indices_for_rank def create_gqa_query_or_output_projection_weight_from_full_weight( @@ -42,15 +42,15 @@ def create_gqa_query_or_output_projection_weight_from_full_weight( hidden_size = full_weight.size(0) full_weight = full_weight.transpose(0, 1) - indicies = [ - compute_query_indicies_for_rank(tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier) + indices = [ + compute_query_indices_for_rank(tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier) for tp_rank in range(tp_size) ] - indicies = torch.cat(indicies, dim=0) - reversed_indicies = torch.sort(indicies, dim=0).indices + indices = torch.cat(indices, dim=0) + reversed_indices = torch.sort(indices, dim=0).indices full_weight = full_weight.reshape(num_attention_heads, -1, hidden_size) - full_weight = full_weight[reversed_indicies] + full_weight = full_weight[reversed_indices] full_weight = full_weight.reshape(-1, hidden_size) if query_or_output == "output": diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index 5a34fb3a3..3d4d6df27 100644 --- 
a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -508,7 +508,7 @@ def create_kv_proj_local_weight_from_regular_weight( return torch.cat(split[tp_rank::tp_size], dim=0) -def compute_query_indicies_for_rank( +def compute_query_indices_for_rank( tp_size: int, tp_rank: int, num_attention_heads: int, num_key_value_heads: int, kv_size_multiplier: int ): """ @@ -519,32 +519,34 @@ def compute_query_indicies_for_rank( query_group_size = num_attention_heads // num_key_value_heads query_group_size_per_rank = num_attention_heads_per_rank // num_key_value_heads_per_rank - queries_indicies = [torch.arange(query_group_size_per_rank) for _ in range(num_key_value_heads_per_rank)] + queries_indices = [torch.arange(query_group_size_per_rank) for _ in range(num_key_value_heads_per_rank)] - keys_indicies = torch.arange(num_key_value_heads).repeat(kv_size_multiplier) - keys_indicies = torch.repeat_interleave( - keys_indicies, num_attention_heads_per_rank // num_key_value_heads_per_rank - ) - keys_indicies = torch.chunk(keys_indicies, tp_size) + keys_indices = torch.arange(num_key_value_heads).repeat(kv_size_multiplier) + keys_indices = torch.repeat_interleave(keys_indices, num_attention_heads_per_rank // num_key_value_heads_per_rank) + keys_indices = torch.chunk(keys_indices, tp_size) shift_per_key = torch.arange(0, num_attention_heads, query_group_size) shift_within_query_group = torch.arange(0, query_group_size, query_group_size_per_rank) + num_ranks_to_fit_all_key_value_heads = num_key_value_heads // num_key_value_heads_per_rank + num_query_heads_before_next_head_of_same_group = ( + num_ranks_to_fit_all_key_value_heads * num_attention_heads_per_rank + ) shift_within_query_group = torch.repeat_interleave( - shift_within_query_group, num_attention_heads_per_rank * num_key_value_heads_per_rank + shift_within_query_group, num_query_heads_before_next_head_of_same_group ) shift_within_query_group = torch.chunk(shift_within_query_group, tp_size) - indicies 
= [] - for idx, q_indicies in enumerate(queries_indicies): + indices = [] + for idx, q_indices in enumerate(queries_indices): s = slice(idx * query_group_size_per_rank, (idx + 1) * query_group_size_per_rank) - k_indicies = keys_indicies[tp_rank][s] - k_shift = shift_per_key[k_indicies] + k_indices = keys_indices[tp_rank][s] + k_shift = shift_per_key[k_indices] group_shift = shift_within_query_group[tp_rank][s] - indicies.append(q_indicies + k_shift + group_shift) + indices.append(q_indices + k_shift + group_shift) - indicies = torch.cat(indicies, dim=0) - return indicies + indices = torch.cat(indices, dim=0) + return indices @requires_neuronx_distributed @@ -578,11 +580,11 @@ def create_query_or_output_projection_local_weight_from_regular_weight( head_dim = weight_data.size(1) // num_attention_heads weight_data = weight_data.transpose(0, 1) - indicies = compute_query_indicies_for_rank( + indices = compute_query_indices_for_rank( tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier ) reshaped_weight = weight_data.view(num_attention_heads, head_dim, hidden_size) - shuffled_weight = reshaped_weight[indicies] + shuffled_weight = reshaped_weight[indices] shuffled_weight = shuffled_weight.reshape(-1, hidden_size) if query_or_output_proj == "output": @@ -620,9 +622,9 @@ def create_local_bias_from_regular_bias( else: if gather_output: - indicies = torch.cat( + indices = torch.cat( [ - compute_query_indicies_for_rank( + compute_query_indices_for_rank( tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier ) for tp_rank in range(tp_size) @@ -630,11 +632,11 @@ def create_local_bias_from_regular_bias( dim=0, ) else: - indicies = compute_query_indicies_for_rank( + indices = compute_query_indices_for_rank( tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier ) reshaped_bias_weight = bias_weigth_data.view(num_attention_heads, -1) - shuffled_bias_weight = reshaped_bias_weight[indicies] + 
shuffled_bias_weight = reshaped_bias_weight[indices] local_bias_weight = shuffled_bias_weight.reshape(-1) return local_bias_weight diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index 41941d205..fe2c305aa 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -339,7 +339,8 @@ def compute_loss(self, model, inputs, return_outputs: bool = False): loss = model.run_train(**inputs) return loss - return super().compute_loss(model, inputs, return_outputs=return_outputs) + loss = super().compute_loss(model, inputs, return_outputs=return_outputs) + return loss def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: from neuronx_distributed.pipeline import NxDPPModel @@ -397,10 +398,10 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for ) if self.args.mp_plugin.should_parallelize: - dp_size = get_data_parallel_size() pp_size = get_pipeline_model_parallel_size() pp_rank = get_pipeline_model_parallel_rank() + tr_loss_div = tr_loss / dp_size if pp_size > 1 and pp_rank == pp_size - 1: diff --git a/optimum/neuron/training_args.py b/optimum/neuron/training_args.py index a49ccade6..c56f2fe1d 100644 --- a/optimum/neuron/training_args.py +++ b/optimum/neuron/training_args.py @@ -38,7 +38,7 @@ from ..utils import check_if_transformers_greater, logging from .accelerate import NeuronAcceleratorState, NeuronPartialState from .accelerate.utils import ModelParallelismPlugin, patch_accelerate_is_tpu_available -from .utils import is_accelerate_available, is_torch_xla_available +from .utils import is_accelerate_available, is_main_worker, is_torch_xla_available from .utils.patching import Patcher from .utils.training_utils import TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP @@ -100,6 +100,15 @@ class NeuronTrainingArgumentsMixin: ) }, ) + num_ranks_per_loading_step: int = field( + default=-1, + metadata={ + "help": ( + "The number of ranks to use concurrently during weight 
initialization and loading when tensor " + "parallelism is enabled. If left unspecified, the maximum number of ranks will be used." + ) + }, + ) def __post_init__(self): # Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available` @@ -141,6 +150,14 @@ def __post_init__(self): resume_from_checkpoint = checkpoint if self.pipeline_parallel_size > 1: + if self.gradient_accumulation_steps > 1: + if is_main_worker(): + logger.info( + "Pipeline parallel used, setting gradient_accumulation_steps to 1 and scaling the pipeline batch size." + ) + self.per_device_train_batch_size *= self.gradient_accumulation_steps + self.per_device_eval_batch_size *= self.gradient_accumulation_steps + self.gradient_accumulation_steps = 1 if self.pipeline_parallel_num_microbatches == -1: self.pipeline_parallel_num_microbatches = self.per_device_train_batch_size if self.per_device_train_batch_size % self.pipeline_parallel_num_microbatches != 0: @@ -164,6 +181,7 @@ def __post_init__(self): pipeline_parallel_use_zero1_optimizer=self.zero_1, gradient_checkpointing=self.gradient_checkpointing, checkpoint_dir=resume_from_checkpoint, + num_ranks_per_loading_step=self.num_ranks_per_loading_step, ) # This is required to be able to use bf16, otherwise a check in super().__post_init__() fails. 
diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index 11a74a518..17ac6890c 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -39,6 +39,7 @@ DiffusersPretrainedConfig, check_if_weights_replacable, get_stable_diffusion_configs, + is_main_worker, replace_weights, ) from .optimization_utils import get_attention_scores_sd, get_attention_scores_sdxl diff --git a/tests/distributed/test_model_parallelization.py b/tests/distributed/test_model_parallelization.py index dbbb0abd1..8eaceda2c 100644 --- a/tests/distributed/test_model_parallelization.py +++ b/tests/distributed/test_model_parallelization.py @@ -45,7 +45,7 @@ import optimum from optimum.neuron.accelerate.accelerator import NeuronAccelerator from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager -from optimum.neuron.distributed.utils import compute_query_indicies_for_rank +from optimum.neuron.distributed.utils import compute_query_indices_for_rank from optimum.neuron.utils.cache_utils import ( get_num_neuron_cores, ) @@ -609,11 +609,52 @@ def test_llama_v2_gqa( [7, 15, 23, 31], ], ], + [ + 32, + 32, + 4, + 8, + [ + [0], + [8], + [16], + [24], + [1], + [9], + [17], + [25], + [2], + [10], + [18], + [26], + [3], + [11], + [19], + [27], + [4], + [12], + [20], + [28], + [5], + [13], + [21], + [29], + [6], + [14], + [22], + [30], + [7], + [15], + [23], + [31], + ], + ], ], ids=[ "32-heads-4kv-heads-kv-mul-2,one kv head per rank", "32-heads-4kv-heads-kv-mul-4,multiple kv heads per rank", "32-heads-4kv-heads-kv-mul-8,all kv heads per rank", + "tp=32,32-heads-4kv-heads-kv-mul-8,one query head per rank", ], ) @is_trainium_test @@ -622,7 +663,7 @@ def test_compute_query_indices_for_rank( ): for tp_rank in range(tp_size): expected = torch.tensor(ground_truth[tp_rank]) - computed = compute_query_indicies_for_rank( + computed = compute_query_indices_for_rank( tp_size, tp_rank, num_attention_heads, num_key_value_heads, 
kv_size_multiplier ) print(f"TP rank = {tp_rank}")