From b6c0217a3631cf697478db76a988215e01d08679 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 8 Jul 2024 15:41:11 +0200 Subject: [PATCH] [WIP] optimum/neuron/models --- optimum/neuron/accelerate/accelerator.py | 84 +---------- optimum/neuron/accelerate/utils/misc.py | 71 +-------- optimum/neuron/models/__init__.py | 16 ++ optimum/neuron/models/core.py | 84 ++++++++++- optimum/neuron/models/modeling_llama.py | 34 ++++- optimum/neuron/models/preparator.py | 139 ++++++++++++++++++ .../torch_xla_and_neuronx_initialization.py | 2 +- 7 files changed, 271 insertions(+), 159 deletions(-) create mode 100644 optimum/neuron/models/__init__.py create mode 100644 optimum/neuron/models/preparator.py diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index 05d5a8a3e..ff389ef0d 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -23,7 +23,7 @@ import warnings from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from accelerate import Accelerator @@ -33,19 +33,14 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from transformers import PreTrainedModel -from transformers.utils import is_peft_available from ...utils import logging from ..distributed import Parallelizer, ParallelizersManager from ..utils import ( - DynamicPatch, - ModelPatcher, - NeuronPeftModel, Patcher, is_neuronx_distributed_available, is_torch_xla_available, patch_within_function, - replace_class_in_inheritance_hierarchy, ) from ..utils.misc import args_and_kwargs_to_kwargs_only, is_main_worker from ..utils.model_utils import get_tied_parameters_dict, tie_parameters @@ -62,8 +57,6 @@ ) from .utils.misc import ( apply_activation_checkpointing, - create_patched_finfo, - create_patched_save_pretrained, ) from .utils.operations import _xla_gather @@ -87,14 +80,6 @@ logger = logging.get_logger(__name__) -MODEL_PATCHING_SPECS = [ - ("config.layerdrop", 0), - ("no_sync", lambda: contextlib.nullcontext()), -] - -NxDPPMODEL_PATCHING_SPECS = [] - - class NeuronAccelerator(Accelerator): def __init__( self, @@ -318,73 +303,6 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement: def prepare_scheduler(self, scheduler: "LRScheduler"): return super().prepare_scheduler(scheduler) - def patch_model_for_neuron( - self, - model: "torch.nn.Module", - patching_specs: Optional[List[Tuple[str, Any]]] = None, - ) -> "torch.nn.Module": - if patching_specs is None: - patching_specs = MODEL_PATCHING_SPECS - - # Working on a copy for safety. 
- patching_specs = list(patching_specs) - - mixed_precision_is_bf16 = self.state.mixed_precision == "bf16" - patched_finfo = create_patched_finfo( - xla_downcast_bf16=mixed_precision_is_bf16 and self.state.downcast_bfloat, - use_amp=mixed_precision_is_bf16 and self.state.autocast_backend is AutocastBackend.AMP, - xla_use_bf16=mixed_precision_is_bf16 and not self.state.downcast_bfloat, - ) - patching_specs.append( - ( - "forward", - DynamicPatch(patch_within_function(("torch.finfo", patched_finfo))), - ), - ) - - if isinstance(model, PreTrainedModel): - patching_specs.append( - ( - "save_pretrained", - DynamicPatch(create_patched_save_pretrained), - ), - ) - - # TODO: @michaelbenayoun generalize an implementation of gradient checkpointing working for: - # - DDP - # - TP - # - PP - # if hasattr(model, "gradient_checkpointing_enable"): - # patching_specs.append( - # ( - # "gradient_checkpointing_enable", - # patched_gradient_checkpointing_enable, - # ), - # ) - - prepared_patching_specs = [] - for spec in patching_specs: - prepared_patching_specs.append((model,) + spec) - - model_patcher = ModelPatcher(prepared_patching_specs, ignore_missing_attributes=True) - model_patcher.patch() - - if is_peft_available(): - from peft import PeftModel - from peft.tuners.tuners_utils import BaseTunerLayer - from peft.utils import ModulesToSaveWrapper - - if isinstance(model, PeftModel): - replace_class_in_inheritance_hierarchy(model, PeftModel, NeuronPeftModel) - else: - for _, module in model.named_modules(): - if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): - raise ValueError( - "It appears that the model is using a PEFT method, please wrap your model with `PeftModel` " - "to make it work with `optimum-neuron`" - ) - return model - @requires_neuronx_distributed def _prepare_model_for_mp( self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False diff --git a/optimum/neuron/accelerate/utils/misc.py b/optimum/neuron/accelerate/utils/misc.py index 15d094691..3eb06f23c 100644 --- a/optimum/neuron/accelerate/utils/misc.py +++ b/optimum/neuron/accelerate/utils/misc.py @@ -15,16 +15,13 @@ """Utilities of various sorts related to accelerate with Neuron.""" import functools -import gc import inspect -from typing import TYPE_CHECKING, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Dict, Optional, Union import torch -from transformers.modeling_utils import get_parameter_dtype from ....utils import logging from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere -from ...utils.patching import Patcher from ...utils.peft_utils import NeuronPeftModel from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla @@ -63,30 +60,6 @@ def patch_accelerate_is_torch_xla_available(): ) -_ORIG_TORCH_FINFO = torch.finfo - - -def create_patched_finfo(xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False): - def patched_finfo(dtype): - if xla_downcast_bf16 or use_amp or xla_use_bf16: - return _ORIG_TORCH_FINFO(torch.bfloat16) - return _ORIG_TORCH_FINFO(dtype) - - return patched_finfo - - -def create_patched_get_parameter_dtype( - xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False -): - def patched_get_parameter_dtype(module): - dtype = get_parameter_dtype(module) - if xla_downcast_bf16 or use_amp or xla_use_bf16: - return torch.bfloat16 - return dtype - - return patched_get_parameter_dtype - - 
@requires_neuronx_distributed @requires_safetensors def torch_xla_safe_save_file( @@ -109,48 +82,6 @@ def torch_xla_safe_save_file( save_file(cpu_data, filename, metadata=metadata) -@requires_neuronx_distributed -def create_patched_save_pretrained(orig_save_pretrained_function: Callable[["PreTrainedModel"], None]): - """ - Creates a wrapper around the `transformers.modeling_utils.PreTrainedModel.save_pretrained` method. - This methods calls `tensor.data_ptr()` on the model parameters, which causes segmentation fault when the tensors - are on the XLA device. To prevent that, this wrapper calls `save_pretrained` with the model on the CPU device. - """ - import torch_xla.core.xla_model as xm - from neuronx_distributed.parallel_layers.parallel_state import ( - get_data_parallel_rank, - model_parallel_is_initialized, - ) - from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu - - orig_self = orig_save_pretrained_function.__self__ - orig_func = orig_save_pretrained_function.__func__ - - patcher = Patcher([("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)]) - - @functools.wraps(orig_func) - def wrapper(*args, **kwargs): - self = args[0] - if model_parallel_is_initialized(): - should_write_data = get_data_parallel_rank() == 0 - else: - should_write_data = xm.is_master_ordinal(local=True) - orig_state_dict = self.state_dict() - cpu_state_dict = move_all_tensor_to_cpu(self.state_dict(), convert=should_write_data) - self.load_state_dict(cpu_state_dict, assign=True) - output = None - if should_write_data: - with patcher: - output = orig_func(*args, **kwargs) - self.load_state_dict(orig_state_dict, assign=True) - xm.mark_step() - del cpu_state_dict - gc.collect() - return output - - return wrapper.__get__(orig_self) - - # TODO: @michaelbenayoun # Needs to make it work in the general case or be deleted and only use `apply_activation_checkpointing`. @requires_torch_xla diff --git a/optimum/neuron/models/__init__.py b/optimum/neuron/models/__init__.py new file mode 100644 index 000000000..6fba0ec31 --- /dev/null +++ b/optimum/neuron/models/__init__.py @@ -0,0 +1,16 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from .preparator import NeuronPreparator
diff --git a/optimum/neuron/models/core.py b/optimum/neuron/models/core.py
index 0aa144798..ec58ccdc0 100644
--- a/optimum/neuron/models/core.py
+++ b/optimum/neuron/models/core.py
@@ -15,13 +15,97 @@
 """Core functionalities and tools for rewriting modules for Neuron."""
 
+import functools
+import gc
 import math
-from typing import Optional
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Callable, Optional
 
 import torch
 import torch.nn as nn
+from transformers.modeling_utils import get_parameter_dtype
 
+from ..accelerate.utils.misc import torch_xla_safe_save_file
+from ..utils import Patcher
+from ..utils.require_utils import requires_neuronx_distributed
 
-class NeuronAttention:
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+
+_ORIG_TORCH_FINFO = torch.finfo
+
+
+def create_patched_finfo(xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False):
+    def patched_finfo(dtype):
+        if xla_downcast_bf16 or use_amp or xla_use_bf16:
+            return _ORIG_TORCH_FINFO(torch.bfloat16)
+        return _ORIG_TORCH_FINFO(dtype)
+
+    return patched_finfo
+
+
+def create_patched_get_parameter_dtype(
+    xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False
+):
+    def patched_get_parameter_dtype(module):
+        dtype = get_parameter_dtype(module)
+        if xla_downcast_bf16 or use_amp or xla_use_bf16:
+            return torch.bfloat16
+        return dtype
+
+    return patched_get_parameter_dtype
+
+
+@requires_neuronx_distributed
+def create_patched_save_pretrained(orig_save_pretrained_function: Callable[["PreTrainedModel"], None]):
+    """
+    Creates a wrapper around the `transformers.modeling_utils.PreTrainedModel.save_pretrained` method.
+    This method calls `tensor.data_ptr()` on the model parameters, which causes a segmentation fault when the tensors
+    are on the XLA device. To prevent that, this wrapper calls `save_pretrained` with the model on the CPU device.
+    """
+    import torch_xla.core.xla_model as xm
+    from neuronx_distributed.parallel_layers.parallel_state import (
+        get_data_parallel_rank,
+        model_parallel_is_initialized,
+    )
+    from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu
+
+    orig_self = orig_save_pretrained_function.__self__
+    orig_func = orig_save_pretrained_function.__func__
+
+    patcher = Patcher([("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)])
+
+    @functools.wraps(orig_func)
+    def wrapper(*args, **kwargs):
+        self = args[0]
+        if model_parallel_is_initialized():
+            should_write_data = get_data_parallel_rank() == 0
+        else:
+            should_write_data = xm.is_master_ordinal(local=True)
+        orig_state_dict = self.state_dict()
+        cpu_state_dict = move_all_tensor_to_cpu(self.state_dict(), convert=should_write_data)
+        self.load_state_dict(cpu_state_dict, assign=True)
+        output = None
+        if should_write_data:
+            with patcher:
+                output = orig_func(*args, **kwargs)
+        self.load_state_dict(orig_state_dict, assign=True)
+        xm.mark_step()
+        del cpu_state_dict
+        gc.collect()
+        return output
+
+    return wrapper.__get__(orig_self)
+
+
+class PatchedModule(ABC):
+    @abstractmethod
+    def from_original(cls, orig_module: torch.nn.Module, **options) -> "PatchedModule":
+        pass
+
+
+class NeuronAttention(PatchedModule):
     # TODO: add dosctring
     @property
     def sequence_parallel_enabled(self) -> bool:
diff --git a/optimum/neuron/models/modeling_llama.py b/optimum/neuron/models/modeling_llama.py
index 7d1a32aae..63b4ffdcc 100644
--- a/optimum/neuron/models/modeling_llama.py
+++ b/optimum/neuron/models/modeling_llama.py
@@ -14,21 +14,21 @@
 # limitations under the License.
"""Parallelization of the Llama architecture.""" -from typing import TYPE_CHECKING, Optional, Tuple +from typing import Optional, Tuple import torch import torch.nn.functional as F - from transformers import LlamaConfig from transformers.cache_utils import Cache from transformers.models.llama.modeling_llama import ( LlamaAttention, + LlamaModel, apply_rotary_pos_emb, repeat_kv, ) from ..utils.require_utils import requires_neuronx_distributed -from .core import NeuronAttention, CoreAttention +from .core import CoreAttention, NeuronAttention, PatchedModule class NeuronLlamaAttention(LlamaAttention, NeuronAttention): @@ -36,6 +36,12 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx=layer_idx) self.core_attn = CoreAttention() + @classmethod + def from_original(cls, orig_module: torch.nn.Module, **options) -> "NeuronLlamaAttention": + orig_module.core_attn = CoreAttention() + orig_module.__class__ = cls + return orig_module + @requires_neuronx_distributed def forward( self, @@ -130,3 +136,25 @@ def forward( attn_weights = None return attn_output, attn_weights, past_key_value + + +class NeuronLlamaModel(LlamaModel, PatchedModule): + @classmethod + def from_original(cls, orig_module: torch.nn.Module, **options) -> "NeuronLlamaModel": + orig_module.__class__ = cls + return orig_module + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: work on the validity of that. + if self.training: + return None + return super()._update_causal_mask( + attention_mask, input_tensor, cache_position, past_key_values, output_attentions + ) diff --git a/optimum/neuron/models/preparator.py b/optimum/neuron/models/preparator.py new file mode 100644 index 000000000..3ce7d4837 --- /dev/null +++ b/optimum/neuron/models/preparator.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Implements the NeuronPreparator class, which transforms a model with custom modules defined for Neuron.""" + +import contextlib +import importlib +from typing import Dict + +import torch +from transformers import PreTrainedModel +from transformers.utils import is_peft_available + +from ..accelerate.state import NeuronAcceleratorState +from ..accelerate.utils import AutocastBackend +from ..utils import ( + DynamicPatch, + ModelPatcher, + NeuronPeftModel, + patch_within_function, + replace_class_in_inheritance_hierarchy, +) +from .core import create_patched_finfo, create_patched_save_pretrained + + +MODEL_PATCHING_SPECS = [ + ("config.layerdrop", 0), + ("no_sync", lambda: contextlib.nullcontext()), +] + +NxDPPMODEL_PATCHING_SPECS = [] + + +class NeuronPreparator: + _TRANSFORMERS_TO_NEURON_CLASSES: Dict[str, Dict[str, str]] = { + "llama": { + "LlamaAttention": "NeuronLlamaAttention", + "LlamaModel": "NeuronLlamaModel", + } + } + + @classmethod + def prepare_modeling(cls, model: PreTrainedModel, **options): + if model.config.model_type not in cls._TRANSFORMERS_TO_NEURON_CLASSES: + return + + patches = cls._TRANSFORMERS_TO_NEURON_CLASSES[model.config.model_type] + module = importlib.import_module(f"..modeling_{model.config.model_type}.py") + for name, mod in model.modules(): + replacement_cls_name = patches.get(mod.__class__.__name__, "") + if replacement_cls_name: + names = name.rsplit(".", maxsplit=1) + if len(names) == 1: + parent, attr_name = model, names[0] + else: + parent, attr_name = model.get_submodule(names[0]), names[1] + replacement_cls = getattr(module, replacement_cls_name) + setattr(parent, attr_name, replacement_cls.from_original(mod, **options)) + + @classmethod + def patch_model_for_neuron( + cls, + model: "torch.nn.Module", + patching_specs: Optional[List[Tuple[str, Any]]] = None, + ) -> "torch.nn.Module": + if patching_specs is None: + patching_specs = MODEL_PATCHING_SPECS + + # Working on a copy for safety. 
+        patching_specs = list(patching_specs)
+
+        accelerator_state = NeuronAcceleratorState()
+
+        mixed_precision_is_bf16 = accelerator_state.mixed_precision == "bf16"
+        patched_finfo = create_patched_finfo(
+            xla_downcast_bf16=mixed_precision_is_bf16 and accelerator_state.downcast_bfloat,
+            use_amp=mixed_precision_is_bf16 and accelerator_state.autocast_backend is AutocastBackend.AMP,
+            xla_use_bf16=mixed_precision_is_bf16 and not accelerator_state.downcast_bfloat,
+        )
+        patching_specs.append(
+            (
+                "forward",
+                DynamicPatch(patch_within_function(("torch.finfo", patched_finfo))),
+            ),
+        )
+
+        if isinstance(model, PreTrainedModel):
+            patching_specs.append(
+                (
+                    "save_pretrained",
+                    DynamicPatch(create_patched_save_pretrained),
+                ),
+            )
+
+        # TODO: @michaelbenayoun generalize an implementation of gradient checkpointing working for:
+        # - DDP
+        # - TP
+        # - PP
+        # if hasattr(model, "gradient_checkpointing_enable"):
+        #     patching_specs.append(
+        #         (
+        #             "gradient_checkpointing_enable",
+        #             patched_gradient_checkpointing_enable,
+        #         ),
+        #     )
+
+        prepared_patching_specs = []
+        for spec in patching_specs:
+            prepared_patching_specs.append((model,) + spec)
+
+        model_patcher = ModelPatcher(prepared_patching_specs, ignore_missing_attributes=True)
+        model_patcher.patch()
+
+        if is_peft_available():
+            from peft import PeftModel
+            from peft.tuners.tuners_utils import BaseTunerLayer
+            from peft.utils import ModulesToSaveWrapper
+
+            if isinstance(model, PeftModel):
+                replace_class_in_inheritance_hierarchy(model, PeftModel, NeuronPeftModel)
+            else:
+                for _, module in model.named_modules():
+                    if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
+                        raise ValueError(
+                            "It appears that the model is using a PEFT method. Please wrap your model with `PeftModel` "
+                            "to make it work with `optimum-neuron`."
+                        )
+        return model
diff --git a/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py b/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py
index b0e78e6ec..01ca8c379 100644
--- a/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py
+++ b/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py
@@ -85,7 +85,7 @@ def set_neuron_cc_optlevel(optlevel: int = 2):
 
 def check_neuron_cc_flags_for_model(model: "PreTrainedModel"):
     """
-    Sets flags for the Neuron compiler depending on the model.
+    Checks flags for the Neuron compiler depending on the model.
     """
     neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "")
     if "ForCausalLM" or "ForConditionalGeneration" in model.__class__.__name__:
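
Usage sketch (not part of the patch): a minimal example, under stated assumptions, of how the new `NeuronPreparator` entry points are meant to be chained once the accelerator delegates model preparation to `optimum.neuron.models`. Only `NeuronPreparator.prepare_modeling` and `NeuronPreparator.patch_model_for_neuron` come from this patch; the checkpoint name, the bare `NeuronAccelerator()` setup and the final `prepare()` call are placeholders.

    # Assumes a Neuron/XLA environment with optimum-neuron installed and a Llama checkpoint available.
    from transformers import LlamaForCausalLM

    from optimum.neuron import NeuronAccelerator
    from optimum.neuron.models import NeuronPreparator

    # Instantiating the accelerator initializes NeuronAcceleratorState, which
    # patch_model_for_neuron reads to pick the bf16/autocast-related patches.
    accelerator = NeuronAccelerator()

    model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint

    # Swap the supported Transformers modules for their Neuron counterparts in place
    # (LlamaAttention -> NeuronLlamaAttention, LlamaModel -> NeuronLlamaModel).
    NeuronPreparator.prepare_modeling(model)

    # Apply the runtime patches (torch.finfo, save_pretrained, PEFT checks) that the
    # accelerator previously applied itself.
    model = NeuronPreparator.patch_model_for_neuron(model)

    model = accelerator.prepare(model)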