diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py
index 1b3d935b8..cfcb76d51 100644
--- a/optimum/neuron/__init__.py
+++ b/optimum/neuron/__init__.py
@@ -61,6 +61,7 @@
         "ModelParallelismPlugin",
     ],
     "pipelines": ["pipeline"],
+    "utils": ["get_peft_model"],
 }
 
 if TYPE_CHECKING:
@@ -94,6 +95,7 @@
     from .pipelines import pipeline
     from .trainers import NeuronTrainer, Seq2SeqNeuronTrainer
     from .training_args import NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments
+    from .utils import get_peft_model
 else:
     import sys
 
diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py
index 25d4499c8..94d8118c2 100644
--- a/optimum/neuron/accelerate/accelerator.py
+++ b/optimum/neuron/accelerate/accelerator.py
@@ -32,16 +32,20 @@
 from accelerate.utils.operations import gather_object, recursively_apply
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from transformers import PreTrainedModel
+from transformers.utils import is_peft_available
 
 from ...utils import logging
 from ..distributed import Parallelizer, ParallelizersManager
 from ..utils import (
     DynamicPatch,
     ModelPatcher,
+    NeuronPeftModel,
     Patcher,
     is_neuronx_distributed_available,
     is_torch_xla_available,
     patch_within_function,
+    replace_class_in_inheritance_hierarchy,
 )
 from ..utils.misc import args_and_kwargs_to_kwargs_only, is_main_worker
 from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla
@@ -66,8 +70,6 @@
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-
     try:
         from torch.optim.lr_scheduler import LRScheduler
     except ImportError:
@@ -341,7 +343,7 @@ def patch_model_for_neuron(
             ),
         )
 
-        if hasattr(model, "save_pretrained"):
+        if isinstance(model, PreTrainedModel):
             patching_specs.append(
                 (
                     "save_pretrained",
@@ -367,6 +369,21 @@
 
         model_patcher = ModelPatcher(prepared_patching_specs, ignore_missing_attributes=True)
         model_patcher.patch()
+
+        if is_peft_available():
+            from peft import PeftModel
+            from peft.tuners.tuners_utils import BaseTunerLayer
+            from peft.utils import ModulesToSaveWrapper
+
+            if isinstance(model, PeftModel):
+                replace_class_in_inheritance_hierarchy(model, PeftModel, NeuronPeftModel)
+            else:
+                for _, module in model.named_modules():
+                    if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
+                        raise ValueError(
+                            "It appears that the model is using a PEFT method, please wrap your model with `PeftModel` "
+                            "to make it work with `optimum-neuron`"
+                        )
         return model
 
     @requires_neuronx_distributed
@@ -466,6 +483,8 @@ def prepare_model(
                 module._use_flash_attention_2 = False
 
         if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
+            if isinstance(model, NeuronPeftModel):
+                raise NotImplementedError("PEFT is not supported with model parallelism for now.")
             model = self._prepare_model_for_mp(
                 model, device_placement=device_placement, evaluation_mode=evaluation_mode
             )
diff --git a/optimum/neuron/accelerate/utils/misc.py b/optimum/neuron/accelerate/utils/misc.py
index 2a564f7dc..6bda3027a 100644
--- a/optimum/neuron/accelerate/utils/misc.py
+++ b/optimum/neuron/accelerate/utils/misc.py
@@ -143,6 +143,7 @@ def wrapper(*args, **kwargs):
         with patcher:
             output = orig_func(*args, **kwargs)
         self.load_state_dict(orig_state_dict, assign=True)
+        xm.mark_step()
         del cpu_state_dict
         gc.collect()
         return output
diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 4e5c03478..998c00ba0 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -59,7 +59,13 @@
     has_length,
     speed_metrics,
 )
-from transformers.utils import WEIGHTS_NAME, is_accelerate_available, is_apex_available, is_sagemaker_mp_enabled
+from transformers.utils import (
+    WEIGHTS_NAME,
+    is_accelerate_available,
+    is_apex_available,
+    is_peft_available,
+    is_sagemaker_mp_enabled,
+)
 
 from ..utils import logging
 from .accelerate import NeuronAccelerator, NeuronDistributedType
@@ -436,18 +442,21 @@ def _reduce_loss(self, tr_loss: torch.Tensor) -> torch.Tensor:
         else:
             reduced_tr_loss = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div)
 
-        # reset tr_loss to zero
-        tr_loss.zero_()
-
         return reduced_tr_loss
 
     def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
+        # We always reduce the loss, even when we do not use it to avoid a new graph.
+        # This communication is not costly.
+        reduced_tr_loss = self._reduce_loss(tr_loss)
+
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
+            # reset tr_loss to zero
+            tr_loss.zero_()
 
-            def log_closure(self, tr_loss, grad_norm):
+            def log_closure(self, reduced_tr_loss, grad_norm):
                 if is_main_worker_for_metrics():
                     logs: Dict[str, float] = {}
 
-                    tr_loss_scalar = tr_loss.to("cpu").item()
+                    tr_loss_scalar = reduced_tr_loss.to("cpu").item()
                     logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
                     logs["learning_rate"] = self._get_learning_rate()
@@ -462,7 +471,7 @@ def log_closure(self, tr_loss, grad_norm):
                     self.store_flos()
                     self.log(logs)
 
-            xm.add_step_closure(log_closure, (self, tr_loss, grad_norm))
+            xm.add_step_closure(log_closure, (self, reduced_tr_loss, grad_norm))
 
         metrics = None
         if self.control.should_evaluate:
@@ -518,8 +527,15 @@ def _save_xla(self, output_dir: Optional[str] = None):
                 num_local_ranks_per_step=self.accelerator.state.mp_plugin.num_local_ranks_per_step,
             )
         else:
