[WIP] peft support
michaelbenayoun committed May 28, 2024
1 parent 2561c33 commit 7e64a5d
Showing 6 changed files with 163 additions and 27 deletions.
5 changes: 2 additions & 3 deletions optimum/neuron/accelerate/accelerator.py
@@ -30,6 +30,7 @@
from accelerate.checkpointing import save_accelerator_state, save_custom_state
from accelerate.utils import AutocastKwargs, DistributedType
from accelerate.utils.operations import gather_object, recursively_apply
from transformers import PreTrainedModel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

@@ -66,8 +67,6 @@


if TYPE_CHECKING:
from transformers import PreTrainedModel

try:
from torch.optim.lr_scheduler import LRScheduler
except ImportError:
@@ -341,7 +340,7 @@ def patch_model_for_neuron(
),
)

if hasattr(model, "save_pretrained"):
if isinstance(model, PreTrainedModel):
patching_specs.append(
(
"save_pretrained",
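
A minimal sketch of the distinction this change relies on (illustrative only, not part of the commit; the checkpoint name and the motivation are assumptions): a PEFT-wrapped model also exposes `save_pretrained`, so the old `hasattr` test would have matched it, whereas the new `isinstance` test restricts the patch to plain `PreTrainedModel` instances.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, PreTrainedModel

base = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

print(hasattr(peft_model, "save_pretrained"))   # True: the old condition would have patched it
print(isinstance(peft_model, PreTrainedModel))  # False: the new condition skips it
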
1 change: 1 addition & 0 deletions optimum/neuron/accelerate/utils/misc.py
@@ -143,6 +143,7 @@ def wrapper(*args, **kwargs):
with patcher:
output = orig_func(*args, **kwargs)
self.load_state_dict(orig_state_dict, assign=True)
xm.mark_step()
del cpu_state_dict
gc.collect()
return output
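
For context, a minimal sketch of what the added `xm.mark_step()` buys (illustrative, not part of the commit): torch_xla records operations lazily, and `mark_step()` cuts and executes the pending graph, so the preceding `load_state_dict` is materialized before the CPU copy is deleted.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.zeros(4, device=device)
t = t + 1       # recorded lazily; nothing has executed on the device yet
xm.mark_step()  # cut and run the pending XLA graph so `t` is materialized
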
19 changes: 16 additions & 3 deletions optimum/neuron/trainers.py
@@ -59,7 +59,13 @@
has_length,
speed_metrics,
)
from transformers.utils import WEIGHTS_NAME, is_accelerate_available, is_apex_available, is_sagemaker_mp_enabled
from transformers.utils import (
WEIGHTS_NAME,
is_accelerate_available,
is_apex_available,
is_peft_available,
is_sagemaker_mp_enabled,
)

from ..utils import logging
from .accelerate import NeuronAccelerator, NeuronDistributedType
@@ -518,8 +524,15 @@ def _save_xla(self, output_dir: Optional[str] = None):
num_local_ranks_per_step=self.accelerator.state.mp_plugin.num_local_ranks_per_step,
)
else:
if not isinstance(self.model, PreTrainedModel):
if isinstance(unwrap_model(self.model), PreTrainedModel):
if is_peft_available():
from peft import PeftModel

supported_classes = (PreTrainedModel, PeftModel)
else:
supported_classes = (PreTrainedModel,)

if not isinstance(self.model, supported_classes):
if isinstance(unwrap_model(self.model), supported_classes):
unwrap_model(self.model).save_pretrained(
output_dir,
is_main_process=self.args.should_save,
27 changes: 26 additions & 1 deletion optimum/neuron/utils/patching.py
@@ -19,7 +19,7 @@
import inspect
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, Type


if TYPE_CHECKING:
@@ -221,3 +221,28 @@ def patch_everywhere(attribute_name: str, patch: Any, module_name_prefix: Option
continue
if hasattr(module, attribute_name):
setattr(module, attribute_name, patch)


def replace_class_in_inheritance_hierarchy(obj: Any, orig_cls: Type, replacement_cls: Type):
"""
Inspects the inheritance hierarchy of `obj` and replaces `orig_cls` with `replacement_cls` if found.
"""
to_visit = [obj.__class__]
should_stop = False
while to_visit and not should_stop:
cls = to_visit.pop(0)
if cls is object:
continue
bases = cls.__bases__
new_bases = []
for base in bases:
to_visit.append(base)
if base == orig_cls:
new_bases.append(replacement_cls)
should_stop = True
elif base == replacement_cls:
should_stop = True
new_bases.append(base)
else:
new_bases.append(base)
cls.__bases__ = tuple(new_bases)
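
A minimal, self-contained sketch of how `replace_class_in_inheritance_hierarchy` behaves (the toy classes are made up for illustration and are not part of the commit):

from optimum.neuron.utils.patching import replace_class_in_inheritance_hierarchy

class Mixin:
    def hello(self):
        return "original"

class NeuronMixin(Mixin):
    def hello(self):
        return "neuron"

class Model(Mixin):
    pass

model = Model()
# Walks Model's bases, finds Mixin, and rewrites Model.__bases__ to use NeuronMixin.
replace_class_in_inheritance_hierarchy(model, Mixin, NeuronMixin)
assert isinstance(model, NeuronMixin)
assert model.hello() == "neuron"

Note that the swap rewrites `Model.__bases__`, so it affects the class itself, and therefore every existing and future instance, not only the object passed in.
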
115 changes: 115 additions & 0 deletions optimum/neuron/utils/peft_utils.py
@@ -0,0 +1,115 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to the PEFT library and support."""
import gc
import functools
from typing import Any, List, Optional, Union

from transformers.utils import is_peft_available

from .patching import Patcher, replace_class_in_inheritance_hierarchy
from .require_utils import requires_neuronx_distributed


if is_peft_available():
from peft import PeftModel, get_peft_model as orig_get_peft_model
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict

else:

class PeftModel:
pass

def orig_get_peft_model(*args, **kwargs):
pass

def get_peft_model_state_dict(*args, **kwargs):
pass

def set_peft_model_state_dict(*args, **kwargs):
pass


class NeuronPeftModel(PeftModel):
@requires_neuronx_distributed
def save_pretrained(
self,
save_directory: str,
safe_serialization: bool = True,
selected_adapters: Optional[List[str]] = None,
save_embedding_layers: Union[str, bool] = "auto",
is_main_process: bool = True,
convert_pissa_to_lora: Optional[str] = None,
**kwargs: Any,
):
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_rank,
model_parallel_is_initialized,
)
from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu

if model_parallel_is_initialized():
should_write_data = get_data_parallel_rank() == 0
else:
should_write_data = xm.is_master_ordinal(local=True)

if selected_adapters is None:
selected_adapters = list(self.peft_config.keys())

orig_state_dicts = {}
cpu_state_dicts = {}
for adapter_name in selected_adapters:
state_dict = get_peft_model_state_dict(
self,
state_dict=kwargs.get("state_dict", None),
adapter_name=adapter_name,
save_embedding_layers=save_embedding_layers,
)
cpu_state_dict = move_all_tensor_to_cpu(state_dict, convert=should_write_data)
orig_state_dicts[adapter_name] = state_dict
cpu_state_dicts[adapter_name] = cpu_state_dict

for adapter_name, state_dict in cpu_state_dicts.items():
set_peft_model_state_dict(self, state_dict, adapter_name=adapter_name)

output = None
if should_write_data:
output = super().save_pretrained(
save_directory,
safe_serialization=safe_serialization,
selected_adapters=selected_adapters,
save_embedding_layers=save_embedding_layers,
is_main_process=is_main_process,
convert_pissa_to_lora=convert_pissa_to_lora,
)

for adapter_name, state_dict in orig_state_dicts.items():
set_peft_model_state_dict(self, state_dict, adapter_name=adapter_name)

xm.mark_step()
del cpu_state_dicts
gc.collect()
return output



@functools.wraps(orig_get_peft_model)
def get_peft_model(*args, **kwargs):
peft_model = orig_get_peft_model(*args, **kwargs)
replace_class_in_inheritance_hierarchy(peft_model, PeftModel, NeuronPeftModel)
return peft_model


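A hedged usage sketch of the wrapper above (the checkpoint, LoRA settings, and output directory are illustrative; saving requires a Neuron/XLA environment because of `@requires_neuronx_distributed`): `get_peft_model` still returns the model built by PEFT, but its class hierarchy has been rewritten so that it is also a `NeuronPeftModel`, and `save_pretrained` resolves to the Neuron-aware override.

from peft import LoraConfig
from transformers import AutoModelForCausalLM

from optimum.neuron.utils.peft_utils import NeuronPeftModel, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM", r=8))
assert isinstance(peft_model, NeuronPeftModel)

# On a Neuron host this moves the adapter state dict to CPU and lets only the
# designated writer ranks (data-parallel rank 0) perform the actual write.
peft_model.save_pretrained("lora_adapter")
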
23 changes: 3 additions & 20 deletions optimum/neuron/utils/training_utils.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""Training utilities"""

from typing import TYPE_CHECKING, List, Optional, Type, Union
from typing import TYPE_CHECKING, List, Optional, Type, Union, Any

import torch
import transformers
@@ -46,6 +46,7 @@
from ..generation import GeneralNeuronGenerationMixin, NeuronGenerationMixin
from . import is_neuronx_distributed_available
from .require_utils import requires_neuronx_distributed, requires_torch_xla
from .patching import replace_class_in_inheritance_hierarchy


if is_neuronx_distributed_available():
@@ -140,25 +141,7 @@ def patch_generation_mixin_to_neuron_generation_mixin(
Changes the vanilla `GenerationMixin` class from Transformers to `neuron_generation_mixin_cls` in the model's
inheritance hierarchy. This makes the model Neuron-compatible for generation without much hassle.
"""
to_visit = [model.__class__]
should_stop = False
while to_visit and not should_stop:
cls = to_visit.pop(0)
if cls is object:
continue
bases = cls.__bases__
new_bases = []
for base in bases:
to_visit.append(base)
if base == GenerationMixin:
new_bases.append(neuron_generation_mixin_cls)
should_stop = True
elif base == neuron_generation_mixin_cls:
should_stop = True
new_bases.append(base)
else:
new_bases.append(base)
cls.__bases__ = tuple(new_bases)
return replace_class_in_inheritance_hierarchy(model, GenerationMixin, neuron_generation_mixin_cls)


def patch_generation_mixin_to_general_neuron_generation_mixin(model: "PreTrainedModel"):
