Skip to content

Commit

Permalink
- Passing stats to _post_param_update_processing in examples code.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671546076
  • Loading branch information
james-martens authored and KfacJaxDev committed Sep 7, 2024
1 parent 315f5e9 commit cec1b1a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
5 changes: 3 additions & 2 deletions examples/training.py
Original file line number Diff line number Diff line change
@@ -539,7 +539,8 @@ def _maybe_update_polyak_average_and_stats(
)
)

- def _post_param_update_processing(self, global_step: Array):
+ def _post_param_update_processing(
+ self, global_step: Array, stats: dict[str, Numeric]):
pass

def train_step(self, global_step: Array, rng: PRNGKey) -> dict[str, Numeric]:
@@ -561,7 +562,7 @@ def train_step(self, global_step: Array, rng: PRNGKey) -> dict[str, Numeric]:
else:
self._params, self._opt_state, stats = result

- self._post_param_update_processing(global_step)
+ self._post_param_update_processing(global_step, stats)

self._maybe_update_polyak_average_and_stats(rng, stats)

Expand Down
8 changes: 4 additions & 4 deletions kfac_jax/_src/optimizer.py
Original file line number Diff line number Diff line change
@@ -147,7 +147,7 @@ def __init__(
norm_to_scale_identity_weight_per_block: str | None = None,
precon_power: Scalar = -1.0,
):
- """Initializes the K-FAC optimizer with the provided settings.
+ """Initializes the kfac_jax optimizer with the provided settings.
NOTE: Please read the docstring for this constructor carefully. Especially
the description of ``value_and_grad_func``.
@@ -404,13 +404,13 @@ def __init__(
raise ValueError("When use_adaptive_damping is False you should not "
"provide a value for initial_damping.")
if use_adaptive_learning_rate and learning_rate_schedule is not None:
- raise ValueError("If you are using adaptive learning rate than "
+ raise ValueError("If you are using adaptive learning rate then "
"`learning_rate_schedule` should be None.")
if use_adaptive_momentum and momentum_schedule is not None:
- raise ValueError("If you are using adaptive momentum than "
+ raise ValueError("If you are using adaptive momentum then "
"`momentum_schedule` should be None.")
if use_adaptive_damping and damping_schedule is not None:
- raise ValueError("If you are using adaptive damping than "
+ raise ValueError("If you are using adaptive damping then "
"`damping_schedule` should be None.")

self._value_and_grad_func = value_and_grad_func
Expand Down

0 comments on commit cec1b1a

Please sign in to comment.