Skip to content

Commit

Permalink
Move optax interface into the main kfac codebase.
Browse files Browse the repository at this point in the history
Fix bugs in the optax interface, due to argument changes for the estimator.

Add tests for the optax interface.

PiperOrigin-RevId: 670937403
  • Loading branch information
joeljennings authored and KfacJaxDev committed Sep 5, 2024
1 parent 56054a5 commit 1d63c3b
Show file tree
Hide file tree
Showing 10 changed files with 341 additions and 133 deletions.
5 changes: 2 additions & 3 deletions examples/optax_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import jax
import kfac_jax
from examples import optax_preconditioner
import optax


Expand All @@ -29,8 +28,8 @@
ScheduleType = kfac_jax.optimizer.ScheduleType
OptaxCtor = Callable[[ScheduleType], optax.GradientTransformation]

PreconditionState = optax_preconditioner.PreconditionState
Preconditioner = optax_preconditioner.Preconditioner
PreconditionState = kfac_jax.OptaxPreconditionState
Preconditioner = kfac_jax.OptaxPreconditioner


class OptaxAndPreconditionState(NamedTuple):
Expand Down
2 changes: 2 additions & 0 deletions kfac_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@
curvature_estimator.set_default_tag_to_block_ctor)
get_default_tag_to_block_ctor = (
curvature_estimator.get_default_tag_to_block_ctor)
OptaxPreconditioner = curvature_estimator.OptaxPreconditioner
OptaxPreconditionState = curvature_estimator.OptaxPreconditionState

# Optimizers
Optimizer = optimizer.Optimizer
Expand Down
4 changes: 4 additions & 0 deletions kfac_jax/_src/curvature_estimator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from kfac_jax._src.curvature_estimator import curvature_estimator
from kfac_jax._src.curvature_estimator import explicit_exact
from kfac_jax._src.curvature_estimator import implicit_exact
from kfac_jax._src.curvature_estimator import optax_interface


BlockDiagonalCurvature = block_diagonal.BlockDiagonalCurvature
Expand All @@ -75,3 +76,6 @@
LossFunctionInputs = implicit_exact.LossFunctionInputs
LossFunctionInputsSequence = implicit_exact.LossFunctionInputsSequence
LossFunctionInputsTuple = implicit_exact.LossFunctionInputsTuple

OptaxPreconditioner = optax_interface.OptaxPreconditioner
OptaxPreconditionState = optax_interface.OptaxPreconditionState
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,29 @@
import jax
from jax import lax
import jax.numpy as jnp
import kfac_jax
from kfac_jax._src import utils
from kfac_jax._src.curvature_estimator import block_diagonal
from kfac_jax._src.curvature_estimator import curvature_estimator
import optax


Array = kfac_jax.utils.Array
Numeric = kfac_jax.utils.Numeric
PRNGKey = kfac_jax.utils.PRNGKey
Params = kfac_jax.utils.Params
Batch = kfac_jax.utils.Batch
ValueFunc = kfac_jax.optimizer.ValueFunc
FuncArgsVariants = kfac_jax.optimizer.FuncArgsVariants
ScheduleType = kfac_jax.optimizer.ScheduleType
EstimatorState = kfac_jax.curvature_estimator.BlockDiagonalCurvature.State
Array = utils.Array
Numeric = utils.Numeric
PRNGKey = utils.PRNGKey
Params = utils.Params
Batch = utils.Batch
ValueFunc = utils.ValueFunc
FuncArgs = utils.FuncArgs
ScheduleType = utils.ScheduleType
EstimatorState = block_diagonal.BlockDiagonalCurvature.State


class PreconditionState(NamedTuple):
class OptaxPreconditionState(NamedTuple):
count: Array
estimator_state: EstimatorState


class Preconditioner:
class OptaxPreconditioner:
"""An Optax-compatible K-FAC preconditioner."""

