diff --git a/benchmark/text-generation-inference/performance/generate_csv.py b/benchmark/text-generation-inference/performance/generate_csv.py index 1e7770f63..366370e19 100644 --- a/benchmark/text-generation-inference/performance/generate_csv.py +++ b/benchmark/text-generation-inference/performance/generate_csv.py @@ -3,7 +3,6 @@ import os import pandas as pd - from guidellm.core import GuidanceReport, TextGenerationBenchmark @@ -16,11 +15,7 @@ def _benchmark_rate_id(benchmark: TextGenerationBenchmark) -> str: :return: A string representing the benchmark rate ID. :rtype: str """ - rate_id = ( - f"{benchmark.mode}@{benchmark.rate:.2f} req/sec" - if benchmark.rate - else f"{benchmark.mode}" - ) + rate_id = f"{benchmark.mode}@{benchmark.rate:.2f} req/sec" if benchmark.rate else f"{benchmark.mode}" return rate_id @@ -38,20 +33,20 @@ def main(): for path in paths: filename = os.path.basename(path) # Extract model_id - model_id, date = filename.replace(suffix, '').split('#') + model_id, date = filename.replace(suffix, "").split("#") with open(path) as f: report = GuidanceReport.from_json(f.read()) for benchmark in report.benchmarks: for b in benchmark.benchmarks_sorted: d = { - "model_id": model_id, - "Date": date, - "Input type": _benchmark_rate_id(b), - "Requests per Second": b.completed_request_rate, - "Request Latency (s)": b.request_latency, - "Time-to-first-token (ms)": b.time_to_first_token, - "Inter Token Latency (ms)": b.inter_token_latency, - "Output Token Throughput (t/s)": b.output_token_throughput, + "model_id": model_id, + "Date": date, + "Input type": _benchmark_rate_id(b), + "Requests per Second": b.completed_request_rate, + "Request Latency (s)": b.request_latency, + "Time-to-first-token (ms)": b.time_to_first_token, + "Inter Token Latency (ms)": b.inter_token_latency, + "Output Token Throughput (t/s)": b.output_token_throughput, } results.append(pd.DataFrame.from_dict(d, orient="index").transpose()) diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index a55e42ef3..67263fd68 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -27,7 +27,7 @@ _import_structure = { "hf_argparser": ["NeuronHfArgumentParser"], - "trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer", "NeuronSFTTrainer"], + "trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer", "NeuronSFTTrainer", "NeuronORPOTrainer"], "training_args": ["NeuronTrainingArguments", "Seq2SeqNeuronTrainingArguments"], "modeling_traced": ["NeuronTracedModel"], "modeling": [ @@ -69,7 +69,7 @@ "ModelParallelismPlugin", ], "pipelines": ["pipeline"], - "utils": ["NeuronSFTConfig", "get_peft_model"], + "utils": ["NeuronSFTConfig", "NeuronORPOConfig", "get_peft_model"], } if TYPE_CHECKING: @@ -109,9 +109,9 @@ from .modeling_seq2seq import NeuronModelForSeq2SeqLM from .modeling_traced import NeuronTracedModel from .pipelines import pipeline - from .trainers import NeuronSFTTrainer, NeuronTrainer, Seq2SeqNeuronTrainer + from .trainers import NeuronORPOTrainer, NeuronSFTTrainer, NeuronTrainer, Seq2SeqNeuronTrainer from .training_args import NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments - from .utils import NeuronSFTConfig, get_peft_model + from .utils import NeuronORPOConfig, NeuronSFTConfig, get_peft_model else: import sys diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index 6608d5825..e682dc003 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -24,7 +24,7 @@ import time import warnings from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import datasets import numpy as np @@ -131,7 +131,7 @@ if is_trl_available(): - from trl import SFTConfig, SFTTrainer + from trl import ORPOConfig, ORPOTrainer, SFTConfig, SFTTrainer else: class SFTTrainer: @@ -140,6 +140,12 @@ class SFTTrainer: class SFTConfig: pass + class ORPOConfig: + pass + + class ORPOTrainer: + pass + if is_peft_available(): from peft import PeftConfig @@ -1863,3 +1869,54 @@ def tokenize(element): tokenized_dataset = dataset.map(tokenize, **map_kwargs) return tokenized_dataset + + +# class NeuronORPOTrainer(ORPOTrainer): +class NeuronORPOTrainer(_TrainerForNeuron, ORPOTrainer): + def get_batch_loss_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + metrics = {} + + forward_output = self.concatenated_forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( + policy_chosen_logps, policy_rejected_logps + ) + # full ORPO loss + loss = policy_nll_loss - losses.mean() + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean() + # metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean() + # metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean() + metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean() + metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio + metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen + if is_torch_xla_available(): + xm.mark_step() # needed because .item() calls + for k, v in metrics.items(): + metrics[k] = v.item() + if self.aux_loss_enabled: + loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss + + return loss, metrics diff --git a/optimum/neuron/training_args.py b/optimum/neuron/training_args.py index 176373716..05b68cf89 100644 --- a/optimum/neuron/training_args.py +++ b/optimum/neuron/training_args.py @@ -127,6 +127,9 @@ class NeuronTrainingArgumentsMixin: ) def __post_init__(self): + if self.neuron_cc_flags_model_type is not None: + os.environ["OPTIMUM_NEURON_COMMON_FLAGS_MODEL_TYPE"] = self.neuron_cc_flags_model_type + # Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available` patch_accelerate_is_torch_xla_available() @@ -221,6 +224,11 @@ def __post_init__(self): def _setup_devices(self) -> "torch.device": return super()._setup_devices + @property + def neuron_cc_flags_model_type(self) -> Optional[str]: + """Controls the value to provide to the Neuron Compiler for the model-type flag.""" + return "transformer" + @property def place_model_on_device(self): return not self.mp_plugin.should_parallelize and super().place_model_on_device diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index 19562b16a..965aeaaaa 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -75,7 +75,7 @@ "is_model_officially_supported", "patch_transformers_for_neuron_sdk", ], - "trl_utils": ["NeuronSFTConfig"], + "trl_utils": ["NeuronSFTConfig", "NeuronORPOConfig"], } if TYPE_CHECKING: @@ -135,7 +135,7 @@ is_model_officially_supported, patch_transformers_for_neuron_sdk, ) - from .trl_utils import NeuronSFTConfig + from .trl_utils import NeuronORPOConfig, NeuronSFTConfig else: import sys diff --git a/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py b/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py index b0e78e6ec..6c2d79d9c 100644 --- a/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py +++ b/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py @@ -47,8 +47,9 @@ def set_common_flags(): """ Sets environment variables for transformer-based models training with AWS Neuron. """ - # Set compiler flag to compile for transformer model type - os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --model-type=transformer" + model_type = os.environ.get("OPTIMUM_NEURON_COMMON_FLAGS_MODEL_TYPE", "") + if model_type != "": + os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + f" --model-type={model_type}" # Setting MALLOC_ARENA_MAX is needed because of a memory issue in XLA/glic, otherwise OOM can happen during # checkpointing. More information here: # https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/torch/torch-neuronx/index.html#memory-leaking-in-glibc diff --git a/optimum/neuron/utils/trl_utils.py b/optimum/neuron/utils/trl_utils.py index c3b4d129c..31041122f 100644 --- a/optimum/neuron/utils/trl_utils.py +++ b/optimum/neuron/utils/trl_utils.py @@ -15,13 +15,14 @@ """Utilities related to the TRL library and support.""" from dataclasses import dataclass +from typing import Optional from ..training_args import NeuronTrainingArguments from .import_utils import is_trl_available if is_trl_available(): - from trl import SFTConfig + from trl import ORPOConfig, SFTConfig else: @dataclass @@ -29,7 +30,20 @@ class SFTConfig: def __init__(self, *args, **kwargs): raise RuntimeError("You need to install the `trl` library to use the `NeuronSFTConfig`.") + @dataclass + class ORPOConfig: + def __init__(self, *args, **kwargs): + raise RuntimeError("You need to install the `trl` library to use the `NeuronSFTConfig`.") + @dataclass class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig): pass + + +@dataclass +class NeuronORPOConfig(NeuronTrainingArguments, ORPOConfig): + + @property + def neuron_cc_flags_model_type(self) -> Optional[str]: + return None