-            if not isinstance(self.model, PreTrainedModel):
-                if isinstance(unwrap_model(self.model), PreTrainedModel):
+            if is_peft_available():
+                from peft import PeftModel
+
+                supported_classes = (PreTrainedModel, PeftModel)
+            else:
+                supported_classes = (PreTrainedModel,)
+
+            if not isinstance(self.model, supported_classes):
+                if isinstance(unwrap_model(self.model), supported_classes):
                     unwrap_model(self.model).save_pretrained(
                         output_dir,
                         is_main_process=self.args.should_save,
@@ -981,6 +997,7 @@ def _inner_training_loop(
                         f"{tr_loss_step.device}"
                     )
                 tr_loss += tr_loss_step
+                print("tr loss", tr_loss)
 
                 self.current_flos += float(self.floating_point_ops(inputs))
 
@@ -1032,11 +1049,7 @@
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-
-                    reduced_tr_loss = self._reduce_loss(tr_loss)
-                    self._maybe_log_save_evaluate(
-                        reduced_tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval
-                    )
+                    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
 
diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py
index 45c7c3e8d..d89159eec 100644
--- a/optimum/neuron/utils/__init__.py
+++ b/optimum/neuron/utils/__init__.py
@@ -60,7 +60,9 @@
         "Patcher",
         "patch_everywhere",
         "patch_within_function",
+        "replace_class_in_inheritance_hierarchy",
     ],
