Skip to content

Commit

Permalink
[FIX] Checks if Enum contains string (#119)
Browse files Browse the repository at this point in the history
* Add test for issue #118

* Add MetaEnum to get desired behavior of in-operator

* Add FisherType and KFACType to init to make them easy to import
  • Loading branch information
runame committed Jun 12, 2024
1 parent 13b1082 commit 92ccd9f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
5 changes: 4 additions & 1 deletion curvlinops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
NeumannInverseLinearOperator,
)
from curvlinops.jacobian import JacobianLinearOperator, TransposedJacobianLinearOperator
from curvlinops.kfac import KFACLinearOperator
from curvlinops.kfac import FisherType, KFACLinearOperator, KFACType
from curvlinops.norm.hutchinson import HutchinsonSquaredFrobeniusNormEstimator
from curvlinops.papyan2020traces.spectrum import (
LanczosApproximateLogSpectrumCached,
Expand All @@ -33,6 +33,9 @@
"KFACLinearOperator",
"JacobianLinearOperator",
"TransposedJacobianLinearOperator",
# Enums
"FisherType",
"KFACType",
# inversion
"CGInverseLinearOperator",
"LSMRInverseLinearOperator",
Expand Down
17 changes: 14 additions & 3 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import annotations

from collections.abc import MutableMapping
from enum import Enum
from enum import Enum, EnumMeta
from functools import partial
from math import sqrt
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
Expand Down Expand Up @@ -52,7 +52,18 @@
ParameterMatrixType = TypeVar("ParameterMatrixType", Tensor, List[Tensor])


class FisherType(str, Enum):
class MetaEnum(EnumMeta):
"""Metaclass for the Enum class for desired behavior of the `in` operator."""

def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True


class FisherType(str, Enum, metaclass=MetaEnum):
"""Enum for the Fisher type."""

TYPE2 = "type-2"
Expand All @@ -61,7 +72,7 @@ class FisherType(str, Enum):
FORWARD_ONLY = "forward-only"


class KFACType(str, Enum):
class KFACType(str, Enum, metaclass=MetaEnum):
"""Enum for the KFAC approximation type."""

EXPAND = "expand"
Expand Down
18 changes: 18 additions & 0 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,3 +1270,21 @@ def test_from_state_dict():
compare_state_dicts(kfac.state_dict(), kfac_new.state_dict())
test_vec = rand(kfac.shape[1])
report_nonclose(kfac @ test_vec, kfac_new @ test_vec)


@mark.parametrize("fisher_type", ["type-2", "mc", "empirical", "forward-only"])
@mark.parametrize("kfac_approx", ["expand", "reduce"])
def test_string_in_enum(fisher_type: str, kfac_approx: str):
"""Test whether checking if a string is contained in enum works.
To reproduce issue #118.
"""
model = Linear(2, 2)
KFACLinearOperator(
model,
MSELoss(),
list(model.parameters()),
[(rand(2, 2), rand(2, 2))],
fisher_type=fisher_type,
kfac_approx=kfac_approx,
)

0 comments on commit 92ccd9f

Please sign in to comment.