From c8e92043b22a50f0eed6b4613e77b80641897d21 Mon Sep 17 00:00:00 2001
From: Runa Eschenhagen <33333409+runame@users.noreply.github.com>
Date: Sat, 11 May 2024 12:20:41 +0100
Subject: [PATCH] Fix scaling of MC Fisher for MSELoss and BCEWithLogitsLoss
 with mean reduction (#112)

---
 curvlinops/fisher.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/curvlinops/fisher.py b/curvlinops/fisher.py
index 86a7bb8..c056586 100644
--- a/curvlinops/fisher.py
+++ b/curvlinops/fisher.py
@@ -225,7 +225,7 @@ def _matmat_batch(
 
         grad_output = self.sample_grad_output(output, self._mc_samples, y)
 
-        # Adjust the scale depending on the loss reduction used
+        # Adjust the scale depending on the loss function and reduction used
         num_loss_terms, C = output.shape
         reduction_factor = {
             "mean": (
@@ -266,7 +266,12 @@ def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Ten
         For a single data point, the would-be gradient's outer product equals the
         Hessian ``∇²_f log p(·|f)`` in expectation.
 
-        Currently only supports ``MSELoss`` and ``CrossEntropyLoss``.
+        Currently only supports ``MSELoss``, ``CrossEntropyLoss``, and
+        ``BCEWithLogitsLoss``.
+
+        The returned gradient does not account for the scaling of the loss function by
+        the output dimension ``C`` that ``MSELoss`` and ``BCEWithLogitsLoss`` apply when
+        ``reduction='mean'``.
 
         Args:
             output: model prediction ``f`` for multiple data with batch axis as
@@ -289,10 +294,7 @@ def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Ten
         C = output.shape[1]
 
         if isinstance(self._loss_func, MSELoss):
-            std = as_tensor(
-                {"mean": sqrt(0.5 / C), "sum": sqrt(0.5)}[self._loss_func.reduction],
-                device=output.device,
-            )
+            std = as_tensor(sqrt(0.5), device=output.device)
             mean = zeros(
                 num_samples, *output.shape, device=output.device, dtype=output.dtype
             )
@@ -320,12 +322,7 @@ def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Ten
             # repeat ``num_sample`` times along a new leading axis
             prob = prob.unsqueeze(0).expand(num_samples, -1, -1)
             sample = prob.bernoulli(generator=self._generator)
-
-            # With ``reduction="mean"``, BCEWithLogitsLoss averages over all
-            # dimensions, like ``MSELoss``. We need to incorporate this scaling
-            # into the backpropagated gradient
-            scale = {"sum": 1.0, "mean": sqrt(1.0 / C)}[self._loss_func.reduction]
-            return (prob - sample) * scale
+            return prob - sample
 
         else:
             raise NotImplementedError(f"Supported losses: {self.supported_losses}")
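
Note, not part of the patch: below is a minimal sanity-check sketch of the
convention the updated docstring describes for ``MSELoss``. The sampled
would-be gradients are drawn with std ``sqrt(0.5)`` independently of the
reduction, and the extra ``1 / C`` (together with ``1 / num_loss_terms``) of
``reduction='mean'`` is applied separately as a reduction factor. The factor
``2`` on the samples and the explicit ``1 / (N * C)`` reduction factor used
below are assumptions for illustration only; neither appears verbatim in the
hunks above.

from math import sqrt

import torch

torch.manual_seed(0)

N, C = 4, 3            # num_loss_terms and output dimension
num_samples = 100_000  # MC samples for estimating the expectation

f = torch.randn(N, C, requires_grad=True)  # model outputs
y = torch.randn(N, C)                      # targets (unused by the sampler)

# Sampled would-be gradients, independent of the reduction (as in the patch);
# the factor 2 is an assumed convention from the Gaussian likelihood of MSE.
eps = torch.normal(mean=torch.zeros(num_samples, N, C), std=sqrt(0.5))
grad_samples = 2 * eps

# MC estimate of E[g g^T] per datum: should approach 2 * I, the Hessian of the
# *sum*-reduced squared error w.r.t. the model output of a single datum.
outer = torch.einsum("snc,snd->ncd", grad_samples, grad_samples) / num_samples
print(outer[0])  # ~ 2 * eye(C)

# Exact Hessian of the mean-reduced MSELoss w.r.t. the output of datum 0.
loss = torch.nn.functional.mse_loss(f, y, reduction="mean")
(g,) = torch.autograd.grad(loss, f, create_graph=True)
hess_datum0 = torch.stack(
    [torch.autograd.grad(g[0, c], f, retain_graph=True)[0][0] for c in range(C)]
)
print(hess_datum0)  # = 2 / (N * C) * eye(C)

# Applying the assumed reduction factor 1 / (N * C) to the MC estimate recovers
# the mean-reduced Hessian, consistent with the docstring's note that the 1 / C
# scaling is handled outside of ``sample_grad_output``.
print(outer[0] / (N * C))  # ~ 2 / (N * C) * eye(C)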