Skip to content

Commit

Permalink
Remove loss_average argument of KFACLinearOperator (#117)
Browse files Browse the repository at this point in the history
* Remove loss_average argument and add FisherType and KFACType enums

* Support static type checkers with torch_matmat/vec

* Remove unused import

* Remove noqa: C901

* Fix error string
  • Loading branch information
runame authored Jun 12, 2024
1 parent d509cd6 commit 13b1082
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 217 deletions.
30 changes: 14 additions & 16 deletions curvlinops/inverse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Implements linear operator inverses."""

from math import sqrt
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
from warnings import warn

from einops import einsum, rearrange
Expand All @@ -10,9 +10,11 @@
from torch import Tensor, cat, cholesky_inverse, eye, float64, outer
from torch.linalg import cholesky, eigh

from curvlinops.kfac import KFACLinearOperator
from curvlinops.kfac import KFACLinearOperator, ParameterMatrixType

KFAC_INV_TYPE = Union[Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]]
KFACInvType = TypeVar(
"KFACInvType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]
)


class _InverseLinearOperator(LinearOperator):
Expand Down Expand Up @@ -360,8 +362,8 @@ def __init__(
self._use_exact_damping = use_exact_damping
self._cache = cache
self._retry_double_precision = retry_double_precision
self._inverse_input_covariances: Dict[str, KFAC_INV_TYPE] = {}
self._inverse_gradient_covariances: Dict[str, KFAC_INV_TYPE] = {}
self._inverse_input_covariances: Dict[str, KFACInvType] = {}
self._inverse_gradient_covariances: Dict[str, KFACInvType] = {}

def _compute_damping(
self, aaT: Optional[Tensor], ggT: Optional[Tensor]
Expand Down Expand Up @@ -407,7 +409,7 @@ def _damped_cholesky(self, M: Tensor, damping: float) -> Tensor:

def _compute_inverse_factors(
self, aaT: Optional[Tensor], ggT: Optional[Tensor]
) -> Tuple[KFAC_INV_TYPE, KFAC_INV_TYPE]:
) -> Tuple[KFACInvType, KFACInvType]:
"""Compute the inverses of the Kronecker factors for a given layer.
Args:
Expand Down Expand Up @@ -478,7 +480,7 @@ def _compute_inverse_factors(

def _compute_or_get_cached_inverse(
self, name: str
) -> Tuple[KFAC_INV_TYPE, KFAC_INV_TYPE]:
) -> Tuple[KFACInvType, KFACInvType]:
"""Invert the Kronecker factors of the KFACLinearOperator or retrieve them.
Args:
Expand All @@ -505,7 +507,7 @@ def _compute_or_get_cached_inverse(
return aaT_inv, ggT_inv

def _left_and_right_multiply(
self, M_joint: Tensor, aaT_inv: KFAC_INV_TYPE, ggT_inv: KFAC_INV_TYPE
self, M_joint: Tensor, aaT_inv: KFACInvType, ggT_inv: KFACInvType
) -> Tensor:
"""Left and right multiply matrix with inverse Kronecker factors.
Expand Down Expand Up @@ -541,8 +543,8 @@ def _separate_left_and_right_multiply(
self,
M_torch: Tensor,
param_pos: Dict[str, int],
aaT_inv: KFAC_INV_TYPE,
ggT_inv: KFAC_INV_TYPE,
aaT_inv: KFACInvType,
ggT_inv: KFACInvType,
) -> Tensor:
"""Multiply matrix with inverse Kronecker factors for separated weight and bias.
Expand Down Expand Up @@ -598,9 +600,7 @@ def _separate_left_and_right_multiply(

return M_torch

def torch_matmat(
self, M_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
"""Apply the inverse of KFAC to a matrix (multiple vectors) in PyTorch.
This allows for matrix-matrix products with the inverse KFAC approximation in
Expand Down Expand Up @@ -650,9 +650,7 @@ def torch_matmat(

return M_torch

def torch_matvec(
self, v_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
def torch_matvec(self, v_torch: ParameterMatrixType) -> ParameterMatrixType:
"""Apply the inverse of KFAC to a vector in PyTorch.
This allows for matrix-vector products with the inverse KFAC approximation in
Expand Down
Loading

0 comments on commit 13b1082

Please sign in to comment.