From 7422b0044749857d066d1897c8ca54f1a51187f7 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 21 Sep 2024 15:06:51 -0400 Subject: [PATCH] Fix docstring and error catching in test --- curvlinops/kfac.py | 4 ++++ test/test_kfac.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index d41ee0e..33d0d74 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -243,7 +243,11 @@ def __init__( Raises: RuntimeError: If the check for deterministic behavior fails. ValueError: If the loss function is not supported. + ValueError: If the Fisher type is not supported. + ValueError: If the KFAC approximation type is not supported. ValueError: If ``fisher_type != FisherType.MC`` and ``mc_samples != 1``. + NotImplementedError: If ``correct_eigenvalues`` and ``fisher_type == + FisherType.FORWARD_ONLY``. ValueError: If ``X`` is not a tensor and ``batch_size_fn`` is not specified. """ if not isinstance(loss_func, self._SUPPORTED_LOSSES): diff --git a/test/test_kfac.py b/test/test_kfac.py index 8d72972..33d25dc 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -512,7 +512,7 @@ def test_multi_dim_output( # KFAC for deep linear network with 4d input and output params = list(model.parameters()) context = ( - raises(ValueError, match="eigenvalues") + raises(NotImplementedError, match="eigenvalues") if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY else nullcontext() ) # EKFAC for FOOF is currently not supported