diff --git a/examples/optimizers.py b/examples/optimizers.py
index 66c6225..7003869 100644
--- a/examples/optimizers.py
+++ b/examples/optimizers.py
@@ -290,6 +290,7 @@ def _update_estimator_curvature(
       sync: Union[Array, bool] = True
   ) -> EstimatorState:
     """Updates the curvature estimator state."""
+
     state = self.estimator.update_curvature_matrix_estimate(
         state=estimator_state,
         ema_old=ema_old,
@@ -316,7 +317,9 @@ def maybe_update_estimator_curvature(
       sync: Union[Array, bool] = True,
   ) -> PreconditionState:
     """Updates the curvature estimates if it is the right iteration."""
+
     ema_old = decay_old_ema * self._curvature_ema + (1.0 - decay_old_ema) * 1.0
+
     return self._maybe_update_estimator_state(
         state,
         self.should_update_estimate_curvature(state),
@@ -587,9 +590,12 @@ def _step(
       Tuple[Params, OptaxAndPreconditionState, Mapping[str, Array]],
   ]:
     """A single step of optax."""
+
+    rng_func, rng_precon = jax.random.split(rng)
+
     batch = self._batch_process_func(batch)
     func_args = kfac_jax.optimizer.make_func_args(
-        params, func_state, rng, batch,
+        params, func_state, rng_func, batch,
         has_state=self._value_func_has_state,
         has_rng=self._value_func_has_rng
     )
@@ -599,7 +605,7 @@ def _step(
     precond_state = self._preconditioner.maybe_update(
         precond_state,
         func_args,
-        rng,
+        rng_precon,
     )
     precond_state = self._preconditioner.increment_count(precond_state)
     out, grads = self._value_and_grad_func(*func_args)
@@ -666,16 +672,22 @@ def step(
       Tuple[Params, Any, Mapping[str, Array]],
   ]:
     """A step with similar interface to KFAC."""
+
+    rng_init, rng_step = self._pmap_rng_split(rng, 2)
+
     batch = next(data_iterator)
+
     if self._preconditioner is not None and state.precond_state is None:
+
       precond_state = self._pmap_init_preconditioner(
-          params, rng, batch, func_state
+          params, rng_init, batch, func_state
       )
       state = OptaxAndPreconditionState(state.optax_state, precond_state)
+
     return self._pmap_step(
         params,
         state,
-        rng,
+        rng_step,
         batch,
         func_state,
         global_step_int,
diff --git a/kfac_jax/_src/curvature_blocks.py b/kfac_jax/_src/curvature_blocks.py
index 65e0b3d..9e69727 100644
--- a/kfac_jax/_src/curvature_blocks.py
+++ b/kfac_jax/_src/curvature_blocks.py
@@ -880,7 +880,6 @@ def _multiply_matpower_unscaled(
     vector = self.parameters_list_to_single_vector(vector)
 
     if power == 1:
-
       result = jnp.matmul(state.matrix.value, vector) + identity_weight * vector
 
     elif not use_cached:
@@ -1208,22 +1207,6 @@ def _eigenvalues_unscaled(
 
     return utils.outer_product(*s)
 
-  @utils.auto_scope_method
-  def update_curvature_matrix_estimate(
-      self,
-      state: "KroneckerFactored.State",
-      estimation_data: Mapping[str, Sequence[Array]],
-      ema_old: Numeric,
-      ema_new: Numeric,
-      batch_size: Numeric,
-  ) -> "KroneckerFactored.State":
-    assert len(state.factors) == len(self.axis_groups)
-
-    # This function call will return a copy of state:
-    return self._update_curvature_matrix_estimate(
-        state, estimation_data, ema_old, ema_new, batch_size
-    )
-
   def _update_cache(  # pytype: disable=signature-mismatch  # numpy-scalars
       self,
       state: "KroneckerFactored.State",
diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py
index f33e8b4..b17d2f3 100644
--- a/kfac_jax/_src/optimizer.py
+++ b/kfac_jax/_src/optimizer.py
@@ -1463,8 +1463,11 @@ def convert_value_and_grad_to_value_func(
   Returns:
     A function that returns only the loss value.
   """
+
   def value_func(*args, **kwargs) -> Array:
+
     out, _ = value_and_grad_func(*args, **kwargs)
+
     return out[0] if has_aux else out
 
   return value_func