diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py
index 9ad4e43..8b49b24 100644
--- a/curvlinops/inverse.py
+++ b/curvlinops/inverse.py
@@ -1,7 +1,7 @@
 """Implements linear operator inverses."""
 
 from math import sqrt
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from warnings import warn
 
 from einops import einsum, rearrange
@@ -695,3 +695,68 @@ def _matmat(self, M: ndarray) -> ndarray:
         M_torch = self._A._preprocess(M)
         M_torch = self.torch_matmat(M_torch)
         return self._A._postprocess(M_torch)
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Return the state of the inverse KFAC linear operator.
+
+        Returns:
+            State dictionary.
+        """
+        return {
+            "A": self._A.state_dict(),
+            # Attributes
+            "damping": self._damping,
+            "use_heuristic_damping": self._use_heuristic_damping,
+            "min_damping": self._min_damping,
+            "use_exact_damping": self._use_exact_damping,
+            "cache": self._cache,
+            "retry_double_precision": self._retry_double_precision,
+            # Inverse Kronecker factors (if computed and cached)
+            "inverse_input_covariances": self._inverse_input_covariances,
+            "inverse_gradient_covariances": self._inverse_gradient_covariances,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]):
+        """Load the state of the inverse KFAC linear operator.
+
+        Args:
+            state_dict: State dictionary.
+        """
+        self._A.load_state_dict(state_dict["A"])
+
+        # Set attributes
+        self._damping = state_dict["damping"]
+        self._use_heuristic_damping = state_dict["use_heuristic_damping"]
+        self._min_damping = state_dict["min_damping"]
+        self._use_exact_damping = state_dict["use_exact_damping"]
+        self._cache = state_dict["cache"]
+        self._retry_double_precision = state_dict["retry_double_precision"]
+
+        # Set inverse Kronecker factors (if computed and cached)
+        self._inverse_input_covariances = state_dict["inverse_input_covariances"]
+        self._inverse_gradient_covariances = state_dict["inverse_gradient_covariances"]
+
+    @classmethod
+    def from_state_dict(
+        cls, state_dict: Dict[str, Any], A: KFACLinearOperator
+    ) -> "KFACInverseLinearOperator":
+        """Load an inverse KFAC linear operator from a state dictionary.
+
+        Args:
+            state_dict: State dictionary.
+            A: ``KFACLinearOperator`` whose inverse is formed.
+
+        Returns:
+            Linear operator of inverse KFAC approximation.
+ """ + inv_kfac = cls( + A, + damping=state_dict["damping"], + use_heuristic_damping=state_dict["use_heuristic_damping"], + min_damping=state_dict["min_damping"], + use_exact_damping=state_dict["use_exact_damping"], + cache=state_dict["cache"], + retry_double_precision=state_dict["retry_double_precision"], + ) + inv_kfac.load_state_dict(state_dict) + return inv_kfac diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 4909d8e..419ba79 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -21,7 +21,7 @@ from collections.abc import MutableMapping from functools import partial from math import sqrt -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from einops import einsum, rearrange, reduce from numpy import ndarray @@ -111,7 +111,7 @@ class KFACLinearOperator(_LinearOperator): def __init__( # noqa: C901 self, model_func: Module, - loss_func: MSELoss, + loss_func: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], params: List[Parameter], data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]], progressbar: bool = False, @@ -1070,3 +1070,173 @@ def frobenius_norm(self) -> Tensor: ) self._frobenius_norm.sqrt_() return self._frobenius_norm + + def state_dict(self) -> Dict[str, Any]: + """Return the state of the KFAC linear operator. + + Returns: + State dictionary. + """ + loss_type = { + MSELoss: "MSELoss", + CrossEntropyLoss: "CrossEntropyLoss", + BCEWithLogitsLoss: "BCEWithLogitsLoss", + }[type(self._loss_func)] + return { + # Model and loss function + "model_func_state_dict": self._model_func.state_dict(), + "loss_type": loss_type, + "loss_reduction": self._loss_func.reduction, + # Attributes + "progressbar": self._progressbar, + "shape": self.shape, + "seed": self._seed, + "fisher_type": self._fisher_type, + "mc_samples": self._mc_samples, + "kfac_approx": self._kfac_approx, + "loss_average": self._loss_average, + "num_per_example_loss_terms": self._num_per_example_loss_terms, + "separate_weight_and_bias": self._separate_weight_and_bias, + "num_data": self._N_data, + # Kronecker factors (if computed) + "input_covariances": self._input_covariances, + "gradient_covariances": self._gradient_covariances, + # Properties (not necessarily computed) + "trace": self._trace, + "det": self._det, + "logdet": self._logdet, + "frobenius_norm": self._frobenius_norm, + } + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Load the state of the KFAC linear operator. + + Args: + state_dict: State dictionary. + + Raises: + ValueError: If the loss function does not match the state dict. + ValueError: If the loss function reduction does not match the state dict. + """ + self._model_func.load_state_dict(state_dict["model_func_state_dict"]) + # Verify that the loss function and its reduction match the state dict + loss_func_type = { + "MSELoss": MSELoss, + "CrossEntropyLoss": CrossEntropyLoss, + "BCEWithLogitsLoss": BCEWithLogitsLoss, + }[state_dict["loss_type"]] + if not isinstance(self._loss_func, loss_func_type): + raise ValueError( + f"Loss function mismatch: {loss_func_type} != {type(self._loss_func)}." + ) + if state_dict["loss_reduction"] != self._loss_func.reduction: + raise ValueError( + "Loss function reduction mismatch: " + f"{state_dict['loss_reduction']} != {self._loss_func.reduction}." 
+            )
+
+        # Set attributes
+        self._progressbar = state_dict["progressbar"]
+        self.shape = state_dict["shape"]
+        self._seed = state_dict["seed"]
+        self._fisher_type = state_dict["fisher_type"]
+        self._mc_samples = state_dict["mc_samples"]
+        self._kfac_approx = state_dict["kfac_approx"]
+        self._loss_average = state_dict["loss_average"]
+        self._num_per_example_loss_terms = state_dict["num_per_example_loss_terms"]
+        self._separate_weight_and_bias = state_dict["separate_weight_and_bias"]
+        self._N_data = state_dict["num_data"]
+
+        # Set Kronecker factors (if computed)
+        if state_dict["input_covariances"] or state_dict["gradient_covariances"]:
+            # If computed, check if the keys match the mapping keys
+            input_covariances_keys = set(state_dict["input_covariances"].keys())
+            gradient_covariances_keys = set(state_dict["gradient_covariances"].keys())
+            mapping_keys = set(self._mapping.keys())
+            if (
+                input_covariances_keys != mapping_keys
+                or gradient_covariances_keys != mapping_keys
+            ):
+                raise ValueError(
+                    "Input or gradient covariance keys in state dict do not match "
+                    "mapping keys of linear operator. "
+                    "Difference between input covariance and mapping keys: "
+                    f"{input_covariances_keys - mapping_keys}. "
+                    "Difference between gradient covariance and mapping keys: "
+                    f"{gradient_covariances_keys - mapping_keys}."
+                )
+        self._input_covariances = state_dict["input_covariances"]
+        self._gradient_covariances = state_dict["gradient_covariances"]
+
+        # Set properties (not necessarily computed)
+        self._trace = state_dict["trace"]
+        self._det = state_dict["det"]
+        self._logdet = state_dict["logdet"]
+        self._frobenius_norm = state_dict["frobenius_norm"]
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, Any],
+        model_func: Module,
+        params: List[Parameter],
+        data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
+        check_deterministic: bool = True,
+        batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
+    ) -> KFACLinearOperator:
+        """Load a KFAC linear operator from a state dictionary.
+
+        Args:
+            state_dict: State dictionary.
+            model_func: The model function.
+            params: The model's parameters that KFAC is computed for.
+            data: A data loader containing the data of the Fisher/GGN.
+            check_deterministic: Whether to check that the linear operator is
+                deterministic. Defaults to ``True``.
+            batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this
+                needs to be specified. The intended behavior is to consume the first
+                entry of the iterates from ``data`` and return their batch size.
+
+        Returns:
+            Linear operator of KFAC approximation.
+
+        Raises:
+            RuntimeError: If the check for deterministic behavior fails.
+ """ + loss_func = { + "MSELoss": MSELoss, + "CrossEntropyLoss": CrossEntropyLoss, + "BCEWithLogitsLoss": BCEWithLogitsLoss, + }[state_dict["loss_type"]](reduction=state_dict["loss_reduction"]) + kfac = cls( + model_func, + loss_func, + params, + data, + batch_size_fn=batch_size_fn, + check_deterministic=False, + progressbar=state_dict["progressbar"], + shape=state_dict["shape"], + seed=state_dict["seed"], + fisher_type=state_dict["fisher_type"], + mc_samples=state_dict["mc_samples"], + kfac_approx=state_dict["kfac_approx"], + loss_average=state_dict["loss_average"], + num_per_example_loss_terms=state_dict["num_per_example_loss_terms"], + separate_weight_and_bias=state_dict["separate_weight_and_bias"], + num_data=state_dict["num_data"], + ) + kfac.load_state_dict(state_dict) + + # Potentially call `check_deterministic` after the state dict is loaded + if check_deterministic: + old_device = kfac._device + kfac.to_device(device("cpu")) + try: + kfac._check_deterministic() + except RuntimeError as e: + raise e + finally: + kfac.to_device(old_device) + + return kfac diff --git a/test/test_inverse.py b/test/test_inverse.py index 136a65d..0a5e68c 100644 --- a/test/test_inverse.py +++ b/test/test_inverse.py @@ -1,7 +1,8 @@ """Contains tests for ``curvlinops/inverse``.""" +import os from math import sqrt -from test.utils import cast_input +from test.utils import cast_input, compare_state_dicts from typing import Iterable, List, Tuple, Union import torch @@ -654,3 +655,82 @@ def test_KFAC_inverse_damped_torch_matvec( # Test against _matmat report_nonclose(inv_KFAC @ x.cpu().numpy(), inv_KFAC_x.cpu().numpy()) + + +def test_KFAC_inverse_save_and_load_state_dict(): + """Test that KFACInverseLinearOperator can be saved and loaded from state dict.""" + torch.manual_seed(0) + batch_size, D_in, D_out = 4, 3, 2 + X = torch.rand(batch_size, D_in) + y = torch.rand(batch_size, D_out) + model = torch.nn.Linear(D_in, D_out) + + params = list(model.parameters()) + # create and compute KFAC + kfac = KFACLinearOperator( + model, + MSELoss(reduction="sum"), + params, + [(X, y)], + loss_average=None, + ) + + # create inverse KFAC + inv_kfac = KFACInverseLinearOperator( + kfac, damping=1e-2, use_heuristic_damping=True, retry_double_precision=False + ) + _ = inv_kfac @ eye(kfac.shape[1]) # to trigger inverse computation + + # save state dict + state_dict = inv_kfac.state_dict() + torch.save(state_dict, "inv_kfac_state_dict.pt") + + # create new inverse KFAC with different linop input and try to load state dict + wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)]) + inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac) + with raises(ValueError, match="mismatch"): + inv_kfac_wrong.load_state_dict(torch.load("inv_kfac_state_dict.pt")) + + # create new inverse KFAC and load state dict + inv_kfac_new = KFACInverseLinearOperator(kfac) + inv_kfac_new.load_state_dict(torch.load("inv_kfac_state_dict.pt")) + # clean up + os.remove("inv_kfac_state_dict.pt") + + # check that the two inverse KFACs are equal + compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict()) + test_vec = torch.rand(inv_kfac.shape[1]) + report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec) + + +def test_KFAC_inverse_from_state_dict(): + """Test that KFACInverseLinearOperator can be created from state dict.""" + torch.manual_seed(0) + batch_size, D_in, D_out = 4, 3, 2 + X = torch.rand(batch_size, D_in) + y = torch.rand(batch_size, D_out) + model = torch.nn.Linear(D_in, D_out) + + params = 
+    params = list(model.parameters())
+    # create and compute KFAC
+    kfac = KFACLinearOperator(
+        model,
+        MSELoss(reduction="sum"),
+        params,
+        [(X, y)],
+        loss_average=None,
+    )
+
+    # create inverse KFAC and save state dict
+    inv_kfac = KFACInverseLinearOperator(
+        kfac, damping=1e-2, use_heuristic_damping=True, retry_double_precision=False
+    )
+    state_dict = inv_kfac.state_dict()
+
+    # create new KFAC from state dict
+    inv_kfac_new = KFACInverseLinearOperator.from_state_dict(state_dict, kfac)
+
+    # check that the two inverse KFACs are equal
+    compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict())
+    test_vec = torch.rand(kfac.shape[1])
+    report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec)
diff --git a/test/test_kfac.py b/test/test_kfac.py
index 06394dd..a2b7a2f 100644
--- a/test/test_kfac.py
+++ b/test/test_kfac.py
@@ -1,5 +1,6 @@
 """Contains tests for ``curvlinops.kfac``."""
 
+import os
 from test.cases import DEVICES, DEVICES_IDS
 from test.utils import (
     Conv2dModel,
@@ -7,6 +8,7 @@
     WeightShareModel,
     binary_classification_targets,
     classification_targets,
+    compare_state_dicts,
     ggn_block_diagonal,
     regression_targets,
 )
@@ -20,7 +22,7 @@
 from scipy.linalg import block_diag
 from torch import Tensor, allclose, cat, cuda, device
 from torch import eye as torch_eye
-from torch import isinf, isnan, manual_seed, rand, rand_like, randperm
+from torch import isinf, isnan, load, manual_seed, rand, rand_like, randperm, save
 from torch.nn import (
     BCEWithLogitsLoss,
     CrossEntropyLoss,
@@ -1220,3 +1222,107 @@ def test_kfac_does_affect_grad():
     # make sure gradients are unchanged
     for grad_before, p in zip(grads_before, params):
         assert allclose(grad_before, p.grad)
+
+
+def test_save_and_load_state_dict():
+    """Test that KFACLinearOperator can be saved and loaded from state dict."""
+    manual_seed(0)
+    batch_size, D_in, D_out = 4, 3, 2
+    X = rand(batch_size, D_in)
+    y = rand(batch_size, D_out)
+    model = Linear(D_in, D_out)
+
+    params = list(model.parameters())
+    # create and compute KFAC
+    kfac = KFACLinearOperator(
+        model,
+        MSELoss(reduction="sum"),
+        params,
+        [(X, y)],
+        loss_average=None,
+    )
+
+    # save state dict
+    state_dict = kfac.state_dict()
+    save(state_dict, "kfac_state_dict.pt")
+
+    # create new KFAC with different loss function and try to load state dict
+    kfac_new = KFACLinearOperator(
+        model,
+        CrossEntropyLoss(),
+        params,
+        [(X, y)],
+    )
+    with raises(ValueError, match="loss"):
+        kfac_new.load_state_dict(load("kfac_state_dict.pt"))
+
+    # create new KFAC with different loss reduction and try to load state dict
+    kfac_new = KFACLinearOperator(
+        model,
+        MSELoss(),
+        params,
+        [(X, y)],
+    )
+    with raises(ValueError, match="reduction"):
+        kfac_new.load_state_dict(load("kfac_state_dict.pt"))
+
+    # create new KFAC with different model and try to load state dict
+    wrong_model = Sequential(Linear(D_in, 10), ReLU(), Linear(10, D_out))
+    wrong_params = list(wrong_model.parameters())
+    kfac_new = KFACLinearOperator(
+        wrong_model,
+        MSELoss(reduction="sum"),
+        wrong_params,
+        [(X, y)],
+        loss_average=None,
+    )
+    with raises(RuntimeError, match="loading state_dict"):
+        kfac_new.load_state_dict(load("kfac_state_dict.pt"))
+
+    # create new KFAC and load state dict
+    kfac_new = KFACLinearOperator(
+        model,
+        MSELoss(reduction="sum"),
+        params,
+        [(X, y)],
+        loss_average=None,
+        check_deterministic=False,  # turn off to avoid computing KFAC again
+    )
+    kfac_new.load_state_dict(load("kfac_state_dict.pt"))
+    # clean up
+    os.remove("kfac_state_dict.pt")
+
+    # check that the two KFACs are equal
+    compare_state_dicts(kfac.state_dict(), kfac_new.state_dict())
+    test_vec = rand(kfac.shape[1])
+    report_nonclose(kfac @ test_vec, kfac_new @ test_vec)
+
+
+def test_from_state_dict():
+    """Test that KFACLinearOperator can be created from state dict."""
+    manual_seed(0)
+    batch_size, D_in, D_out = 4, 3, 2
+    X = rand(batch_size, D_in)
+    y = rand(batch_size, D_out)
+    model = Linear(D_in, D_out)
+
+    params = list(model.parameters())
+    # create and compute KFAC
+    kfac = KFACLinearOperator(
+        model,
+        MSELoss(reduction="sum"),
+        params,
+        [(X, y)],
+        loss_average=None,
+    )
+
+    # save state dict
+    state_dict = kfac.state_dict()
+
+    # create new KFAC from state dict
+    kfac_new = KFACLinearOperator.from_state_dict(state_dict, model, params, [(X, y)])
+
+    # check that the two KFACs are equal
+    compare_state_dicts(kfac.state_dict(), kfac_new.state_dict())
+    test_vec = rand(kfac.shape[1])
+    report_nonclose(kfac @ test_vec, kfac_new @ test_vec)
diff --git a/test/utils.py b/test/utils.py
index 07e6b82..dcc656b 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -7,7 +7,18 @@
 from einops import rearrange, reduce
 from einops.layers.torch import Rearrange
 from numpy import eye, ndarray
-from torch import Tensor, cat, cuda, device, dtype, from_numpy, rand, randint
+from torch import (
+    Tensor,
+    allclose,
+    as_tensor,
+    cat,
+    cuda,
+    device,
+    dtype,
+    from_numpy,
+    rand,
+    randint,
+)
 from torch.nn import (
     AdaptiveAvgPool2d,
     BCEWithLogitsLoss,
@@ -367,3 +378,29 @@ def batch_size_fn(X: MutableMapping) -> int:
         batch_size: The first dimension size of the tensor.
     """
     return X["x"].shape[0]
+
+
+def compare_state_dicts(state_dict: dict, state_dict_new: dict):
+    """Compare two state dicts recursively.
+
+    Args:
+        state_dict (dict): The first state dict to compare.
+        state_dict_new (dict): The second state dict to compare.
+
+    Raises:
+        AssertionError: If the state dicts are not equal.
+    """
+    assert len(state_dict) == len(state_dict_new)
+    for value, value_new in zip(state_dict.values(), state_dict_new.values()):
+        if isinstance(value, Tensor):
+            assert allclose(value, value_new)
+        elif isinstance(value, dict):
+            compare_state_dicts(value, value_new)
+        elif isinstance(value, tuple):
+            assert len(value) == len(value_new)
+            assert all(isinstance(v, type(v2)) for v, v2 in zip(value, value_new))
+            assert all(
+                allclose(as_tensor(v), as_tensor(v2)) for v, v2 in zip(value, value_new)
+            )
+        else:
+            assert value == value_new
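Usage note (not part of the patch): a minimal sketch of how the new (de)serialization API fits together, mirroring the tests above. The toy model, data, and file name are illustrative assumptions, as is importing `KFACInverseLinearOperator` from the `curvlinops.inverse` module edited by this diff.

```python
import torch
from torch.nn import Linear, MSELoss

from curvlinops import KFACLinearOperator
from curvlinops.inverse import KFACInverseLinearOperator

# toy regression setup (illustrative)
X, y = torch.rand(4, 3), torch.rand(4, 2)
model = Linear(3, 2)
params = list(model.parameters())
data = [(X, y)]

# compute KFAC once and persist it, including the Kronecker factors
kfac = KFACLinearOperator(
    model, MSELoss(reduction="sum"), params, data, loss_average=None
)
torch.save(kfac.state_dict(), "kfac_state_dict.pt")

# later: restore without recomputing the Kronecker factors
kfac_loaded = KFACLinearOperator.from_state_dict(
    torch.load("kfac_state_dict.pt"), model, params, data
)

# the damped inverse operator round-trips the same way
inv_kfac = KFACInverseLinearOperator(kfac, damping=1e-2)
inv_loaded = KFACInverseLinearOperator.from_state_dict(inv_kfac.state_dict(), kfac)
```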