Skip to content

Commit

Permalink
- Fixing improper reuse of PRNG keys in the Optax wrapper
Browse files Browse the repository at this point in the history
- Getting rid of unused function and adding some whitespace for readability

PiperOrigin-RevId: 571075953
  • Loading branch information
james-martens authored and KfacJaxDev committed Oct 5, 2023
1 parent bd7c149 commit f0326d8
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 21 deletions.
20 changes: 16 additions & 4 deletions examples/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def _update_estimator_curvature(
sync: Union[Array, bool] = True
) -> EstimatorState:
"""Updates the curvature estimator state."""

state = self.estimator.update_curvature_matrix_estimate(
state=estimator_state,
ema_old=ema_old,
Expand All @@ -316,7 +317,9 @@ def maybe_update_estimator_curvature(
sync: Union[Array, bool] = True,
) -> PreconditionState:
"""Updates the curvature estimates if it is the right iteration."""

ema_old = decay_old_ema * self._curvature_ema + (1.0 - decay_old_ema) * 1.0

return self._maybe_update_estimator_state(
state,
self.should_update_estimate_curvature(state),
Expand Down Expand Up @@ -587,9 +590,12 @@ def _step(
Tuple[Params, OptaxAndPreconditionState, Mapping[str, Array]],
]:
"""A single step of optax."""

rng_func, rng_precon = jax.random.split(rng)

batch = self._batch_process_func(batch)
func_args = kfac_jax.optimizer.make_func_args(
params, func_state, rng, batch,
params, func_state, rng_func, batch,
has_state=self._value_func_has_state,
has_rng=self._value_func_has_rng
)
Expand All @@ -599,7 +605,7 @@ def _step(
precond_state = self._preconditioner.maybe_update(
precond_state,
func_args,
rng,
rng_precon,
)
precond_state = self._preconditioner.increment_count(precond_state)
out, grads = self._value_and_grad_func(*func_args)
Expand Down Expand Up @@ -666,16 +672,22 @@ def step(
Tuple[Params, Any, Mapping[str, Array]],
]:
"""A step with similar interface to KFAC."""

rng_init, rng_step = self._pmap_rng_split(rng, 2)

batch = next(data_iterator)

if self._preconditioner is not None and state.precond_state is None:

precond_state = self._pmap_init_preconditioner(
params, rng, batch, func_state
params, rng_init, batch, func_state
)
state = OptaxAndPreconditionState(state.optax_state, precond_state)

return self._pmap_step(
params,
state,
rng,
rng_step,
batch,
func_state,
global_step_int,
Expand Down
17 changes: 0 additions & 17 deletions kfac_jax/_src/curvature_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,6 @@ def _multiply_matpower_unscaled(
vector = self.parameters_list_to_single_vector(vector)

if power == 1:

result = jnp.matmul(state.matrix.value, vector) + identity_weight * vector

elif not use_cached:
Expand Down Expand Up @@ -1208,22 +1207,6 @@ def _eigenvalues_unscaled(

return utils.outer_product(*s)

@utils.auto_scope_method
def update_curvature_matrix_estimate(
self,
state: "KroneckerFactored.State",
estimation_data: Mapping[str, Sequence[Array]],
ema_old: Numeric,
ema_new: Numeric,
batch_size: Numeric,
) -> "KroneckerFactored.State":
assert len(state.factors) == len(self.axis_groups)

# This function call will return a copy of state:
return self._update_curvature_matrix_estimate(
state, estimation_data, ema_old, ema_new, batch_size
)

def _update_cache( # pytype: disable=signature-mismatch # numpy-scalars
self,
state: "KroneckerFactored.State",
Expand Down
3 changes: 3 additions & 0 deletions kfac_jax/_src/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,8 +1463,11 @@ def convert_value_and_grad_to_value_func(
Returns:
A function that returns only the loss value.
"""

def value_func(*args, **kwargs) -> Array:

out, _ = value_and_grad_func(*args, **kwargs)

return out[0] if has_aux else out

return value_func
Expand Down

0 comments on commit f0326d8

Please sign in to comment.