Skip to content

Commit

Permalink
Add state dict functionality to (inverse) KFAC linear operator
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed May 15, 2024
1 parent 07ffeb2 commit f873494
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 3 deletions.
69 changes: 68 additions & 1 deletion curvlinops/inverse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Implements linear operator inverses."""

from math import sqrt
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from warnings import warn

from einops import einsum, rearrange
Expand Down Expand Up @@ -695,3 +695,70 @@ def _matmat(self, M: ndarray) -> ndarray:
M_torch = self._A._preprocess(M)
M_torch = self.torch_matmat(M_torch)
return self._A._postprocess(M_torch)

def state_dict(self) -> Dict[str, Any]:
"""Return the state of the inverse KFAC linear operator.
Returns:
State dictionary.
"""
return {
"A": self._A.state_dict(),
# Attributes
"damping": self._damping,
"use_heuristic_damping": self._use_heuristic_damping,
"min_damping": self._min_damping,
"use_exact_damping": self._use_exact_damping,
"cache": self._cache,
"retry_double_precision": self._retry_double_precision,
# Inverse Kronecker factors (if computed and cached)
"inverse_input_covariances": self._inverse_input_covariances,
"inverse_gradient_covariances": self._inverse_gradient_covariances,
}

def load_state_dict(self, state_dict: Dict[str, Any]):
"""Load the state of the inverse KFAC linear operator.
Args:
state_dict: State dictionary.
"""
self._A.load_state_dict(state_dict["A"])

# Set attributes
self._damping = state_dict["damping"]
self._use_heuristic_damping = state_dict["use_heuristic_damping"]
self._min_damping = state_dict["min_damping"]
self._use_exact_damping = state_dict["use_exact_damping"]
self._cache = state_dict["cache"]
self._retry_double_precision = state_dict["retry_double_precision"]

# Set inverse Kronecker factors (if computed and cached)
self._inverse_input_covariances = state_dict["inverse_input_covariances"]
self._inverse_gradient_covariances = state_dict["inverse_gradient_covariances"]

@classmethod
def from_state_dict(
cls,
state_dict: Dict[str, Any],
A: KFACLinearOperator,
) -> "KFACInverseLinearOperator":
"""Load an inverse KFAC linear operator from a state dictionary.
Args:
state_dict: State dictionary.
A: ``KFACLinearOperator`` whose inverse is formed.
Returns:
Linear operator of inverse KFAC approximation.
"""
inv_kfac = cls(
A,
damping=state_dict["damping"],
use_heuristic_damping=state_dict["use_heuristic_damping"],
min_damping=state_dict["min_damping"],
use_exact_damping=state_dict["use_exact_damping"],
cache=state_dict["cache"],
retry_double_precision=state_dict["retry_double_precision"],
)
inv_kfac.load_state_dict(state_dict)
return inv_kfac
154 changes: 152 additions & 2 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from collections.abc import MutableMapping
from functools import partial
from math import sqrt
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from einops import einsum, rearrange, reduce
from numpy import ndarray
Expand Down Expand Up @@ -111,7 +111,7 @@ class KFACLinearOperator(_LinearOperator):
def __init__( # noqa: C901
self,
model_func: Module,
loss_func: MSELoss,
loss_func: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
params: List[Parameter],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
progressbar: bool = False,
Expand Down Expand Up @@ -1070,3 +1070,153 @@ def frobenius_norm(self) -> Tensor:
)
self._frobenius_norm.sqrt_()
return self._frobenius_norm

def state_dict(self) -> Dict[str, Any]:
"""Return the state of the KFAC linear operator.
Returns:
State dictionary.
"""
loss_type = {
MSELoss: "MSELoss",
CrossEntropyLoss: "CrossEntropyLoss",
BCEWithLogitsLoss: "BCEWithLogitsLoss",
}[type(self._loss_func)]
return {
# Model and loss function
"model_func_state_dict": self._model_func.state_dict(),
"loss_type": loss_type,
"loss_reduction": self._loss_func.reduction,
# Attributes
"progressbar": self._progressbar,
"shape": self._shape,
"seed": self._seed,
"fisher_type": self._fisher_type,
"mc_samples": self._mc_samples,
"kfac_approx": self._kfac_approx,
"loss_average": self._loss_average,
"num_per_example_loss_terms": self._num_per_example_loss_terms,
"separate_weight_and_bias": self._separate_weight_and_bias,
"num_data": self._N_data,
# Kronecker factors (if computed)
"input_covariances": self._input_covariances,
"gradient_covariances": self._gradient_covariances,
# Properties (not necessarily computed)
"trace": self._trace,
"det": self._det,
"logdet": self._logdet,
"frobenius_norm": self._frobenius_norm,
}

def load_state_dict(self, state_dict: Dict[str, Any]):
"""Load the state of the KFAC linear operator.
Args:
state_dict: State dictionary.
Raises:
ValueError: If the loss function does not match the state dict.
ValueError: If the loss function reduction does not match the state dict.
"""
self._model_func.load_state_dict(state_dict["model_func_state_dict"])
# Verify that the loss function and its reduction match the state dict
loss_func_type = {
"MSELoss": MSELoss,
"CrossEntropyLoss": CrossEntropyLoss,
"BCEWithLogitsLoss": BCEWithLogitsLoss,
}[state_dict["loss_type"]]
if not isinstance(self._loss_func, loss_func_type):
raise ValueError(
f"Loss function mismatch: {loss_func_type} != {type(self._loss_func)}."
)
if state_dict["loss_reduction"] != self._loss_func.reduction:
raise ValueError(
"Loss function reduction mismatch: "
f"{state_dict['loss_reduction']} != {self._loss_func.reduction}."
)

# Set attributes
self._progressbar = state_dict["progressbar"]
self._shape = state_dict["shape"]
self._seed = state_dict["seed"]
self._fisher_type = state_dict["fisher_type"]
self._mc_samples = state_dict["mc_samples"]
self._kfac_approx = state_dict["kfac_approx"]
self._loss_average = state_dict["loss_average"]
self._num_per_example_loss_terms = state_dict["num_per_example_loss_terms"]
self._separate_weight_and_bias = state_dict["separate_weight_and_bias"]
self._N_data = state_dict["num_data"]

# Set Kronecker factors (if computed)
self._input_covariances = state_dict["input_covariances"]
self._gradient_covariances = state_dict["gradient_covariances"]

# Set properties (not necessarily computed)
self._trace = state_dict["trace"]
self._det = state_dict["det"]
self._logdet = state_dict["logdet"]
self._frobenius_norm = state_dict["frobenius_norm"]

@classmethod
def from_state_dict(
cls,
state_dict: Dict[str, Any],
model_func: Module,
params: List[Parameter],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
check_deterministic: bool = True,
batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
) -> KFACLinearOperator:
"""Load a KFAC linear operator from a state dictionary.
Args:
state_dict: State dictionary.
model_func: The model function.
params: The model's parameters that KFAC is computed for.
data: A data loader containing the data of the Fisher/GGN.
check_deterministic: Whether to check that the linear operator is
deterministic. Defaults to ``True``.
batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this
needs to be specified. The intended behavior is to consume the first
entry of the iterates from ``data`` and return their batch size.
Returns:
Linear operator of KFAC approximation.
"""
loss_func = {
"MSELoss": MSELoss,
"CrossEntropyLoss": CrossEntropyLoss,
"BCEWithLogitsLoss": BCEWithLogitsLoss,
}[state_dict["loss_type"]](reduction=state_dict["loss_reduction"])
kfac = cls(
model_func,
loss_func,
params,
data,
batch_size_fn=batch_size_fn,
check_deterministic=False,
progressbar=state_dict["progressbar"],
shape=state_dict["shape"],
seed=state_dict["seed"],
fisher_type=state_dict["fisher_type"],
mc_samples=state_dict["mc_samples"],
kfac_approx=state_dict["kfac_approx"],
loss_average=state_dict["loss_average"],
num_per_example_loss_terms=state_dict["num_per_example_loss_terms"],
separate_weight_and_bias=state_dict["separate_weight_and_bias"],
num_data=state_dict["num_data"],
)
kfac.load_state_dict(state_dict)

# Potentially call `check_deterministic` after the state dict is loaded
if check_deterministic:
old_device = kfac._device
kfac.to_device(device("cpu"))
try:
kfac._check_deterministic()
except RuntimeError as e:
raise e
finally:
kfac.to_device(old_device)

return kfac

0 comments on commit f873494

Please sign in to comment.