Remove loss_average argument and add FisherType and KFACType enums
runame committed May 29, 2024
1 parent d509cd6 commit ab72a67
Showing 6 changed files with 102 additions and 190 deletions.
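In practice, the call-site migration is mechanical: drop ``loss_average`` (the scaling is now derived from the loss function's ``reduction`` and ``num_per_example_loss_terms``) and, optionally, switch the string arguments to the new enums. A minimal before/after sketch under assumed toy inputs (the model, loss, and data below are hypothetical, not from this commit):

from torch import rand
from torch.nn import Linear, MSELoss

from curvlinops.kfac import FisherType, KFACLinearOperator, KFACType

model = Linear(10, 3)                  # toy regression model
loss_func = MSELoss(reduction="mean")  # scaling now inferred from ``reduction``
params = [p for p in model.parameters() if p.requires_grad]
data = [(rand(32, 10), rand(32, 3))]   # one mini-batch of (X, y) pairs

# Before: KFACLinearOperator(model, loss_func, params, data, fisher_type="mc",
#                            kfac_approx="expand", loss_average="batch")
kfac = KFACLinearOperator(
    model,
    loss_func,
    params,
    data,
    fisher_type=FisherType.MC,    # plain "mc" also works; the enum subclasses str
    kfac_approx=KFACType.EXPAND,
)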
137 changes: 54 additions & 83 deletions curvlinops/kfac.py
@@ -19,6 +19,7 @@
from __future__ import annotations

from collections.abc import MutableMapping
from enum import Enum
from functools import partial
from math import sqrt
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
@@ -46,6 +47,22 @@
)


class FisherType(str, Enum):
"""Enum for the Fisher type."""

TYPE2 = "type-2"
MC = "mc"
EMPIRICAL = "empirical"
FORWARD_ONLY = "forward-only"


class KFACType(str, Enum):
"""Enum for the KFAC approximation type."""

EXPAND = "expand"
REDUCE = "reduce"
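
# Illustrative check, not part of the diff: because both enums subclass ``str``,
# call sites that pass plain strings keep working unchanged.
#     assert FisherType.MC == "mc" and KFACType.EXPAND == "expand"
#     assert FisherType("type-2") is FisherType.TYPE2
#     assert isinstance(KFACType.REDUCE, str)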


class KFACLinearOperator(_LinearOperator):
r"""Linear operator to multiply with the Fisher/GGN's KFAC approximation.
@@ -95,18 +112,8 @@ class KFACLinearOperator(_LinearOperator):

_SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss)
_SUPPORTED_MODULES = (Linear, Conv2d)
_SUPPORTED_LOSS_AVERAGE: Tuple[Union[None, str], ...] = (
None,
"batch",
"batch+sequence",
)
_SUPPORTED_FISHER_TYPE: Tuple[str, ...] = (
"type-2",
"mc",
"empirical",
"forward-only",
)
_SUPPORTED_KFAC_APPROX: Tuple[str, ...] = ("expand", "reduce")
_SUPPORTED_FISHER_TYPE: FisherType = FisherType
_SUPPORTED_KFAC_APPROX: KFACType = KFACType

def __init__( # noqa: C901
self,
@@ -118,10 +125,9 @@ def __init__( # noqa: C901
check_deterministic: bool = True,
shape: Union[Tuple[int, int], None] = None,
seed: int = 2147483647,
fisher_type: str = "mc",
fisher_type: str = FisherType.MC,
mc_samples: int = 1,
kfac_approx: str = "expand",
loss_average: Union[None, str] = "batch",
kfac_approx: str = KFACType.EXPAND,
num_per_example_loss_terms: Optional[int] = None,
separate_weight_and_bias: bool = True,
num_data: Optional[int] = None,
@@ -157,38 +163,31 @@ def __init__( # noqa: C901
from the parameters. Defaults to ``None``.
seed: The seed for the random number generator used to draw labels
from the model's predictive distribution. Defaults to ``2147483647``.
fisher_type: The type of Fisher/GGN to approximate. If 'type-2', the
exact Hessian of the loss w.r.t. the model outputs is used. This
requires as many backward passes as the output dimension, i.e.
the number of classes for classification. This is sometimes also
called type-2 Fisher. If ``'mc'``, the expectation is approximated
by sampling ``mc_samples`` labels from the model's predictive
distribution. If ``'empirical'``, the empirical gradients are
used which corresponds to the uncentered gradient covariance, or
the empirical Fisher. If ``'forward-only'``, the gradient covariances
will be identity matrices, see the FOOF method in
fisher_type: The type of Fisher/GGN to approximate.
If ``FisherType.TYPE2``, the exact Hessian of the loss w.r.t. the model
outputs is used. This requires as many backward passes as the output
dimension, i.e. the number of classes for classification. This is
sometimes also called type-2 Fisher. If ``FisherType.MC``, the
expectation is approximated by sampling ``mc_samples`` labels from the
model's predictive distribution. If ``FisherType.EMPIRICAL``, the
empirical gradients are used, which corresponds to the uncentered
gradient covariance, also known as the empirical Fisher.
If ``FisherType.FORWARD_ONLY``, the gradient covariances will be
identity matrices, see the FOOF method in
`Benzing, 2022 <https://arxiv.org/abs/2201.12250>`_ or ISAAC in
`Petersen et al., 2023 <https://arxiv.org/abs/2305.00604>`_.
Defaults to ``'mc'``.
Defaults to ``FisherType.MC``.
mc_samples: The number of Monte-Carlo samples to use per data point.
Has to be set to ``1`` when ``fisher_type != 'mc'``.
Has to be set to ``1`` when ``fisher_type != FisherType.MC``.
Defaults to ``1``.
kfac_approx: A string specifying the KFAC approximation that should
be used for linear weight-sharing layers, e.g. ``Conv2d`` modules
or ``Linear`` modules that process matrix- or higher-dimensional
features.
Possible values are ``'expand'`` and ``'reduce'``.
Possible values are ``KFACType.EXPAND`` and ``KFACType.REDUCE``.
See `Eschenhagen et al., 2023 <https://arxiv.org/abs/2311.00636>`_
for an explanation of the two approximations.
loss_average: Whether the loss function is a mean over per-sample
losses and if yes, over which dimensions the mean is taken.
If ``"batch"``, the loss function is a mean over as many terms as
the size of the mini-batch. If ``"batch+sequence"``, the loss
function is a mean over as many terms as the size of the
mini-batch times the sequence length, e.g. in the case of
language modeling. If ``None``, the loss function is a sum. This
argument is used to ensure that the preconditioner is scaled
consistently with the loss and the gradient. Default: ``"batch"``.
Defaults to ``KFACType.EXPAND``.
num_per_example_loss_terms: Number of per-example loss terms, e.g., the
number of tokens in a sequence. The model outputs will have
``num_data * num_per_example_loss_terms * C`` entries, where ``C`` is
@@ -210,39 +209,19 @@ def __init__( # noqa: C901
Raises:
RuntimeError: If the check for deterministic behavior fails.
ValueError: If the loss function is not supported.
ValueError: If the loss average is not supported.
ValueError: If the loss average is ``None`` and the loss function's
reduction is not ``'sum'``.
ValueError: If the loss average is not ``None`` and the loss function's
reduction is ``'sum'``.
ValueError: If ``fisher_type != 'mc'`` and ``mc_samples != 1``.
ValueError: If ``fisher_type != FisherType.MC`` and ``mc_samples != 1``.
ValueError: If ``X`` is not a tensor and ``batch_size_fn`` is not specified.
"""
if not isinstance(loss_func, self._SUPPORTED_LOSSES):
raise ValueError(
f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
)
if loss_average not in self._SUPPORTED_LOSS_AVERAGE:
raise ValueError(
f"Invalid loss_average: {loss_average}. "
f"Supported: {self._SUPPORTED_LOSS_AVERAGE}."
)
if loss_average is None and loss_func.reduction != "sum":
raise ValueError(
f"Invalid loss_average: {loss_average}. "
f"Must be 'batch' or 'batch+sequence' if loss_func.reduction != 'sum'."
)
if loss_func.reduction == "sum" and loss_average is not None:
raise ValueError(
f"Loss function uses reduction='sum', but loss_average={loss_average}."
" Set loss_average to None if you want to use reduction='sum'."
)
if fisher_type not in self._SUPPORTED_FISHER_TYPE:
raise ValueError(
f"Invalid fisher_type: {fisher_type}. "
f"Supported: {self._SUPPORTED_FISHER_TYPE}."
)
if fisher_type != "mc" and mc_samples != 1:
if fisher_type != FisherType.MC and mc_samples != 1:
raise ValueError(
f"Invalid mc_samples: {mc_samples}. "
"Only mc_samples=1 is supported for fisher_type != 'mc'."
@@ -259,7 +238,6 @@ def __init__( # noqa: C901
self._fisher_type = fisher_type
self._mc_samples = mc_samples
self._kfac_approx = kfac_approx
self._loss_average = loss_average
self._input_covariances: Dict[str, Tensor] = {}
self._gradient_covariances: Dict[str, Tensor] = {}
self._mapping = self.compute_parameter_mapping(params, model_func)
@@ -613,8 +591,9 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
y: The labels :math:`\{\mathbf{y}_n\}_{n=1}^N`.
Raises:
ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
``'empirical'``.
ValueError: If ``fisher_type`` is not ``FisherType.TYPE2``,
``FisherType.MC``, ``FisherType.EMPIRICAL``, or
``FisherType.FORWARD_ONLY``.
"""
# if >2d output we convert to an equivalent 2d output
if isinstance(self._loss_func, CrossEntropyLoss):
@@ -624,7 +603,7 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
output = rearrange(output, "batch ... c -> (batch ...) c")
y = rearrange(y, "batch ... c -> (batch ...) c")

if self._fisher_type == "type-2":
if self._fisher_type == FisherType.TYPE2:
# Compute per-sample Hessian square root, then concatenate over samples.
# Result has shape `(batch_size, num_classes, num_classes)`
hessian_sqrts = stack(
@@ -651,19 +630,19 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
retain_graph=c < num_cols - 1,
)

elif self._fisher_type == "mc":
elif self._fisher_type == FisherType.MC:
for mc in range(self._mc_samples):
y_sampled = self.draw_label(output)
loss = self._loss_func(output, y_sampled)
loss = self._maybe_adjust_loss_scale(loss, output)
grad(loss, self._params, retain_graph=mc != self._mc_samples - 1)

elif self._fisher_type == "empirical":
elif self._fisher_type == FisherType.EMPIRICAL:
loss = self._loss_func(output, y)
loss = self._maybe_adjust_loss_scale(loss, output)
grad(loss, self._params)

elif self._fisher_type == "forward-only":
elif self._fisher_type == FisherType.FORWARD_ONLY:
# Since FOOF sets the gradient covariance Kronecker factors to the identity,
# we don't need to do a backward pass. See https://arxiv.org/abs/2201.12250.
# We choose to set the gradient covariance to the identity explicitly for
@@ -781,26 +760,21 @@ def _accumulate_gradient_covariance(
if isinstance(module, Conv2d):
g = rearrange(g, "batch c o1 o2 -> batch o1 o2 c")

if self._kfac_approx == "expand":
if self._kfac_approx == KFACType.EXPAND:
# KFAC-expand approximation
g = rearrange(g, "batch ... d_out -> (batch ...) d_out")
else:
# KFAC-reduce approximation
g = reduce(g, "batch ... d_out -> batch d_out", "sum")

# Compute correction for the loss scaling depending on the loss reduction used
num_loss_terms = {
None: batch_size,
"batch": batch_size,
"batch+sequence": batch_size * self._num_per_example_loss_terms,
}[self._loss_average]
# self._mc_samples will be 1 if fisher_type != "mc"
num_loss_terms = batch_size * self._num_per_example_loss_terms
# self._mc_samples will be 1 if fisher_type != FisherType.MC
correction = {
None: 1.0 / self._mc_samples,
"batch": num_loss_terms**2 / (self._N_data * self._mc_samples),
"batch+sequence": num_loss_terms**2
"sum": 1.0 / self._mc_samples,
"mean": num_loss_terms**2
/ (self._N_data * self._mc_samples * self._num_per_example_loss_terms),
}[self._loss_average]
}[self._loss_func.reduction]

covariance = einsum(g, g, "b i,b j->i j").mul_(correction)
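# Editorial sketch, not in the commit: a numeric check of the correction.
# With reduction="mean", batch_size=32, self._N_data=320, one per-example
# loss term, and mc_samples=1: num_loss_terms = 32 * 1 = 32, so
# correction = 32**2 / (320 * 1 * 1) = 3.2.
# With reduction="sum", the correction is simply 1 / mc_samples = 1.0.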

@@ -830,8 +804,8 @@ def _hook_accumulate_input_covariance(

if isinstance(module, Conv2d):
patch_extractor_fn = {
"expand": extract_patches,
"reduce": extract_averaged_patches,
KFACType.EXPAND: extract_patches,
KFACType.REDUCE: extract_averaged_patches,
}[self._kfac_approx]
x = patch_extractor_fn(
x,
@@ -842,7 +816,7 @@
module.groups,
)

if self._kfac_approx == "expand":
if self._kfac_approx == KFACType.EXPAND:
# KFAC-expand approximation
scale = x.shape[1:-1].numel() # sequence length
x = rearrange(x, "batch ... d_in -> (batch ...) d_in")
@@ -1094,7 +1068,6 @@ def state_dict(self) -> Dict[str, Any]:
"fisher_type": self._fisher_type,
"mc_samples": self._mc_samples,
"kfac_approx": self._kfac_approx,
"loss_average": self._loss_average,
"num_per_example_loss_terms": self._num_per_example_loss_terms,
"separate_weight_and_bias": self._separate_weight_and_bias,
"num_data": self._N_data,
@@ -1142,7 +1115,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
self._fisher_type = state_dict["fisher_type"]
self._mc_samples = state_dict["mc_samples"]
self._kfac_approx = state_dict["kfac_approx"]
self._loss_average = state_dict["loss_average"]
self._num_per_example_loss_terms = state_dict["num_per_example_loss_terms"]
self._separate_weight_and_bias = state_dict["separate_weight_and_bias"]
self._N_data = state_dict["num_data"]
@@ -1221,7 +1193,6 @@ def from_state_dict(
fisher_type=state_dict["fisher_type"],
mc_samples=state_dict["mc_samples"],
kfac_approx=state_dict["kfac_approx"],
loss_average=state_dict["loss_average"],
num_per_example_loss_terms=state_dict["num_per_example_loss_terms"],
separate_weight_and_bias=state_dict["separate_weight_and_bias"],
num_data=state_dict["num_data"],
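Because the enum members are also plain strings, they pass through ``state_dict`` and ``load_state_dict`` without special handling, and the ``loss_average`` entry disappears from saved state. A short sketch, reusing the hypothetical ``kfac`` operator constructed above:

sd = kfac.state_dict()
assert sd["fisher_type"] == FisherType.MC == "mc"  # str-enum equality round-trips
assert "loss_average" not in sd  # the argument and its state entry are gone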
10 changes: 8 additions & 2 deletions test/cases.py
@@ -24,6 +24,8 @@
)
from torch.utils.data import DataLoader, TensorDataset

from curvlinops.kfac import KFACType

DEVICES = get_available_devices()
DEVICES_IDS = [f"dev={d}" for d in DEVICES]

@@ -163,7 +165,9 @@ def forward(self, data: MutableMapping):
# softmax cross-entropy loss with additional input/output dimension
{
"model_func": lambda: WeightShareModel(
Sequential(Linear(10, 5), ReLU(), Linear(5, 3)), setting="expand", loss="CE"
Sequential(Linear(10, 5), ReLU(), Linear(5, 3)),
setting=KFACType.EXPAND,
loss="CE",
),
"loss_func": lambda: CrossEntropyLoss(reduction="mean"),
"data": lambda: [
@@ -175,7 +179,9 @@ def forward(self, data: MutableMapping):
# same as above, but uses reduction='sum'
{
"model_func": lambda: WeightShareModel(
Sequential(Linear(10, 5), ReLU(), Linear(5, 3)), setting="expand", loss="CE"
Sequential(Linear(10, 5), ReLU(), Linear(5, 3)),
setting=KFACType.EXPAND,
loss="CE",
),
"loss_func": lambda: CrossEntropyLoss(reduction="sum"),
"data": lambda: [
3 changes: 2 additions & 1 deletion test/conftest.py
@@ -35,7 +35,8 @@ def initialize_case(
params = [p for p in model_func.parameters() if p.requires_grad]
data = case["data"]()

# In some KFAC cases, ``data = {"expand": [(X, y), ...], "reduce": [(X, y), ...]}``
# In some KFAC cases,
# ``data = {KFACType.EXPAND: [(X, y), ...], KFACType.REDUCE: [(X, y), ...]}``
# unlike the standard ``data = [(X: Tensor | MutableMapping, y), ...]``.
# We ignore the former since the latter is included in KFAC cases, and thus the
# feature of ``MutableMapping`` inputs is sufficiently covered already.
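# A hypothetical guard matching the comment above (not the repository's code):
#     if isinstance(data, dict) and set(data) == {KFACType.EXPAND, KFACType.REDUCE}:
#         pytest.skip("weight-sharing data dicts are exercised by the KFAC cases")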
18 changes: 10 additions & 8 deletions test/kfac_cases.py
@@ -22,6 +22,8 @@
Sequential,
)

from curvlinops.kfac import KFACType

# Add test cases here; devices and loss functions with different reductions will be
# added automatically below.
KFAC_EXACT_CASES_NO_DEVICE_NO_LOSS_FUNC = [
@@ -79,11 +81,11 @@
{
"model_func": lambda: WeightShareModel(Linear(5, 4), Linear(4, 3)),
"data": lambda: {
"expand": [
KFACType.EXPAND: [
(rand(2, 4, 8, 5), regression_targets((2, 4, 8, 3))),
(rand(7, 4, 8, 5), regression_targets((7, 4, 8, 3))),
],
"reduce": [
KFACType.REDUCE: [
(rand(1, 4, 8, 5), regression_targets((1, 3))),
(rand(7, 4, 8, 5), regression_targets((7, 3))),
],
@@ -94,11 +96,11 @@
{
"model_func": lambda: Conv2dModel(),
"data": lambda: {
"expand": [
KFACType.EXPAND: [
(rand(2, 3, 32, 32), regression_targets((2, 33, 33, 2))),
(rand(7, 3, 32, 32), regression_targets((7, 33, 33, 2))),
],
"reduce": [
KFACType.REDUCE: [
(rand(1, 3, 32, 32), regression_targets((1, 2))),
(rand(8, 3, 32, 32), regression_targets((8, 2))),
],
@@ -210,11 +212,11 @@
{
"model_func": lambda: WeightShareModel(Linear(5, 3)),
"data": lambda: {
"expand": [
KFACType.EXPAND: [
(rand(7, 4, 8, 5), regression_targets((7, 4, 8, 3))),
(rand(7, 4, 8, 5), regression_targets((7, 4, 8, 3))),
],
"reduce": [
KFACType.REDUCE: [
(rand(8, 4, 8, 5), regression_targets((8, 3))),
(rand(8, 4, 8, 5), regression_targets((8, 3))),
],
@@ -225,11 +227,11 @@
{
"model_func": lambda: Conv2dModel(),
"data": lambda: {
"expand": [
KFACType.EXPAND: [
(rand(7, 3, 32, 32), regression_targets((7, 33, 33, 2))),
(rand(7, 3, 32, 32), regression_targets((7, 33, 33, 2))),
],
"reduce": [
KFACType.REDUCE: [
(rand(8, 3, 32, 32), regression_targets((8, 2))),
(rand(8, 3, 32, 32), regression_targets((8, 2))),
],
(Diff for the remaining 2 changed files not shown.)