diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py index 7c6b0a5227..2b7fb65430 100644 --- a/optimum/onnxruntime/trainer.py +++ b/optimum/onnxruntime/trainer.py @@ -744,19 +744,27 @@ def get_dataloader_sampler(dataloader): # deepspeed does its own clipping if is_sagemaker_mp_enabled() and args.fp16: - self.optimizer.clip_master_grads(args.max_grad_norm) + _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) elif hasattr(self.optimizer, "clip_grad_norm"): # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping - self.optimizer.clip_grad_norm(args.max_grad_norm) + _grad_norm = self.optimizer.clip_grad_norm(args.max_grad_norm) elif hasattr(model, "clip_grad_norm_"): # Some models (like FullyShardedDDP) have a specific way to do gradient clipping - model.clip_grad_norm_(args.max_grad_norm) + _grad_norm = model.clip_grad_norm_(args.max_grad_norm) else: - self.accelerator.clip_grad_norm_( + _grad_norm = self.accelerator.clip_grad_norm_( model.parameters(), args.max_grad_norm, ) + if ( + is_accelerate_available() + and self.accelerator.distributed_type == DistributedType.DEEPSPEED + ): + grad_norm = model.get_global_grad_norm() + else: + grad_norm = _grad_norm.item() if _grad_norm is not None else None + # Optimizer step self.optimizer.step() optimizer_was_run = not self.accelerator.optimizer_step_was_skipped @@ -767,11 +775,12 @@ def get_dataloader_sampler(dataloader): self.lr_scheduler.step() model.zero_grad() + grad_norm: Optional[float] = None self.state.global_step += 1 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) - self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) else: self.control = self.callback_handler.on_substep_end(args, self.state, self.control) @@ -786,7 +795,7 @@ def get_dataloader_sampler(dataloader): self.control.should_training_stop = True self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) - self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: logger.warning(