diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py
index 1a98ed9..e45daa2 100644
--- a/bmtrain/optim/optim_manager.py
+++ b/bmtrain/optim/optim_manager.py
@@ -136,6 +136,12 @@ def step(self):
             self.zero_grad()
             return
         for optimizer, lr_scheduler in zip(self.optimizers, self.lr_schedulers):
+            try:
+                check_overflow(optimizer.param_groups)
+            except OverflowError:
+                has_overflow = True
+                print_rank("Gradient overflow, change scale from %lf to %lf" % (self.loss_scale, self.loss_scale / self.loss_scale_factor))
+                break
             if hasattr(optimizer, "_bmtrain_optimizer") and optimizer._bmtrain_optimizer:
                 optimizer.step(scale=self.loss_scale)
             else:
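
Note: the hunk calls check_overflow(optimizer.param_groups) and catches OverflowError, but the helper's implementation is not part of this diff. Below is a minimal sketch of what such a helper could look like, assuming it scans each parameter's gradient for non-finite values and signals overflow by raising OverflowError (matching the except clause above); the actual bmtrain implementation may differ.

import torch

def check_overflow(param_groups):
    # Hypothetical sketch: raise OverflowError if any gradient in the
    # optimizer's param_groups contains inf or NaN, so the caller can
    # skip the step and shrink the loss scale.
    for group in param_groups:
        for param in group["params"]:
            if param.grad is not None and not torch.isfinite(param.grad).all():
                raise OverflowError("Gradient overflow detected")

Raising an exception here (rather than returning a flag) lets the step loop break out early and reduce the loss scale by loss_scale_factor, which is the standard recovery path in dynamic loss scaling for mixed-precision training.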