Skip to content

Commit

Permalink
Change the default estimation mode of the curvature estimators to `gg…
Browse files Browse the repository at this point in the history
…n_curvature_prop`

PiperOrigin-RevId: 663706387
  • Loading branch information
joeljennings authored and KfacJaxDev committed Aug 16, 2024
1 parent b51f93c commit 9623ab4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions kfac_jax/_src/curvature_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def __init__(
func: utils.Func,
params_index: int = 0,
batch_index: int = 1,
default_estimation_mode: str = "fisher_gradients",
default_estimation_mode: str = "ggn_curvature_prop",
):
"""Initializes the CurvatureEstimator instance.
Expand Down Expand Up @@ -968,7 +968,7 @@ def __init__(
func: The model function, which should have at least one registered loss.
default_estimation_mode: The estimation mode which to use by default when
calling ``self.update_curvature_matrix_estimate``. If ``None`` this will
be ``'fisher_gradients'``.
be ``'ggn_curvature_prop'``.
layer_tag_to_block_ctor: An optional dict mapping tags to specific classes
of block approximations, which to override the default ones.
index_to_block_ctor: An optional dict mapping a specific block parameter
Expand Down Expand Up @@ -1000,7 +1000,7 @@ def __init__(

super().__init__(
func=func,
default_estimation_mode=default_estimation_mode or "fisher_gradients",
default_estimation_mode=default_estimation_mode or "ggn_curvature_prop",
**kwargs,
)

Expand Down Expand Up @@ -1759,7 +1759,7 @@ def retagged_func(params, *args):

super().__init__(
func=retagged_func,
default_estimation_mode=default_estimation_mode or "fisher_exact",
default_estimation_mode=default_estimation_mode or "ggn_curvature_prop",
layer_tag_to_block_ctor=layer_tag_to_block_ctor,
auto_register_tags=False,
**kwargs,
Expand Down
2 changes: 1 addition & 1 deletion kfac_jax/_src/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def __init__(
matrix. See the documentation for :class:`~CurvatureEstimator` for a
detailed description of the possible options. If ``None`` will use
default estimation_mode mode of the used CurvatureEstimator subclass,
which is typically "fisher_gradients". (Default: ``None``)
which is typically "ggn_curvature_prop". (Default: ``None``)
custom_estimator_ctor: Optional constructor for subclass of
:class:`~BlockDiagonalCurvature`. If specified, the optimizer will use
this conastructor instead of the default
Expand Down

0 comments on commit 9623ab4

Please sign in to comment.