From ecbaf31c4b5a47ed092140797c2bafcb788288d3 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 5 Jul 2024 14:28:54 +0200 Subject: [PATCH] Apply new way of resetting the loss --- optimum/neuron/trainers.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index 8640538e4..e9e1715c1 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -429,7 +429,6 @@ def _reduce_loss(self, tr_loss: torch.Tensor) -> torch.Tensor: else: dp_size = xm.xrt_world_size() - tr_loss = tr_loss - self._prev_tr_loss tr_loss_div = tr_loss / dp_size if self.args.mp_plugin.should_parallelize: @@ -447,10 +446,10 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno reduced_tr_loss = self._reduce_loss(tr_loss) if self.control.should_log and self.state.global_step > self._globalstep_last_logged: - # xm.mark_step() - # reset tr_loss to zero - # tr_loss.zero_() - self._prev_tr_loss = tr_loss + if isinstance(getattr(self, "_zero_loss_value"), torch.Tensor): + tr_loss.data = self._zero_loss_value.data + else: + tr_loss.zero_() def log_closure(self, reduced_tr_loss, grad_norm): if is_main_worker_for_metrics(): @@ -892,12 +891,6 @@ def _inner_training_loop( # tr_loss is a tensor to avoid synchronization of TPUs through .item() tr_loss = torch.tensor(0.0).to(args.device) - - # `_prev_tr_loss` is used to keep track of the previously saved loss. - # By doing that, we do not have to do `tr_loss.zero_()` when logging the loss. - # This way we do not create multiple graphs depending on the fact that we are logging or not. - self._prev_tr_loss = torch.tensor(0.0).to(args.device) - # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step @@ -1065,6 +1058,15 @@ def _inner_training_loop( self.state.global_step += 1 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) + + # `_zero_loss_value` is used to reset the value of `tr_loss`. + # By doing that, we do not have to do `tr_loss.zero_()` when logging the loss. + # This way we do not insert a new op in the XLA graph (for `tr_loss.zero_()`) which woud create + # multiple graphs depending on the fact that we are logging or not. + # Here we always create a scalar whose value is `0.0`, this way the graph stays the same whether or + # not we are logging. The only difference when logging is that we set + # `tr_loss.data = self._zero_loss_value.data`, which should not create new graph ops. + self._zero_loss_value = torch.tensor(0.0, device=args.device) self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) else: self.control = self.callback_handler.on_substep_end(args, self.state, self.control)