diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py index 8b49b24..7a2bb8b 100644 --- a/curvlinops/inverse.py +++ b/curvlinops/inverse.py @@ -1,7 +1,7 @@ """Implements linear operator inverses.""" from math import sqrt -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union from warnings import warn from einops import einsum, rearrange @@ -10,9 +10,11 @@ from torch import Tensor, cat, cholesky_inverse, eye, float64, outer from torch.linalg import cholesky, eigh -from curvlinops.kfac import KFACLinearOperator +from curvlinops.kfac import KFACLinearOperator, ParameterMatrixType -KFAC_INV_TYPE = Union[Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]] +KFACInvType = TypeVar( + "KFACInvType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]] +) class _InverseLinearOperator(LinearOperator): @@ -360,8 +362,8 @@ def __init__( self._use_exact_damping = use_exact_damping self._cache = cache self._retry_double_precision = retry_double_precision - self._inverse_input_covariances: Dict[str, KFAC_INV_TYPE] = {} - self._inverse_gradient_covariances: Dict[str, KFAC_INV_TYPE] = {} + self._inverse_input_covariances: Dict[str, KFACInvType] = {} + self._inverse_gradient_covariances: Dict[str, KFACInvType] = {} def _compute_damping( self, aaT: Optional[Tensor], ggT: Optional[Tensor] @@ -407,7 +409,7 @@ def _damped_cholesky(self, M: Tensor, damping: float) -> Tensor: def _compute_inverse_factors( self, aaT: Optional[Tensor], ggT: Optional[Tensor] - ) -> Tuple[KFAC_INV_TYPE, KFAC_INV_TYPE]: + ) -> Tuple[KFACInvType, KFACInvType]: """Compute the inverses of the Kronecker factors for a given layer. Args: @@ -478,7 +480,7 @@ def _compute_inverse_factors( def _compute_or_get_cached_inverse( self, name: str - ) -> Tuple[KFAC_INV_TYPE, KFAC_INV_TYPE]: + ) -> Tuple[KFACInvType, KFACInvType]: """Invert the Kronecker factors of the KFACLinearOperator or retrieve them. Args: @@ -505,7 +507,7 @@ def _compute_or_get_cached_inverse( return aaT_inv, ggT_inv def _left_and_right_multiply( - self, M_joint: Tensor, aaT_inv: KFAC_INV_TYPE, ggT_inv: KFAC_INV_TYPE + self, M_joint: Tensor, aaT_inv: KFACInvType, ggT_inv: KFACInvType ) -> Tensor: """Left and right multiply matrix with inverse Kronecker factors. @@ -541,8 +543,8 @@ def _separate_left_and_right_multiply( self, M_torch: Tensor, param_pos: Dict[str, int], - aaT_inv: KFAC_INV_TYPE, - ggT_inv: KFAC_INV_TYPE, + aaT_inv: KFACInvType, + ggT_inv: KFACInvType, ) -> Tensor: """Multiply matrix with inverse Kronecker factors for separated weight and bias. @@ -598,9 +600,7 @@ def _separate_left_and_right_multiply( return M_torch - def torch_matmat( - self, M_torch: Union[Tensor, List[Tensor]] - ) -> Union[Tensor, List[Tensor]]: + def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: """Apply the inverse of KFAC to a matrix (multiple vectors) in PyTorch. This allows for matrix-matrix products with the inverse KFAC approximation in @@ -650,9 +650,7 @@ def torch_matmat( return M_torch - def torch_matvec( - self, v_torch: Union[Tensor, List[Tensor]] - ) -> Union[Tensor, List[Tensor]]: + def torch_matvec(self, v_torch: ParameterMatrixType) -> ParameterMatrixType: """Apply the inverse of KFAC to a vector in PyTorch. This allows for matrix-vector products with the inverse KFAC approximation in diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 19a79a6..052a64c 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -22,7 +22,7 @@ from enum import Enum from functools import partial from math import sqrt -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union from einops import einsum, rearrange, reduce from numpy import ndarray @@ -46,6 +46,11 @@ loss_hessian_matrix_sqrt, ) +# Type for a matrix/vector that can be represented as a list of tensors with the same +# shape as the parameters, or a single matrix/vector of shape `[D, D]`/`[D]` where `D` +# is the number of parameters. +ParameterMatrixType = TypeVar("ParameterMatrixType", Tensor, List[Tensor]) + class FisherType(str, Enum): """Enum for the Fisher type.""" @@ -342,7 +347,7 @@ def _torch_preprocess(self, M: Tensor) -> List[Tensor]: return [res.T.reshape(shape) for res, shape in zip(result, shapes)] def _check_input_type_and_preprocess( - self, M_torch: Union[Tensor, List[Tensor]] + self, M_torch: ParameterMatrixType ) -> Tuple[bool, List[Tensor]]: """Check input type and maybe preprocess to list format. @@ -392,9 +397,7 @@ def _check_input_type_and_preprocess( M_torch = self._torch_preprocess(M_torch) return return_tensor, M_torch - def torch_matmat( - self, M_torch: Union[Tensor, List[Tensor]] - ) -> Union[Tensor, List[Tensor]]: + def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: """Apply KFAC to a matrix (multiple vectors) in PyTorch. This allows for matrix-matrix products with the KFAC approximation in PyTorch @@ -459,9 +462,7 @@ def torch_matmat( return M_torch - def torch_matvec( - self, v_torch: Union[Tensor, List[Tensor]] - ) -> Union[Tensor, List[Tensor]]: + def torch_matvec(self, v_torch: ParameterMatrixType) -> ParameterMatrixType: """Apply KFAC to a vector in PyTorch. This allows for matrix-vector products with the KFAC approximation in PyTorch