Follow grad_norm changes in transformers #27326 (#1730)
* Follow change in transformers #27326

* fix typo
jingyanwangms authored Mar 1, 2024
1 parent 8748081 commit 27deaf6
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions optimum/onnxruntime/trainer.py
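
In short: the ORTTrainer training loop now keeps the total gradient norm returned by the gradient-clipping call instead of discarding it, and forwards that value to _maybe_log_save_evaluate, whose signature gained a grad_norm argument in transformers #27326. Under DeepSpeed the engine does its own clipping, so the norm is read back with model.get_global_grad_norm() rather than taken from the clipping call. Below is a minimal standalone sketch of the capture pattern, using plain torch.nn.utils.clip_grad_norm_ (which accelerator.clip_grad_norm_ wraps in the simple single-device case); the model and optimizer are made up for illustration, this is not the Trainer's own code:

# Sketch only: keep the value returned by gradient clipping and log it as a float.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# clip_grad_norm_ returns the total gradient norm (computed before clipping)
_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
grad_norm = _grad_norm.item() if _grad_norm is not None else None

optimizer.step()
optimizer.zero_grad()
print({"grad_norm": grad_norm})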
@@ -744,19 +744,27 @@ def get_dataloader_sampler(dataloader):
                         # deepspeed does its own clipping

                         if is_sagemaker_mp_enabled() and args.fp16:
-                            self.optimizer.clip_master_grads(args.max_grad_norm)
+                            _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                         elif hasattr(self.optimizer, "clip_grad_norm"):
                             # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
-                            self.optimizer.clip_grad_norm(args.max_grad_norm)
+                            _grad_norm = self.optimizer.clip_grad_norm(args.max_grad_norm)
                         elif hasattr(model, "clip_grad_norm_"):
                             # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
-                            model.clip_grad_norm_(args.max_grad_norm)
+                            _grad_norm = model.clip_grad_norm_(args.max_grad_norm)
                         else:
-                            self.accelerator.clip_grad_norm_(
+                            _grad_norm = self.accelerator.clip_grad_norm_(
                                 model.parameters(),
                                 args.max_grad_norm,
                             )

+                        if (
+                            is_accelerate_available()
+                            and self.accelerator.distributed_type == DistributedType.DEEPSPEED
+                        ):
+                            grad_norm = model.get_global_grad_norm()
+                        else:
+                            grad_norm = _grad_norm.item() if _grad_norm is not None else None
+
                     # Optimizer step
                     self.optimizer.step()
                     optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
@@ -767,11 +775,12 @@ def get_dataloader_sampler(dataloader):
                             self.lr_scheduler.step()

                     model.zero_grad()
+                    grad_norm: Optional[float] = None
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)

-                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+                    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

@@ -786,7 +795,7 @@ def get_dataloader_sampler(dataloader):
                 self.control.should_training_stop = True

             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)

             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 logger.warning(
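
On the transformers side, _maybe_log_save_evaluate now accepts the grad_norm argument and, as far as I recall from #27326, only includes it in the step logs when a value was actually captured; passing None simply means nothing extra is reported. A rough sketch of that consumer behaviour, with a hypothetical maybe_log stand-in for the Trainer's logging machinery (not the actual transformers code):

from typing import Optional

def maybe_log(tr_loss: float, grad_norm: Optional[float]) -> dict:
    # Hypothetical stand-in: report grad_norm only when clipping produced a value.
    logs = {"loss": round(tr_loss, 4)}
    if grad_norm is not None:
        logs["grad_norm"] = grad_norm
    return logs

print(maybe_log(0.7312, 2.15))  # {'loss': 0.7312, 'grad_norm': 2.15}
print(maybe_log(0.7312, None))  # {'loss': 0.7312}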
