From 281d9bb40347dad752b97cf71cbee17970858445 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 5 Sep 2024 10:44:17 +0200 Subject: [PATCH] SFTTrainer support (#682) --- optimum/neuron/__init__.py | 8 +- optimum/neuron/trainers.py | 414 ++++++++++++++++++++++++++- optimum/neuron/training_args.py | 2 +- optimum/neuron/utils/__init__.py | 4 + optimum/neuron/utils/import_utils.py | 11 + optimum/neuron/utils/trl_utils.py | 35 +++ setup.py | 1 + tests/test_trainers.py | 70 ++++- 8 files changed, 530 insertions(+), 15 deletions(-) create mode 100644 optimum/neuron/utils/trl_utils.py diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index 2b8d7b81b..a55e42ef3 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -27,7 +27,7 @@ _import_structure = { "hf_argparser": ["NeuronHfArgumentParser"], - "trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer"], + "trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer", "NeuronSFTTrainer"], "training_args": ["NeuronTrainingArguments", "Seq2SeqNeuronTrainingArguments"], "modeling_traced": ["NeuronTracedModel"], "modeling": [ @@ -69,7 +69,7 @@ "ModelParallelismPlugin", ], "pipelines": ["pipeline"], - "utils": ["get_peft_model"], + "utils": ["NeuronSFTConfig", "get_peft_model"], } if TYPE_CHECKING: @@ -109,9 +109,9 @@ from .modeling_seq2seq import NeuronModelForSeq2SeqLM from .modeling_traced import NeuronTracedModel from .pipelines import pipeline - from .trainers import NeuronTrainer, Seq2SeqNeuronTrainer + from .trainers import NeuronSFTTrainer, NeuronTrainer, Seq2SeqNeuronTrainer from .training_args import NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments - from .utils import get_peft_model + from .utils import NeuronSFTConfig, get_peft_model else: import sys diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index 103c29ba3..6608d5825 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -15,21 +15,36 @@ """Defines Trainer subclasses to perform training on AWS Neuron instances.""" import copy +import dataclasses +import inspect import math import os import shutil import sys import time import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import datasets import numpy as np import torch from accelerate import __version__ as accelerate_version +from accelerate.state import PartialState from accelerate.utils import AutocastKwargs, DataLoaderConfiguration, GradientAccumulationPlugin from packaging import version from torch.utils.data import Dataset -from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, TrainingArguments +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollator, + DataCollatorForLanguageModeling, + PreTrainedModel, + PreTrainedTokenizerBase, + Seq2SeqTrainer, + Trainer, + TrainingArguments, +) from transformers.debug_utils import DebugOption, DebugUnderflowOverflow from transformers.integrations import hp_params from transformers.modeling_utils import unwrap_model @@ -39,7 +54,7 @@ TRAINER_STATE_NAME, TRAINING_ARGS_NAME, ) -from transformers.trainer_callback import TrainerState +from transformers.trainer_callback import TrainerCallback, TrainerState from transformers.trainer_pt_utils import ( IterableDatasetShard, find_batch_size, @@ -63,6 +78,7 @@ WEIGHTS_NAME, is_accelerate_available, is_apex_available, + is_peft_available, is_sagemaker_mp_enabled, ) @@ -73,6 +89,7 @@ from .training_args import 
NeuronTrainingArguments from .utils import ( is_torch_xla_available, + is_trl_available, patch_within_function, ) from .utils.cache_utils import ( @@ -84,7 +101,7 @@ ) from .utils.hub_cache_utils import ModelCacheEntry, hub_neuronx_cache, patch_neuron_cc_wrapper, synchronize_hub_cache from .utils.misc import is_main_worker, is_precompilation -from .utils.peft_utils import NeuronPeftModel +from .utils.peft_utils import NeuronPeftModel, get_peft_model from .utils.require_utils import requires_neuronx_distributed, requires_torch_neuronx from .utils.training_utils import ( get_model_param_count, @@ -93,6 +110,7 @@ patch_generation_mixin_to_neuron_generation_mixin, skip_first_batches, ) +from .utils.trl_utils import NeuronSFTConfig from .utils.version_utils import get_neuronxcc_version @@ -111,6 +129,26 @@ else: IS_SAGEMAKER_MP_POST_1_10 = False + +if is_trl_available(): + from trl import SFTConfig, SFTTrainer +else: + + class SFTTrainer: + pass + + class SFTConfig: + pass + + +if is_peft_available(): + from peft import PeftConfig +else: + + class PeftConfig: + pass + + logger = logging.get_logger("transformers.trainer") KEEP_HF_HUB_PROGRESS_BARS = os.environ.get("KEEP_HF_HUB_PROGRESS_BARS") @@ -120,7 +158,7 @@ transformers_get_optimizer_cls_and_kwargs = Trainer.get_optimizer_cls_and_kwargs -class AugmentTrainerForNeuronMixin: +class _TrainerForNeuron: def __init__(self, *args, **kwargs): if not isinstance(self, Trainer): raise TypeError(f"{self.__class__.__name__} can only be mixed with Trainer subclasses.") @@ -454,7 +492,11 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno tr_loss.zero_() def log_closure(self, reduced_tr_loss, grad_norm): - if is_main_worker_for_metrics(): + # We need to check that self.state.global_step > self._globalstep_last_logged because if two + # closures are added in a row (which can happen at the end of the training), then it will fail the + # second time because at this point we will have: + # self.state.global_step = self._globalstep_last_logged + if is_main_worker_for_metrics() and self.state.global_step > self._globalstep_last_logged: logs: Dict[str, float] = {} tr_loss_scalar = reduced_tr_loss.to("cpu").item() @@ -1455,13 +1497,369 @@ def save_state(self): return super().save_state() -class NeuronTrainer(AugmentTrainerForNeuronMixin, Trainer): +class NeuronTrainer(_TrainerForNeuron, Trainer): """ Trainer that is suited for performing training on AWS Tranium instances. """ -class Seq2SeqNeuronTrainer(AugmentTrainerForNeuronMixin, Seq2SeqTrainer): +class Seq2SeqNeuronTrainer(_TrainerForNeuron, Seq2SeqTrainer): """ Seq2SeqTrainer that is suited for performing training on AWS Tranium instances. """ + + +class _SFTTrainerTrainerInit(SFTTrainer): + def __init__(self, *args, **kwargs): + return Trainer.__init__(self, *args, **kwargs) + + +class NeuronSFTTrainer(_TrainerForNeuron, _SFTTrainerTrainerInit): + """ + `SFTTrainer` adapted for Neuron. + + It differs from the original `SFTTrainer` by: + - Using `_TrainerForNeuron.__init__()` instead of `Trainer.__init__()` + - Using the `_TrainerForNeuron.train()` instead of `Trainer.train()` + - Adapts the `_prepare_non_packed_dataloader` to pad to max length. In the original `SFTTrainer` examples are + not padded, which is an issue here because it triggers compilation every time. 
+ """ + + def __init__( + self, + model: Optional[Union[PreTrainedModel, torch.nn.Module, str]] = None, + args: Optional[SFTConfig] = None, + data_collator: Optional[DataCollator] = None, # type: ignore + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + formatting_func: Optional[Callable] = None, + ): + if not is_trl_available(): + raise RuntimeError("Using NeuronSFTTrainer requires the trl library.") + + from trl.extras.dataset_formatting import get_formatting_func_from_dataset + + # This will be changed to : + from trl.trainer.callbacks import RichProgressCallback + from trl.trainer.utils import ( + DataCollatorForCompletionOnlyLM, + peft_module_casting_to_bf16, + ) + + if is_peft_available(): + from peft import PeftConfig, prepare_model_for_kbit_training + + if args is None: + output_dir = "tmp_trainer" + warnings.warn(f"No `SFTConfig` passed, using `output_dir={output_dir}`.") + args = NeuronSFTConfig(output_dir=output_dir) + elif args is not None and args.__class__.__name__ == "NeuronTrainingArguments": + args_as_dict = args.to_dict() + # Manually copy token values as TrainingArguments.to_dict() redacts them + args_as_dict.update({k: getattr(args, k) for k in args_as_dict.keys() if k.endswith("_token")}) + args = NeuronSFTConfig(**args_as_dict) + + if getattr(args, "model_init_kwargs", None) is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.") + else: + model_init_kwargs = args.model_init_kwargs + torch_dtype = model_init_kwargs.get("torch_dtype") + if torch_dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(torch_dtype, str) and torch_dtype != "auto": + torch_dtype = getattr(torch, torch_dtype) + if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + ) + model_init_kwargs["torch_dtype"] = torch_dtype + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the SFTTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if args.packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM): + raise ValueError( + "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument." + ) + + if is_peft_available() and peft_config is not None: + if not isinstance(peft_config, PeftConfig): + raise ValueError( + "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer." + f" and you passed a {type(peft_config)}." 
+ ) + + if not isinstance(model, NeuronPeftModel): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {} + is_sharded_qlora = False + # Below is to support QLoRA + FSDP / DS-Zero3 - one should never call + # peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing + # QLoRA + FSDP / DS-Zero3 + if getattr(model, "is_loaded_in_4bit", False): + for _, param in model.named_parameters(): + if param.__class__.__name__ == "Params4bit": + is_sharded_qlora = param.data.device.type == "cpu" + break + if getattr(model, "is_loaded_in_8bit", False) or ( + getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora + ): + prepare_model_kwargs = { + "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False) + } + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + + if args is not None: + args = dataclasses.replace(args, gradient_checkpointing=False) + elif getattr(args, "gradient_checkpointing", False) and ( + "use_reentrant" not in gradient_checkpointing_kwargs + or gradient_checkpointing_kwargs["use_reentrant"] + ): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if ( + "autocast_adapter_dtype" in list(inspect.signature(get_peft_model).parameters) + and getattr(model, "is_loaded_in_4bit", False) + and is_sharded_qlora + ): + model = get_peft_model(model, peft_config, autocast_adapter_dtype=False) + else: + model = get_peft_model(model, peft_config) + if ( + args is not None + and args.bf16 + and getattr(model, "is_loaded_in_4bit", False) + and not is_sharded_qlora + ): + peft_module_casting_to_bf16(model) + + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + if args.max_seq_length is None: + # to overcome some issues with broken tokenizers + args.max_seq_length = min(tokenizer.model_max_length, 1024) + + warnings.warn( + f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {args.max_seq_length}" + ) + + self.dataset_num_proc = args.dataset_num_proc + + self.dataset_batch_size = args.dataset_batch_size + + self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha") + + if args.dataset_kwargs is None: + args.dataset_kwargs = {} + + if formatting_func is None and args.dataset_text_field is None: + # check if dataset has ChatML format or instruction format and is supported + # if not stays #None + formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer) + # if a template is detected, we don't need to add special tokens again + if formatting_func is not None: + args.dataset_kwargs["add_special_tokens"] = False + + if not args.packing: + # If we aren't skipping data preparation, then a dataset_text_field + # or formatting_func must be provided. 
+ if ( + args.dataset_text_field is None + and formatting_func is None + and not args.dataset_kwargs.get("skip_prepare_dataset", False) + ): + raise ValueError( + "You passed `packing=False` to the SFTTrainer/SFTConfig, but you didn't pass a `dataset_text_field` or `formatting_func` argument." + ) + + if data_collator is None: + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + # Pre-process the datasets only once per node. The remaining processes will use the cache. + with PartialState().local_main_process_first(): + if train_dataset is not None: + train_dataset = self._prepare_dataset( + train_dataset, + tokenizer, + args.packing, + args.dataset_text_field, + args.max_seq_length, + formatting_func, + args.num_of_sequences, + args.chars_per_token, + remove_unused_columns=args.remove_unused_columns if args is not None else True, + **args.dataset_kwargs, + ) + if eval_dataset is not None: + _multiple = isinstance(eval_dataset, dict) + _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset} + + eval_packing = args.packing if args.eval_packing is None else args.eval_packing + + for _eval_dataset_name, _eval_dataset in _eval_datasets.items(): + _eval_datasets[_eval_dataset_name] = self._prepare_dataset( + _eval_dataset, + tokenizer, + eval_packing, + args.dataset_text_field, + args.max_seq_length, + formatting_func, + args.num_of_sequences, + args.chars_per_token, + remove_unused_columns=args.remove_unused_columns if args is not None else True, + **args.dataset_kwargs, + ) + if not _multiple: + eval_dataset = _eval_datasets["singleton"] + + if tokenizer.padding_side is not None and tokenizer.padding_side != "right": + warnings.warn( + "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to " + "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code." + ) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if self.args.max_steps > 0 and args.packing: + warnings.warn( + "You passed `packing=True` to the SFTTrainer/SFTConfig, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached." + ) + self.train_dataset.infinite = True + elif self.args.max_steps == -1 and args.packing: + self.train_dataset.infinite = False + + if any(isinstance(callback, RichProgressCallback) for callback in self.callback_handler.callbacks): + for callback in self.callback_handler.callbacks: + # Remove the PrinterCallback to avoid duplicated prints in case we passed a `RichProgressCallback` + if callback.__class__.__name__ == "PrinterCallback": + self.callback_handler.pop_callback(callback) + + @wraps(_TrainerForNeuron.train) + def train(self, *args, **kwargs): + # Activate neftune right before training. 
+ if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune: + self.model = self._trl_activate_neftune(self.model) + + output = super().train(*args, **kwargs) + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune: + unwrapped_model = unwrap_model(self.model) + if is_peft_available() and isinstance(unwrapped_model, NeuronPeftModel): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + self.neftune_hook_handle.remove() + del embeddings.neftune_noise_alpha + + return output + + def _prepare_non_packed_dataloader( + self, + tokenizer, + dataset, + dataset_text_field, + max_seq_length, + formatting_func=None, + add_special_tokens=True, + remove_unused_columns=True, + ): + use_formatting_func = formatting_func is not None and dataset_text_field is None + self._dataset_sanity_checked = False + + # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt + def tokenize(element): + outputs = tokenizer( + element[dataset_text_field] if not use_formatting_func else formatting_func(element), + add_special_tokens=add_special_tokens, + truncation=True, + # For Neuron we need to pad because otherwise it will trigger compilation for each new sequence length. + padding="max_length", + max_length=max_seq_length, + return_overflowing_tokens=False, + return_length=False, + ) + + if use_formatting_func and not self._dataset_sanity_checked: + if not isinstance(formatting_func(element), list): + raise ValueError( + "The `formatting_func` should return a list of processed strings since it can lead to silent bugs." + ) + else: + self._dataset_sanity_checked = True + + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + + signature_columns = ["input_ids", "labels", "attention_mask"] + + if dataset.column_names is not None: # None for IterableDataset + extra_columns = list(set(dataset.column_names) - set(signature_columns)) + else: + extra_columns = [] + + if not remove_unused_columns and len(extra_columns) > 0: + warnings.warn( + "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with the default collator and yield to errors. If you want to " + f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns." 
+ ) + + map_kwargs = { + "batched": True, + "remove_columns": dataset.column_names if remove_unused_columns else None, + "batch_size": self.dataset_batch_size, + } + if isinstance(dataset, datasets.Dataset): + map_kwargs["num_proc"] = self.dataset_num_proc # this arg is not available for IterableDataset + tokenized_dataset = dataset.map(tokenize, **map_kwargs) + + return tokenized_dataset diff --git a/optimum/neuron/training_args.py b/optimum/neuron/training_args.py index ce6e34a0b..176373716 100644 --- a/optimum/neuron/training_args.py +++ b/optimum/neuron/training_args.py @@ -130,7 +130,7 @@ def __post_init__(self): # Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available` patch_accelerate_is_torch_xla_available() - if self.fsdp != "": + if self.fsdp not in ["", []]: raise RuntimeError("FSDP is not supported.") if self.fp16: diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index ce8283639..0c4e60209 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -40,6 +40,7 @@ "is_torch_neuronx_available", "is_torch_xla_available", "is_transformers_neuronx_available", + "is_trl_available", ], "input_generators": [ "DummyBeamValuesGenerator", @@ -73,6 +74,7 @@ "is_model_officially_supported", "patch_transformers_for_neuron_sdk", ], + "trl_utils": ["NeuronSFTConfig"], } if TYPE_CHECKING: @@ -97,6 +99,7 @@ is_torch_neuronx_available, is_torch_xla_available, is_transformers_neuronx_available, + is_trl_available, ) from .input_generators import ( ASTDummyAudioInputGenerator, @@ -130,6 +133,7 @@ is_model_officially_supported, patch_transformers_for_neuron_sdk, ) + from .trl_utils import NeuronSFTConfig else: import sys diff --git a/optimum/neuron/utils/import_utils.py b/optimum/neuron/utils/import_utils.py index 11340e1d6..ebfc7d81d 100644 --- a/optimum/neuron/utils/import_utils.py +++ b/optimum/neuron/utils/import_utils.py @@ -65,3 +65,14 @@ def is_accelerate_available(min_version: Optional[str] = MIN_ACCELERATE_VERSION) def is_torch_neuronx_available() -> bool: return importlib.util.find_spec("torch_neuronx") is not None + + +def is_trl_available() -> bool: + trl_available = importlib.util.find_spec("trl") is not None + if trl_available: + import trl + + if version.parse(trl.__version__) >= version.parse("0.10.0"): + return True + raise RuntimeError("Only `trl` 0.10.0 and more recent is supported.") + return False diff --git a/optimum/neuron/utils/trl_utils.py b/optimum/neuron/utils/trl_utils.py new file mode 100644 index 000000000..c3b4d129c --- /dev/null +++ b/optimum/neuron/utils/trl_utils.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Utilities related to the TRL library and support."""
+
+from dataclasses import dataclass
+
+from ..training_args import NeuronTrainingArguments
+from .import_utils import is_trl_available
+
+
+if is_trl_available():
+    from trl import SFTConfig
+else:
+
+    @dataclass
+    class SFTConfig:
+        def __init__(self, *args, **kwargs):
+            raise RuntimeError("You need to install the `trl` library to use the `NeuronSFTConfig`.")
+
+
+@dataclass
+class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig):
+    pass
diff --git a/setup.py b/setup.py
index f2e770e51..a389de8dc 100644
--- a/setup.py
+++ b/setup.py
@@ -33,6 +33,7 @@
     "safetensors",
     "sentence-transformers >= 2.2.0",
     "peft",
+    "trl",
     "compel",
     "rjieba",
     "soundfile",
diff --git a/tests/test_trainers.py b/tests/test_trainers.py
index 58ef2c4c4..514120247 100644
--- a/tests/test_trainers.py
+++ b/tests/test_trainers.py
@@ -28,7 +28,7 @@
     AutoModelForSequenceClassification,
 )
 
-from optimum.neuron import NeuronTrainer, NeuronTrainingArguments
+from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainer, NeuronTrainingArguments
 from optimum.neuron.distributed.utils import MODEL_PARALLEL_SHARDS_DIR_NAME
 from optimum.neuron.utils import is_neuronx_distributed_available
 from optimum.neuron.utils.cache_utils import (
@@ -300,7 +300,7 @@ def create_training_args(output_dir, resume_from_checkpoint=None, max_steps=max_
                 per_device_train_batch_size=train_batch_size,
                 per_device_eval_batch_size=eval_batch_size,
                 max_steps=max_steps,
-                logging_steps=1,
+                logging_steps=2,
                 save_steps=5,
                 do_eval=do_eval,
                 output_dir=output_dir,
@@ -396,3 +396,69 @@ def preprocess_function(examples):
         trainer.train(resume_from_checkpoint=True)
 
         trainer.evaluate()
+
+
+@is_trainium_test
+class TestNeuronSFTTrainer(DistributedTest):
+    @pytest.fixture(
+        scope="class",
+        params=[[2, 1, 1], [2, 2, 1]],
+        ids=["dp=2", "tp=2"],
+    )
+    def parallel_sizes(self, request):
+        return request.param
+
+    def _test_sft_trainer(self, parallel_sizes, tmpdir, packing):
+        _, tp_size, pp_size = parallel_sizes
+
+        output_dir = Path(tmpdir)
+
+        dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
+
+        def format_dolly(sample):
+            instruction = f"### Instruction\n{sample['instruction']}"
+            context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
+            response = f"### Answer\n{sample['response']}"
+            # join all the parts together
+            prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
+            if packing:
+                return prompt
+            return [prompt]
+
+        tokenizer, model = get_tokenizer_and_tiny_llama_model()
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"  # to prevent warnings
+
+        args = NeuronTrainingArguments(
+            output_dir=output_dir,
+            do_train=True,
+            max_steps=20,
+            per_device_train_batch_size=1,
+            tensor_parallel_size=tp_size,
+            pipeline_parallel_size=pp_size,
+            logging_steps=1,
+        )
+        args = args.to_dict()
+        sft_config = NeuronSFTConfig(
+            max_seq_length=512,
+            packing=packing,
+            dataset_num_proc=1,
+            **args,
+        )
+
+        # Create Trainer instance
+        trainer = NeuronSFTTrainer(
+            model=model,
+            tokenizer=tokenizer,
+            train_dataset=dataset,
+            formatting_func=format_dolly,
+            args=sft_config,
+        )
+
+        trainer.train()
+
+    def test_without_packing(self, parallel_sizes, tmpdir):
+        return self._test_sft_trainer(parallel_sizes, tmpdir, False)
+
+    def test_with_packing(self, parallel_sizes, tmpdir):
+        return self._test_sft_trainer(parallel_sizes, tmpdir, True)
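For reference, a minimal usage sketch of the `NeuronSFTTrainer` / `NeuronSFTConfig` API introduced by this patch, following the same pattern as the test above. The model id, parallelism size and hyper-parameter values below are illustrative placeholders, not values taken from the patch:

    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer

    model_id = "my-org/my-llama-model"  # placeholder checkpoint, substitute your own
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id)

    dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

    def format_dolly(sample):
        # With packing=False the formatting function must return a list of strings.
        instruction = f"### Instruction\n{sample['instruction']}"
        context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
        response = f"### Answer\n{sample['response']}"
        return ["\n\n".join(part for part in [instruction, context, response] if part is not None)]

    sft_config = NeuronSFTConfig(
        output_dir="dolly_sft",          # regular TrainingArguments fields
        max_steps=100,
        per_device_train_batch_size=1,
        tensor_parallel_size=2,          # Neuron-specific field from NeuronTrainingArguments
        max_seq_length=512,              # SFTConfig field; examples are padded to this length
        packing=False,
    )

    trainer = NeuronSFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=format_dolly,
        args=sft_config,
    )
    trainer.train()

Unlike the upstream `SFTTrainer`, non-packed examples are padded to `max_seq_length` so every batch has the same shape; on Neuron, a new sequence length would otherwise trigger a fresh compilation.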
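The trainer also accepts a plain `NeuronTrainingArguments` object and converts it internally; equivalently, as in the test above, an existing argument object can be expanded into a `NeuronSFTConfig` via `to_dict()`. A short sketch, reusing the placeholder values from the previous snippet:

    from optimum.neuron import NeuronSFTConfig, NeuronTrainingArguments

    training_args = NeuronTrainingArguments(
        output_dir="dolly_sft",
        max_steps=100,
        per_device_train_batch_size=1,
        tensor_parallel_size=2,
    )
    # SFT-specific options are layered on top of the regular Neuron training arguments.
    sft_config = NeuronSFTConfig(max_seq_length=512, packing=False, **training_args.to_dict())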