From 4d45c5cfd522e6e1afeb22d11ef00416ae14af1c Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Fri, 5 Jul 2024 11:56:03 +0200
Subject: [PATCH] Apply new way of resetting the loss

---
 optimum/neuron/trainers.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 9b12d7b51..8640538e4 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -429,7 +429,7 @@ def _reduce_loss(self, tr_loss: torch.Tensor) -> torch.Tensor:
         else:
             dp_size = xm.xrt_world_size()

-        # tr_loss = tr_loss - self._prev_tr_loss
+        tr_loss = tr_loss - self._prev_tr_loss
         tr_loss_div = tr_loss / dp_size

         if self.args.mp_plugin.should_parallelize:
@@ -892,8 +892,12 @@ def _inner_training_loop(

         # tr_loss is a tensor to avoid synchronization of TPUs through .item()
         tr_loss = torch.tensor(0.0).to(args.device)
-        # `_prev_tr_loss` is used to keep track of the previously saved loss. This way we do not create multiple graphs when resetting it.
+
+        # `_prev_tr_loss` is used to keep track of the previously saved loss.
+        # By doing that, we do not have to call `tr_loss.zero_()` when logging the loss.
+        # This way we do not create multiple graphs depending on whether we are logging or not.
         self._prev_tr_loss = torch.tensor(0.0).to(args.device)
+
         # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
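
Context for the change above: on XLA devices, an in-place `tr_loss.zero_()` that only runs on logging steps would make logging steps trace a different graph than non-logging steps; the patch instead remembers the loss value at the last logging step and subtracts it. The following is a minimal, self-contained sketch of that idea, not the optimum-neuron code; the names `reduce_loss`, `training_loop`, `dp_size`, `num_steps`, and `logging_steps` are illustrative only.

# Sketch of the technique used by the patch (assumed names, not the optimum-neuron API):
# keep the running loss as a tensor that is never reset in-place, remember its value
# at the last logging step, and log the difference.
import torch

def reduce_loss(tr_loss: torch.Tensor, prev_tr_loss: torch.Tensor, dp_size: int) -> torch.Tensor:
    # Same ops on every step, so an XLA-style tracer sees one graph whether or
    # not this step happens to be a logging step.
    return (tr_loss - prev_tr_loss) / dp_size

def training_loop(num_steps: int = 10, logging_steps: int = 4, dp_size: int = 1) -> None:
    tr_loss = torch.tensor(0.0)       # running sum of the loss, never zeroed in-place
    prev_tr_loss = torch.tensor(0.0)  # snapshot of tr_loss at the last logging step

    for step in range(1, num_steps + 1):
        step_loss = torch.rand(())    # stand-in for the real per-step loss
        tr_loss = tr_loss + step_loss

        if step % logging_steps == 0:
            loss_to_log = reduce_loss(tr_loss, prev_tr_loss, dp_size)
            print(f"step {step}: loss since last log = {loss_to_log.item():.4f}")
            # "Reset" by remembering the current sum instead of calling tr_loss.zero_().
            prev_tr_loss = tr_loss.clone()

training_loop()

In the patch, `_reduce_loss` additionally divides by the data-parallel world size (`xm.xrt_world_size()` when model parallelism is not used); the `dp_size` argument stands in for that here.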