+    "peft_utils": ["NeuronPeftModel", "get_peft_model"],
     "training_utils": [
         "is_model_officially_supported",
"patch_transformers_for_neuron_sdk", @@ -103,7 +105,15 @@ get_attention_scores_sd15, get_attention_scores_sdxl, ) - from .patching import DynamicPatch, ModelPatcher, Patcher, patch_everywhere, patch_within_function + from .patching import ( + DynamicPatch, + ModelPatcher, + Patcher, + patch_everywhere, + patch_within_function, + replace_class_in_inheritance_hierarchy, + ) + from .peft_utils import NeuronPeftModel, get_peft_model from .training_utils import ( is_model_officially_supported, patch_transformers_for_neuron_sdk, diff --git a/optimum/neuron/utils/patching.py b/optimum/neuron/utils/patching.py index 5a455c5d9..adcd1a8c2 100644 --- a/optimum/neuron/utils/patching.py +++ b/optimum/neuron/utils/patching.py @@ -19,7 +19,7 @@ import inspect import sys from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type, Union if TYPE_CHECKING: @@ -221,3 +221,28 @@ def patch_everywhere(attribute_name: str, patch: Any, module_name_prefix: Option continue if hasattr(module, attribute_name): setattr(module, attribute_name, patch) + + +def replace_class_in_inheritance_hierarchy(obj: Any, orig_cls: Type, replacement_cls: Type): + """ + Inspects the inheritance hierarchy of `obj` and replace `orig_cls` by `replacement_cls` if found. + """ + to_visit = [obj.__class__] + should_stop = False + while to_visit and not should_stop: + cls = to_visit.pop(0) + if cls is object: + continue + bases = cls.__bases__ + new_bases = [] + for base in bases: + to_visit.append(base) + if base == orig_cls: + new_bases.append(replacement_cls) + should_stop = True + elif base == replacement_cls: + should_stop = True + new_bases.append(base) + else: + new_bases.append(base) + cls.__bases__ = tuple(new_bases) diff --git a/optimum/neuron/utils/peft_utils.py b/optimum/neuron/utils/peft_utils.py new file mode 100644 index 000000000..6302711ca --- /dev/null +++ b/optimum/neuron/utils/peft_utils.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Utilities related to the PEFT library and support.""" +import functools +import gc +from typing import TYPE_CHECKING, Any, List, Optional, Union + +from transformers.utils import is_peft_available + +from .patching import replace_class_in_inheritance_hierarchy +from .require_utils import requires_neuronx_distributed + + +if is_peft_available(): + from peft import PeftModel + from peft import get_peft_model as orig_get_peft_model + from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict + +else: + + class PeftModel: + pass + + def orig_get_peft_model(*args, **kwargs): + pass + + def get_peft_model_state_dict(*args, **kwargs): + pass + + def set_peft_model_state_dict(*args, **kwargs): + pass + + +if TYPE_CHECKING: + pass + + +class NeuronPeftModel(PeftModel): + @requires_neuronx_distributed + def save_pretrained( + self, + save_directory: str, + safe_serialization: bool = True, + selected_adapters: Optional[List[str]] = None, + save_embedding_layers: Union[str, bool] = "auto", + is_main_process: bool = True, + convert_pissa_to_lora: Optional[str] = None, + **kwargs: Any, + ): + import torch_xla.core.xla_model as xm + from neuronx_distributed.parallel_layers.parallel_state import ( + get_data_parallel_rank, + model_parallel_is_initialized, + ) + from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu + + if model_parallel_is_initialized(): + should_write_data = get_data_parallel_rank() == 0 + else: + should_write_data = xm.is_master_ordinal(local=True) + + if selected_adapters is None: + selected_adapters = list(self.peft_config.keys()) + + orig_state_dicts = {} + cpu_state_dicts = {} + for adapter_name in selected_adapters: + state_dict = get_peft_model_state_dict( + self, + state_dict=kwargs.get("state_dict", None), + adapter_name=adapter_name, + save_embedding_layers=save_embedding_layers, + ) + cpu_state_dict = move_all_tensor_to_cpu(state_dict, convert=should_write_data) + orig_state_dicts[adapter_name] = state_dict + cpu_state_dicts[adapter_name] = cpu_state_dict + + for adapter_name, state_dict in cpu_state_dicts.items(): + set_peft_model_state_dict(self, state_dict, adapter_name=adapter_name) + + output = None + if should_write_data: + output = super().save_pretrained( + save_directory, + safe_serialization=safe_serialization, + selected_adapters=selected_adapters, + save_embedding_layers=save_embedding_layers, + is_main_process=is_main_process, + convert_pissa_to_lora=convert_pissa_to_lora, + ) + + for adapter_name, state_dict in orig_state_dicts.items(): + set_peft_model_state_dict(self, state_dict, adapter_name=adapter_name) + + xm.mark_step() + del cpu_state_dicts + gc.collect() + return output + + +@functools.wraps(orig_get_peft_model) +def get_peft_model(*args, **kwargs): + peft_model = orig_get_peft_model(*args, **kwargs) + replace_class_in_inheritance_hierarchy(peft_model, PeftModel, NeuronPeftModel) + return peft_model diff --git a/optimum/neuron/utils/require_utils.py b/optimum/neuron/utils/require_utils.py index d828ebe89..df9f68313 100644 --- a/optimum/neuron/utils/require_utils.py +++ b/optimum/neuron/utils/require_utils.py @@ -17,7 +17,7 @@ import functools from typing import Any, Callable, Dict -from transformers.utils import is_safetensors_available +from transformers.utils import is_peft_available, is_safetensors_available from .import_utils import ( is_neuronx_distributed_available, @@ -27,12 +27,13 @@ ) -_AVAILABILITIES: Dict[str, Callable[[], bool]] = { +_AVAILABILITIES: Dict[str, Callable] = { "safetensors": 
     "torch_xla": is_torch_xla_available,
     "neuronx_distributed": is_neuronx_distributed_available,
     "torch_neuronx": is_torch_neuronx_available,
     "transformers_neuronx": is_transformers_neuronx_available,
+    "peft": is_peft_available,
 }
 
 
@@ -59,3 +60,4 @@ def wrapper(*args, **kwargs):
 requires_neuronx_distributed = _create_requires_function("neuronx_distributed")
 requires_torch_neuronx = _create_requires_function("torch_neuronx")
 requires_transformers_neuronx = _create_requires_function("transformers_neuronx")
+requires_peft = _create_requires_function("peft")
diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py
index 1988dfe92..40a567b37 100644
--- a/optimum/neuron/utils/training_utils.py
+++ b/optimum/neuron/utils/training_utils.py
@@ -45,6 +45,7 @@
 from ...utils.logging import set_verbosity as set_verbosity_optimum
 from ..generation import GeneralNeuronGenerationMixin, NeuronGenerationMixin
 from . import is_neuronx_distributed_available
+from .patching import replace_class_in_inheritance_hierarchy
 from .require_utils import requires_neuronx_distributed, requires_torch_xla
 
 
@@ -140,25 +141,7 @@ def patch_generation_mixin_to_neuron_generation_mixin(
     Changes the vanilla `GenerationMixin` class from Transformers to `neuron_generation_mixin_cls` in the model's
     inheritance. This allows to make the model Neuron-compatible for generation without much hassle.
     """
-    to_visit = [model.__class__]
-    should_stop = False
-    while to_visit and not should_stop:
-        cls = to_visit.pop(0)
-        if cls is object:
-            continue
-        bases = cls.__bases__
-        new_bases = []
-        for base in bases:
-            to_visit.append(base)
-            if base == GenerationMixin:
-                new_bases.append(neuron_generation_mixin_cls)
-                should_stop = True
-            elif base == neuron_generation_mixin_cls:
-                should_stop = True
-                new_bases.append(base)
-            else:
-                new_bases.append(base)
-        cls.__bases__ = tuple(new_bases)
+    return replace_class_in_inheritance_hierarchy(model, GenerationMixin, neuron_generation_mixin_cls)
 
 
 def patch_generation_mixin_to_general_neuron_generation_mixin(model: "PreTrainedModel"):
diff --git a/tests/distributed/test_common.py b/tests/distributed/test_common.py
index 5bc70ffcd..e2e23236b 100644
--- a/tests/distributed/test_common.py
+++ b/tests/distributed/test_common.py
@@ -36,8 +36,8 @@
 from optimum.neuron.utils.testing_utils import is_trainium_test
 
 from .. import DistributedTest
-from ..utils import create_static_seed_patcher, get_model
-from .utils import create_accelerator_for_mp, get_model_inputs
+from ..utils import create_accelerator, create_static_seed_patcher, get_model
+from .utils import get_model_inputs
 
 
 if is_torch_xla_available():
@@ -159,7 +159,7 @@ def test_optimizer_parameters_match_model_parameters(
         model = get_tiny_llama_model(tp_size=tp_size, pp_size=pp_size, lazy_load=lazy_load)
         optimizer = get_optimizer(model, lazy_optimizer, with_groups)
 
-        accelerator = create_accelerator_for_mp(tp_size, pp_size, zero_1=zero_1)
+        accelerator = create_accelerator(tp_size, pp_size, zero_1=zero_1)
 
         if tp_size > 1 or pp_size > 1:
             assert accelerator.state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM
@@ -198,7 +198,7 @@ def test_optimizer_step(self, zero_1, gradient_accumulation_steps, max_grad_norm
 
         optimizer = get_optimizer(model, with_groups=False)
 
-        accelerator = create_accelerator_for_mp(
+        accelerator = create_accelerator(
             tp_size, pp_size, zero_1=zero_1, gradient_accumulation_steps=gradient_accumulation_steps
         )
 
@@ -302,7 +302,7 @@ def test_lazy_load(self, from_config, parallel_sizes):
 
         orig_parameters: Dict[str, torch.nn.Parameter] = dict(model.named_parameters())
 
-        accelerator = create_accelerator_for_mp(tp_size, pp_size)
+        accelerator = create_accelerator(tp_size, pp_size)
         lazy_model = get_tiny_llama_model(
             tp_size=tp_size, pp_size=pp_size, lazy_load=True, from_config=from_config, use_static_seed_patcher=True
         )
@@ -349,7 +349,7 @@ def test_save_model_and_load_model(self, parallel_sizes, tmpdir, monkeypatch):
 
         model = get_tiny_llama_model(tp_size=tp_size, pp_size=pp_size, lazy_load=False, add_random_noise=True)
 
-        accelerator = create_accelerator_for_mp(tp_size, pp_size)
+        accelerator = create_accelerator(tp_size, pp_size)
         model = accelerator.prepare(model)
         accelerator.save_state(tmpdir.as_posix())
         accelerator.state._reset_state(reset_partial_state=True)
@@ -382,7 +382,7 @@ def test_save_model_and_load_model(self, parallel_sizes, tmpdir, monkeypatch):
 
         # Making sure that we end-up with a different model when starting over.
         new_model = get_tiny_llama_model(tp_size=tp_size, pp_size=pp_size, lazy_load=False, add_random_noise=True)
-        new_accelerator = create_accelerator_for_mp(tp_size, pp_size)
+        new_accelerator = create_accelerator(tp_size, pp_size)
         new_model = new_accelerator.prepare(new_model)
         new_accelerator.state._reset_state(reset_partial_state=True)
         del new_accelerator
@@ -401,7 +401,7 @@ def test_save_model_and_load_model(self, parallel_sizes, tmpdir, monkeypatch):
 
         # Checking that when providing a checkpoint, we end-up with the same model as the original.
         new_model = get_tiny_llama_model(tp_size=tp_size, pp_size=pp_size, lazy_load=False, add_random_noise=True)
-        new_accelerator = create_accelerator_for_mp(tp_size, pp_size, checkpoint_dir=tmpdir)
+        new_accelerator = create_accelerator(tp_size, pp_size, checkpoint_dir=tmpdir)
         new_model = new_accelerator.prepare(new_model)
 
         # If there is no model parallelism, the checkpoint weights will not be loaded automatically since we do not
@@ -463,9 +463,7 @@ def test_consolidate_model_parallel_checkpoints(
         # Saving to pytorch instead of safetensors because it fails otherwise for pickling issues with distributed tests.
         orig_model.save_pretrained(orig_model_path, safe_serialization=False)
 
-        accelerator = create_accelerator_for_mp(
-            tp_size, pp_size, kv_size_multiplier=kv_size_multiplier, use_xser=use_xser
-        )
+        accelerator = create_accelerator(tp_size, pp_size, kv_size_multiplier=kv_size_multiplier, use_xser=use_xser)
         _ = accelerator.prepare(orig_model)
 
         output_dir = Path(tmpdir) / "parallel_model"
diff --git a/tests/distributed/test_model_parallelization.py b/tests/distributed/test_model_parallelization.py
index 9961d10b9..582833cc3 100644
--- a/tests/distributed/test_model_parallelization.py
+++ b/tests/distributed/test_model_parallelization.py
@@ -56,8 +56,8 @@
 from optimum.neuron.utils.testing_utils import is_trainium_test
 
 from .. import DistributedTest
-from ..utils import SEED, create_static_seed_patcher, get_model
-from .utils import create_accelerator_for_mp, get_model_inputs
+from ..utils import SEED, create_accelerator, create_static_seed_patcher, get_model
+from .utils import get_model_inputs
 
 
 if is_torch_xla_available():
@@ -298,7 +298,7 @@ def _parallel_model_matches_original_model(
             use_static_seed_patcher=True,
         )
 
-        accelerator = create_accelerator_for_mp(
+        accelerator = create_accelerator(
             tp_size,
             pp_size,
             parallelize_embeddings=parallelize_embeddings,
diff --git a/tests/distributed/utils.py b/tests/distributed/utils.py
index f9790adbd..6d5c39822 100644
--- a/tests/distributed/utils.py
+++ b/tests/distributed/utils.py
@@ -15,7 +15,6 @@
 """Utilities for tests distributed."""
 
 import inspect
-from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import torch
@@ -39,7 +38,6 @@
     MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
 )
 
-from optimum.neuron import ModelParallelismPlugin, NeuronAccelerator
 from optimum.neuron.utils.require_utils import requires_neuronx_distributed, requires_torch_xla
 
 
@@ -258,28 +256,3 @@ def get_model_inputs(
         )
         inputs[name] = tensor
     return inputs
-
-
-def create_accelerator_for_mp(
-    tp_size: int,
-    pp_size: int,
-    zero_1: bool = False,
-    gradient_accumulation_steps: int = 1,
-    parallelize_embeddings: bool = True,
-    sequence_parallel_enabled: bool = True,
-    kv_size_multiplier: Optional[int] = None,
-    checkpoint_dir: Optional[Union[Path, str]] = None,
-    use_xser: bool = True,
-) -> NeuronAccelerator:
-    mp_plugin = ModelParallelismPlugin(
-        tensor_parallel_size=tp_size,
-        kv_size_multiplier=kv_size_multiplier,
-        parallelize_embeddings=parallelize_embeddings,
-        sequence_parallel_enabled=sequence_parallel_enabled,
-        pipeline_parallel_size=pp_size,
-        checkpoint_dir=checkpoint_dir,
-        use_xser=use_xser,
-    )
-    return NeuronAccelerator(
-        mp_plugin=mp_plugin, zero_1=zero_1, gradient_accumulation_steps=gradient_accumulation_steps
-    )
diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py
index 5e84ebf33..f3a0bfda1 100644
--- a/tests/inference/inference_utils.py
+++ b/tests/inference/inference_utils.py
@@ -75,7 +75,7 @@
 
 
 class NeuronModelIntegrationTestMixin(unittest.TestCase):
-    USER = "optimum"
+    USER = "optimum-internal-testing"
     MODEL_ID = None
     NEURON_MODEL_REPO = None
     NEURON_MODEL_CLASS = None
diff --git a/tests/peft/__init__.py b/tests/peft/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/peft/test_peft_training.py b/tests/peft/test_peft_training.py
new file mode 100644
index 000000000..30abe3120
--- /dev/null
+++ b/tests/peft/test_peft_training.py
@@ -0,0 +1,157 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests related to PEFT integration."""
+
+import json
+from pathlib import Path
+
+import pytest
+import torch
+from peft import AutoPeftModelForCausalLM, LoraConfig, PeftModel
+from peft import get_peft_model as orig_get_peft_model
+from safetensors.torch import load_file
+from transformers import LlamaForCausalLM
+
+from optimum.neuron import NeuronTrainer, NeuronTrainingArguments, get_peft_model
+from optimum.neuron.utils.peft_utils import NeuronPeftModel
+from optimum.neuron.utils.testing_utils import is_trainium_test
+
+from .. import DistributedTest
+from ..utils import (
+    create_accelerator,
+    create_dummy_causal_lm_dataset,
+    create_static_seed_patcher,
+    default_data_collator_for_causal_lm,
+    get_tokenizer_and_tiny_llama_model,
+)
+
+
+def get_peft_config():
+    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+    return LoraConfig(
+        r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
+    )
+
+
+def test_get_peft_model():
+    peft_config = get_peft_config()
+    _, model = get_tokenizer_and_tiny_llama_model()
+    orig_peft_model = orig_get_peft_model(model, peft_config)
+
+    assert isinstance(orig_peft_model, PeftModel)
+    assert not isinstance(orig_peft_model, NeuronPeftModel)
+
+    _, model = get_tokenizer_and_tiny_llama_model()
+    peft_model = get_peft_model(model, peft_config)
+
+    assert isinstance(peft_model, NeuronPeftModel)
+
+
+@is_trainium_test
+class TestPeft(DistributedTest):
+    @pytest.fixture(
+        scope="class",
+        params=[[2, 1, 1]],
+        ids=["dp=2"],
+    )
+    def parallel_sizes(self, request):
+        return request.param
+
+    @pytest.mark.world_size(2)
+    def test_peft_model_is_converted_to_neuron_peft_model(self):
+        model = AutoPeftModelForCausalLM.from_pretrained("peft-internal-testing/tiny-random-BertModel-lora")
+        assert isinstance(model, PeftModel)
+        accelerator = create_accelerator(1, 1)
+        model = accelerator.prepare(model)
+        assert isinstance(model, NeuronPeftModel)
+
+    def test_save_pretrained(self, parallel_sizes, tmpdir):
+        _, tp_size, pp_size = parallel_sizes
+
+        output_dir = Path(tmpdir)
+
+        peft_config = get_peft_config()
+
+        # PEFT model saved using `PeftModel`.
+        seed_patcher = create_static_seed_patcher(LlamaForCausalLM, 42)
+        with seed_patcher:
+            _, model = get_tokenizer_and_tiny_llama_model()
+            orig_model_path = output_dir / "orig_peft"
+            orig_peft_model = orig_get_peft_model(model, peft_config)
+
+        orig_peft_model.save_pretrained(orig_model_path.as_posix())
+
+        # PEFT model saved using `NeuronPeftModel`.
+        seed_patcher = create_static_seed_patcher(LlamaForCausalLM, 42)
+        with seed_patcher:
+            _, model = get_tokenizer_and_tiny_llama_model()
+            model_path = output_dir / "peft"
+            peft_model = get_peft_model(model, peft_config)
+
+        accelerator = create_accelerator(tp_size, pp_size)
+        peft_model = accelerator.prepare_model(peft_model)
+        peft_model.save_pretrained(model_path.as_posix())
+
+        with open(orig_model_path / "adapter_config.json") as fp:
+            orig_adapter_config_content = json.dumps(json.load(fp), sort_keys=True)
+
+        with open(model_path / "adapter_config.json") as fp:
+            adapter_config_content = json.dumps(json.load(fp), sort_keys=True)
+
+        assert orig_adapter_config_content == adapter_config_content, "adapter_config.json files do not match"
+
+        orig_state_dict = load_file(orig_model_path / "adapter_model.safetensors")
+        state_dict = load_file(model_path / "adapter_model.safetensors")
+
+        assert orig_state_dict.keys() == state_dict.keys()
+        for name, tensor in orig_state_dict.items():
+            print(f"Checking that the parameter {name} matches")
+            torch.testing.assert_close(tensor, state_dict[name])
+
+    def test_peft_training(self, parallel_sizes, tmpdir):
+        _, tp_size, pp_size = parallel_sizes
+
+        per_device_train_batch_size = 1
+        output_dir = Path(tmpdir)
+        args = NeuronTrainingArguments(
+            output_dir=output_dir.as_posix(),
+            do_train=True,
+            do_eval=False,
+            bf16=True,
+            per_device_train_batch_size=per_device_train_batch_size,
+            save_strategy="epoch",
+            logging_steps=10,
+            num_train_epochs=2,
+            tensor_parallel_size=tp_size,
+            pipeline_parallel_size=pp_size,
+        )
+
+        tokenizer, model = get_tokenizer_and_tiny_llama_model()
+
+        num_train_samples = num_eval_samples = 50
+        datasets = create_dummy_causal_lm_dataset(
+            model.config.vocab_size, num_train_samples, num_eval_samples, max_number_of_unique_examples=3
+        )
+
+        trainer = NeuronTrainer(
+            model,
+            args,
+            tokenizer=tokenizer,
+            train_dataset=datasets["train"],
+            eval_dataset=datasets["eval"],
+            data_collator=default_data_collator_for_causal_lm,
+        )
+
+        trainer.train()
diff --git a/tests/test_trainers.py b/tests/test_trainers.py
index ea6b78148..58ef2c4c4 100644
--- a/tests/test_trainers.py
+++ b/tests/test_trainers.py
@@ -25,9 +25,7 @@
 from huggingface_hub import HfApi
 from transformers import (
     AutoConfig,
-    AutoModelForCausalLM,
     AutoModelForSequenceClassification,
-    AutoTokenizer,
 )
 
 from optimum.neuron import NeuronTrainer, NeuronTrainingArguments
@@ -41,8 +39,10 @@
 
 from . import DistributedTest
 from .utils import (
+    MODEL_NAME,
     create_dummy_causal_lm_dataset,
     default_data_collator_for_causal_lm,
+    get_tokenizer_and_tiny_llama_model,
 )
 
 
@@ -54,17 +54,6 @@
 )
 
 
-MODEL_NAME = "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random"
-
-
-def get_tokenizer_and_tiny_llama_model(parallel_sizes):
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    _, tp_size, pp_size = parallel_sizes
-    config = AutoConfig.from_pretrained(MODEL_NAME)
-    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config, ignore_mismatched_sizes=True)
-    return tokenizer, model
-
-
 @is_trainium_test
 class TestNeuronTrainingUtils(DistributedTest):
     @pytest.fixture(
@@ -80,7 +69,7 @@ def test_get_model_param_count(self, parallel_sizes, tmpdir):
         _, tp_size, pp_size = parallel_sizes
         output_dir = Path(tmpdir)
 
-        _, model = get_tokenizer_and_tiny_llama_model(parallel_sizes)
+        _, model = get_tokenizer_and_tiny_llama_model()
 
         target_num_parameters = sum(p.numel() for p in model.parameters())
 
@@ -130,7 +119,7 @@ def test_save_checkpoint(self, hub_test, tmpdir, parallel_sizes):
             output_dir=output_dir.as_posix(),
         )
 
-        tokenizer, model = get_tokenizer_and_tiny_llama_model(parallel_sizes)
+        tokenizer, model = get_tokenizer_and_tiny_llama_model()
         datasets = create_dummy_causal_lm_dataset(model.config.vocab_size, 120, 1, sequence_length=128)
 
         trainer = NeuronTrainer(
@@ -197,7 +186,7 @@ def test_train_and_eval_use_remote_cache(self, hub_test_with_local_cache, tmpdir
         num_eval_samples = 100
         per_device_eval_batch_size = 16
 
-        tokenizer, model = get_tokenizer_and_tiny_llama_model(parallel_sizes)
+        tokenizer, model = get_tokenizer_and_tiny_llama_model()
         clone = copy.deepcopy(model)
 
         datasets = create_dummy_causal_lm_dataset(model.config.vocab_size, num_train_samples, num_eval_samples)
@@ -296,7 +285,7 @@ def test_save_and_resume_from_checkpoint(self, parallel_sizes, tmpdir):
         max_train_samples = 100
         max_eval_samples = 16
 
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+        tokenizer, _ = get_tokenizer_and_tiny_llama_model()
         tokenizer.pad_token = tokenizer.eos_token
 
         def create_training_args(output_dir, resume_from_checkpoint=None, max_steps=max_steps):
diff --git a/tests/utils.py b/tests/utils.py
index bc9aadb37..bb54b0992 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,15 +20,17 @@
 import os
 import random
 import string
-from typing import Callable, Dict, List, Optional, Tuple, Type
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 from datasets import Dataset, DatasetDict
 from huggingface_hub import CommitOperationDelete, HfApi, create_repo, delete_repo, get_token, login, logout
 from huggingface_hub.utils import RepositoryNotFoundError
-from transformers import AutoConfig, PreTrainedModel
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
 from transformers.testing_utils import ENDPOINT_STAGING
 
+from optimum.neuron import ModelParallelismPlugin, NeuronAccelerator
 from optimum.neuron.distributed import lazy_load_for_parallelism
 from optimum.neuron.utils.cache_utils import (
     delete_custom_cache_repo_name_from_hf_home,
@@ -45,6 +47,8 @@
 SEED = 42
 
 OPTIMUM_INTERNAL_TESTING_CACHE_REPO = "optimum-internal-testing/optimum-neuron-cache-for-testing"
 
+MODEL_NAME = "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random"
+
 
 def get_random_string(length) -> str:
     letters = string.ascii_lowercase
@@ -88,8 +92,10 @@ def generate_input_ids(vocab_size: int, batch_size: int, sequence_length: int) ->
     return torch.randint(0, vocab_size, (batch_size, sequence_length))
 
 
-def generate_attention_mask(batch_size: int, sequence_length: int) -> torch.Tensor:
-    return torch.randint(0, 2, (batch_size, sequence_length))
+def generate_attention_mask(batch_size: int, sequence_length: int, random: bool = False) -> torch.Tensor:
+    if random:
+        return torch.randint(0, 2, (batch_size, sequence_length))
+    return torch.ones((batch_size, sequence_length))
 
 
 def create_dummy_causal_lm_dataset(
@@ -97,21 +103,31 @@ def create_dummy_causal_lm_dataset(
     num_train_examples: int,
     num_eval_examples: int,
     num_test_examples: Optional[int] = None,
+    max_number_of_unique_examples: Optional[int] = None,
     sequence_length: int = 32,
+    random_attention_mask: bool = False,
 ) -> DatasetDict:
     if num_test_examples is None:
         num_test_examples = num_eval_examples
 
+    if max_number_of_unique_examples is None:
+        max_number_of_unique_examples = max(num_train_examples, num_eval_examples, num_test_examples)
+
     def create_gen(num_examples):
         def gen():
-            for _ in range(num_examples):
+            examples = []
+            for _ in range(min(num_examples, max_number_of_unique_examples)):
                 input_ids = generate_input_ids(vocab_size, 1, sequence_length)
-                attention_mask = generate_attention_mask(1, sequence_length)
-                yield {
-                    "input_ids": input_ids,
-                    "attention_mask": attention_mask,
-                    "labels": input_ids,
-                }
+                attention_mask = generate_attention_mask(1, sequence_length, random=random_attention_mask)
+                examples.append(
+                    {
+                        "input_ids": input_ids,
+                        "attention_mask": attention_mask,
+                        "labels": input_ids,
+                    }
+                )
+            for i in range(num_examples):
+                yield examples[i % max_number_of_unique_examples]
 
         return gen
 
@@ -213,6 +229,38 @@ def get_model(
     return model
 
 
+def get_tokenizer_and_tiny_llama_model():
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    config = AutoConfig.from_pretrained(MODEL_NAME)
+    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config, ignore_mismatched_sizes=True)
+    return tokenizer, model
+
+
+def create_accelerator(
+    tp_size: int,
+    pp_size: int,
+    zero_1: bool = False,
+    gradient_accumulation_steps: int = 1,
+    parallelize_embeddings: bool = True,
+    sequence_parallel_enabled: bool = True,
+    kv_size_multiplier: Optional[int] = None,
+    checkpoint_dir: Optional[Union[Path, str]] = None,
+    use_xser: bool = True,
+) -> NeuronAccelerator:
+    mp_plugin = ModelParallelismPlugin(
+        tensor_parallel_size=tp_size,
+        kv_size_multiplier=kv_size_multiplier,
+        parallelize_embeddings=parallelize_embeddings,
+        sequence_parallel_enabled=sequence_parallel_enabled,
+        pipeline_parallel_size=pp_size,
+        checkpoint_dir=checkpoint_dir,
+        use_xser=use_xser,
+    )
+    return NeuronAccelerator(
+        mp_plugin=mp_plugin, zero_1=zero_1, gradient_accumulation_steps=gradient_accumulation_steps
+    )
+
+
 class TrainiumTestMixin:
     @classmethod
     def setUpClass(cls):
diff --git a/text-generation-inference/tests/fixtures/model.py b/text-generation-inference/tests/fixtures/model.py
index 661b3d839..5ee46b598 100644
--- a/text-generation-inference/tests/fixtures/model.py
+++ b/text-generation-inference/tests/fixtures/model.py
@@ -21,7 +21,7 @@
 )
 
 logger = logging.getLogger(__file__)
-OPTIMUM_CACHE_REPO_ID = "optimum/neuron-testing-cache"
+OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache"
 
 # All model configurations below will be added to the neuron_model_config fixture
 MODEL_CONFIGURATIONS = {
@@ -41,7 +41,7 @@
 
 
 def get_hub_neuron_model_id(config_name: str):
-    return f"optimum/neuron-testing-{version}-{sdk_version}-{config_name}"
f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}" def export_model(model_id, export_kwargs, neuron_model_path): @@ -63,8 +63,8 @@ def neuron_model_config(request): """Expose a pre-trained neuron model The fixture first makes sure the following model artifacts are present on the hub: - - exported neuron model under optimum/neuron-testing--, - - cached artifacts under optimum/neuron-testing-cache. + - exported neuron model under optimum-internal-testing/neuron-testing--, + - cached artifacts under optimum-internal-testing/neuron-testing-cache. If not, it will export the model and push it to the hub. It then fetches the model locally and return a dictionary containing: diff --git a/text-generation-inference/tests/fixtures/service.py b/text-generation-inference/tests/fixtures/service.py index f5d4e3932..f108e600b 100644 --- a/text-generation-inference/tests/fixtures/service.py +++ b/text-generation-inference/tests/fixtures/service.py @@ -17,7 +17,7 @@ from huggingface_hub import AsyncInferenceClient, TextGenerationOutput -OPTIMUM_CACHE_REPO_ID = "optimum/neuron-testing-cache" +OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache" DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "neuronx-tgi:latest") HF_TOKEN = huggingface_hub.get_token()