From dd49c38ea59cd5bf886f25f9c63258c377fdac84 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Tue, 6 Feb 2024 17:41:26 +0100
Subject: [PATCH] [WIP] llama-70b

---
 optimum/neuron/distributed/base.py            | 22 +++++++++----------
 optimum/neuron/distributed/decoder_models.py  |  5 ++++-
 .../distributed/encoder_decoder_models.py     | 11 ++++++++--
 optimum/neuron/distributed/encoder_models.py  | 11 ++++++++--
 optimum/neuron/distributed/parallel_layers.py |  5 +++--
 optimum/neuron/trainers.py                    |  5 ++---
 optimum/neuron/utils/training_utils.py        |  2 +-
 7 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index adc21c8c1..b241687fb 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -244,7 +244,7 @@ def _parallelize(
         device: Optional["torch.device"] = None,
         parallelize_embeddings: bool = True,
         sequence_parallel_enabled: bool = False,
-        should_parallelize_predicate_func: Optional[Callable[["torch.nn.Module"], "torch.nn.Module"]] = None,
+        should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None,
     ) -> "PreTrainedModel":
         """
         Parallelizes the model by transforming regular layer into their parallel counterparts.
@@ -260,7 +260,9 @@
                 This can be disabled in the case when the TP size does not divide the vocabulary size.
             sequence_parallel_enabled (`bool`, defaults to `False`):
                 Whether or not sequence parallelism is enabled.
-            # TODO: add docstring
+            should_parallelize_layer_predicate_func (Optional[Callable[[torch.nn.Module], bool]], defaults to `None`):
+                A function that takes a layer as input and returns a boolean specifying whether the layer should be
+                parallelized. This is useful to skip unnecessary parallelization, for instance with pipeline parallelism.
         Returns:
             `PreTrainedModel`: The parallelized model.
""" @@ -337,27 +339,23 @@ def parallelize( name_to_parameter = dict(named_parameters(model, remove_duplicate=False)) parameter_to_name = {p: n for n, p in name_to_parameter.items()} - xm.master_print(name_to_parameter.keys()) - def predicate_func(layer): - for n, p in layer.named_parameters(): + for p in layer.parameters(): if p not in parameter_to_name: - xm.master_print(n) return True names = {parameter_to_name[p] for p in layer.parameters()} return names < names_of_the_parameters_to_consider - model.predicate = predicate_func - if tp_size > 1: model = cls._parallelize( model, device=device, parallelize_embeddings=parallelize_embeddings, sequence_parallel_enabled=sequence_parallel_enabled, - # should_parallelize_predicate_func=predicate_func, + should_parallelize_predicate_func=predicate_func, ) - # xm.rendezvous("End of tensor parallelism") + + xm.rendezvous("End of tensor parallelism") # Preparing the model for sequence parallelism: sp_specs_cls = cls.SEQUENCE_PARALLELSIM_SPECS_CLS @@ -507,7 +505,7 @@ def predicate_func(layer): if left_uninitialized and hasattr(mod, "reset_parameters"): initialize_torch_nn_module(mod, parameter_names) - # xm.rendezvous("End of initalization") + xm.rendezvous("End of initalization") pp_size = get_pipeline_model_parallel_size() if pp_size > 1: @@ -535,7 +533,7 @@ def predicate_func(layer): if gradient_checkpointing: apply_checkpoint(model) - # xxm.rendezvous("End of pipeline paralellism") + xm.rendezvous("End of pipeline paralellism") if checkpoint_dir is not None: cls.load_model_checkpoint(model, checkpoint_dir) diff --git a/optimum/neuron/distributed/decoder_models.py b/optimum/neuron/distributed/decoder_models.py index 8e30829b1..44763076b 100644 --- a/optimum/neuron/distributed/decoder_models.py +++ b/optimum/neuron/distributed/decoder_models.py @@ -15,7 +15,7 @@ """Classes related to `neuronx-distributed` to perform parallelism.""" import warnings -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple import torch from transformers.cache_utils import Cache @@ -605,7 +605,10 @@ def transform( layer: "torch.nn.Module", sequence_parallel_enabled: bool = False, device: Optional["torch.device"] = None, + should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None, ) -> "torch.nn.Module": + if should_parallelize_layer_predicate_func is not None and not should_parallelize_layer_predicate_func(layer): + return layer # TODO: Make it smart by merging the gate and the up_proj. # WARNING: be careful of the interleaved outputs when doing TP! layer = super().transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device) diff --git a/optimum/neuron/distributed/encoder_decoder_models.py b/optimum/neuron/distributed/encoder_decoder_models.py index fa29ee8b6..737f6299d 100644 --- a/optimum/neuron/distributed/encoder_decoder_models.py +++ b/optimum/neuron/distributed/encoder_decoder_models.py @@ -14,7 +14,7 @@ # limitations under the License. 
"""Classes related to `neuronx-distributed` to perform parallelism.""" -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Callable, Optional import torch from transformers.models.t5.modeling_t5 import T5Attention, T5ForSequenceClassification, T5LayerNorm @@ -54,9 +54,12 @@ def transform( cls, model: "PreTrainedModel", layer: "torch.nn.Module", - device: Optional["torch.device"] = None, sequence_parallel_enabled: bool = False, + device: Optional["torch.device"] = None, + should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None, ) -> "torch.nn.Module": + if should_parallelize_layer_predicate_func is not None and not should_parallelize_layer_predicate_func(layer): + return layer from neuronx_distributed.parallel_layers.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_size, @@ -100,7 +103,11 @@ def transform( layer: "torch.nn.Module", sequence_parallel_enabled: bool = False, device: Optional["torch.device"] = None, + should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None, ) -> "torch.nn.Module": + if should_parallelize_layer_predicate_func is not None and not should_parallelize_layer_predicate_func(layer): + return layer + from transformers.models.t5.modeling_t5 import T5DenseGatedActDense if cls.FIRST_LINEAR_NAME is None or cls.SECOND_LINEAR_NAME is None: diff --git a/optimum/neuron/distributed/encoder_models.py b/optimum/neuron/distributed/encoder_models.py index c8e2c617c..5abdfe8a9 100644 --- a/optimum/neuron/distributed/encoder_models.py +++ b/optimum/neuron/distributed/encoder_models.py @@ -14,7 +14,7 @@ # limitations under the License. """Classes related to `neuronx-distributed` to perform parallelism.""" -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Callable, Optional import torch @@ -64,8 +64,15 @@ def transform( layer: "torch.nn.Module", sequence_parallel_enabled: bool = False, device: Optional["torch.device"] = None, + should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None, ) -> "torch.nn.Module": - layer = super().transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device) + layer = super().transform( + model, + layer, + sequence_parallel_enabled=sequence_parallel_enabled, + device=device, + should_parallelize_layer_predicate_func=should_parallelize_layer_predicate_func, + ) from transformers.models.bert.modeling_bert import BertLMPredictionHead for mod in layer.modules(): diff --git a/optimum/neuron/distributed/parallel_layers.py b/optimum/neuron/distributed/parallel_layers.py index 554460d11..a5c37549d 100644 --- a/optimum/neuron/distributed/parallel_layers.py +++ b/optimum/neuron/distributed/parallel_layers.py @@ -20,7 +20,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Type, Union import torch from torch.nn.modules.loss import _WeightedLoss @@ -133,8 +133,9 @@ def transform( layer: "torch.nn.Module", sequence_parallel_enabled: bool = False, device: Optional["torch.device"] = None, + should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None, ) -> "torch.nn.Module": - if not model.predicate(layer): + if should_parallelize_layer_predicate_func is not None and not 
should_parallelize_layer_predicate_func(layer): return layer return cls._transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device) diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index c92890881..ebfe94b2a 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -74,7 +74,7 @@ is_torch_xla_available, patch_within_function, ) -from .utils.cache_utils import get_neuron_cache_path, set_neuron_cache_path +from .utils.cache_utils import get_neuron_cache_path from .utils.misc import is_main_worker from .utils.require_utils import requires_neuronx_distributed from .utils.training_utils import ( @@ -787,8 +787,7 @@ def _inner_training_loop( # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. # Train! - # parameter_count = get_model_param_count(model, trainable_only=True) - parameter_count = 10 + parameter_count = get_model_param_count(model, trainable_only=True) if is_main_worker(): logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py index fe709d143..2f836099c 100644 --- a/optimum/neuron/utils/training_utils.py +++ b/optimum/neuron/utils/training_utils.py @@ -419,7 +419,7 @@ def numel(parameter_name, parameter) -> int: if get_pipeline_model_parallel_size() > 1: param_count = torch.tensor(param_count, dtype=torch.float32).to(xm.xla_device()) - # param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True)) + param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True)) param_count = int(param_count.detach().item()) return param_count
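
Note: the sketch below is an illustration of the `should_parallelize_layer_predicate_func` hook introduced by this patch, loosely mirroring the `predicate_func` closure defined in `parallelize` in base.py. It is not part of the patch: the helper name `make_should_parallelize_predicate` and the non-strict subset check are assumptions made for the sake of a self-contained example.

# Illustration only (not part of the patch): one way a caller could build the
# layer predicate. Hypothetical helper name: make_should_parallelize_predicate.
from typing import Callable, Dict, Set

import torch


def make_should_parallelize_predicate(
    model: torch.nn.Module,
    names_of_the_parameters_to_consider: Set[str],
) -> Callable[[torch.nn.Module], bool]:
    # Map every parameter object back to its fully qualified name in the model.
    parameter_to_name: Dict[torch.nn.Parameter, str] = {
        p: n for n, p in model.named_parameters(remove_duplicate=False)
    }

    def should_parallelize(layer: torch.nn.Module) -> bool:
        names = set()
        for p in layer.parameters():
            if p not in parameter_to_name:
                # Unknown parameter: be conservative and parallelize the layer.
                return True
            names.add(parameter_to_name[p])
        # Only parallelize layers whose parameters all belong to the set this
        # worker is responsible for (e.g. the layers of its pipeline stage).
        # Note: the patch's own predicate_func uses a strict subset check (`<`).
        return names <= names_of_the_parameters_to_consider

    return should_parallelize

The resulting callable is what the patch threads through `cls._parallelize(...)` and down into each `transform` classmethod, where a `False` return causes the layer to be returned untouched.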
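
The patch also re-enables the cross-stage parameter count so that the training log reports the full model size under pipeline parallelism. The standalone sketch below restates that pattern; the function name `reduce_param_count` and the `neuronx_distributed.parallel_layers.parallel_state` import path are assumptions, not code from the patch.

# Sketch (assumes a Neuron/XLA environment): summing per-stage parameter counts
# across the pipeline-parallel group, as re-enabled in training_utils.py above.
# `reduce_param_count` is a made-up name; the import path is an assumption.
import torch
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.parallel_state import (
    get_pipeline_model_parallel_group,
    get_pipeline_model_parallel_size,
)


def reduce_param_count(local_param_count: int) -> int:
    # Each pipeline stage only holds its own slice of the model, so the local
    # count must be all-reduced over the pipeline-parallel group.
    if get_pipeline_model_parallel_size() > 1:
        count = torch.tensor(local_param_count, dtype=torch.float32).to(xm.xla_device())
        count = xm.all_reduce(
            xm.REDUCE_SUM, count, groups=get_pipeline_model_parallel_group(as_list=True)
        )
        return int(count.detach().item())
    return local_param_count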