def __init__(
Expand All @@ -59,12 +61,12 @@ def __init__(
patterns_to_skip: Sequence[str] = (),
auto_register_kwargs: dict[str, Any] | None = None,
layer_tag_to_block_ctor: (
dict[str, kfac_jax.curvature_estimator.CurvatureBlockCtor] | None
dict[str, curvature_estimator.CurvatureBlockCtor] | None
) = None,
pmap_axis_name: str = "kfac_axis",
batch_size_extractor: Callable[
[Batch], Numeric
] = kfac_jax.utils.default_batch_size_extractor,
] = utils.default_batch_size_extractor,
distributed_inverses: bool = True,
distributed_precon_apply: bool = True,
num_samples: int = 1,
Expand Down Expand Up @@ -168,29 +170,32 @@ def __init__(
norm_to_scale_identity_weight_per_block
)

auto_register_kwargs = auto_register_kwargs or {}
auto_register_kwargs.update(dict(
register_only_generic=register_only_generic,
patterns_to_skip=patterns_to_skip,
))
# Curvature estimator
self._estimator = kfac_jax.curvature_estimator.BlockDiagonalCurvature(
self._estimator = block_diagonal.BlockDiagonalCurvature(
func=value_func,
default_estimation_mode=estimation_mode,
params_index=0,
layer_tag_to_block_ctor=layer_tag_to_block_ctor,
register_only_generic=register_only_generic,
patterns_to_skip=patterns_to_skip,
distributed_multiplies=distributed_precon_apply,
distributed_cache_updates=distributed_inverses,
num_samples=num_samples,
should_vmap_samples=should_vmap_samples,
**(auto_register_kwargs or {}),
auto_register_kwargs=auto_register_kwargs,
)

def init(
self,
func_args: FuncArgsVariants,
func_args: FuncArgs,
rng: PRNGKey,
) -> PreconditionState:
) -> OptaxPreconditionState:
"""Initializes the preconditioner and returns the state."""

return PreconditionState(
return OptaxPreconditionState(
count=jnp.array(0, dtype=jnp.int32),
estimator_state=self.estimator.init(
rng=rng,
Expand All @@ -202,23 +207,19 @@ def init(
)

@property
def _exact_powers_to_cache(self) -> int | Sequence[int] | None:

def _exact_powers_to_cache(self) -> int | None:
if self._use_exact_inverses and self._use_cached_inverses:
return -1
else:
return None
return None

@property
def _approx_powers_to_cache(self) -> int | Sequence[int] | None:

def _approx_powers_to_cache(self) -> int | None:
if not self._use_exact_inverses and self._use_cached_inverses:
return -1
else:
return None
return None

@property
def estimator(self) -> kfac_jax.curvature_estimator.BlockDiagonalCurvature:
def estimator(self) -> block_diagonal.BlockDiagonalCurvature:
"""The underlying curvature estimator used by the preconditioner."""
return self._estimator

Expand All @@ -227,7 +228,7 @@ def pmap_axis_name(self):
return self._pmap_axis_name

def get_identity_weight(
self, state: PreconditionState
self, state: OptaxPreconditionState
) -> Array | float:

damping = self._damping
Expand All @@ -239,18 +240,18 @@ def get_identity_weight(

def sync_estimator_state(
self,
state: PreconditionState,
) -> PreconditionState:
state: OptaxPreconditionState,
) -> OptaxPreconditionState:
"""Syncs the estimator state."""

return PreconditionState(
return OptaxPreconditionState(
count=state.count,
estimator_state=self.estimator.sync(
state.estimator_state, pmap_axis_name=self.pmap_axis_name),
)

def should_update_estimator_curvature(
self, state: PreconditionState
self, state: OptaxPreconditionState
) -> Array | bool:
"""Whether at the current step the preconditioner should update the curvature estimates."""

Expand All @@ -260,7 +261,7 @@ def should_update_estimator_curvature(
return state.count % self._curvature_update_period == 0

def should_sync_estimate_curvature(
self, state: PreconditionState
self, state: OptaxPreconditionState
) -> Array | bool:
"""Whether at the current step the preconditioner should synchronize (pmean) the curvature estimates."""

Expand All @@ -272,7 +273,7 @@ def should_sync_estimate_curvature(
return self.should_update_inverse_cache(state)

def should_update_inverse_cache(
self, state: PreconditionState
self, state: OptaxPreconditionState
) -> Array | bool:
"""Whether at the current step the preconditioner should update the inverse cache."""

Expand All @@ -283,10 +284,10 @@ def should_update_inverse_cache(

def maybe_update(
self,
state: PreconditionState,
func_args: FuncArgsVariants,
state: OptaxPreconditionState,
func_args: FuncArgs,
rng: PRNGKey,
) -> PreconditionState:
) -> OptaxPreconditionState:
"""Updates the estimates if it is the right iteration."""

# NOTE: This maybe update curvatures and inverses at an iteration. But
Expand All @@ -303,12 +304,12 @@ def maybe_update(

state = self.maybe_update_inverse_cache(state)

return PreconditionState(state.count, state.estimator_state)
return OptaxPreconditionState(state.count, state.estimator_state)

def _update_estimator_curvature(
self,
estimator_state: EstimatorState,
func_args: FuncArgsVariants,
func_args: FuncArgs,
rng: PRNGKey,
ema_old: Numeric,
ema_new: Numeric,
Expand All @@ -322,6 +323,7 @@ def _update_estimator_curvature(
ema_new=ema_new,
# Note that the batch is always the last entry of FuncArgsVariantsdef
batch_size=self._batch_size_extractor(func_args[-1]),
identity_weight=self.get_identity_weight(estimator_state),
rng=rng,
func_args=func_args,
)
Expand All @@ -336,12 +338,12 @@ def _update_estimator_curvature(

def maybe_update_estimator_curvature(
self,
state: PreconditionState,
func_args: FuncArgsVariants,
state: OptaxPreconditionState,
func_args: FuncArgs,
rng: PRNGKey,
decay_old_ema: Array | bool = True,
sync: Array | bool = True,
) -> PreconditionState:
) -> OptaxPreconditionState:
"""Updates the curvature estimates if it is the right iteration."""

ema_old = decay_old_ema * self._curvature_ema + (1.0 - decay_old_ema) * 1.0
Expand All @@ -359,8 +361,8 @@ def maybe_update_estimator_curvature(

def maybe_update_inverse_cache(
self,
state: PreconditionState,
) -> PreconditionState:
state: OptaxPreconditionState,
) -> OptaxPreconditionState:
"""Updates the estimator state cache if it is the right iteration."""

if state.count is None:
Expand All @@ -383,11 +385,11 @@ def maybe_update_inverse_cache(

def _maybe_update_estimator_state(
self,
state: PreconditionState,
state: OptaxPreconditionState,
should_update: Array | bool,
update_func: Callable[..., EstimatorState],
**update_func_kwargs,
) -> PreconditionState:
) -> OptaxPreconditionState:
"""Updates the estimator state if it should update."""

estimator_state = lax.cond(
Expand All @@ -397,12 +399,12 @@ def _maybe_update_estimator_state(
state.estimator_state,
)

return PreconditionState(state.count, estimator_state)
return OptaxPreconditionState(state.count, estimator_state)

def apply(
self,
updates: optax.Updates,
state: PreconditionState,
state: OptaxPreconditionState,
) -> optax.Updates:
"""Preconditions (= multiplies the inverse curvature estimation matrix to) updates."""

Expand All @@ -419,13 +421,13 @@ def apply(

if self._norm_constraint is not None:

sq_norm_grads = kfac_jax.utils.inner_product(new_updates, updates)
sq_norm_grads = utils.inner_product(new_updates, updates)
del updates

max_coefficient = jnp.sqrt(self._norm_constraint / sq_norm_grads)
coeff = jnp.minimum(max_coefficient, 1)

new_updates = kfac_jax.utils.scalar_mul(new_updates, coeff)
new_updates = utils.scalar_mul(new_updates, coeff)

else:
del updates
Expand All @@ -435,7 +437,7 @@ def apply(
def multiply_curvature(
self,
updates: optax.Updates,
state: PreconditionState,
state: OptaxPreconditionState,
) -> optax.Updates:
"""Multiplies the (non-inverse) curvature estimation matrix to updates."""

Expand All @@ -447,7 +449,7 @@ def multiply_curvature(
# `use_exact_inverses == False` (default). In particular, the former uses
# non-factored damping while the latter uses factored one, and the two are
# NOT the exact inverses of each other.
updates = self.estimator.multiply(
return self.estimator.multiply(
state=state.estimator_state,
parameter_structured_vector=updates,
identity_weight=self.get_identity_weight(state),
Expand All @@ -456,7 +458,6 @@ def multiply_curvature(
pmap_axis_name=self.pmap_axis_name,
norm_to_scale_identity_weight_per_block=self._norm_to_scale_identity_weight_per_block,
)
return updates

def as_gradient_transform(
self, use_inverse: bool = True
Expand All @@ -474,15 +475,14 @@ def update_fn(
state,
params=None,
*,
precond_state: PreconditionState,
precond_state: OptaxPreconditionState,
**extra_args,
):
del params, extra_args
updates = multiply_fn(updates, precond_state)
return updates, state
return multiply_fn(updates, precond_state), state

return optax.GradientTransformationExtraArgs(init_fn, update_fn)

def increment_count(self, state: PreconditionState):
def increment_count(self, state: OptaxPreconditionState):
count_inc = optax.safe_int32_increment(state.count)
return PreconditionState(count_inc, state.estimator_state)
return OptaxPreconditionState(count_inc, state.estimator_state)
5 changes: 1 addition & 4 deletions kfac_jax/_src/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,8 @@
FuncState = Any
FuncAux = utils.FuncAux
Scalar = utils.Scalar
ScheduleType = utils.ScheduleType

ScheduleType = (
Callable[[Numeric, Numeric | None], Numeric] |
Callable[[Numeric], Numeric]
)
FuncArgsVariants = (
tuple[Params, Batch] |
tuple[Params, FuncState, Batch] |
Expand Down
1 change: 1 addition & 0 deletions kfac_jax/_src/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
ValueFunc = types.ValueFunc
ValueAndGradFunc = types.ValueAndGradFunc
AssumedFuncOutput = types.AssumedFuncOutput
ScheduleType = types.ScheduleType
tree_is_empty = types.tree_is_empty
abstract_objects_equal = types.abstract_objects_equal
get_float_dtype_and_check_consistency = (
Expand Down
Loading

0 comments on commit 1d63c3b

Please sign in to comment.