Skip to content

Commit

Permalink
Add tests for the optax preconditioner wrapper
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 670937403
  • Loading branch information
joeljennings authored and KfacJaxDev committed Sep 4, 2024
1 parent 56054a5 commit 324d8d0
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 80 deletions.
28 changes: 13 additions & 15 deletions examples/optax_preconditioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,22 @@ def __init__(
norm_to_scale_identity_weight_per_block
)

auto_register_kwargs = auto_register_kwargs or {}
auto_register_kwargs.update(dict(
register_only_generic=register_only_generic,
patterns_to_skip=patterns_to_skip,
))
# Curvature estimator
self._estimator = kfac_jax.curvature_estimator.BlockDiagonalCurvature(
func=value_func,
default_estimation_mode=estimation_mode,
params_index=0,
layer_tag_to_block_ctor=layer_tag_to_block_ctor,
register_only_generic=register_only_generic,
patterns_to_skip=patterns_to_skip,
distributed_multiplies=distributed_precon_apply,
distributed_cache_updates=distributed_inverses,
num_samples=num_samples,
should_vmap_samples=should_vmap_samples,
**(auto_register_kwargs or {}),
auto_register_kwargs=auto_register_kwargs,
)

def init(
Expand All @@ -202,20 +205,16 @@ def init(
)

@property
def _exact_powers_to_cache(self) -> int | Sequence[int] | None:

def _exact_powers_to_cache(self) -> int | None:
if self._use_exact_inverses and self._use_cached_inverses:
return -1
else:
return None
return None

@property
def _approx_powers_to_cache(self) -> int | Sequence[int] | None:

def _approx_powers_to_cache(self) -> int | None:
if not self._use_exact_inverses and self._use_cached_inverses:
return -1
else:
return None
return None

@property
def estimator(self) -> kfac_jax.curvature_estimator.BlockDiagonalCurvature:
Expand Down Expand Up @@ -322,6 +321,7 @@ def _update_estimator_curvature(
ema_new=ema_new,
# Note that the batch is always the last entry of FuncArgsVariants.
batch_size=self._batch_size_extractor(func_args[-1]),
identity_weight=self.get_identity_weight(estimator_state),
rng=rng,
func_args=func_args,
)
Expand Down Expand Up @@ -447,7 +447,7 @@ def multiply_curvature(
# `use_exact_inverses == False` (default). In particular, the former uses
# non-factored damping while the latter uses factored one, and the two are
# NOT the exact inverses of each other.
updates = self.estimator.multiply(
return self.estimator.multiply(
state=state.estimator_state,
parameter_structured_vector=updates,
identity_weight=self.get_identity_weight(state),
Expand All @@ -456,7 +456,6 @@ def multiply_curvature(
pmap_axis_name=self.pmap_axis_name,
norm_to_scale_identity_weight_per_block=self._norm_to_scale_identity_weight_per_block,
)
return updates

def as_gradient_transform(
self, use_inverse: bool = True
Expand All @@ -478,8 +477,7 @@ def update_fn(
**extra_args,
):
del params, extra_args
updates = multiply_fn(updates, precond_state)
return updates, state
return multiply_fn(updates, precond_state), state

return optax.GradientTransformationExtraArgs(init_fn, update_fn)

Expand Down
93 changes: 93 additions & 0 deletions tests/estimator_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing functionalities of the curvature estimation."""

import functools

import jax
import kfac_jax
from kfac_jax.tests import models


Array = kfac_jax.utils.Array
PRNGKey = kfac_jax.utils.PRNGKey
Shape = kfac_jax.utils.Shape
StateType = kfac_jax.curvature_estimator.StateType


# Each list below takes every entry of the corresponding `models` list and
# appends a curvature type string to it as a trailing tuple element. All "ggn"
# variants come first, followed by all "fisher" variants.
NON_LINEAR_MODELS_AND_CURVATURE_TYPE = [
    spec + (curvature,)
    for curvature in ("ggn", "fisher")
    for spec in models.NON_LINEAR_MODELS
]


LINEAR_MODELS_AND_CURVATURE_TYPE = [
    spec + (curvature,)
    for curvature in ("ggn", "fisher")
    for spec in models.LINEAR_MODELS
]


PIECEWISE_LINEAR_MODELS_AND_CURVATURE = [
    spec + (curvature,)
    for curvature in ("ggn", "fisher")
    for spec in models.PIECEWISE_LINEAR_MODELS
]


# One case per curvature type. Each case is a list holding a dict with
# "images" and "labels" tuples, an integer (presumably a random seed - TODO
# confirm against the tests that consume it), and the curvature type string.
CONV_SIZES_AND_ESTIMATION_MODES = [
    [dict(images=(16, 16, 3), labels=(10,)), 1230971, curvature]
    for curvature in ("ggn", "fisher")
]

LAYER_CHANNELS = [4, 8, 16]


@functools.partial(jax.jit, static_argnums=(0, 3, 4))
def compute_exact_approx_curvature(
    estimator: kfac_jax.CurvatureEstimator[StateType],
    rng: PRNGKey,
    func_args: kfac_jax.utils.FuncArgs,
    batch_size: int,
    curvature_type: str,
) -> StateType:
  """Builds an estimator state from a single exact-mode curvature update.

  The state is initialized with no cached powers or eigenvalues and is then
  updated once using the `"<curvature_type>_exact"` estimation mode.

  Args:
    estimator: The curvature estimator to run (static under `jax.jit`).
    rng: Random key, passed to both the init and the update call.
    func_args: Function arguments forwarded to the estimator.
    batch_size: Batch size forwarded to the update (static under `jax.jit`).
    curvature_type: Prefix of the estimation mode, e.g. "ggn" or "fisher"
      (static under `jax.jit`).

  Returns:
    The estimator state produced by the curvature update.
  """
  fresh_state = estimator.init(
      rng=rng,
      func_args=func_args,
      exact_powers_to_cache=None,
      approx_powers_to_cache=None,
      cache_eigenvalues=False,
  )
  exact_mode = f"{curvature_type}_exact"
  # ema_old=0.0 / ema_new=1.0: presumably the new estimate fully replaces the
  # initial one - TODO confirm against the estimator's documentation.
  updated_state = estimator.update_curvature_matrix_estimate(
      state=fresh_state,
      ema_old=0.0,
      ema_new=1.0,
      identity_weight=0.0,  # This doesn't matter here.
      batch_size=batch_size,
      rng=rng,
      func_args=func_args,
      estimation_mode=exact_mode,
  )
  # NOTE(review): the result of `sync` is not used; confirm whether `sync`
  # acts in place or returns a new state.
  estimator.sync(updated_state, pmap_axis_name="i")
  return updated_state
82 changes: 17 additions & 65 deletions tests/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jax
import jax.numpy as jnp
import kfac_jax
from tests import estimator_test_utils
from tests import models
import numpy as np

Expand All @@ -30,71 +31,22 @@
StateType = kfac_jax.curvature_estimator.StateType


# Each list below takes every entry of the corresponding `models` list and
# appends a curvature type string to it as a trailing tuple element. All "ggn"
# variants come first, followed by all "fisher" variants.
NON_LINEAR_MODELS_AND_CURVATURE_TYPE = [
    spec + (curvature,)
    for curvature in ("ggn", "fisher")
    for spec in models.NON_LINEAR_MODELS
]


LINEAR_MODELS_AND_CURVATURE_TYPE = [
    spec + (curvature,)
    for curvature in ("ggn", "fisher")
    for spec in models.LINEAR_MODELS
]


PIECEWISE_LINEAR_MODELS_AND_CURVATURE = [
    spec + (curvature,)
    for curvature in ("ggn", "fisher")
    for spec in models.PIECEWISE_LINEAR_MODELS
]


# One case per curvature type. Each case is a list holding a dict with
# "images" and "labels" tuples, an integer (presumably a random seed - TODO
# confirm against the tests that consume it), and the curvature type string.
CONV_SIZES_AND_ESTIMATION_MODES = [
    [dict(images=(16, 16, 3), labels=(10,)), 1230971, curvature]
    for curvature in ("ggn", "fisher")
]

LAYER_CHANNELS = [4, 8, 16]


@functools.partial(jax.jit, static_argnums=(0, 3, 4))
def compute_exact_approx_curvature(
    estimator: kfac_jax.CurvatureEstimator[StateType],
    rng: PRNGKey,
    func_args: kfac_jax.utils.FuncArgs,
    batch_size: int,
    curvature_type: str,
) -> StateType:
  """Builds an estimator state from a single exact-mode curvature update.

  The state is initialized with no cached powers or eigenvalues and is then
  updated once using the `"<curvature_type>_exact"` estimation mode.

  Args:
    estimator: The curvature estimator to run (static under `jax.jit`).
    rng: Random key, passed to both the init and the update call.
    func_args: Function arguments forwarded to the estimator.
    batch_size: Batch size forwarded to the update (static under `jax.jit`).
    curvature_type: Prefix of the estimation mode, e.g. "ggn" or "fisher"
      (static under `jax.jit`).

  Returns:
    The estimator state produced by the curvature update.
  """
  fresh_state = estimator.init(
      rng=rng,
      func_args=func_args,
      exact_powers_to_cache=None,
      approx_powers_to_cache=None,
      cache_eigenvalues=False,
  )
  exact_mode = f"{curvature_type}_exact"
  # ema_old=0.0 / ema_new=1.0: presumably the new estimate fully replaces the
  # initial one - TODO confirm against the estimator's documentation.
  updated_state = estimator.update_curvature_matrix_estimate(
      state=fresh_state,
      ema_old=0.0,
      ema_new=1.0,
      identity_weight=0.0,  # This doesn't matter here.
      batch_size=batch_size,
      rng=rng,
      func_args=func_args,
      estimation_mode=exact_mode,
  )
  # NOTE(review): the result of `sync` is not used; confirm whether `sync`
  # acts in place or returns a new state.
  estimator.sync(updated_state, pmap_axis_name="i")
  return updated_state
# Module-level aliases for the helper and the parameterization fixtures
# defined in `estimator_test_utils`.
compute_exact_approx_curvature = (
    estimator_test_utils.compute_exact_approx_curvature
)

# Model lists paired with a curvature type.
NON_LINEAR_MODELS_AND_CURVATURE_TYPE = (
    estimator_test_utils.NON_LINEAR_MODELS_AND_CURVATURE_TYPE
)
LINEAR_MODELS_AND_CURVATURE_TYPE = (
    estimator_test_utils.LINEAR_MODELS_AND_CURVATURE_TYPE
)
PIECEWISE_LINEAR_MODELS_AND_CURVATURE = (
    estimator_test_utils.PIECEWISE_LINEAR_MODELS_AND_CURVATURE
)

# Convolutional test cases and channel sizes.
CONV_SIZES_AND_ESTIMATION_MODES = (
    estimator_test_utils.CONV_SIZES_AND_ESTIMATION_MODES
)
LAYER_CHANNELS = estimator_test_utils.LAYER_CHANNELS


class TestEstimator(parameterized.TestCase):
Expand Down
Loading

0 comments on commit 324d8d0

Please sign in to comment.