Commit

Performance improvements and neuron_parallel_compile and gradient checkpointing fixes (#602)
michaelbenayoun authored May 29, 2024
1 parent 1efb43d commit c7f1b7e
Showing 14 changed files with 214 additions and 109 deletions.
45 changes: 39 additions & 6 deletions optimum/neuron/accelerate/accelerator.py
@@ -57,7 +57,11 @@
patch_accelerate_is_torch_xla_available,
tie_parameters,
)
from .utils.misc import apply_activation_checkpointing, create_patched_finfo, create_patched_save_pretrained
from .utils.misc import (
apply_activation_checkpointing,
create_patched_finfo,
create_patched_save_pretrained,
)
from .utils.operations import _xla_gather


@@ -132,6 +136,15 @@ def __init__(
if not isinstance(autocast_backend, AutocastBackend):
autocast_backend = AutocastBackend(autocast_backend)

# The original `is_torch_xla_available` function is checking for TPU or GPU in `accelerate`.
# Here, we patch it to return True for Neuron cores as well.
def patched_is_torch_xla_available(check_is_tpu: bool = False, check_is_gpu: bool = False) -> bool:
return is_torch_xla_available()

import accelerate

accelerate.state.is_torch_xla_available = patched_is_torch_xla_available

patched_accelerator_state = partial(
NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend
)
@@ -336,13 +349,24 @@ def patch_model_for_neuron(
),
)

# TODO: @michaelbenayoun generalize an implementation of gradient checkpointing working for:
# - DDP
# - TP
# - PP
# if hasattr(model, "gradient_checkpointing_enable"):
# patching_specs.append(
# (
# "gradient_checkpointing_enable",
# patched_gradient_checkpointing_enable,
# ),
# )

prepared_patching_specs = []
for spec in patching_specs:
prepared_patching_specs.append((model,) + spec)

model_patcher = ModelPatcher(prepared_patching_specs, ignore_missing_attributes=True)
model_patcher.patch()

return model

@requires_neuronx_distributed
@@ -428,6 +452,12 @@ def prepare_model(
model.config.output_attentions = False
model.config.output_hidden_states = False

should_apply_activation_checkpointing = False
for mod in model.modules():
if getattr(mod, "gradient_checkpointing", False):
should_apply_activation_checkpointing = True
model.gradient_checkpointing_disable()

# It is needed for now otherwise sdpa is used since PT > 2.* is available.
for module in model.modules():
if getattr(module, "_use_sdpa", False):
@@ -439,13 +469,16 @@
model = self._prepare_model_for_mp(
model, device_placement=device_placement, evaluation_mode=evaluation_mode
)
apply_activation_checkpointing(model)
return model
if should_apply_activation_checkpointing:
apply_activation_checkpointing(model)
else:
apply_activation_checkpointing(model)
if should_apply_activation_checkpointing:
apply_activation_checkpointing(model)
move_model_to_device(model, xm.xla_device())
device_placement = False
return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
xm.mark_step()
return model

def backward(self, loss, **kwargs):
if self.distributed_type != DistributedType.DEEPSPEED:
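For context on the `prepare_model` changes above: gradient checkpointing requested on the model before preparation is now detected, the Transformers-native mechanism is disabled, and a Neuron-compatible wrapper is re-applied after preparation via `apply_activation_checkpointing`. Below is a minimal sketch of the detection step only; it assumes a standard `transformers`-style model exposing `gradient_checkpointing_disable` and is an illustration, not the code added by this commit.

import torch

def needs_activation_checkpointing(model: torch.nn.Module) -> bool:
    # Mirrors the loop added to `prepare_model`: any submodule flagged with
    # `gradient_checkpointing` means checkpointing was requested by the user.
    flagged = any(getattr(mod, "gradient_checkpointing", False) for mod in model.modules())
    if flagged and hasattr(model, "gradient_checkpointing_disable"):
        # Disable the Transformers implementation here; the Neuron-compatible
        # wrapper is re-applied once the model has been prepared.
        model.gradient_checkpointing_disable()
    return flagged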
5 changes: 1 addition & 4 deletions optimum/neuron/accelerate/state.py
@@ -96,10 +96,7 @@ def __init__(self, cpu: bool = False, **kwargs):
self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)

def wait_for_everyone(self):
if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
xm.rendezvous("accelerate.utils.wait_for_everyone")
else:
super().wait_for_everyone()
xm.rendezvous("accelerate.utils.wait_for_everyone")


class NeuronAcceleratorState(AcceleratorState):
50 changes: 46 additions & 4 deletions optimum/neuron/accelerate/utils/misc.py
@@ -16,17 +16,21 @@

import functools
import gc
import inspect
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

import torch
from transformers.modeling_utils import get_parameter_dtype

from ....utils import logging
from ...distributed.utils import named_parameters
from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere
from ...utils.patching import Patcher
from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors
from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla


logger = logging.get_logger(__name__)

if TYPE_CHECKING:
import os

@@ -191,6 +195,41 @@ def tie_parameters(model: Union["torch.nn.Module", "NxDPPModel"], tied_parameter
setattr(param_to_tie_parent_module, param_to_tie_name[1], param)


# TODO: @michaelbenayoun
# Needs to make it work in the general case or be deleted and only use `apply_activation_checkpointing`.
@requires_torch_xla
def patched_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
from torch_xla.utils.checkpoint import checkpoint

if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}

gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)

# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters

if not _is_using_old_format:
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
else:
self.apply(functools.partial(self._set_gradient_checkpointing, value=True))
logger.warning(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)

if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()


@requires_neuronx_distributed
def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel"]):
from neuronx_distributed.pipeline import NxDPPModel
@@ -205,9 +244,12 @@ def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel"]

gradient_checkpointing_modules = set()
for module in modules:
if getattr(module, "gradient_checkpointing", False):
module.gradient_checkpointing = False
gradient_checkpointing_modules.add(module)
if isinstance(module, torch.nn.ModuleList):
for mod in module:
# TODO: @michaelbenayoun. Need to find a better way to identify the blocks to apply gradient
# checkpointing to.
if "Layer" in mod.__class__.__name__ or "Block" in mod.__class__.__name__:
gradient_checkpointing_modules.add(mod)

def check_fn(m: torch.nn.Module) -> bool:
return m in gradient_checkpointing_modules
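Illustration only (not part of the commit): `apply_activation_checkpointing` now picks its targets by name rather than by the `gradient_checkpointing` flag — any child of an `nn.ModuleList` whose class name contains "Layer" or "Block", which in practice matches typical transformer blocks such as `LlamaDecoderLayer` or `GPT2Block`. A self-contained sketch of that selection rule:

import torch

def collect_checkpointable_blocks(model: torch.nn.Module) -> set:
    # Same heuristic as above: children of an nn.ModuleList whose class name
    # contains "Layer" or "Block" are treated as transformer blocks.
    blocks = set()
    for module in model.modules():
        if isinstance(module, torch.nn.ModuleList):
            for mod in module:
                if "Layer" in mod.__class__.__name__ or "Block" in mod.__class__.__name__:
                    blocks.add(mod)
    return blocks

# The resulting set backs the `check_fn` used when wrapping modules:
# check_fn = lambda m: m in collect_checkpointable_blocks(model)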
4 changes: 0 additions & 4 deletions optimum/neuron/distributed/base.py
@@ -631,7 +631,6 @@ def should_parallelize_layer_predicate_func(layer):
skip_linear_weight_load=skip_linear_weight_load,
kv_size_multiplier=kv_size_multiplier,
)
xm.rendezvous("End of tensor parallelism")
if is_main_worker():
logger.info("Tensor parallelism done.")

@@ -708,8 +707,6 @@ def should_parallelize_layer_predicate_func(layer):
# Initialize or load the weights for the parallelized model if it was lazily loaded.
cls._initialize_or_load_weights(model, names_of_the_parameters_to_consider, device=device)
gc.collect()
xm.rendezvous(f"weight_loading_and_initialization_{worker}")
xm.rendezvous("End of initalization")

if is_main_worker():
logger.info("Load and initialization of the weights done.")
@@ -750,7 +747,6 @@ def should_parallelize_layer_predicate_func(layer):
tracer_cls=OptimumNeuronFXTracer,
)

xm.rendezvous("End of pipeline paralellism")
if is_main_worker():
logger.info("Pipeline parallelism done.")

22 changes: 17 additions & 5 deletions optimum/neuron/distributed/parallelizers_manager.py
@@ -15,7 +15,7 @@
"""Factory class mapping model architectures to their Parallelizer class."""

import importlib
from typing import Dict, List, Type, Union
from typing import Dict, List, Tuple, Type, Union

from transformers import PreTrainedModel

@@ -83,16 +83,27 @@ def _get_model_type(cls, model_type_or_model: Union[str, PreTrainedModel]) -> st
return model_type

@classmethod
def is_model_supported(cls, model_type_or_model: Union[str, PreTrainedModel]) -> bool:
def is_model_supported(cls, model_type_or_model: Union[str, PreTrainedModel]) -> Tuple[bool, bool, bool]:
"""
Returns `True` if the model can be parallelized, `False` otherwise.
Returns a tuple of 3 booleans where:
- The first element indicates if tensor parallelism can be used for this model,
- The second element indicates if sequence parallelism can be used on top of tensor parallelism for this model,
- The third element indicates if pipeline parallelism can be used for this model.
Args:
model_type_or_model (`Union[str, PreTrainedModel]`):
Either the model type or an instance of the model.
"""
model_type = cls._get_model_type(model_type_or_model)
return model_type in cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS
for_tp = model_type in cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS
if for_tp:
parallelizer = cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS[model_type]
for_sp = parallelizer.supports_sequence_parallelism()
for_pp = parallelizer.supports_pipeline_parallelism()
else:
for_sp = for_pp = False

return (for_tp, for_sp, for_pp)

@classmethod
def parallelizer_for_model(cls, model_type_or_model: Union[str, PreTrainedModel]) -> Type[Parallelizer]:
@@ -105,7 +116,8 @@ def parallelizer_for_model(cls, model_type_or_model: Union[str, PreTrainedModel]
"""
model_type = cls._get_model_type(model_type_or_model)
if not cls.is_model_supported(model_type_or_model):
is_tp_supported, _, _ = cls.is_model_supported(model_type_or_model)
if not is_tp_supported:
supported_models = ", ".join(cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS.keys())
raise NotImplementedError(
f"{model_type} is not supported for parallelization, supported models: {supported_models}"
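A short usage sketch for the new `is_model_supported` signature (an illustration, assuming `ParallelizersManager` is importable from `optimum.neuron.distributed` and using "llama" as an example model type; the previous API returned a single boolean):

from optimum.neuron.distributed import ParallelizersManager

# `is_model_supported` accepts either a model type string or a PreTrainedModel instance.
supports_tp, supports_sp, supports_pp = ParallelizersManager.is_model_supported("llama")
if not supports_tp:
    raise NotImplementedError("Tensor parallelism is not available for this model type.")
if supports_pp:
    print("Pipeline parallelism can also be enabled for this architecture.")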