Commit

[WIP] optimum/neuron/models
michaelbenayoun committed Jul 8, 2024
1 parent 3948b7c commit b6c0217
Showing 7 changed files with 271 additions and 159 deletions.
84 changes: 1 addition & 83 deletions optimum/neuron/accelerate/accelerator.py
@@ -23,7 +23,7 @@
import warnings
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from accelerate import Accelerator
@@ -33,19 +33,14 @@
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import PreTrainedModel
from transformers.utils import is_peft_available

from ...utils import logging
from ..distributed import Parallelizer, ParallelizersManager
from ..utils import (
DynamicPatch,
ModelPatcher,
NeuronPeftModel,
Patcher,
is_neuronx_distributed_available,
is_torch_xla_available,
patch_within_function,
replace_class_in_inheritance_hierarchy,
)
from ..utils.misc import args_and_kwargs_to_kwargs_only, is_main_worker
from ..utils.model_utils import get_tied_parameters_dict, tie_parameters
@@ -62,8 +57,6 @@
)
from .utils.misc import (
apply_activation_checkpointing,
create_patched_finfo,
create_patched_save_pretrained,
)
from .utils.operations import _xla_gather

@@ -87,14 +80,6 @@
logger = logging.get_logger(__name__)


MODEL_PATCHING_SPECS = [
("config.layerdrop", 0),
("no_sync", lambda: contextlib.nullcontext()),
]

NxDPPMODEL_PATCHING_SPECS = []


class NeuronAccelerator(Accelerator):
def __init__(
self,
@@ -318,73 +303,6 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement:
def prepare_scheduler(self, scheduler: "LRScheduler"):
return super().prepare_scheduler(scheduler)

def patch_model_for_neuron(
self,
model: "torch.nn.Module",
patching_specs: Optional[List[Tuple[str, Any]]] = None,
) -> "torch.nn.Module":
if patching_specs is None:
patching_specs = MODEL_PATCHING_SPECS

# Working on a copy for safety.
patching_specs = list(patching_specs)

mixed_precision_is_bf16 = self.state.mixed_precision == "bf16"
patched_finfo = create_patched_finfo(
xla_downcast_bf16=mixed_precision_is_bf16 and self.state.downcast_bfloat,
use_amp=mixed_precision_is_bf16 and self.state.autocast_backend is AutocastBackend.AMP,
xla_use_bf16=mixed_precision_is_bf16 and not self.state.downcast_bfloat,
)
patching_specs.append(
(
"forward",
DynamicPatch(patch_within_function(("torch.finfo", patched_finfo))),
),
)

if isinstance(model, PreTrainedModel):
patching_specs.append(
(
"save_pretrained",
DynamicPatch(create_patched_save_pretrained),
),
)

# TODO: @michaelbenayoun generalize an implementation of gradient checkpointing working for:
# - DDP
# - TP
# - PP
# if hasattr(model, "gradient_checkpointing_enable"):
# patching_specs.append(
# (
# "gradient_checkpointing_enable",
# patched_gradient_checkpointing_enable,
# ),
# )

prepared_patching_specs = []
for spec in patching_specs:
prepared_patching_specs.append((model,) + spec)

model_patcher = ModelPatcher(prepared_patching_specs, ignore_missing_attributes=True)
model_patcher.patch()

if is_peft_available():
from peft import PeftModel
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper

if isinstance(model, PeftModel):
replace_class_in_inheritance_hierarchy(model, PeftModel, NeuronPeftModel)
else:
for _, module in model.named_modules():
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
raise ValueError(
"It appears that the model is using a PEFT method, please wrap your model with `PeftModel` "
"to make it work with `optimum-neuron`"
)
return model

@requires_neuronx_distributed
def _prepare_model_for_mp(
self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
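For reference, the removed `patch_model_for_neuron` applied its changes through `ModelPatcher`, which consumes `(owner, attribute_name, patch)` tuples and swaps each attribute in place, with `DynamicPatch` deriving the replacement from the attribute's current value. A minimal sketch of that pattern follows; the wrapper function `attach_neuron_patches` and its `patched_finfo` argument (as produced by `create_patched_finfo`) are illustrative names, while `ModelPatcher`, `DynamicPatch`, and `patch_within_function` are the helpers imported at the top of this file.

    # Illustrative sketch of the (owner, attribute_name, patch) spec pattern used by the
    # removed `patch_model_for_neuron`; only the helper names imported above are real.
    from optimum.neuron.utils import DynamicPatch, ModelPatcher, patch_within_function

    def attach_neuron_patches(model, patched_finfo):
        specs = [
            # Constant patch: force `config.layerdrop` to 0 on the patched model.
            (model, "config.layerdrop", 0),
            # DynamicPatch receives the current attribute value and returns its replacement;
            # here it wraps `forward` so `torch.finfo` is only overridden inside the call.
            (model, "forward", DynamicPatch(patch_within_function(("torch.finfo", patched_finfo)))),
        ]
        ModelPatcher(specs, ignore_missing_attributes=True).patch()
        return model
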
71 changes: 1 addition & 70 deletions optimum/neuron/accelerate/utils/misc.py
@@ -15,16 +15,13 @@
"""Utilities of various sorts related to accelerate with Neuron."""

import functools
import gc
import inspect
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Dict, Optional, Union

import torch
from transformers.modeling_utils import get_parameter_dtype

from ....utils import logging
from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere
from ...utils.patching import Patcher
from ...utils.peft_utils import NeuronPeftModel
from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla

@@ -63,30 +60,6 @@ def patch_accelerate_is_torch_xla_available():
)


_ORIG_TORCH_FINFO = torch.finfo


def create_patched_finfo(xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False):
def patched_finfo(dtype):
if xla_downcast_bf16 or use_amp or xla_use_bf16:
return _ORIG_TORCH_FINFO(torch.bfloat16)
return _ORIG_TORCH_FINFO(dtype)

return patched_finfo


def create_patched_get_parameter_dtype(
xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False
):
def patched_get_parameter_dtype(module):
dtype = get_parameter_dtype(module)
if xla_downcast_bf16 or use_amp or xla_use_bf16:
return torch.bfloat16
return dtype

return patched_get_parameter_dtype


@requires_neuronx_distributed
@requires_safetensors
def torch_xla_safe_save_file(
@@ -109,48 +82,6 @@ def torch_xla_safe_save_file(
save_file(cpu_data, filename, metadata=metadata)


@requires_neuronx_distributed
def create_patched_save_pretrained(orig_save_pretrained_function: Callable[["PreTrainedModel"], None]):
"""
Creates a wrapper around the `transformers.modeling_utils.PreTrainedModel.save_pretrained` method.
This methods calls `tensor.data_ptr()` on the model parameters, which causes segmentation fault when the tensors
are on the XLA device. To prevent that, this wrapper calls `save_pretrained` with the model on the CPU device.
"""
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_rank,
model_parallel_is_initialized,
)
from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu

orig_self = orig_save_pretrained_function.__self__
orig_func = orig_save_pretrained_function.__func__

patcher = Patcher([("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)])

@functools.wraps(orig_func)
def wrapper(*args, **kwargs):
self = args[0]
if model_parallel_is_initialized():
should_write_data = get_data_parallel_rank() == 0
else:
should_write_data = xm.is_master_ordinal(local=True)
orig_state_dict = self.state_dict()
cpu_state_dict = move_all_tensor_to_cpu(self.state_dict(), convert=should_write_data)
self.load_state_dict(cpu_state_dict, assign=True)
output = None
if should_write_data:
with patcher:
output = orig_func(*args, **kwargs)
self.load_state_dict(orig_state_dict, assign=True)
xm.mark_step()
del cpu_state_dict
gc.collect()
return output

return wrapper.__get__(orig_self)


# TODO: @michaelbenayoun
# Needs to make it work in the general case or be deleted and only use `apply_activation_checkpointing`.
@requires_torch_xla
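As the docstring above notes, calling `save_pretrained` on a model whose tensors live on the XLA device can segfault once `tensor.data_ptr()` is reached, so the wrapper (now defined in `optimum/neuron/models/core.py`) swaps in a CPU copy of the state dict, writes only from the data-parallel rank 0 (or the local master ordinal), and then restores the original XLA tensors. A hedged usage sketch, assuming the model has already been placed on the XLA device and that the module path matches this commit; the checkpoint name and output directory are hypothetical.

    from transformers import AutoModelForCausalLM

    from optimum.neuron.models.core import create_patched_save_pretrained

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    # ... model is moved to the XLA device / parallelized by the NeuronAccelerator ...

    # Rebind `save_pretrained` so checkpoints are serialized from CPU copies of the weights.
    model.save_pretrained = create_patched_save_pretrained(model.save_pretrained)
    model.save_pretrained("./neuron-checkpoint")  # hypothetical output directory

In the removed accelerator code the same rebinding happened automatically via the `("save_pretrained", DynamicPatch(create_patched_save_pretrained))` patching spec.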
16 changes: 16 additions & 0 deletions optimum/neuron/models/__init__.py
@@ -0,0 +1,16 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .preparator import NeuronPreparator
84 changes: 82 additions & 2 deletions optimum/neuron/models/core.py
@@ -15,13 +15,93 @@
"""Core functionalities and tools for rewriting modules for Neuron."""

import functools
import gc
import math
from typing import Optional
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Optional

import torch
import torch.nn as nn
from transformers.modeling_utils import get_parameter_dtype

# `Patcher` and `torch_xla_safe_save_file` are used by `create_patched_save_pretrained` below;
# these import paths are inferred from the other files in this commit.
from ..accelerate.utils.misc import torch_xla_safe_save_file
from ..utils.patching import Patcher
from ..utils.require_utils import requires_neuronx_distributed

class NeuronAttention:

if TYPE_CHECKING:
from transformers import PreTrainedModel


_ORIG_TORCH_FINFO = torch.finfo


def create_patched_finfo(xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False):
def patched_finfo(dtype):
if xla_downcast_bf16 or use_amp or xla_use_bf16:
return _ORIG_TORCH_FINFO(torch.bfloat16)
return _ORIG_TORCH_FINFO(dtype)

return patched_finfo


def create_patched_get_parameter_dtype(
xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False
):
def patched_get_parameter_dtype(module):
dtype = get_parameter_dtype(module)
if xla_downcast_bf16 or use_amp or xla_use_bf16:
return torch.bfloat16
return dtype

return patched_get_parameter_dtype


@requires_neuronx_distributed
def create_patched_save_pretrained(orig_save_pretrained_function: Callable[["PreTrainedModel"], None]):
"""
Creates a wrapper around the `transformers.modeling_utils.PreTrainedModel.save_pretrained` method.
This method calls `tensor.data_ptr()` on the model parameters, which causes a segmentation fault when the tensors
are on the XLA device. To prevent that, this wrapper calls `save_pretrained` with the model on the CPU device.
"""
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_rank,
model_parallel_is_initialized,
)
from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu

orig_self = orig_save_pretrained_function.__self__
orig_func = orig_save_pretrained_function.__func__

patcher = Patcher([("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)])

@functools.wraps(orig_func)
def wrapper(*args, **kwargs):
self = args[0]
if model_parallel_is_initialized():
should_write_data = get_data_parallel_rank() == 0
else:
should_write_data = xm.is_master_ordinal(local=True)
orig_state_dict = self.state_dict()
cpu_state_dict = move_all_tensor_to_cpu(self.state_dict(), convert=should_write_data)
self.load_state_dict(cpu_state_dict, assign=True)
output = None
if should_write_data:
with patcher:
output = orig_func(*args, **kwargs)
self.load_state_dict(orig_state_dict, assign=True)
xm.mark_step()
del cpu_state_dict
gc.collect()
return output

return wrapper.__get__(orig_self)


class PatchedModule(ABC):
@classmethod
@abstractmethod
def from_original(cls, orig_module: torch.nn.Module, **options) -> "PatchedModule":
pass


class NeuronAttention(PatchedModule):
# TODO: add docstring
@property
def sequence_parallel_enabled(self) -> bool:
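The new `PatchedModule` base class fixes the contract for Neuron-side module replacements: a subclass reconstructs itself from the original `torch.nn.Module` through `from_original`, and `NeuronAttention` above now subclasses it. A minimal illustrative subclass follows; `NeuronLinear` and everything inside it are hypothetical and not part of this commit.

    import torch

    from optimum.neuron.models.core import PatchedModule  # as introduced in this commit


    class NeuronLinear(PatchedModule, torch.nn.Linear):
        """Hypothetical example of a PatchedModule built from an existing nn.Linear."""

        @classmethod
        def from_original(cls, orig_module: torch.nn.Module, **options) -> "NeuronLinear":
            patched = cls(
                orig_module.in_features,
                orig_module.out_features,
                bias=orig_module.bias is not None,
            )
            # Reuse the original parameters instead of re-initializing them.
            patched.weight = orig_module.weight
            if orig_module.bias is not None:
                patched.bias = orig_module.bias
            return patched
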
(3 remaining changed files not shown)
