Skip to content

Commit

Permalink
SFTTrainer support (#682)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun authored Sep 5, 2024
1 parent 0edf65f commit 281d9bb
Show file tree
Hide file tree
Showing 8 changed files with 530 additions and 15 deletions.
8 changes: 4 additions & 4 deletions optimum/neuron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

_import_structure = {
"hf_argparser": ["NeuronHfArgumentParser"],
"trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer"],
"trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer", "NeuronSFTTrainer"],
"training_args": ["NeuronTrainingArguments", "Seq2SeqNeuronTrainingArguments"],
"modeling_traced": ["NeuronTracedModel"],
"modeling": [
Expand Down Expand Up @@ -69,7 +69,7 @@
"ModelParallelismPlugin",
],
"pipelines": ["pipeline"],
"utils": ["get_peft_model"],
"utils": ["NeuronSFTConfig", "get_peft_model"],
}

if TYPE_CHECKING:
Expand Down Expand Up @@ -109,9 +109,9 @@
from .modeling_seq2seq import NeuronModelForSeq2SeqLM
from .modeling_traced import NeuronTracedModel
from .pipelines import pipeline
from .trainers import NeuronTrainer, Seq2SeqNeuronTrainer
from .trainers import NeuronSFTTrainer, NeuronTrainer, Seq2SeqNeuronTrainer
from .training_args import NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments
from .utils import get_peft_model
from .utils import NeuronSFTConfig, get_peft_model

else:
import sys
Expand Down
414 changes: 406 additions & 8 deletions optimum/neuron/trainers.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion optimum/neuron/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __post_init__(self):
# Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
patch_accelerate_is_torch_xla_available()

if self.fsdp != "":
if self.fsdp not in ["", []]:
raise RuntimeError("FSDP is not supported.")

if self.fp16:
Expand Down
4 changes: 4 additions & 0 deletions optimum/neuron/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"is_torch_neuronx_available",
"is_torch_xla_available",
"is_transformers_neuronx_available",
"is_trl_available",
],
"input_generators": [
"DummyBeamValuesGenerator",
Expand Down Expand Up @@ -73,6 +74,7 @@
"is_model_officially_supported",
"patch_transformers_for_neuron_sdk",
],
"trl_utils": ["NeuronSFTConfig"],
}

if TYPE_CHECKING:
Expand All @@ -97,6 +99,7 @@
is_torch_neuronx_available,
is_torch_xla_available,
is_transformers_neuronx_available,
is_trl_available,
)
from .input_generators import (
ASTDummyAudioInputGenerator,
Expand Down Expand Up @@ -130,6 +133,7 @@
is_model_officially_supported,
patch_transformers_for_neuron_sdk,
)
from .trl_utils import NeuronSFTConfig
else:
import sys

Expand Down
11 changes: 11 additions & 0 deletions optimum/neuron/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,14 @@ def is_accelerate_available(min_version: Optional[str] = MIN_ACCELERATE_VERSION)

def is_torch_neuronx_available() -> bool:
return importlib.util.find_spec("torch_neuronx") is not None


def is_trl_available() -> bool:
trl_available = importlib.util.find_spec("trl") is not None
if trl_available:
import trl

if version.parse(trl.__version__) >= version.parse("0.10.0"):
return True
raise RuntimeError("Only `trl` 0.10.0 and more recent is supported.")
return False
35 changes: 35 additions & 0 deletions optimum/neuron/utils/trl_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to the TRL library and support."""

from dataclasses import dataclass

from ..training_args import NeuronTrainingArguments
from .import_utils import is_trl_available


if is_trl_available():
from trl import SFTConfig
else:

@dataclass
class SFTConfig:
def __init__(self, *args, **kwargs):
raise RuntimeError("You need to install the `trl` library to use the `NeuronSFTConfig`.")


@dataclass
class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig):
pass
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"safetensors",
"sentence-transformers >= 2.2.0",
"peft",
"trl",
"compel",
"rjieba",
"soundfile",
Expand Down
70 changes: 68 additions & 2 deletions tests/test_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
AutoModelForSequenceClassification,
)

from optimum.neuron import NeuronTrainer, NeuronTrainingArguments
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainer, NeuronTrainingArguments
from optimum.neuron.distributed.utils import MODEL_PARALLEL_SHARDS_DIR_NAME
from optimum.neuron.utils import is_neuronx_distributed_available
from optimum.neuron.utils.cache_utils import (
Expand Down Expand Up @@ -300,7 +300,7 @@ def create_training_args(output_dir, resume_from_checkpoint=None, max_steps=max_
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=eval_batch_size,
max_steps=max_steps,
logging_steps=1,
logging_steps=2,
save_steps=5,
do_eval=do_eval,
output_dir=output_dir,
Expand Down Expand Up @@ -396,3 +396,69 @@ def preprocess_function(examples):

trainer.train(resume_from_checkpoint=True)
trainer.evaluate()


@is_trainium_test
class TestNeuronSFTTrainer(DistributedTest):
@pytest.fixture(
scope="class",
params=[[2, 1, 1], [2, 2, 1]],
ids=["dp=2", "tp=2"],
)
def parallel_sizes(self, request):
return request.param

def _test_sft_trainer(self, parallel_sizes, tmpdir, packing):
_, tp_size, pp_size = parallel_sizes

output_dir = Path(tmpdir)

dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

def format_dolly(sample):
instruction = f"### Instruction\n{sample['instruction']}"
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
response = f"### Answer\n{sample['response']}"
# join all the parts together
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
if packing:
return prompt
return [prompt]

tokenizer, model = get_tokenizer_and_tiny_llama_model()
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" # to prevent warnings

args = NeuronTrainingArguments(
output_dir=output_dir,
do_train=True,
max_steps=20,
per_device_train_batch_size=1,
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
logging_steps=1,
)
args = args.to_dict()
sft_config = NeuronSFTConfig(
max_seq_length=512,
packing=packing,
dataset_num_proc=1,
**args,
)

# Create Trainer instance
trainer = NeuronSFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
formatting_func=format_dolly,
args=sft_config,
)

trainer.train()

def test_without_packing(self, parallel_sizes, tmpdir):
return self._test_sft_trainer(parallel_sizes, tmpdir, False)

def test_with_packing(self, parallel_sizes, tmpdir):
return self._test_sft_trainer(parallel_sizes, tmpdir, True)

0 comments on commit 281d9bb

Please sign in to comment.