diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index fd049089f..42f8aab4c 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -159,6 +159,9 @@ class PeftConfig: logger = logging.get_logger("transformers.trainer") + +TRL_VERSION = "0.11.4" + KEEP_HF_HUB_PROGRESS_BARS = os.environ.get("KEEP_HF_HUB_PROGRESS_BARS") if KEEP_HF_HUB_PROGRESS_BARS is None: os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" @@ -1549,7 +1552,7 @@ def __init__( peft_config: Optional["PeftConfig"] = None, formatting_func: Optional[Callable] = None, ): - if not is_trl_available(): + if not is_trl_available(required_version=TRL_VERSION): raise RuntimeError("Using NeuronSFTTrainer requires the trl library.") from trl.extras.dataset_formatting import get_formatting_func_from_dataset @@ -1894,7 +1897,7 @@ def __init__( peft_config: Optional[Dict] = None, compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, ): - if not is_trl_available(): + if not is_trl_available(required_version=TRL_VERSION): raise RuntimeError("Using NeuronORPOTrainer requires the trl library.") from trl.trainer.utils import DPODataCollatorWithPadding, disable_dropout_in_model, peft_module_casting_to_bf16 diff --git a/setup.py b/setup.py index 3126af110..1dfa521ed 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ "safetensors", "sentence-transformers >= 2.2.0", "peft", - "trl", + "trl==0.11.4", "compel", "rjieba", "soundfile",