diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py
index d54efc143..4b1e7ffac 100755
--- a/examples/language-modeling/run_clm.py
+++ b/examples/language-modeling/run_clm.py
@@ -466,9 +466,10 @@ def main():
 
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
-    embedding_size = model.get_input_embeddings().weight.shape[0]
-    if len(tokenizer) > embedding_size:
-        model.resize_token_embeddings(len(tokenizer))
+    # TODO: uncomment this.
+    # embedding_size = model.get_input_embeddings().weight.shape[0]
+    # if len(tokenizer) > embedding_size:
+    #     model.resize_token_embeddings(len(tokenizer))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index 2ab59a995..9b624f065 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -21,7 +21,7 @@
 from dataclasses import asdict
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, Type, Union
 
 import torch
 from transformers import PreTrainedModel
@@ -98,7 +98,7 @@ class PipelineParallelismSpecs:
 
     @classmethod
     @requires_torch_xla
-    def create_pipeline_cuts(cls, model: PreTrainedModel, pipeline_parallel_size: int) -> List[str]:
+    def create_pipeline_cuts(cls, model: PreTrainedModel, pipeline_parallel_size: int, log: bool = True) -> List[str]:
         """
         Creates the pipeline cuts, e.g. the name of the layers at each the cuts happen for pipeline parallelism.
         """
@@ -117,7 +117,7 @@ def create_pipeline_cuts(cls, model: PreTrainedModel, pipeline_parallel_size: in
             for cut_idx in range(num_layers_per_partition - 1, num_layers - 1, num_layers_per_partition)
         ]
 
-        if xm.get_local_ordinal() == 0:
+        if log and xm.get_ordinal() == 0:
             logger.info(f"Pipeline parallelism cuts: {pipeline_cuts}.")
 
         return pipeline_cuts
@@ -197,7 +197,7 @@ def _get_parameter_names_for_current_pipeline(
         if not cls.supports_pipeline_parallelism():
             raise NotImplementedError(f"{cls} does not support pipeline parallelism.")
 
-        cuts = cls.PIPELINE_PARALLELISM_SPECS_CLS.create_pipeline_cuts(model, pp_size)
+        cuts = cls.PIPELINE_PARALLELISM_SPECS_CLS.create_pipeline_cuts(model, pp_size, log=False)
 
         start_module_name = cuts[pp_rank - 1] if pp_rank >= 1 else None
         end_module_name = None if pp_rank == pp_size - 1 else cuts[pp_rank]
@@ -243,6 +243,7 @@ def _parallelize(
         device: Optional["torch.device"] = None,
         parallelize_embeddings: bool = True,
         sequence_parallel_enabled: bool = False,
+        should_parallelize_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None,
     ) -> "PreTrainedModel":
         """
         Parallelizes the model by transforming regular layer into their parallel counterparts.
@@ -258,6 +259,7 @@ def _parallelize(
                 This can be disabled in the case when the TP size does not divide the vocabulary size.
             sequence_parallel_enabled (`bool`, defaults to `False`):
                 Whether or not sequence parallelism is enabled.
+            # TODO: add a docstring entry for `should_parallelize_predicate_func`.
 
         Returns:
             `PreTrainedModel`: The parallelized model.
@@ -304,6 +306,7 @@ def parallelize(
         Returns:
             `PreTrainedModel`: The parallelized model.
         """
+        import torch_xla.core.xla_model as xm
         from neuronx_distributed import parallel_layers
 
         if sequence_parallel_enabled and not cls.supports_sequence_parallelism():
@@ -322,13 +325,34 @@ def parallelize(
 
         # Parallelizing the model.
         # This needs to be done prior to preparing the model for sequence parallelism because modules can be overriden.
+
+        names_of_the_parameters_to_consider = cls._get_parameter_names_for_current_pipeline(
+            model, remove_duplicate=True
+        )
+
+
+        name_to_parameter = dict(named_parameters(model, remove_duplicate=False))
+        parameter_to_name = {p: n for n, p in name_to_parameter.items()}
+
+        def predicate_func(layer):
+            for n, p in layer.named_parameters():
+                if p not in parameter_to_name:
+                    print(n)
+                    return False
+            names = {parameter_to_name[p] for p in layer.parameters()}
+            return names < names_of_the_parameters_to_consider
+
+        model.predicate = predicate_func
+
         if tp_size > 1:
             model = cls._parallelize(
                 model,
                 device=device,
                 parallelize_embeddings=parallelize_embeddings,
                 sequence_parallel_enabled=sequence_parallel_enabled,
+                # should_parallelize_predicate_func=predicate_func,
             )
+        xm.rendezvous("End of tensor parallelism")
 
         # Preparing the model for sequence parallelism:
         sp_specs_cls = cls.SEQUENCE_PARALLELSIM_SPECS_CLS
@@ -358,10 +382,6 @@ def parallelize(
         # The model was not loaded lazily, it is already ready.
         weight_map = getattr(model, "_weight_map", {})
 
-        names_of_the_parameters_to_consider = cls._get_parameter_names_for_current_pipeline(
-            model, remove_duplicate=True
-        )
-
         with torch.no_grad():
             tied_weights = {}
             new_parameters = set()
@@ -482,6 +502,8 @@ def parallelize(
                     if left_uninitialized and hasattr(mod, "reset_parameters"):
                         initialize_torch_nn_module(mod, parameter_names)
 
+        xm.rendezvous("End of initialization")
+
         pp_size = get_pipeline_model_parallel_size()
         if pp_size > 1:
             if not cls.supports_pipeline_parallelism():
@@ -506,9 +528,12 @@ def parallelize(
                 use_zero1_optimizer=pipeline_parallel_use_zero1_optimizer,
             )
 
+            xm.rendezvous("End of pipeline parallelism")
+
         if checkpoint_dir is not None:
             cls.load_model_checkpoint(model, checkpoint_dir)
+
 
         return model
 
     @classmethod
diff --git a/optimum/neuron/distributed/decoder_models.py b/optimum/neuron/distributed/decoder_models.py
index 0bb795e31..8e30829b1 100644
--- a/optimum/neuron/distributed/decoder_models.py
+++ b/optimum/neuron/distributed/decoder_models.py
@@ -330,7 +330,7 @@ class LLamaParallelMLP(ParallelMLP):
     SECOND_LINEAR_NAME = "down_proj"
 
     @classmethod
-    def transform(
+    def _transform(
         cls,
         model: "PreTrainedModel",
         layer: "torch.nn.Module",
@@ -339,7 +339,7 @@ def transform(
     ) -> "torch.nn.Module":
         # TODO: Make it smart by merging the gate and the up_proj.
         # WARNING: be careful of the interleaved outputs when doing TP!
-        layer = super().transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)
+        layer = super()._transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)
 
         weight_map = getattr(model, "_weight_map", None)
diff --git a/optimum/neuron/distributed/parallel_layers.py b/optimum/neuron/distributed/parallel_layers.py
index 4129cf327..1f4cb24bc 100644
--- a/optimum/neuron/distributed/parallel_layers.py
+++ b/optimum/neuron/distributed/parallel_layers.py
@@ -105,7 +105,7 @@ def _get_linear_weight_info(
         return linear_layer_weight_info, linear_layer_bias_weight_info
 
     @abstractclassmethod
-    def transform(
+    def _transform(
        cls,
         model: "PreTrainedModel",
         layer: "torch.nn.Module",
@@ -126,6 +126,18 @@ def transform(
                 The device where the new parallel layer should be put.
         """
 
+    @classmethod
+    def transform(
+        cls,
+        model: "PreTrainedModel",
+        layer: "torch.nn.Module",
+        sequence_parallel_enabled: bool = False,
+        device: Optional["torch.device"] = None,
+    ) -> "torch.nn.Module":
+        if not model.predicate(layer):
+            return layer
+        return cls._transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)
+
 
 class ParallelEmbedding(ParallelLayer):
     """
@@ -164,7 +176,7 @@ def overwrite_vocab_size_value_for_cross_entropy_computation(cls, layer: "torch.
 
     @classmethod
     @requires_neuronx_distributed
-    def transform(
+    def _transform(
         cls,
         model: "PreTrainedModel",
         layer: "torch.nn.Module",
@@ -295,7 +307,7 @@ class ParallelSelfAttention(ParallelLayer):
 
     @classmethod
     @requires_neuronx_distributed
-    def transform(
+    def _transform(
         cls,
         model: "PreTrainedModel",
         layer: "torch.nn.Module",
@@ -475,7 +487,7 @@ class ParallelSelfAttentionWithFusedQKV(ParallelLayer):
 
     @classmethod
     @requires_neuronx_distributed
-    def transform(
+    def _transform(
         cls,
         model: "PreTrainedModel",
         layer: "torch.nn.Module",
@@ -582,7 +594,7 @@ class ParallelSelfOutput(ParallelLayer):
     OUTPUT_PROJECTION_NAME = "dense"
 
     @classmethod
-    def transform(
+    def _transform(
         cls,
         model: "PreTrainedModel",
         layer: "torch.nn.Module",
@@ -632,7 +644,7 @@ class ParallelMLP(ParallelLayer):
     SECOND_LINEAR_NAME: str
 
     @classmethod
-    def transform(
+    def _transform(
         cls,
         model: "PreTrainedModel",
         layer: "torch.nn.Module",
@@ -789,7 +801,7 @@ def patch_cross_entropy(cls, model: "PreTrainedModel"):
 
     @classmethod
     @requires_neuronx_distributed
-    def transform(
+    def _transform(
         cls,
         model: "PreTrainedModel",
         layer: "torch.nn.Module",
diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index 7f98f392f..f5837e8ff 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -408,6 +408,7 @@ def linear_to_parallel_linear(
                 )
                 parallel_linear_layer.weight.copy_(weight_data)
                 mark_parameter_init_status_during_parallelization(parallel_linear_layer.weight, True)
+                del weight_data
             elif linear_layer.weight.device != torch.device("meta"):
                 parallel_linear_layer.weight.copy_(
                     linear_layer.weight[tp_rank * row_size : (tp_rank + 1) * row_size, :]
@@ -433,6 +434,7 @@ def linear_to_parallel_linear(
                     )
                     parallel_linear_layer.bias.copy_(bias_weight_data)
                     mark_parameter_init_status_during_parallelization(parallel_linear_layer.bias, True)
+                    del bias_weight_data
                 elif linear_layer.bias.device != torch.device("meta"):
                     if gather_output:
                         parallel_linear_layer.bias.copy_(linear_layer.bias)
@@ -444,6 +446,10 @@ def linear_to_parallel_linear(
             else:
                 mark_parameter_init_status_during_parallelization(parallel_linear_layer.bias, False)
 
+    del linear_layer.weight
+    if linear_layer.bias is not None:
+        del linear_layer.bias
+
     return parallel_linear_layer
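
The predicate-gating idea in this patch can be read in isolation: `parallelize` derives a predicate from the parameter names owned by the current pipeline stage, stores it on the model, and the new `ParallelLayer.transform` wrapper only delegates to `_transform` when the predicate accepts the layer. Below is a minimal, self-contained sketch of that flow in plain PyTorch; the toy `nn.Sequential` model, the `make_predicate` and `transform` helper names, and the attribute-tagging stand-in for `_transform` are illustrative assumptions, not optimum-neuron code.

from torch import nn


def make_predicate(model, names_to_consider):
    # Map every parameter object back to its fully qualified name, mirroring the
    # `parameter_to_name` mapping built in `Parallelizer.parallelize`.
    parameter_to_name = {p: n for n, p in model.named_parameters()}

    def predicate_func(layer):
        names = set()
        for _, p in layer.named_parameters():
            if p not in parameter_to_name:
                # Unknown parameter (e.g. the module was already swapped out): skip the layer.
                return False
            names.add(parameter_to_name[p])
        # Transform the layer only if its parameters all belong to the current pipeline
        # stage (same strict-subset check as in the patch).
        return names < names_to_consider

    return predicate_func


def transform(layer, predicate):
    # Mirrors the new `ParallelLayer.transform` wrapper: gate on the predicate,
    # then hand off to the real transformation (stubbed here as a simple tag).
    if not predicate(layer):
        return layer
    layer.parallelized = True  # stand-in for `cls._transform(...)`
    return layer


model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4))
# Pretend the current pipeline stage only owns the first two layers.
stage_param_names = {"0.weight", "0.bias", "1.weight", "1.bias"}
predicate = make_predicate(model, stage_param_names)

for module in model:
    transform(module, predicate)

print([getattr(m, "parallelized", False) for m in model])  # [True, True, False]

In the patch itself, the predicate is attached to the model (`model.predicate = predicate_func`) and consulted from `ParallelLayer.transform`, while the commented-out `should_parallelize_predicate_func` argument sketches an alternative way of passing it explicitly into `_parallelize`.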