Skip to content

Commit

Permalink
Apply deprecated evaluation_strategy (#1819)
Browse files Browse the repository at this point in the history
Apply deprecation `evaluation_strategy`
  • Loading branch information
muellerzr authored Sep 6, 2024
1 parent c0d9111 commit 29f23f1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE run_image_classification.py \
--per_device_eval_batch_size 32 \
--logging_strategy steps \
--logging_steps 10 \
--evaluation_strategy epoch \
--eval_strategy epoch \
--seed 1337
```

Expand Down
22 changes: 11 additions & 11 deletions optimum/onnxruntime/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,32 +117,32 @@ def __post_init__(self):
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN

if isinstance(self.evaluation_strategy, EvaluationStrategy):
if isinstance(self.eval_strategy, EvaluationStrategy):
warnings.warn(
"using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5"
"using `EvaluationStrategy` for `eval_strategy` is deprecated and will be removed in version 5"
" of 🤗 Transformers. Use `IntervalStrategy` instead",
FutureWarning,
)
# Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it.
self.evaluation_strategy = self.evaluation_strategy.value
self.eval_strategy = self.eval_strategy.value

self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
self.eval_strategy = IntervalStrategy(self.eval_strategy)
self.logging_strategy = IntervalStrategy(self.logging_strategy)
self.save_strategy = IntervalStrategy(self.save_strategy)
self.hub_strategy = HubStrategy(self.hub_strategy)

self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
if self.do_eval is False and self.eval_strategy != IntervalStrategy.NO:
self.do_eval = True

# eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero
if self.evaluation_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0):
if self.eval_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0):
if self.logging_steps > 0:
logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}")
self.eval_steps = self.logging_steps
else:
raise ValueError(
f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or"
f"evaluation strategy {self.eval_strategy} requires either non-zero --eval_steps or"
" --logging_steps"
)

Expand All @@ -154,7 +154,7 @@ def __post_init__(self):
if self.logging_steps != int(self.logging_steps):
raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}")
self.logging_steps = int(self.logging_steps)
if self.evaluation_strategy == IntervalStrategy.STEPS and self.eval_steps > 1:
if self.eval_strategy == IntervalStrategy.STEPS and self.eval_steps > 1:
if self.eval_steps != int(self.eval_steps):
raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}")
self.eval_steps = int(self.eval_steps)
Expand All @@ -165,13 +165,13 @@ def __post_init__(self):

# Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible.
if self.load_best_model_at_end:
if self.evaluation_strategy != self.save_strategy:
if self.eval_strategy != self.save_strategy:
raise ValueError(
"--load_best_model_at_end requires the saving steps to be a multiple of the evaluation "
"steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps "
f"{self.save_steps} and eval_steps {self.eval_steps}."
)
if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0:
if self.eval_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0:
if self.eval_steps < 1 or self.save_steps < 1:
if not (self.eval_steps < 1 and self.save_steps < 1):
raise ValueError(
Expand Down Expand Up @@ -244,7 +244,7 @@ def __post_init__(self):
)

if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
if self.evaluation_strategy == IntervalStrategy.NO:
if self.eval_strategy == IntervalStrategy.NO:
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
if not is_torch_available():
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0")
Expand Down

0 comments on commit 29f23f1

Please sign in to comment.