Skip to content

Commit

Permalink
Add required trl version
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun committed Oct 18, 2024
1 parent 70180cc commit b3d5590
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions optimum/neuron/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ class PeftConfig:

logger = logging.get_logger("transformers.trainer")


TRL_VERSION = "0.11.4"

KEEP_HF_HUB_PROGRESS_BARS = os.environ.get("KEEP_HF_HUB_PROGRESS_BARS")
if KEEP_HF_HUB_PROGRESS_BARS is None:
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
Expand Down Expand Up @@ -1549,7 +1552,7 @@ def __init__(
peft_config: Optional["PeftConfig"] = None,
formatting_func: Optional[Callable] = None,
):
if not is_trl_available():
if not is_trl_available(required_version=TRL_VERSION):
raise RuntimeError("Using NeuronSFTTrainer requires the trl library.")

from trl.extras.dataset_formatting import get_formatting_func_from_dataset
Expand Down Expand Up @@ -1894,7 +1897,7 @@ def __init__(
peft_config: Optional[Dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
):
if not is_trl_available():
if not is_trl_available(required_version=TRL_VERSION):
raise RuntimeError("Using NeuronORPOTrainer requires the trl library.")

from trl.trainer.utils import DPODataCollatorWithPadding, disable_dropout_in_model, peft_module_casting_to_bf16
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"safetensors",
"sentence-transformers >= 2.2.0",
"peft",
"trl",
"trl==0.11.4",
"compel",
"rjieba",
"soundfile",
Expand Down

0 comments on commit b3d5590

Please sign in to comment.