diff --git a/examples/optax_wrapper.py b/examples/optax_wrapper.py index 9e8dd37..b8c3655 100644 --- a/examples/optax_wrapper.py +++ b/examples/optax_wrapper.py @@ -16,7 +16,6 @@ import jax import kfac_jax -from examples import optax_preconditioner import optax @@ -29,8 +28,8 @@ ScheduleType = kfac_jax.optimizer.ScheduleType OptaxCtor = Callable[[ScheduleType], optax.GradientTransformation] -PreconditionState = optax_preconditioner.PreconditionState -Preconditioner = optax_preconditioner.Preconditioner +PreconditionState = kfac_jax.OptaxPreconditionState +Preconditioner = kfac_jax.OptaxPreconditioner class OptaxAndPreconditionState(NamedTuple): diff --git a/kfac_jax/__init__.py b/kfac_jax/__init__.py index 64a988d..bc201ce 100644 --- a/kfac_jax/__init__.py +++ b/kfac_jax/__init__.py @@ -112,6 +112,8 @@ curvature_estimator.set_default_tag_to_block_ctor) get_default_tag_to_block_ctor = ( curvature_estimator.get_default_tag_to_block_ctor) +OptaxPreconditioner = curvature_estimator.OptaxPreconditioner +OptaxPreconditionState = curvature_estimator.OptaxPreconditionState # Optimizers Optimizer = optimizer.Optimizer diff --git a/kfac_jax/_src/curvature_estimator/__init__.py b/kfac_jax/_src/curvature_estimator/__init__.py index 6dc2c66..8189108 100644 --- a/kfac_jax/_src/curvature_estimator/__init__.py +++ b/kfac_jax/_src/curvature_estimator/__init__.py @@ -52,6 +52,7 @@ from kfac_jax._src.curvature_estimator import curvature_estimator from kfac_jax._src.curvature_estimator import explicit_exact from kfac_jax._src.curvature_estimator import implicit_exact +from kfac_jax._src.curvature_estimator import optax_interface BlockDiagonalCurvature = block_diagonal.BlockDiagonalCurvature @@ -75,3 +76,6 @@ LossFunctionInputs = implicit_exact.LossFunctionInputs LossFunctionInputsSequence = implicit_exact.LossFunctionInputsSequence LossFunctionInputsTuple = implicit_exact.LossFunctionInputsTuple + +OptaxPreconditioner = optax_interface.OptaxPreconditioner +OptaxPreconditionState = optax_interface.OptaxPreconditionState diff --git a/examples/optax_preconditioner.py b/kfac_jax/_src/curvature_estimator/optax_interface.py similarity index 87% rename from examples/optax_preconditioner.py rename to kfac_jax/_src/curvature_estimator/optax_interface.py index d1d34dd..a0ebc03 100644 --- a/examples/optax_preconditioner.py +++ b/kfac_jax/_src/curvature_estimator/optax_interface.py @@ -19,27 +19,29 @@ import jax from jax import lax import jax.numpy as jnp -import kfac_jax +from kfac_jax._src import utils +from kfac_jax._src.curvature_estimator import block_diagonal +from kfac_jax._src.curvature_estimator import curvature_estimator import optax -Array = kfac_jax.utils.Array -Numeric = kfac_jax.utils.Numeric -PRNGKey = kfac_jax.utils.PRNGKey -Params = kfac_jax.utils.Params -Batch = kfac_jax.utils.Batch -ValueFunc = kfac_jax.optimizer.ValueFunc -FuncArgsVariants = kfac_jax.optimizer.FuncArgsVariants -ScheduleType = kfac_jax.optimizer.ScheduleType -EstimatorState = kfac_jax.curvature_estimator.BlockDiagonalCurvature.State +Array = utils.Array +Numeric = utils.Numeric +PRNGKey = utils.PRNGKey +Params = utils.Params +Batch = utils.Batch +ValueFunc = utils.ValueFunc +FuncArgs = utils.FuncArgs +ScheduleType = utils.ScheduleType +EstimatorState = block_diagonal.BlockDiagonalCurvature.State -class PreconditionState(NamedTuple): +class OptaxPreconditionState(NamedTuple): count: Array estimator_state: EstimatorState -class Preconditioner: +class OptaxPreconditioner: """An Optax-compatible K-FAC preconditioner.""" def __init__( @@ -59,12 +61,12 @@ def __init__( patterns_to_skip: Sequence[str] = (), auto_register_kwargs: dict[str, Any] | None = None, layer_tag_to_block_ctor: ( - dict[str, kfac_jax.curvature_estimator.CurvatureBlockCtor] | None + dict[str, curvature_estimator.CurvatureBlockCtor] | None ) = None, pmap_axis_name: str = "kfac_axis", batch_size_extractor: Callable[ [Batch], Numeric - ] = kfac_jax.utils.default_batch_size_extractor, + ] = utils.default_batch_size_extractor, distributed_inverses: bool = True, distributed_precon_apply: bool = True, num_samples: int = 1, @@ -168,29 +170,32 @@ def __init__( norm_to_scale_identity_weight_per_block ) + auto_register_kwargs = auto_register_kwargs or {} + auto_register_kwargs.update(dict( + register_only_generic=register_only_generic, + patterns_to_skip=patterns_to_skip, + )) # Curvature estimator - self._estimator = kfac_jax.curvature_estimator.BlockDiagonalCurvature( + self._estimator = block_diagonal.BlockDiagonalCurvature( func=value_func, default_estimation_mode=estimation_mode, params_index=0, layer_tag_to_block_ctor=layer_tag_to_block_ctor, - register_only_generic=register_only_generic, - patterns_to_skip=patterns_to_skip, distributed_multiplies=distributed_precon_apply, distributed_cache_updates=distributed_inverses, num_samples=num_samples, should_vmap_samples=should_vmap_samples, - **(auto_register_kwargs or {}), + auto_register_kwargs=auto_register_kwargs, ) def init( self, - func_args: FuncArgsVariants, + func_args: FuncArgs, rng: PRNGKey, - ) -> PreconditionState: + ) -> OptaxPreconditionState: """Initializes the preconditioner and returns the state.""" - return PreconditionState( + return OptaxPreconditionState( count=jnp.array(0, dtype=jnp.int32), estimator_state=self.estimator.init( rng=rng, @@ -202,23 +207,19 @@ def init( ) @property - def _exact_powers_to_cache(self) -> int | Sequence[int] | None: - + def _exact_powers_to_cache(self) -> int | None: if self._use_exact_inverses and self._use_cached_inverses: return -1 - else: - return None + return None @property - def _approx_powers_to_cache(self) -> int | Sequence[int] | None: - + def _approx_powers_to_cache(self) -> int | None: if not self._use_exact_inverses and self._use_cached_inverses: return -1 - else: - return None + return None @property - def estimator(self) -> kfac_jax.curvature_estimator.BlockDiagonalCurvature: + def estimator(self) -> block_diagonal.BlockDiagonalCurvature: """The underlying curvature estimator used by the preconditioner.""" return self._estimator @@ -227,7 +228,7 @@ def pmap_axis_name(self): return self._pmap_axis_name def get_identity_weight( - self, state: PreconditionState + self, state: OptaxPreconditionState ) -> Array | float: damping = self._damping @@ -239,18 +240,18 @@ def get_identity_weight( def sync_estimator_state( self, - state: PreconditionState, - ) -> PreconditionState: + state: OptaxPreconditionState, + ) -> OptaxPreconditionState: """Syncs the estimator state.""" - return PreconditionState( + return OptaxPreconditionState( count=state.count, estimator_state=self.estimator.sync( state.estimator_state, pmap_axis_name=self.pmap_axis_name), ) def should_update_estimator_curvature( - self, state: PreconditionState + self, state: OptaxPreconditionState ) -> Array | bool: """Whether at the current step the preconditioner should update the curvature estimates.""" @@ -260,7 +261,7 @@ def should_update_estimator_curvature( return state.count % self._curvature_update_period == 0 def should_sync_estimate_curvature( - self, state: PreconditionState + self, state: OptaxPreconditionState ) -> Array | bool: """Whether at the current step the preconditioner should synchronize (pmean) the curvature estimates.""" @@ -272,7 +273,7 @@ def should_sync_estimate_curvature( return self.should_update_inverse_cache(state) def should_update_inverse_cache( - self, state: PreconditionState + self, state: OptaxPreconditionState ) -> Array | bool: """Whether at the current step the preconditioner should update the inverse cache.""" @@ -283,10 +284,10 @@ def should_update_inverse_cache( def maybe_update( self, - state: PreconditionState, - func_args: FuncArgsVariants, + state: OptaxPreconditionState, + func_args: FuncArgs, rng: PRNGKey, - ) -> PreconditionState: + ) -> OptaxPreconditionState: """Updates the estimates if it is the right iteration.""" # NOTE: This maybe update curvatures and inverses at an iteration. But @@ -303,12 +304,12 @@ def maybe_update( state = self.maybe_update_inverse_cache(state) - return PreconditionState(state.count, state.estimator_state) + return OptaxPreconditionState(state.count, state.estimator_state) def _update_estimator_curvature( self, estimator_state: EstimatorState, - func_args: FuncArgsVariants, + func_args: FuncArgs, rng: PRNGKey, ema_old: Numeric, ema_new: Numeric, @@ -322,6 +323,7 @@ def _update_estimator_curvature( ema_new=ema_new, # Note that the batch is always the last entry of FuncArgsVariantsdef batch_size=self._batch_size_extractor(func_args[-1]), + identity_weight=self.get_identity_weight(estimator_state), rng=rng, func_args=func_args, ) @@ -336,12 +338,12 @@ def _update_estimator_curvature( def maybe_update_estimator_curvature( self, - state: PreconditionState, - func_args: FuncArgsVariants, + state: OptaxPreconditionState, + func_args: FuncArgs, rng: PRNGKey, decay_old_ema: Array | bool = True, sync: Array | bool = True, - ) -> PreconditionState: + ) -> OptaxPreconditionState: """Updates the curvature estimates if it is the right iteration.""" ema_old = decay_old_ema * self._curvature_ema + (1.0 - decay_old_ema) * 1.0 @@ -359,8 +361,8 @@ def maybe_update_estimator_curvature( def maybe_update_inverse_cache( self, - state: PreconditionState, - ) -> PreconditionState: + state: OptaxPreconditionState, + ) -> OptaxPreconditionState: """Updates the estimator state cache if it is the right iteration.""" if state.count is None: @@ -383,11 +385,11 @@ def maybe_update_inverse_cache( def _maybe_update_estimator_state( self, - state: PreconditionState, + state: OptaxPreconditionState, should_update: Array | bool, update_func: Callable[..., EstimatorState], **update_func_kwargs, - ) -> PreconditionState: + ) -> OptaxPreconditionState: """Updates the estimator state if it should update.""" estimator_state = lax.cond( @@ -397,12 +399,12 @@ def _maybe_update_estimator_state( state.estimator_state, ) - return PreconditionState(state.count, estimator_state) + return OptaxPreconditionState(state.count, estimator_state) def apply( self, updates: optax.Updates, - state: PreconditionState, + state: OptaxPreconditionState, ) -> optax.Updates: """Preconditions (= multiplies the inverse curvature estimation matrix to) updates.""" @@ -419,13 +421,13 @@ def apply( if self._norm_constraint is not None: - sq_norm_grads = kfac_jax.utils.inner_product(new_updates, updates) + sq_norm_grads = utils.inner_product(new_updates, updates) del updates max_coefficient = jnp.sqrt(self._norm_constraint / sq_norm_grads) coeff = jnp.minimum(max_coefficient, 1) - new_updates = kfac_jax.utils.scalar_mul(new_updates, coeff) + new_updates = utils.scalar_mul(new_updates, coeff) else: del updates @@ -435,7 +437,7 @@ def apply( def multiply_curvature( self, updates: optax.Updates, - state: PreconditionState, + state: OptaxPreconditionState, ) -> optax.Updates: """Multiplies the (non-inverse) curvature estimation matrix to updates.""" @@ -447,7 +449,7 @@ def multiply_curvature( # `use_exact_inverses == False` (default). In particular, the former uses # non-factored damping while the latter uses factored one, and the two are # NOT the exact inverses of each other. - updates = self.estimator.multiply( + return self.estimator.multiply( state=state.estimator_state, parameter_structured_vector=updates, identity_weight=self.get_identity_weight(state), @@ -456,7 +458,6 @@ def multiply_curvature( pmap_axis_name=self.pmap_axis_name, norm_to_scale_identity_weight_per_block=self._norm_to_scale_identity_weight_per_block, ) - return updates def as_gradient_transform( self, use_inverse: bool = True @@ -474,15 +475,14 @@ def update_fn( state, params=None, *, - precond_state: PreconditionState, + precond_state: OptaxPreconditionState, **extra_args, ): del params, extra_args - updates = multiply_fn(updates, precond_state) - return updates, state + return multiply_fn(updates, precond_state), state return optax.GradientTransformationExtraArgs(init_fn, update_fn) - def increment_count(self, state: PreconditionState): + def increment_count(self, state: OptaxPreconditionState): count_inc = optax.safe_int32_increment(state.count) - return PreconditionState(count_inc, state.estimator_state) + return OptaxPreconditionState(count_inc, state.estimator_state) diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py index 3fbdca0..73e591b 100644 --- a/kfac_jax/_src/optimizer.py +++ b/kfac_jax/_src/optimizer.py @@ -34,11 +34,8 @@ FuncState = Any FuncAux = utils.FuncAux Scalar = utils.Scalar +ScheduleType = utils.ScheduleType -ScheduleType = ( - Callable[[Numeric, Numeric | None], Numeric] | - Callable[[Numeric], Numeric] - ) FuncArgsVariants = ( tuple[Params, Batch] | tuple[Params, FuncState, Batch] | diff --git a/kfac_jax/_src/utils/__init__.py b/kfac_jax/_src/utils/__init__.py index c44cfd1..55a5ea5 100644 --- a/kfac_jax/_src/utils/__init__.py +++ b/kfac_jax/_src/utils/__init__.py @@ -41,6 +41,7 @@ ValueFunc = types.ValueFunc ValueAndGradFunc = types.ValueAndGradFunc AssumedFuncOutput = types.AssumedFuncOutput +ScheduleType = types.ScheduleType tree_is_empty = types.tree_is_empty abstract_objects_equal = types.abstract_objects_equal get_float_dtype_and_check_consistency = ( diff --git a/kfac_jax/_src/utils/types.py b/kfac_jax/_src/utils/types.py index f3397df..b7eb2ba 100644 --- a/kfac_jax/_src/utils/types.py +++ b/kfac_jax/_src/utils/types.py @@ -41,6 +41,10 @@ AssumedFuncOutput = (Array | tuple[Array, FuncAux] | tuple[Array, tuple[FuncState, FuncAux]]) SCALAR_TYPES = (float, int) +ScheduleType = ( + Callable[[Numeric, Numeric | None], Numeric] | + Callable[[Numeric], Numeric] + ) def tree_is_empty(obj: ArrayTree) -> bool: diff --git a/tests/estimator_test_utils.py b/tests/estimator_test_utils.py new file mode 100644 index 0000000..804de79 --- /dev/null +++ b/tests/estimator_test_utils.py @@ -0,0 +1,93 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing functionalities of the curvature estimation.""" + +import functools + +import jax +import kfac_jax +from tests import models + + +Array = kfac_jax.utils.Array +PRNGKey = kfac_jax.utils.PRNGKey +Shape = kfac_jax.utils.Shape +StateType = kfac_jax.curvature_estimator.StateType + + +NON_LINEAR_MODELS_AND_CURVATURE_TYPE = [ + model + ("ggn",) for model in models.NON_LINEAR_MODELS +] + [ + model + ("fisher",) for model in models.NON_LINEAR_MODELS +] + + +LINEAR_MODELS_AND_CURVATURE_TYPE = [ + model + ("ggn",) for model in models.LINEAR_MODELS +] + [ + model + ("fisher",) for model in models.LINEAR_MODELS +] + + +PIECEWISE_LINEAR_MODELS_AND_CURVATURE = [ + model + ("ggn",) for model in models.PIECEWISE_LINEAR_MODELS +] + [ + model + ("fisher",) for model in models.PIECEWISE_LINEAR_MODELS +] + + +CONV_SIZES_AND_ESTIMATION_MODES = [ + [ + dict(images=(16, 16, 3), labels=(10,)), + 1230971, + "ggn", + ], + [ + dict(images=(16, 16, 3), labels=(10,)), + 1230971, + "fisher", + ], +] + +LAYER_CHANNELS = [4, 8, 16] + + +@functools.partial(jax.jit, static_argnums=(0, 3, 4)) +def compute_exact_approx_curvature( + estimator: kfac_jax.CurvatureEstimator[StateType], + rng: PRNGKey, + func_args: kfac_jax.utils.FuncArgs, + batch_size: int, + curvature_type: str, +) -> StateType: + """Computes the full Fisher matrix approximation for the estimator.""" + state = estimator.init( + rng=rng, + func_args=func_args, + exact_powers_to_cache=None, + approx_powers_to_cache=None, + cache_eigenvalues=False, + ) + state = estimator.update_curvature_matrix_estimate( + state=state, + ema_old=0.0, + ema_new=1.0, + identity_weight=0.0, # This doesn't matter here. + batch_size=batch_size, + rng=rng, + func_args=func_args, + estimation_mode=f"{curvature_type}_exact", + ) + estimator.sync(state, pmap_axis_name="i") + return state diff --git a/tests/test_estimator.py b/tests/test_estimator.py index d4f49f4..48f2d5f 100644 --- a/tests/test_estimator.py +++ b/tests/test_estimator.py @@ -20,6 +20,7 @@ import jax import jax.numpy as jnp import kfac_jax +from tests import estimator_test_utils from tests import models import numpy as np @@ -30,71 +31,22 @@ StateType = kfac_jax.curvature_estimator.StateType -NON_LINEAR_MODELS_AND_CURVATURE_TYPE = [ - model + ("ggn",) for model in models.NON_LINEAR_MODELS -] + [ - model + ("fisher",) for model in models.NON_LINEAR_MODELS -] - - -LINEAR_MODELS_AND_CURVATURE_TYPE = [ - model + ("ggn",) for model in models.LINEAR_MODELS -] + [ - model + ("fisher",) for model in models.LINEAR_MODELS -] - - -PIECEWISE_LINEAR_MODELS_AND_CURVATURE = [ - model + ("ggn",) for model in models.PIECEWISE_LINEAR_MODELS -] + [ - model + ("fisher",) for model in models.PIECEWISE_LINEAR_MODELS -] - - -CONV_SIZES_AND_ESTIMATION_MODES = [ - [ - dict(images=(16, 16, 3), labels=(10,)), - 1230971, - "ggn", - ], - [ - dict(images=(16, 16, 3), labels=(10,)), - 1230971, - "fisher", - ], -] - -LAYER_CHANNELS = [4, 8, 16] - - -@functools.partial(jax.jit, static_argnums=(0, 3, 4)) -def compute_exact_approx_curvature( - estimator: kfac_jax.CurvatureEstimator[StateType], - rng: PRNGKey, - func_args: kfac_jax.utils.FuncArgs, - batch_size: int, - curvature_type: str, -) -> StateType: - """Computes the full Fisher matrix approximation for the estimator.""" - state = estimator.init( - rng=rng, - func_args=func_args, - exact_powers_to_cache=None, - approx_powers_to_cache=None, - cache_eigenvalues=False, - ) - state = estimator.update_curvature_matrix_estimate( - state=state, - ema_old=0.0, - ema_new=1.0, - identity_weight=0.0, # This doesn't matter here. - batch_size=batch_size, - rng=rng, - func_args=func_args, - estimation_mode=f"{curvature_type}_exact", - ) - estimator.sync(state, pmap_axis_name="i") - return state +LAYER_CHANNELS = estimator_test_utils.LAYER_CHANNELS +NON_LINEAR_MODELS_AND_CURVATURE_TYPE = ( + estimator_test_utils.NON_LINEAR_MODELS_AND_CURVATURE_TYPE + ) +LINEAR_MODELS_AND_CURVATURE_TYPE = ( + estimator_test_utils.LINEAR_MODELS_AND_CURVATURE_TYPE + ) +PIECEWISE_LINEAR_MODELS_AND_CURVATURE = ( + estimator_test_utils.PIECEWISE_LINEAR_MODELS_AND_CURVATURE + ) +CONV_SIZES_AND_ESTIMATION_MODES = ( + estimator_test_utils.CONV_SIZES_AND_ESTIMATION_MODES + ) +compute_exact_approx_curvature = ( + estimator_test_utils.compute_exact_approx_curvature + ) class TestEstimator(parameterized.TestCase): diff --git a/tests/test_optax_interface.py b/tests/test_optax_interface.py new file mode 100644 index 0000000..3a92a85 --- /dev/null +++ b/tests/test_optax_interface.py @@ -0,0 +1,156 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for testing the optax interface to K-FAC.""" +import functools +from typing import Callable, Mapping + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import kfac_jax +from tests import estimator_test_utils +from tests import models +import numpy as np + +Array = kfac_jax.utils.Array +PRNGKey = kfac_jax.utils.PRNGKey +Shape = kfac_jax.utils.Shape +StateType = kfac_jax.curvature_estimator.StateType + +NON_LINEAR_MODELS_AND_CURVATURE_TYPE = ( + estimator_test_utils.NON_LINEAR_MODELS_AND_CURVATURE_TYPE + ) + +compute_exact_approx_curvature = ( + estimator_test_utils.compute_exact_approx_curvature + ) + + +@functools.partial(jax.jit, static_argnums=(0,)) +def compute_exact_approx_curvature_precon( + preconditioner: kfac_jax.OptaxPreconditioner, + rng: PRNGKey, + func_args: kfac_jax.optimizer.FuncArgsVariants, +) -> kfac_jax.OptaxPreconditionState: + """Computes the full Fisher matrix approximation for the estimator.""" + return preconditioner.maybe_update_estimator_curvature( + state=preconditioner.init(func_args=func_args, rng=rng), + func_args=func_args, + rng=rng, + decay_old_ema=True, + sync=True, + ) + + +class TestOptaxPreconditioner(parameterized.TestCase): + """Testing the optax interface to K-FAC.""" + + def assert_trees_all_close( + self, + x: kfac_jax.utils.PyTree, + y: kfac_jax.utils.PyTree, + check_dtypes: bool = True, + atol: float = 1e-6, + rtol: float = 1e-6, + ): + """Asserts that the two PyTrees are close up to the provided tolerances.""" + if jax.devices()[0].platform == "tpu": + rtol = 3e3 * rtol + atol = 3e3 * atol + + x_v, x_tree = jax.tree_util.tree_flatten(x) + y_v, y_tree = jax.tree_util.tree_flatten(y) + self.assertEqual(x_tree, y_tree) + for xi, yi in zip(x_v, y_v): + self.assertEqual(xi.shape, yi.shape) + if check_dtypes: + self.assertEqual(xi.dtype, yi.dtype) + np.testing.assert_allclose(xi, yi, rtol=rtol, atol=atol, equal_nan=False) + + @parameterized.parameters(NON_LINEAR_MODELS_AND_CURVATURE_TYPE) + def test_block_diagonal_full( + self, + init_func: Callable[..., models.hk.Params], + model_func: Callable[..., Array], + data_point_shapes: Mapping[str, Shape], + seed: int, + curvature_type: str, + data_size: int = 4, + ): + """Tests that the block diagonal full is equal to the explicit curvature.""" + rng_key = jax.random.PRNGKey(seed) + init_key, data_key, estimator_key = jax.random.split(rng_key, 3) + + # Generate data + data = {} + for name, shape in data_point_shapes.items(): + data_key, key = jax.random.split(data_key) + data[name] = jax.random.uniform(key, (data_size, *shape)) + if name == "labels": + data[name] = jnp.argmax(data[name], axis=-1) + + params = init_func(init_key, data) + func_args = (params, data) + + # Compute curvature matrix using the block diagonal full estimator + optax_estimator = kfac_jax.OptaxPreconditioner( + model_func, + damping=0.0, + curvature_ema=0.0, + layer_tag_to_block_ctor=dict( + dense=kfac_jax.DenseFull, + conv2d=kfac_jax.Conv2DFull, + scale_and_shift=kfac_jax.ScaleAndShiftFull, + ), + pmap_axis_name="i", + estimation_mode=f"{curvature_type}_exact", + ) + + precondition_state = compute_exact_approx_curvature_precon( + preconditioner=optax_estimator, + rng=estimator_key, + func_args=func_args, + ) + + block_estimator = optax_estimator.estimator + blocks = block_estimator.to_diagonal_block_dense_matrix( + precondition_state.estimator_state + ) + + # Compute curvature matrix using the explicit exact curvature + full_estimator = kfac_jax.ExplicitExactCurvature( + model_func, default_estimation_mode="fisher_exact", + param_order=block_estimator.param_order + ) + state = compute_exact_approx_curvature( + full_estimator, + estimator_key, + func_args, + data_size, + curvature_type, + ) + full_matrix = full_estimator.to_dense_matrix(state) + + # Compare blocks + d = 0 + for block in blocks: + s = slice(d, d + block.shape[0]) + self.assert_trees_all_close(block, full_matrix[s, s]) + d = d + block.shape[0] + self.assertEqual(d, full_matrix.shape[0]) + + +if __name__ == "__main__": + absltest.main()