
Commit

- Adding support for the "Schedule-free" method to be used as a wrapper for Optax optimizers in the examples codebase (see the usage sketch after this list).

- Slight reimplementation of WeightedMovingAverage. Should be a no-op up to numerics.

- Adding refresh_func_state_for_eval_with_n_iters to the examples codebase to support better handling of BN stats when using Polyak averaging or Schedule-free (a hypothetical sketch of the idea also follows this list).
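
The Schedule-free wrapping from the first item is what the examples/optimizers.py diff below wires up. As a rough, standalone illustration of the same pattern (not the examples-codebase API), assuming a recent optax release where optax.contrib.schedule_free and optax.contrib.schedule_free_eval_params are available:

import jax
import jax.numpy as jnp
import optax

learning_rate = 1e-3

# Schedule-free maintains its own interpolation/averaging of iterates, so the
# base optimizer is typically configured without its own momentum (b1=0 here).
base_optimizer = optax.adamw(learning_rate=learning_rate, b1=0.0)
optimizer = optax.contrib.schedule_free(
    base_optimizer=base_optimizer,
    learning_rate=learning_rate,
)

params = {"w": jnp.zeros((3,))}
opt_state = optimizer.init(params)


def loss_fn(p):
  return jnp.sum((p["w"] - 1.0) ** 2)


# One training step: the schedule-free update needs the current params.
grads = jax.grad(loss_fn)(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

# For evaluation, the parameters are derived from the wrapper state rather
# than being the raw training iterates.
eval_params = optax.contrib.schedule_free_eval_params(opt_state, params)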
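
The implementation of refresh_func_state_for_eval_with_n_iters from the last item is not part of this excerpt; the sketch below is only a guess at the general idea, assuming a setup where batch-norm statistics live in a mutable function state that can be re-estimated by running forward passes over training batches with the (averaged) evaluation parameters. Every name and signature here is hypothetical.

import jax


def refresh_func_state_for_eval_with_n_iters(
    model_func,      # hypothetical forward pass: (params, func_state, rng, batch) -> (out, new_state)
    params,          # evaluation parameters (e.g. Polyak or Schedule-free averages)
    func_state,      # mutable state holding batch-norm statistics
    batch_iterator,  # iterator over training batches
    rng,
    n_iters,
):
  """Re-estimates batch-norm statistics under `params` before evaluation (sketch only)."""
  for _ in range(n_iters):
    rng, step_rng = jax.random.split(rng)
    batch = next(batch_iterator)
    # Run the model so that the BN statistics in func_state are updated under
    # the evaluation params; outputs are discarded, only the state is kept.
    _, func_state = model_func(params, func_state, step_rng, batch)
  return func_state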

PiperOrigin-RevId: 675596903
james-martens authored and KfacJaxDev committed Sep 17, 2024
1 parent 93f1ef5 commit baaec40
Showing 6 changed files with 225 additions and 84 deletions.
20 changes: 18 additions & 2 deletions examples/optimizers.py
@@ -113,9 +113,17 @@ def create_optimizer(
train_total_batch_size: int,
total_steps: int | None,
total_epochs: float | None,
schedule_free_config: config_dict.ConfigDict,
) -> optax_wrapper.OptaxWrapper | kfac_jax.Optimizer:
"""Creates an optimizer from the provided configuration."""

is_optax = "kfac" not in name and hasattr(optax, name)

if not is_optax and schedule_free_config.enabled:
raise ValueError(
"Schedule Free is only supported for optax optimizers."
)

value_and_grad_func = jax.value_and_grad(train_model_func, has_aux=has_aux)

kwargs = dict(**config[name])
@@ -155,7 +163,7 @@ def create_optimizer(
**kwargs,
)

elif hasattr(optax, name):
elif is_optax:

learning_rate_schedule = schedules.construct_schedule(
dataset_size=dataset_size,
@@ -164,7 +172,15 @@
total_epochs=total_epochs,
**kwargs.pop("learning_rate_schedule")
)
optax_ctor = lambda lr: (getattr(optax, name)(learning_rate=lr, **kwargs))

if schedule_free_config.enabled:
optax_ctor = lambda lr: optax.contrib.schedule_free(
base_optimizer=getattr(optax, name)(learning_rate=lr, **kwargs),
learning_rate=lr,
**schedule_free_config.kwargs
)
else:
optax_ctor = lambda lr: (getattr(optax, name)(learning_rate=lr, **kwargs))

return optax_wrapper.OptaxWrapper(
value_and_grad_func=value_and_grad_func,
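For reference, the new optimizer path above reads schedule_free_config.enabled and schedule_free_config.kwargs. A minimal ml_collections ConfigDict matching those field names could look like the following; the kwargs shown (b1 and weight_lr_power are arguments of optax.contrib.schedule_free) are illustrative values, not defaults taken from this commit.

from ml_collections import config_dict

schedule_free_config = config_dict.ConfigDict()
schedule_free_config.enabled = True
# Extra keyword arguments forwarded to optax.contrib.schedule_free; the values
# below are illustrative only.
schedule_free_config.kwargs = config_dict.ConfigDict(
    dict(b1=0.9, weight_lr_power=2.0)
)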