Skip to content

Commit

Permalink
Support static type checkers with torch_matmat/vec
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed May 29, 2024
1 parent ab72a67 commit fc692e2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
30 changes: 14 additions & 16 deletions curvlinops/inverse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Implements linear operator inverses."""

from math import sqrt
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
from warnings import warn

from einops import einsum, rearrange
Expand All @@ -10,9 +10,11 @@
from torch import Tensor, cat, cholesky_inverse, eye, float64, outer
from torch.linalg import cholesky, eigh

from curvlinops.kfac import KFACLinearOperator
from curvlinops.kfac import KFACLinearOperator, ParameterMatrixType

KFAC_INV_TYPE = Union[Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]]
KFACInvType = TypeVar(
"KFACInvType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]
)


class _InverseLinearOperator(LinearOperator):
Expand Down Expand Up @@ -360,8 +362,8 @@ def __init__(
self._use_exact_damping = use_exact_damping
self._cache = cache
self._retry_double_precision = retry_double_precision
self._inverse_input_covariances: Dict[str, KFAC_INV_TYPE] = {}
self._inverse_gradient_covariances: Dict[str, KFAC_INV_TYPE] = {}
self._inverse_input_covariances: Dict[str, KFACInvType] = {}
self._inverse_gradient_covariances: Dict[str, KFACInvType] = {}

def _compute_damping(
self, aaT: Optional[Tensor], ggT: Optional[Tensor]
Expand Down Expand Up @@ -407,7 +409,7 @@ def _damped_cholesky(self, M: Tensor, damping: float) -> Tensor:

def _compute_inverse_factors(
self, aaT: Optional[Tensor], ggT: Optional[Tensor]
) -> Tuple[KFAC_INV_TYPE, KFAC_INV_TYPE]:
) -> Tuple[KFACInvType, KFACInvType]:
"""Compute the inverses of the Kronecker factors for a given layer.
Args:
Expand Down Expand Up @@ -478,7 +480,7 @@ def _compute_inverse_factors(

def _compute_or_get_cached_inverse(
self, name: str
) -> Tuple[KFAC_INV_TYPE, KFAC_INV_TYPE]:
) -> Tuple[KFACInvType, KFACInvType]:
"""Invert the Kronecker factors of the KFACLinearOperator or retrieve them.
Args:
Expand All @@ -505,7 +507,7 @@ def _compute_or_get_cached_inverse(
return aaT_inv, ggT_inv

def _left_and_right_multiply(
self, M_joint: Tensor, aaT_inv: KFAC_INV_TYPE, ggT_inv: KFAC_INV_TYPE
self, M_joint: Tensor, aaT_inv: KFACInvType, ggT_inv: KFACInvType
) -> Tensor:
"""Left and right multiply matrix with inverse Kronecker factors.
Expand Down Expand Up @@ -541,8 +543,8 @@ def _separate_left_and_right_multiply(
self,
M_torch: Tensor,
param_pos: Dict[str, int],
aaT_inv: KFAC_INV_TYPE,
ggT_inv: KFAC_INV_TYPE,
aaT_inv: KFACInvType,
ggT_inv: KFACInvType,
) -> Tensor:
"""Multiply matrix with inverse Kronecker factors for separated weight and bias.
Expand Down Expand Up @@ -598,9 +600,7 @@ def _separate_left_and_right_multiply(

return M_torch

def torch_matmat(
self, M_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
"""Apply the inverse of KFAC to a matrix (multiple vectors) in PyTorch.
This allows for matrix-matrix products with the inverse KFAC approximation in
Expand Down Expand Up @@ -650,9 +650,7 @@ def torch_matmat(

return M_torch

def torch_matvec(
self, v_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
def torch_matvec(self, v_torch: ParameterMatrixType) -> ParameterMatrixType:
"""Apply the inverse of KFAC to a vector in PyTorch.
This allows for matrix-vector products with the inverse KFAC approximation in
Expand Down
17 changes: 9 additions & 8 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from enum import Enum
from functools import partial
from math import sqrt
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union

from einops import einsum, rearrange, reduce
from numpy import ndarray
Expand All @@ -46,6 +46,11 @@
loss_hessian_matrix_sqrt,
)

# Type of a matrix/vector that is represented either as a list of tensors (one per
# parameter, each with the same shape as that parameter) or as a single matrix/vector
# of shape `[D, D]`/`[D]`, where `D` is the number of parameters.
ParameterMatrixType = TypeVar("ParameterMatrixType", Tensor, List[Tensor])


class FisherType(str, Enum):
"""Enum for the Fisher type."""
Expand Down Expand Up @@ -342,7 +347,7 @@ def _torch_preprocess(self, M: Tensor) -> List[Tensor]:
return [res.T.reshape(shape) for res, shape in zip(result, shapes)]

def _check_input_type_and_preprocess(
self, M_torch: Union[Tensor, List[Tensor]]
self, M_torch: ParameterMatrixType
) -> Tuple[bool, List[Tensor]]:
"""Check input type and maybe preprocess to list format.
Expand Down Expand Up @@ -392,9 +397,7 @@ def _check_input_type_and_preprocess(
M_torch = self._torch_preprocess(M_torch)
return return_tensor, M_torch

def torch_matmat(
self, M_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
"""Apply KFAC to a matrix (multiple vectors) in PyTorch.
This allows for matrix-matrix products with the KFAC approximation in PyTorch
Expand Down Expand Up @@ -459,9 +462,7 @@ def torch_matmat(

return M_torch

def torch_matvec(
self, v_torch: Union[Tensor, List[Tensor]]
) -> Union[Tensor, List[Tensor]]:
def torch_matvec(self, v_torch: ParameterMatrixType) -> ParameterMatrixType:
"""Apply KFAC to a vector in PyTorch.
This allows for matrix-vector products with the KFAC approximation in PyTorch
Expand Down

0 comments on commit fc692e2

Please sign in to comment.