Skip to content

Commit

Permalink
[REF] Rename num_classes into output_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Nov 9, 2023
1 parent 89c2dfc commit 84e5ccb
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions curvlinops/kfac_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def loss_hessian_matrix_sqrt(
f"{output_one_datum.shape}"
)
output = output_one_datum.squeeze(0)
num_classes = output.numel()
output_dim = output.numel()

if isinstance(loss_func, MSELoss):
c = {"sum": 1.0, "mean": 1.0 / num_classes}[loss_func.reduction]
return eye(num_classes, device=output.device, dtype=output.dtype).mul_(
c = {"sum": 1.0, "mean": 1.0 / output_dim}[loss_func.reduction]
return eye(output_dim, device=output.device, dtype=output.dtype).mul_(
sqrt(2 * c)
)
elif isinstance(loss_func, CrossEntropyLoss):
Expand Down

0 comments on commit 84e5ccb

Please sign in to comment.