Initial PEFT support (#612)
michaelbenayoun authored Jun 7, 2024
1 parent af0506f commit da9d261
Showing 19 changed files with 452 additions and 115 deletions.
2 changes: 2 additions & 0 deletions optimum/neuron/__init__.py
@@ -61,6 +61,7 @@
"ModelParallelismPlugin",
],
"pipelines": ["pipeline"],
"utils": ["get_peft_model"],
}

if TYPE_CHECKING:
@@ -94,6 +95,7 @@
from .pipelines import pipeline
from .trainers import NeuronTrainer, Seq2SeqNeuronTrainer
from .training_args import NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments
from .utils import get_peft_model

else:
import sys
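
The net effect of this file's change is that get_peft_model becomes importable from the package root. A minimal sketch (nothing beyond the new export is assumed):

    # Drop-in replacement for `peft.get_peft_model`; the wrapper itself is defined in
    # optimum/neuron/utils/peft_utils.py (new file shown further down).
    from optimum.neuron import get_peft_model
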
25 changes: 22 additions & 3 deletions optimum/neuron/accelerate/accelerator.py
@@ -32,16 +32,20 @@
from accelerate.utils.operations import gather_object, recursively_apply
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import PreTrainedModel
from transformers.utils import is_peft_available

from ...utils import logging
from ..distributed import Parallelizer, ParallelizersManager
from ..utils import (
DynamicPatch,
ModelPatcher,
NeuronPeftModel,
Patcher,
is_neuronx_distributed_available,
is_torch_xla_available,
patch_within_function,
replace_class_in_inheritance_hierarchy,
)
from ..utils.misc import args_and_kwargs_to_kwargs_only, is_main_worker
from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla
@@ -66,8 +70,6 @@


if TYPE_CHECKING:
from transformers import PreTrainedModel

try:
from torch.optim.lr_scheduler import LRScheduler
except ImportError:
@@ -341,7 +343,7 @@ def patch_model_for_neuron(
),
)

if hasattr(model, "save_pretrained"):
if isinstance(model, PreTrainedModel):
patching_specs.append(
(
"save_pretrained",
@@ -367,6 +369,21 @@

model_patcher = ModelPatcher(prepared_patching_specs, ignore_missing_attributes=True)
model_patcher.patch()

if is_peft_available():
from peft import PeftModel
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper

if isinstance(model, PeftModel):
replace_class_in_inheritance_hierarchy(model, PeftModel, NeuronPeftModel)
else:
for _, module in model.named_modules():
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
raise ValueError(
"It appears that the model is using a PEFT method, please wrap your model with `PeftModel` "
"to make it work with `optimum-neuron`"
)
return model

@requires_neuronx_distributed
@@ -466,6 +483,8 @@ def prepare_model(
module._use_flash_attention_2 = False

if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
if isinstance(model, NeuronPeftModel):
raise NotImplementedError("PEFT is not supported with model parallelism for now.")
model = self._prepare_model_for_mp(
model, device_placement=device_placement, evaluation_mode=evaluation_mode
)
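
For illustration, the failure mode the new check guards against is a model that contains PEFT tuner layers but was never wrapped in PeftModel. A sketch using peft's inject_adapter_in_model (the base model and LoRA settings are illustrative, not from this commit):

    from peft import LoraConfig, inject_adapter_in_model
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"])

    # Injects BaseTunerLayer modules into the model without returning a PeftModel,
    # so patch_model_for_neuron above would raise the ValueError asking to wrap
    # the model with `PeftModel`.
    model = inject_adapter_in_model(config, model)

Passing a proper PeftModel instead gets its PeftModel base class swapped for NeuronPeftModel, while combining it with model parallelism raises NotImplementedError for now.
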
1 change: 1 addition & 0 deletions optimum/neuron/accelerate/utils/misc.py
@@ -143,6 +143,7 @@ def wrapper(*args, **kwargs):
with patcher:
output = orig_func(*args, **kwargs)
self.load_state_dict(orig_state_dict, assign=True)
xm.mark_step()
del cpu_state_dict
gc.collect()
return output
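
The added xm.mark_step() closes the current lazy XLA trace, presumably so that the tensors restored by load_state_dict above are materialized on device before the CPU copy of the weights is deleted and garbage-collected.
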
41 changes: 27 additions & 14 deletions optimum/neuron/trainers.py
@@ -59,7 +59,13 @@
has_length,
speed_metrics,
)
from transformers.utils import WEIGHTS_NAME, is_accelerate_available, is_apex_available, is_sagemaker_mp_enabled
from transformers.utils import (
WEIGHTS_NAME,
is_accelerate_available,
is_apex_available,
is_peft_available,
is_sagemaker_mp_enabled,
)

from ..utils import logging
from .accelerate import NeuronAccelerator, NeuronDistributedType
@@ -436,18 +442,21 @@ def _reduce_loss(self, tr_loss: torch.Tensor) -> torch.Tensor:
else:
reduced_tr_loss = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div)

# reset tr_loss to zero
tr_loss.zero_()

return reduced_tr_loss

def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
# We always reduce the loss, even when we do not use it to avoid a new graph.
# This communication is not costly.
reduced_tr_loss = self._reduce_loss(tr_loss)

if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
# reset tr_loss to zero
tr_loss.zero_()

def log_closure(self, tr_loss, grad_norm):
def log_closure(self, reduced_tr_loss, grad_norm):
if is_main_worker_for_metrics():
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.to("cpu").item()
tr_loss_scalar = reduced_tr_loss.to("cpu").item()

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = self._get_learning_rate()
@@ -462,7 +471,7 @@ def log_closure(self, tr_loss, grad_norm):
self.store_flos()
self.log(logs)

xm.add_step_closure(log_closure, (self, tr_loss, grad_norm))
xm.add_step_closure(log_closure, (self, reduced_tr_loss, grad_norm))

metrics = None
if self.control.should_evaluate:
@@ -518,8 +527,15 @@ def _save_xla(self, output_dir: Optional[str] = None):
num_local_ranks_per_step=self.accelerator.state.mp_plugin.num_local_ranks_per_step,
)
else:
if not isinstance(self.model, PreTrainedModel):
if isinstance(unwrap_model(self.model), PreTrainedModel):
if is_peft_available():
from peft import PeftModel

supported_classes = (PreTrainedModel, PeftModel)
else:
supported_classes = (PreTrainedModel,)

if not isinstance(self.model, supported_classes):
if isinstance(unwrap_model(self.model), supported_classes):
unwrap_model(self.model).save_pretrained(
output_dir,
is_main_process=self.args.should_save,
@@ -981,6 +997,7 @@ def _inner_training_loop(
f"{tr_loss_step.device}"
)
tr_loss += tr_loss_step
print("tr loss", tr_loss)

self.current_flos += float(self.floating_point_ops(inputs))

@@ -1032,11 +1049,7 @@ def _inner_training_loop(
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)

reduced_tr_loss = self._reduce_loss(tr_loss)
self._maybe_log_save_evaluate(
reduced_tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval
)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

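
A standalone sketch of the reduce-then-log pattern now used above: the accumulated loss is all-reduced on every step so the XLA graph stays identical, and it is only turned into a Python float inside a step closure that runs after the step's graph has executed (assumes a torch_xla runtime; the numbers are illustrative):

    import torch
    import torch_xla.core.xla_model as xm

    def log_closure(reduced_loss):
        # Safe to call .item() here: the closure runs after the graph for this step
        # has been executed, so it does not cut the lazy trace in the middle.
        print("loss:", reduced_loss.to("cpu").item())

    tr_loss = torch.zeros((), device=xm.xla_device())
    tr_loss += 0.5  # stand-in for the per-step loss accumulation

    reduced = xm.all_reduce(xm.REDUCE_SUM, tr_loss / xm.xrt_world_size())
    tr_loss.zero_()  # the trainer now does this reset under `should_log`
    xm.add_step_closure(log_closure, (reduced,))
    xm.mark_step()
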
12 changes: 11 additions & 1 deletion optimum/neuron/utils/__init__.py
@@ -60,7 +60,9 @@
"Patcher",
"patch_everywhere",
"patch_within_function",
"replace_class_in_inheritance_hierarchy",
],
"peft_utils": ["NeuronPeftModel", "get_peft_model"],
"training_utils": [
"is_model_officially_supported",
"patch_transformers_for_neuron_sdk",
@@ -103,7 +105,15 @@
get_attention_scores_sd15,
get_attention_scores_sdxl,
)
from .patching import DynamicPatch, ModelPatcher, Patcher, patch_everywhere, patch_within_function
from .patching import (
DynamicPatch,
ModelPatcher,
Patcher,
patch_everywhere,
patch_within_function,
replace_class_in_inheritance_hierarchy,
)
from .peft_utils import NeuronPeftModel, get_peft_model
from .training_utils import (
is_model_officially_supported,
patch_transformers_for_neuron_sdk,
27 changes: 26 additions & 1 deletion optimum/neuron/utils/patching.py
@@ -19,7 +19,7 @@
import inspect
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type, Union


if TYPE_CHECKING:
@@ -221,3 +221,28 @@ def patch_everywhere(attribute_name: str, patch: Any, module_name_prefix: Option
continue
if hasattr(module, attribute_name):
setattr(module, attribute_name, patch)


def replace_class_in_inheritance_hierarchy(obj: Any, orig_cls: Type, replacement_cls: Type):
"""
Inspects the inheritance hierarchy of `obj` and replaces `orig_cls` with `replacement_cls` if found.
"""
to_visit = [obj.__class__]
should_stop = False
while to_visit and not should_stop:
cls = to_visit.pop(0)
if cls is object:
continue
bases = cls.__bases__
new_bases = []
for base in bases:
to_visit.append(base)
if base == orig_cls:
new_bases.append(replacement_cls)
should_stop = True
elif base == replacement_cls:
should_stop = True
new_bases.append(base)
else:
new_bases.append(base)
cls.__bases__ = tuple(new_bases)
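
A toy illustration of replace_class_in_inheritance_hierarchy (Base, NewBase and Child are made up for the example; the real use replaces PeftModel with NeuronPeftModel):

    from optimum.neuron.utils import replace_class_in_inheritance_hierarchy

    class Base:
        def greet(self):
            return "base"

    class NewBase(Base):
        def greet(self):
            return "patched"

    class Child(Base):
        pass

    obj = Child()
    assert obj.greet() == "base"

    # Rewrites Child.__bases__ in place so NewBase takes the spot of Base; existing
    # instances immediately resolve methods through NewBase without being re-created.
    replace_class_in_inheritance_hierarchy(obj, Base, NewBase)
    assert obj.greet() == "patched"
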
117 changes: 117 additions & 0 deletions optimum/neuron/utils/peft_utils.py
@@ -0,0 +1,117 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to the PEFT library and support."""
import functools
import gc
from typing import TYPE_CHECKING, Any, List, Optional, Union

from transformers.utils import is_peft_available

from .patching import replace_class_in_inheritance_hierarchy
from .require_utils import requires_neuronx_distributed


if is_peft_available():
from peft import PeftModel
from peft import get_peft_model as orig_get_peft_model
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict

else:

class PeftModel:
pass

def orig_get_peft_model(*args, **kwargs):
pass

def get_peft_model_state_dict(*args, **kwargs):
pass

def set_peft_model_state_dict(*args, **kwargs):
pass


if TYPE_CHECKING:
pass


class NeuronPeftModel(PeftModel):
@requires_neuronx_distributed
def save_pretrained(
self,
save_directory: str,
safe_serialization: bool = True,
selected_adapters: Optional[List[str]] = None,
save_embedding_layers: Union[str, bool] = "auto",
is_main_process: bool = True,
convert_pissa_to_lora: Optional[str] = None,
**kwargs: Any,
):
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_rank,
model_parallel_is_initialized,
)
from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu

if model_parallel_is_initialized():
should_write_data = get_data_parallel_rank() == 0
else:
should_write_data = xm.is_master_ordinal(local=True)

if selected_adapters is None:
selected_adapters = list(self.peft_config.keys())

orig_state_dicts = {}
cpu_state_dicts = {}
for adapter_name in selected_adapters:
state_dict = get_peft_model_state_dict(
self,
state_dict=kwargs.get("state_dict", None),
adapter_name=adapter_name,
save_embedding_layers=save_embedding_layers,
)
cpu_state_dict = move_all_tensor_to_cpu(state_dict, convert=should_write_data)
orig_state_dicts[adapter_name] = state_dict
cpu_state_dicts[adapter_name] = cpu_state_dict

for adapter_name, state_dict in cpu_state_dicts.items():
set_peft_model_state_dict(self, state_dict, adapter_name=adapter_name)

output = None
if should_write_data:
output = super().save_pretrained(
save_directory,
safe_serialization=safe_serialization,
selected_adapters=selected_adapters,
save_embedding_layers=save_embedding_layers,
is_main_process=is_main_process,
convert_pissa_to_lora=convert_pissa_to_lora,
)

for adapter_name, state_dict in orig_state_dicts.items():
set_peft_model_state_dict(self, state_dict, adapter_name=adapter_name)

xm.mark_step()
del cpu_state_dicts
gc.collect()
return output


@functools.wraps(orig_get_peft_model)
def get_peft_model(*args, **kwargs):
peft_model = orig_get_peft_model(*args, **kwargs)
replace_class_in_inheritance_hierarchy(peft_model, PeftModel, NeuronPeftModel)
return peft_model
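
End-to-end sketch of what the new module provides (the base model and LoRA settings are illustrative, and running save_pretrained this way assumes a Neuron/torch_xla environment):

    from peft import LoraConfig, PeftModel, TaskType
    from transformers import AutoModelForCausalLM

    from optimum.neuron.utils import NeuronPeftModel, get_peft_model

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16, target_modules=["c_attn"])

    peft_model = get_peft_model(model, config)

    # The concrete class (here PeftModelForCausalLM) is kept, but PeftModel in its
    # inheritance chain has been replaced by NeuronPeftModel.
    assert isinstance(peft_model, PeftModel) and isinstance(peft_model, NeuronPeftModel)

    # The override moves the adapter state dicts to CPU, writes them only from
    # data-parallel rank 0 (or the local master ordinal), restores the on-device
    # state dicts afterwards, and finishes with xm.mark_step().
    peft_model.save_pretrained("./lora-adapter")
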
6 changes: 4 additions & 2 deletions optimum/neuron/utils/require_utils.py
@@ -17,7 +17,7 @@
import functools
from typing import Any, Callable, Dict

from transformers.utils import is_safetensors_available
from transformers.utils import is_peft_available, is_safetensors_available

from .import_utils import (
is_neuronx_distributed_available,
@@ -27,12 +27,13 @@
)


_AVAILABILITIES: Dict[str, Callable[[], bool]] = {
_AVAILABILITIES: Dict[str, Callable] = {
"safetensors": is_safetensors_available,
"torch_xla": is_torch_xla_available,
"neuronx_distributed": is_neuronx_distributed_available,
"torch_neuronx": is_torch_neuronx_available,
"transformers_neuronx": is_transformers_neuronx_available,
"peft": is_peft_available,
}


@@ -59,3 +60,4 @@ def wrapper(*args, **kwargs):
requires_neuronx_distributed = _create_requires_function("neuronx_distributed")
requires_torch_neuronx = _create_requires_function("torch_neuronx")
requires_transformers_neuronx = _create_requires_function("transformers_neuronx")
requires_peft = _create_requires_function("peft")
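
Sketch of how the new requires_peft decorator is meant to be used (merge_lora_adapters is a hypothetical helper, not part of this commit):

    from optimum.neuron.utils.require_utils import requires_peft

    @requires_peft
    def merge_lora_adapters(model):
        # The decorator lets the body assume peft is importable; when it is not
        # installed, the generated wrapper presumably raises an error at call time
        # telling the user to install the missing dependency.
        from peft.utils import get_peft_model_state_dict
        return get_peft_model_state_dict(model)
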