diff --git a/curvlinops/__init__.py b/curvlinops/__init__.py index 0f8cbb3..9f4fafa 100644 --- a/curvlinops/__init__.py +++ b/curvlinops/__init__.py @@ -12,7 +12,7 @@ NeumannInverseLinearOperator, ) from curvlinops.jacobian import JacobianLinearOperator, TransposedJacobianLinearOperator -from curvlinops.kfac import KFACLinearOperator +from curvlinops.kfac import FisherType, KFACLinearOperator, KFACType from curvlinops.norm.hutchinson import HutchinsonSquaredFrobeniusNormEstimator from curvlinops.papyan2020traces.spectrum import ( LanczosApproximateLogSpectrumCached, @@ -33,6 +33,9 @@ "KFACLinearOperator", "JacobianLinearOperator", "TransposedJacobianLinearOperator", + # Enums + "FisherType", + "KFACType", # inversion "CGInverseLinearOperator", "LSMRInverseLinearOperator",