Skip to content

Commit

Permalink
Add NeuronORPOTrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun committed Oct 17, 2024
1 parent 0ea7285 commit b6fb211
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 26 deletions.
25 changes: 10 additions & 15 deletions benchmark/text-generation-inference/performance/generate_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os

import pandas as pd

from guidellm.core import GuidanceReport, TextGenerationBenchmark


Expand All @@ -16,11 +15,7 @@ def _benchmark_rate_id(benchmark: TextGenerationBenchmark) -> str:
:return: A string representing the benchmark rate ID.
:rtype: str
"""
rate_id = (
f"{benchmark.mode}@{benchmark.rate:.2f} req/sec"
if benchmark.rate
else f"{benchmark.mode}"
)
rate_id = f"{benchmark.mode}@{benchmark.rate:.2f} req/sec" if benchmark.rate else f"{benchmark.mode}"
return rate_id


Expand All @@ -38,20 +33,20 @@ def main():
for path in paths:
filename = os.path.basename(path)
# Extract model_id
model_id, date = filename.replace(suffix, '').split('#')
model_id, date = filename.replace(suffix, "").split("#")
with open(path) as f:
report = GuidanceReport.from_json(f.read())
for benchmark in report.benchmarks:
for b in benchmark.benchmarks_sorted:
d = {
"model_id": model_id,
"Date": date,
"Input type": _benchmark_rate_id(b),
"Requests per Second": b.completed_request_rate,
"Request Latency (s)": b.request_latency,
"Time-to-first-token (ms)": b.time_to_first_token,
"Inter Token Latency (ms)": b.inter_token_latency,
"Output Token Throughput (t/s)": b.output_token_throughput,
"model_id": model_id,
"Date": date,
"Input type": _benchmark_rate_id(b),
"Requests per Second": b.completed_request_rate,
"Request Latency (s)": b.request_latency,
"Time-to-first-token (ms)": b.time_to_first_token,
"Inter Token Latency (ms)": b.inter_token_latency,
"Output Token Throughput (t/s)": b.output_token_throughput,
}
results.append(pd.DataFrame.from_dict(d, orient="index").transpose())

Expand Down
8 changes: 4 additions & 4 deletions optimum/neuron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

_import_structure = {
"hf_argparser": ["NeuronHfArgumentParser"],
"trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer", "NeuronSFTTrainer"],
"trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer", "NeuronSFTTrainer", "NeuronORPOTrainer"],
"training_args": ["NeuronTrainingArguments", "Seq2SeqNeuronTrainingArguments"],
"modeling_traced": ["NeuronTracedModel"],
"modeling": [
Expand Down Expand Up @@ -69,7 +69,7 @@
"ModelParallelismPlugin",
],
"pipelines": ["pipeline"],
"utils": ["NeuronSFTConfig", "get_peft_model"],
"utils": ["NeuronSFTConfig", "NeuronORPOConfig", "get_peft_model"],
}

if TYPE_CHECKING:
Expand Down Expand Up @@ -109,9 +109,9 @@
from .modeling_seq2seq import NeuronModelForSeq2SeqLM
from .modeling_traced import NeuronTracedModel
from .pipelines import pipeline
from .trainers import NeuronSFTTrainer, NeuronTrainer, Seq2SeqNeuronTrainer
from .trainers import NeuronORPOTrainer, NeuronSFTTrainer, NeuronTrainer, Seq2SeqNeuronTrainer
from .training_args import NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments
from .utils import NeuronSFTConfig, get_peft_model
from .utils import NeuronORPOConfig, NeuronSFTConfig, get_peft_model

else:
import sys
Expand Down
61 changes: 59 additions & 2 deletions optimum/neuron/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import time
import warnings
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import datasets
import numpy as np
Expand Down Expand Up @@ -131,7 +131,7 @@


if is_trl_available():
from trl import SFTConfig, SFTTrainer
from trl import ORPOConfig, ORPOTrainer, SFTConfig, SFTTrainer
else:

class SFTTrainer:
Expand All @@ -140,6 +140,12 @@ class SFTTrainer:
class SFTConfig:
pass

# Placeholder defined in the `else` branch of the `is_trl_available()` check, so that the name
# `ORPOConfig` exists in this module even when `trl` cannot be imported.
class ORPOConfig:
    pass

# Placeholder defined in the `else` branch of the `is_trl_available()` check. `NeuronORPOTrainer`
# lists `ORPOTrainer` as a base class, so the name has to exist even when `trl` cannot be imported.
class ORPOTrainer:
    pass


if is_peft_available():
from peft import PeftConfig
Expand Down Expand Up @@ -1863,3 +1869,54 @@ def tokenize(element):
tokenized_dataset = dataset.map(tokenize, **map_kwargs)

return tokenized_dataset


# class NeuronORPOTrainer(ORPOTrainer):
class NeuronORPOTrainer(_TrainerForNeuron, ORPOTrainer):
    """ORPO trainer built by combining `_TrainerForNeuron` with `ORPOTrainer`.

    The only method overridden here is `get_batch_loss_metrics`.
    """

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the ORPO loss for `batch` along with a dictionary of scalar metrics.

        Args:
            model: The model handed to `self.concatenated_forward`.
            batch: The batch handed to `self.concatenated_forward`.
            train_eval: When equal to `"eval"`, every metric key is prefixed with `eval_`; otherwise the
                keys carry no prefix.

        Returns:
            A `(loss, metrics)` tuple. `metrics` maps each metric name to the result of calling `.item()`
            on the corresponding value.
        """
        outputs = self.concatenated_forward(model, batch)
        # NOTE: the two logits entries are unpacked but unused, because the `logits/chosen` and
        # `logits/rejected` metrics are currently disabled.
        chosen_logps, rejected_logps, _chosen_logits, _rejected_logits, nll_loss = outputs[:5]
        if self.aux_loss_enabled:
            aux_loss = outputs[5]

        or_losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
            chosen_logps, rejected_logps
        )
        # Full ORPO loss: NLL term minus the mean of the odds-ratio losses.
        loss = nll_loss - or_losses.mean()

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        key_prefix = "eval_" if train_eval == "eval" else ""
        tensor_metrics = {
            f"{key_prefix}rewards/chosen": chosen_rewards.mean(),
            f"{key_prefix}rewards/rejected": rejected_rewards.mean(),
            f"{key_prefix}rewards/accuracies": reward_accuracies.mean(),
            f"{key_prefix}rewards/margins": (chosen_rewards - rejected_rewards).mean(),
            f"{key_prefix}logps/rejected": rejected_logps.detach().mean(),
            f"{key_prefix}logps/chosen": chosen_logps.detach().mean(),
            f"{key_prefix}nll_loss": nll_loss.detach().mean(),
            f"{key_prefix}log_odds_ratio": log_odds_ratio,
            f"{key_prefix}log_odds_chosen": log_odds_chosen,
        }

        if is_torch_xla_available():
            xm.mark_step()  # needed because of the `.item()` calls right below
        metrics = {name: value.item() for name, value in tensor_metrics.items()}

        if self.aux_loss_enabled:
            loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss

        return loss, metrics
