From 34df63758eb607dd13dbde6ea616d7fe0fa2248f Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 8 Jul 2024 16:10:24 +0200
Subject: [PATCH] [WIP] optimum/neuron/models

---
 optimum/neuron/accelerate/accelerator.py     |  8 ++++--
 optimum/neuron/accelerate/utils/misc.py      | 27 ++----------------
 optimum/neuron/distributed/decoder_models.py | 11 +++----
 optimum/neuron/models/core.py                | 30 ++++++++++++++++++--
 optimum/neuron/models/preparator.py          | 16 +++++++----
 5 files changed, 52 insertions(+), 40 deletions(-)

diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py
index ff389ef0d..389977cf0 100644
--- a/optimum/neuron/accelerate/accelerator.py
+++ b/optimum/neuron/accelerate/accelerator.py
@@ -36,6 +36,7 @@
 from ...utils import logging
 from ..distributed import Parallelizer, ParallelizersManager
+from ..models.preparator import NeuronPreparator
 from ..utils import (
     Patcher,
     is_neuronx_distributed_available,
@@ -79,6 +80,8 @@
 logger = logging.get_logger(__name__)
 
 
+NxDPPMODEL_PATCHING_SPECS = []
+
 
 class NeuronAccelerator(Accelerator):
     def __init__(
@@ -322,7 +325,7 @@ def _prepare_model_for_mp(
         setattr(model, "main_input_name", model_main_input_name)
 
         if isinstance(model, NxDPPModel):
-            model.local_module = self.patch_model_for_neuron(
+            NeuronPreparator.patch_model_for_neuron(
                 model.local_module, patching_specs=NxDPPMODEL_PATCHING_SPECS
             )
@@ -374,7 +377,8 @@ def prepare_model(
         # we get access to the model, we simply check if the flags are the best and notify the user otherwise.
         check_neuron_cc_flags_for_model(model)
 
-        model = self.patch_model_for_neuron(model)
+        NeuronPreparator.prepare_modeling(model)
+        NeuronPreparator.patch_model_for_neuron(model)
 
         # We do not want to use the cache, or output unused tensors as it would imply more communication that we do not
         # need.
diff --git a/optimum/neuron/accelerate/utils/misc.py b/optimum/neuron/accelerate/utils/misc.py
index 3eb06f23c..aca3f0115 100644
--- a/optimum/neuron/accelerate/utils/misc.py
+++ b/optimum/neuron/accelerate/utils/misc.py
@@ -16,20 +16,19 @@
 import functools
 import inspect
-from typing import TYPE_CHECKING, Dict, Optional, Union
+from typing import TYPE_CHECKING, Union
 
 import torch
 
 from ....utils import logging
 from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere
 from ...utils.peft_utils import NeuronPeftModel
-from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla
+from ...utils.require_utils import requires_neuronx_distributed, requires_torch_xla
 
 logger = logging.get_logger(__name__)
 
 if TYPE_CHECKING:
-    import os
 
     from transformers import PreTrainedModel
@@ -60,28 +59,6 @@ def patch_accelerate_is_torch_xla_available():
     )
 
 
-@requires_neuronx_distributed
-@requires_safetensors
-def torch_xla_safe_save_file(
-    tensors: Dict[str, torch.Tensor],
-    filename: Union[str, "os.PathLike"],
-    metadata: Optional[Dict[str, str]] = None,
-    master_only: bool = True,
-    global_master: bool = False,
-):
-    """
-    Torch XLA compatible implementation of `safetensors.torch.save_file`.
- """ - from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu - from safetensors.torch import save_file - from torch_xla.core.xla_model import is_master_ordinal - - should_write_data = not master_only or is_master_ordinal(local=not global_master) - cpu_data = move_all_tensor_to_cpu(tensors, convert=should_write_data) - if should_write_data: - save_file(cpu_data, filename, metadata=metadata) - - # TODO: @michaelbenayoun # Needs to make it work in the general case or be deleted and only use `apply_activation_checkpointing`. @requires_torch_xla diff --git a/optimum/neuron/distributed/decoder_models.py b/optimum/neuron/distributed/decoder_models.py index 94c9c0f81..ff7da77b2 100644 --- a/optimum/neuron/distributed/decoder_models.py +++ b/optimum/neuron/distributed/decoder_models.py @@ -36,6 +36,7 @@ MistralRMSNorm, ) +from ..models.core import NeuronAttention from .base import Parallelizer, PipelineParallelismSpecs, SequenceParallelismSpecs from .parallel_layers import ( LayerNormType, @@ -432,12 +433,12 @@ class LlamaSequenceParallelismSpecs(SequenceParallelismSpecs): @classmethod def patch_for_sequence_parallelism(cls, model: "PreTrainedModel", sequence_parallel_enabled: bool): - if not sequence_parallel_enabled: - return - for module in model.modules(): - if isinstance(module, LlamaAttention): - module.forward = attention_forward.__get__(module) + if isinstance(module, LlamaAttention) and not isinstance(module, NeuronAttention): + raise ValueError( + "The llama model has not been prepare by the NeuronPreparator. It is required for sequence " + "parallelism." + ) class LlamaPipelineParallelismSpecs(PipelineParallelismSpecs): diff --git a/optimum/neuron/models/core.py b/optimum/neuron/models/core.py index ec58ccdc0..2a5978223 100644 --- a/optimum/neuron/models/core.py +++ b/optimum/neuron/models/core.py @@ -14,15 +14,19 @@ # limitations under the License. """Core functionalities and tools for rewriting modules for Neuron.""" +import functools +import gc import math +import os from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Dict, Optional, Union import torch import torch.nn as nn from transformers.modeling_utils import get_parameter_dtype -from ..utils.require_utils import requires_neuronx_distributed +from ..utils.patching import Patcher +from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors if TYPE_CHECKING: @@ -53,6 +57,28 @@ def patched_get_parameter_dtype(module): return patched_get_parameter_dtype +@requires_neuronx_distributed +@requires_safetensors +def torch_xla_safe_save_file( + tensors: Dict[str, torch.Tensor], + filename: Union[str, "os.PathLike"], + metadata: Optional[Dict[str, str]] = None, + master_only: bool = True, + global_master: bool = False, +): + """ + Torch XLA compatible implementation of `safetensors.torch.save_file`. 
+ """ + from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu + from safetensors.torch import save_file + from torch_xla.core.xla_model import is_master_ordinal + + should_write_data = not master_only or is_master_ordinal(local=not global_master) + cpu_data = move_all_tensor_to_cpu(tensors, convert=should_write_data) + if should_write_data: + save_file(cpu_data, filename, metadata=metadata) + + @requires_neuronx_distributed def create_patched_save_pretrained(orig_save_pretrained_function: Callable[["PreTrainedModel"], None]): """ diff --git a/optimum/neuron/models/preparator.py b/optimum/neuron/models/preparator.py index 3ce7d4837..bd28615e0 100644 --- a/optimum/neuron/models/preparator.py +++ b/optimum/neuron/models/preparator.py @@ -16,7 +16,7 @@ import contextlib import importlib -from typing import Dict +from typing import Any, Dict, List, Optional, Tuple import torch from transformers import PreTrainedModel @@ -34,13 +34,11 @@ from .core import create_patched_finfo, create_patched_save_pretrained -MODEL_PATCHING_SPECS = [ +DEFAULT_MODEL_PATCHING_SPECS = [ ("config.layerdrop", 0), ("no_sync", lambda: contextlib.nullcontext()), ] -NxDPPMODEL_PATCHING_SPECS = [] - class NeuronPreparator: _TRANSFORMERS_TO_NEURON_CLASSES: Dict[str, Dict[str, str]] = { @@ -52,6 +50,10 @@ class NeuronPreparator: @classmethod def prepare_modeling(cls, model: PreTrainedModel, **options): + """ + Prepares the modeling of a model by potentially changing some of its modules with Neuron optimized versions of + them. + """ if model.config.model_type not in cls._TRANSFORMERS_TO_NEURON_CLASSES: return @@ -74,8 +76,11 @@ def patch_model_for_neuron( model: "torch.nn.Module", patching_specs: Optional[List[Tuple[str, Any]]] = None, ) -> "torch.nn.Module": + """ + Patches the model in various ways to make sure it works properly on Neuron devices. + """ if patching_specs is None: - patching_specs = MODEL_PATCHING_SPECS + patching_specs = DEFAULT_MODEL_PATCHING_SPECS # Working on a copy for safety. patching_specs = list(patching_specs) @@ -136,4 +141,3 @@ def patch_model_for_neuron( "It appears that the model is using a PEFT method, please wrap your model with `PeftModel` " "to make it work with `optimum-neuron`" ) - return model