From 9623ab4f56a0921c11cd38576360863ad7cb11b4 Mon Sep 17 00:00:00 2001 From: Joel Jennings Date: Fri, 16 Aug 2024 06:06:40 -0700 Subject: [PATCH] Change the default estimation mode of the curvature estimators to `ggn_curvature_prop` PiperOrigin-RevId: 663706387 --- kfac_jax/_src/curvature_estimator.py | 8 ++++---- kfac_jax/_src/optimizer.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/kfac_jax/_src/curvature_estimator.py b/kfac_jax/_src/curvature_estimator.py index e663de0..fed38d9 100644 --- a/kfac_jax/_src/curvature_estimator.py +++ b/kfac_jax/_src/curvature_estimator.py @@ -624,7 +624,7 @@ def __init__( func: utils.Func, params_index: int = 0, batch_index: int = 1, - default_estimation_mode: str = "fisher_gradients", + default_estimation_mode: str = "ggn_curvature_prop", ): """Initializes the CurvatureEstimator instance. @@ -968,7 +968,7 @@ def __init__( func: The model function, which should have at least one registered loss. default_estimation_mode: The estimation mode which to use by default when calling ``self.update_curvature_matrix_estimate``. If ``None`` this will - be ``'fisher_gradients'``. + be ``'ggn_curvature_prop'``. layer_tag_to_block_ctor: An optional dict mapping tags to specific classes of block approximations, which to override the default ones. index_to_block_ctor: An optional dict mapping a specific block parameter @@ -1000,7 +1000,7 @@ def __init__( super().__init__( func=func, - default_estimation_mode=default_estimation_mode or "fisher_gradients", + default_estimation_mode=default_estimation_mode or "ggn_curvature_prop", **kwargs, ) @@ -1759,7 +1759,7 @@ def retagged_func(params, *args): super().__init__( func=retagged_func, - default_estimation_mode=default_estimation_mode or "fisher_exact", + default_estimation_mode=default_estimation_mode or "ggn_curvature_prop", layer_tag_to_block_ctor=layer_tag_to_block_ctor, auto_register_tags=False, **kwargs, diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py index 66dbe97..77e7b98 100644 --- a/kfac_jax/_src/optimizer.py +++ b/kfac_jax/_src/optimizer.py @@ -292,7 +292,7 @@ def __init__( matrix. See the documentation for :class:`~CurvatureEstimator` for a detailed description of the possible options. If ``None`` will use default estimation_mode mode of the used CurvatureEstimator subclass, - which is typically "fisher_gradients". (Default: ``None``) + which is typically "ggn_curvature_prop". (Default: ``None``) custom_estimator_ctor: Optional constructor for subclass of :class:`~BlockDiagonalCurvature`. If specified, the optimizer will use this conastructor instead of the default