[WIP] optimum/neuron/models
michaelbenayoun committed Jul 8, 2024
1 parent b6c0217 commit 34df637
Showing 5 changed files with 52 additions and 40 deletions.
8 changes: 6 additions & 2 deletions optimum/neuron/accelerate/accelerator.py
@@ -36,6 +36,7 @@

from ...utils import logging
from ..distributed import Parallelizer, ParallelizersManager
from ..models.preparator import NeuronPreparator
from ..utils import (
Patcher,
is_neuronx_distributed_available,
@@ -79,6 +80,8 @@

logger = logging.get_logger(__name__)

NxDPPMODEL_PATCHING_SPECS = []


class NeuronAccelerator(Accelerator):
def __init__(
@@ -322,7 +325,7 @@ def _prepare_model_for_mp(
setattr(model, "main_input_name", model_main_input_name)

if isinstance(model, NxDPPModel):
model.local_module = self.patch_model_for_neuron(
model.local_module = NeuronPreparator.patch_model_for_neuron(
model.local_module, patching_specs=NxDPPMODEL_PATCHING_SPECS
)

@@ -374,7 +377,8 @@ def prepare_model(
# we get access to the model, we simply check if the flags are the best and notify the user otherwise.
check_neuron_cc_flags_for_model(model)

model = self.patch_model_for_neuron(model)
NeuronPreparator.prepare_modeling(model)
NeuronPreparator.patch_model_for_neuron(model)

# We do not want to use the cache, or output unused tensors as it would imply more communication that we do not
# need.
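For reference, a minimal sketch (not taken from this commit) of the call order that prepare_model now follows; the wrapper name is illustrative, and the only assumption is that NeuronPreparator is importable from optimum.neuron.models.preparator, which the relative import above implies.

# Sketch only: prepare_model_sketch is a made-up name, not an optimum-neuron API.
from optimum.neuron.models.preparator import NeuronPreparator

def prepare_model_sketch(model):
    # First swap supported modules for their Neuron-optimized counterparts
    # (a no-op when model.config.model_type has no registered Neuron classes).
    NeuronPreparator.prepare_modeling(model)
    # Then apply the generic Neuron patches in place (the result is no longer re-assigned).
    NeuronPreparator.patch_model_for_neuron(model)
    return model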
27 changes: 2 additions & 25 deletions optimum/neuron/accelerate/utils/misc.py
@@ -16,20 +16,19 @@

import functools
import inspect
from typing import TYPE_CHECKING, Dict, Optional, Union
from typing import TYPE_CHECKING, Union

import torch

from ....utils import logging
from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere
from ...utils.peft_utils import NeuronPeftModel
from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla
from ...utils.require_utils import requires_neuronx_distributed, requires_torch_xla


logger = logging.get_logger(__name__)

if TYPE_CHECKING:
import os

from transformers import PreTrainedModel

@@ -60,28 +59,6 @@ def patch_accelerate_is_torch_xla_available():
)


@requires_neuronx_distributed
@requires_safetensors
def torch_xla_safe_save_file(
tensors: Dict[str, torch.Tensor],
filename: Union[str, "os.PathLike"],
metadata: Optional[Dict[str, str]] = None,
master_only: bool = True,
global_master: bool = False,
):
"""
Torch XLA compatible implementation of `safetensors.torch.save_file`.
"""
from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu
from safetensors.torch import save_file
from torch_xla.core.xla_model import is_master_ordinal

should_write_data = not master_only or is_master_ordinal(local=not global_master)
cpu_data = move_all_tensor_to_cpu(tensors, convert=should_write_data)
if should_write_data:
save_file(cpu_data, filename, metadata=metadata)


# TODO: @michaelbenayoun
# Needs to make it work in the general case or be deleted and only use `apply_activation_checkpointing`.
@requires_torch_xla
11 changes: 6 additions & 5 deletions optimum/neuron/distributed/decoder_models.py
@@ -36,6 +36,7 @@
MistralRMSNorm,
)

from ..models.core import NeuronAttention
from .base import Parallelizer, PipelineParallelismSpecs, SequenceParallelismSpecs
from .parallel_layers import (
LayerNormType,
@@ -432,12 +433,12 @@ class LlamaSequenceParallelismSpecs(SequenceParallelismSpecs):

@classmethod
def patch_for_sequence_parallelism(cls, model: "PreTrainedModel", sequence_parallel_enabled: bool):
if not sequence_parallel_enabled:
return

for module in model.modules():
if isinstance(module, LlamaAttention):
module.forward = attention_forward.__get__(module)
if isinstance(module, LlamaAttention) and not isinstance(module, NeuronAttention):
raise ValueError(
"The llama model has not been prepare by the NeuronPreparator. It is required for sequence "
"parallelism."
)


class LlamaPipelineParallelismSpecs(PipelineParallelismSpecs):
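For reference, a hedged sketch (not part of this commit) of the order of operations this guard implies: the model must go through NeuronPreparator.prepare_modeling before the parallelizer patches it for sequence parallelism, otherwise the isinstance check above raises. The checkpoint name and the assumption that "llama" has registered Neuron classes are illustrative.

# Sketch only: the checkpoint and the presence of "llama" in
# NeuronPreparator._TRANSFORMERS_TO_NEURON_CLASSES are assumptions.
from transformers import AutoModelForCausalLM
from optimum.neuron.models.preparator import NeuronPreparator

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
NeuronPreparator.prepare_modeling(model)  # swaps LlamaAttention modules for Neuron ones
# Once every attention module is also a NeuronAttention, the check in
# LlamaSequenceParallelismSpecs.patch_for_sequence_parallelism passes.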
30 changes: 28 additions & 2 deletions optimum/neuron/models/core.py
@@ -14,15 +14,19 @@
# limitations under the License.
"""Core functionalities and tools for rewriting modules for Neuron."""

import functools
import gc
import math
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

import torch
import torch.nn as nn
from transformers.modeling_utils import get_parameter_dtype

from ..utils.require_utils import requires_neuronx_distributed
from ..utils.patching import Patcher
from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors


if TYPE_CHECKING:
@@ -53,6 +57,28 @@ def patched_get_parameter_dtype(module):
return patched_get_parameter_dtype


@requires_neuronx_distributed
@requires_safetensors
def torch_xla_safe_save_file(
tensors: Dict[str, torch.Tensor],
filename: Union[str, "os.PathLike"],
metadata: Optional[Dict[str, str]] = None,
master_only: bool = True,
global_master: bool = False,
):
"""
Torch XLA compatible implementation of `safetensors.torch.save_file`.
"""
from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu
from safetensors.torch import save_file
from torch_xla.core.xla_model import is_master_ordinal

should_write_data = not master_only or is_master_ordinal(local=not global_master)
cpu_data = move_all_tensor_to_cpu(tensors, convert=should_write_data)
if should_write_data:
save_file(cpu_data, filename, metadata=metadata)


@requires_neuronx_distributed
def create_patched_save_pretrained(orig_save_pretrained_function: Callable[["PreTrainedModel"], None]):
"""
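For reference, a small usage sketch of the relocated torch_xla_safe_save_file (the state dict below is illustrative): it mirrors safetensors.torch.save_file, but moves tensors off the XLA device first and, by default, only writes from the master ordinal.

# Sketch only: in practice the tensors live on an XLA device and the call runs
# inside a torch_xla / neuronx_distributed job.
import torch
from optimum.neuron.models.core import torch_xla_safe_save_file

state_dict = {"weight": torch.zeros(4, 4)}
torch_xla_safe_save_file(state_dict, "model.safetensors", metadata={"format": "pt"})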
16 changes: 10 additions & 6 deletions optimum/neuron/models/preparator.py
@@ -16,7 +16,7 @@

import contextlib
import importlib
from typing import Dict
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import PreTrainedModel
@@ -34,13 +34,11 @@
from .core import create_patched_finfo, create_patched_save_pretrained


MODEL_PATCHING_SPECS = [
DEFAULT_MODEL_PATCHING_SPECS = [
("config.layerdrop", 0),
("no_sync", lambda: contextlib.nullcontext()),
]

NxDPPMODEL_PATCHING_SPECS = []


class NeuronPreparator:
_TRANSFORMERS_TO_NEURON_CLASSES: Dict[str, Dict[str, str]] = {
@@ -52,6 +50,10 @@ class NeuronPreparator:

@classmethod
def prepare_modeling(cls, model: PreTrainedModel, **options):
"""
Prepares the modeling of a model by potentially replacing some of its modules with Neuron-optimized
versions of them.
"""
if model.config.model_type not in cls._TRANSFORMERS_TO_NEURON_CLASSES:
return

@@ -74,8 +76,11 @@ def patch_model_for_neuron(
model: "torch.nn.Module",
patching_specs: Optional[List[Tuple[str, Any]]] = None,
) -> "torch.nn.Module":
"""
Patches the model in various ways to make sure it works properly on Neuron devices.
"""
if patching_specs is None:
patching_specs = MODEL_PATCHING_SPECS
patching_specs = DEFAULT_MODEL_PATCHING_SPECS

# Working on a copy for safety.
patching_specs = list(patching_specs)
@@ -136,4 +141,3 @@ def patch_model_for_neuron(
"It appears that the model is using a PEFT method, please wrap your model with `PeftModel` "
"to make it work with `optimum-neuron`"
)
return model
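For reference, a hedged sketch of how patching specs are expressed: each entry pairs an attribute path on the model with a replacement value, as in DEFAULT_MODEL_PATCHING_SPECS above; the custom_specs list and the checkpoint below are illustrative.

# Sketch only: custom_specs and the checkpoint are examples, not defaults.
import contextlib
from transformers import AutoModelForCausalLM
from optimum.neuron.models.preparator import NeuronPreparator

model = AutoModelForCausalLM.from_pretrained("gpt2")
custom_specs = [
    ("config.layerdrop", 0),                        # disable LayerDrop
    ("no_sync", lambda: contextlib.nullcontext()),  # make no_sync a no-op
]
NeuronPreparator.patch_model_for_neuron(model, patching_specs=custom_specs)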
