Apply new way of resetting the loss
michaelbenayoun committed Jul 5, 2024
1 parent 4d45c5c commit ecbaf31
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions optimum/neuron/trainers.py
@@ -429,7 +429,6 @@ def _reduce_loss(self, tr_loss: torch.Tensor) -> torch.Tensor:
         else:
             dp_size = xm.xrt_world_size()
 
-        tr_loss = tr_loss - self._prev_tr_loss
         tr_loss_div = tr_loss / dp_size
 
         if self.args.mp_plugin.should_parallelize:
@@ -447,10 +446,10 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
         reduced_tr_loss = self._reduce_loss(tr_loss)
 
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
-            # xm.mark_step()
-            # reset tr_loss to zero
-            # tr_loss.zero_()
-            self._prev_tr_loss = tr_loss
+            if isinstance(getattr(self, "_zero_loss_value"), torch.Tensor):
+                tr_loss.data = self._zero_loss_value.data
+            else:
+                tr_loss.zero_()
 
         def log_closure(self, reduced_tr_loss, grad_norm):
             if is_main_worker_for_metrics():
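
The new branch only ever swaps tensor storage at the Python level. Below is a minimal sketch of that reset on plain eager CPU tensors (the real `tr_loss` lives on an XLA device, and the accumulated values are made up for illustration): rebinding `.data` makes `tr_loss` read as zero again without recording the in-place `zero_()` op that would otherwise appear in the traced graph only on logging steps.

# Minimal sketch of the reset on eager CPU tensors; the accumulated values are
# made up, and the real code targets XLA tensors through torch_xla.
import torch

tr_loss = torch.tensor(0.0)
zero_loss_value = torch.tensor(0.0)  # created once, reused at every logging step

tr_loss += torch.tensor(0.5)   # pretend these are two accumulated step losses
tr_loss += torch.tensor(0.25)
print(tr_loss)                 # tensor(0.7500)

# Instead of calling `tr_loss.zero_()`, point `tr_loss` at the pre-built zero.
tr_loss.data = zero_loss_value.data
print(tr_loss)                 # tensor(0.)
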
@@ -892,12 +891,6 @@ def _inner_training_loop(
 
         # tr_loss is a tensor to avoid synchronization of TPUs through .item()
         tr_loss = torch.tensor(0.0).to(args.device)
-
-        # `_prev_tr_loss` is used to keep track of the previously saved loss.
-        # By doing that, we do not have to do `tr_loss.zero_()` when logging the loss.
-        # This way we do not create multiple graphs depending on the fact that we are logging or not.
-        self._prev_tr_loss = torch.tensor(0.0).to(args.device)
-
         # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
@@ -1065,6 +1058,15 @@ def _inner_training_loop(
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+
+                    # `_zero_loss_value` is used to reset the value of `tr_loss`.
+                    # By doing that, we do not have to do `tr_loss.zero_()` when logging the loss.
+                    # This way we do not insert a new op in the XLA graph (for `tr_loss.zero_()`) which would create
+                    # multiple graphs depending on the fact that we are logging or not.
+                    # Here we always create a scalar whose value is `0.0`, this way the graph stays the same whether or
+                    # not we are logging. The only difference when logging is that we set
+                    # `tr_loss.data = self._zero_loss_value.data`, which should not create new graph ops.
+                    self._zero_loss_value = torch.tensor(0.0, device=args.device)
                     self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
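
Taken together, the hunks build `_zero_loss_value` once inside the training loop and consume it in `_maybe_log_save_evaluate`. The standalone loop below mirrors that shape on plain eager PyTorch as a sketch only: `MiniLoop`, its toy linear model, and the step counts are invented for illustration and are not part of the commit, and the real Trainer runs on XLA devices through torch_xla.

# Standalone sketch of the accumulate/log/reset shape on eager PyTorch.
# `MiniLoop`, its toy model, and the step counts are invented for illustration.
import torch


class MiniLoop:
    def __init__(self, logging_steps: int = 5):
        self.model = torch.nn.Linear(4, 1)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.logging_steps = logging_steps
        # Built once and reused as the reset value at every logging step.
        self._zero_loss_value = torch.tensor(0.0)

    def _maybe_log(self, tr_loss: torch.Tensor, step: int) -> None:
        if (step + 1) % self.logging_steps != 0:
            return
        print(f"step {step + 1}: mean loss {(tr_loss / self.logging_steps).item():.4f}")
        # Reset the caller's tensor by rebinding `.data` rather than calling
        # `tr_loss.zero_()`; on XLA this keeps the traced graph identical
        # whether or not we log on this step.
        tr_loss.data = self._zero_loss_value.data

    def train(self, num_steps: int = 20) -> None:
        tr_loss = torch.tensor(0.0)
        for step in range(num_steps):
            inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
            loss = torch.nn.functional.mse_loss(self.model(inputs), targets)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Out-of-place accumulation, so the shared zero tensor is never
            # mutated in place after a reset.
            tr_loss = tr_loss + loss.detach()
            self._maybe_log(tr_loss, step)


MiniLoop().train()

The out-of-place accumulation in the sketch is a simplification chosen so the shared zero tensor is never overwritten in eager mode; it is not something the commit itself prescribes.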