8 changes: 8 additions & 0 deletions optimum/neuron/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ class NeuronTrainingArgumentsMixin:
)

def __post_init__(self):
if self.neuron_cc_flags_model_type is not None:
os.environ["OPTIMUM_NEURON_COMMON_FLAGS_MODEL_TYPE"] = self.neuron_cc_flags_model_type

# Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
patch_accelerate_is_torch_xla_available()

Expand Down Expand Up @@ -221,6 +224,11 @@ def __post_init__(self):
def _setup_devices(self) -> "torch.device":
return super()._setup_devices

@property
def neuron_cc_flags_model_type(self) -> Optional[str]:
    """Value handed to the Neuron compiler for its model-type flag.

    `__post_init__` copies this value into the `OPTIMUM_NEURON_COMMON_FLAGS_MODEL_TYPE` environment variable
    whenever it is not `None`. A subclass can override the property and return `None` to skip that step.
    """
    return "transformer"

@property
def place_model_on_device(self):
return not self.mp_plugin.should_parallelize and super().place_model_on_device
Expand Down
4 changes: 2 additions & 2 deletions optimum/neuron/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
"is_model_officially_supported",
"patch_transformers_for_neuron_sdk",
],
"trl_utils": ["NeuronSFTConfig"],
"trl_utils": ["NeuronSFTConfig", "NeuronORPOConfig"],
}

if TYPE_CHECKING:
Expand Down Expand Up @@ -135,7 +135,7 @@
is_model_officially_supported,
patch_transformers_for_neuron_sdk,
)
from .trl_utils import NeuronSFTConfig
from .trl_utils import NeuronORPOConfig, NeuronSFTConfig
else:
import sys

Expand Down
5 changes: 3 additions & 2 deletions optimum/neuron/utils/torch_xla_and_neuronx_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def set_common_flags():
"""
Sets environment variables for transformer-based models training with AWS Neuron.
"""
# Set compiler flag to compile for transformer model type
os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --model-type=transformer"
model_type = os.environ.get("OPTIMUM_NEURON_COMMON_FLAGS_MODEL_TYPE", "")
if model_type != "":
os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + f" --model-type={model_type}"
# Setting MALLOC_ARENA_MAX is needed because of a memory issue in XLA/glic, otherwise OOM can happen during
# checkpointing. More information here:
# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/torch/torch-neuronx/index.html#memory-leaking-in-glibc
Expand Down
16 changes: 15 additions & 1 deletion optimum/neuron/utils/trl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,35 @@
"""Utilities related to the TRL library and support."""

from dataclasses import dataclass
from typing import Optional

from ..training_args import NeuronTrainingArguments
from .import_utils import is_trl_available


if is_trl_available():
from trl import SFTConfig
from trl import ORPOConfig, SFTConfig
else:

@dataclass
class SFTConfig:
    # Stand-in used when `trl` is not installed: creating an instance fails right away with an
    # explicit error instead of a confusing failure later on.
    def __init__(self, *args, **kwargs):
        message = "You need to install the `trl` library to use the `NeuronSFTConfig`."
        raise RuntimeError(message)

@dataclass
class ORPOConfig:
    # Stand-in used when `trl` is not installed: creating an instance fails right away with an
    # explicit error. The message names `NeuronORPOConfig` (it previously named `NeuronSFTConfig`,
    # a copy-paste slip from the SFT stub) so the user is pointed at the class they actually used.
    def __init__(self, *args, **kwargs):
        raise RuntimeError("You need to install the `trl` library to use the `NeuronORPOConfig`.")


# Configuration dataclass that combines `NeuronTrainingArguments` with `SFTConfig`; it adds no
# fields or methods of its own.
@dataclass
class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig):
    pass


# Configuration dataclass that combines `NeuronTrainingArguments` with `ORPOConfig`. Its only
# addition is the `neuron_cc_flags_model_type` override below.
@dataclass
class NeuronORPOConfig(NeuronTrainingArguments, ORPOConfig):

    @property
    def neuron_cc_flags_model_type(self) -> Optional[str]:
        """Return `None` so that no model type is selected for the Neuron compiler.

        NOTE(review): presumably this overrides the `"transformer"` default of the training-arguments
        mixin, whose `__post_init__` only exports the value to the environment when it is not `None`.
        Confirm against `training_args.py`.
        """
        return None

0 comments on commit b6fb211

Please sign in to comment.