- Adding support for the "Schedule-free" method to be used as a wrapper for Optax optimizers in the examples codebase.

- Slight reimplementation of WeightedMovingAverage. Should be a no-op up to numerics.

- Adding refresh_func_state_for_eval_with_n_iters to the examples codebase to support better handling of BN stats when using Polyak averaging or Schedule-free.

PiperOrigin-RevId: 674232052
james-martens authored and KfacJaxDev committed Sep 16, 2024
1 parent 93f1ef5 commit e9738d3
Showing 5 changed files with 204 additions and 79 deletions.
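Note: the following is a minimal, self-contained sketch (not part of this commit) of the Schedule-free wrapping pattern that the examples/optimizers.py change below adopts. Only optax.contrib.schedule_free and optax.contrib.schedule_free_eval_params come from the diff; the toy loss, data, and hyperparameter values are placeholders.

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
  # Toy quadratic loss standing in for the examples' train_model_func.
  x, y = batch
  return jnp.mean((x @ params - y) ** 2)

lr = 1e-2
# Wrap a base Optax optimizer with the Schedule-free method, mirroring the
# new optax_ctor in examples/optimizers.py.
opt = optax.contrib.schedule_free(
    base_optimizer=optax.sgd(learning_rate=lr),
    learning_rate=lr,
    b1=0.9,  # schedule-free momentum/averaging parameter
)

params = jnp.zeros(3)
opt_state = opt.init(params)
batch = (jnp.ones((8, 3)), jnp.ones(8))

for _ in range(100):
  grads = jax.grad(loss_fn)(params, batch)
  # schedule_free needs the current params in update().
  updates, opt_state = opt.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)

# At evaluation time, use the schedule-free averaged parameters rather than
# the raw iterate, as run_evaluation in examples/training.py now does.
eval_params = optax.contrib.schedule_free_eval_params(opt_state, params)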
18 changes: 17 additions & 1 deletion examples/optimizers.py
@@ -113,9 +113,17 @@ def create_optimizer(
train_total_batch_size: int,
total_steps: int | None,
total_epochs: float | None,
schedule_free_config: config_dict.ConfigDict,
) -> optax_wrapper.OptaxWrapper | kfac_jax.Optimizer:
"""Creates an optimizer from the provided configuration."""

is_optax = "kfac" not in name and hasattr(optax, name)

if not is_optax and schedule_free_config.enabled:
raise ValueError(
"Schedule Free is only supported for optax optimizers."
)

value_and_grad_func = jax.value_and_grad(train_model_func, has_aux=has_aux)

kwargs = dict(**config[name])
@@ -155,7 +163,7 @@ def create_optimizer(
**kwargs,
)

elif hasattr(optax, name):
elif is_optax:

learning_rate_schedule = schedules.construct_schedule(
dataset_size=dataset_size,
@@ -164,8 +172,16 @@
total_epochs=total_epochs,
**kwargs.pop("learning_rate_schedule")
)

optax_ctor = lambda lr: (getattr(optax, name)(learning_rate=lr, **kwargs))

if schedule_free_config.enabled:
optax_ctor = lambda lr: optax.contrib.schedule_free(
base_optimizer=getattr(optax, name)(learning_rate=lr, **kwargs),
learning_rate=lr,
**schedule_free_config.kwargs
)

return optax_wrapper.OptaxWrapper(
value_and_grad_func=value_and_grad_func,
value_func_has_aux=has_aux,
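Note: a hypothetical example (not part of this diff) of the schedule_free_config ConfigDict that create_optimizer now consumes. The enabled and kwargs fields follow the code above; the values are illustrative only.

import ml_collections

schedule_free_config = ml_collections.ConfigDict()
schedule_free_config.enabled = True
# Extra keyword arguments forwarded to optax.contrib.schedule_free
# (e.g. the momentum parameter b1); leave empty to use the optax defaults.
schedule_free_config.kwargs = dict(b1=0.9)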
199 changes: 152 additions & 47 deletions examples/training.py
@@ -33,6 +33,7 @@
from examples import optimizers
import ml_collections
import more_itertools
import optax


# Types for annotation
@@ -154,7 +155,8 @@ class SupervisedExperiment(abc.ABC):
to `True`.
train_model_func: The `model_loss_func` with `is_training` set to `True`.
eval_model_func: The `model_loss_func` with `is_training` set to `False`.
eval_batch: A pmapped version of `self._evaluate_single_batch`.
train_batch_pmap: A pmapped version of `self._train_batch`.
eval_batch_pmap: A pmapped version of `self._eval_batch`.
optimizer: The optimizer instance used for training.
"""

@@ -200,6 +202,7 @@ def __init__(
self.has_rng = has_rng
self.has_func_state = has_func_state
self.eval_splits = eval_splits

self.batch_size = ExperimentBatchSizes(
train=batch_size_calculator_ctor(
mode="train", **self.config.batch_size.train
@@ -215,22 +218,24 @@
self.estimator_model_func = functools.partial(
self.model_func_for_estimator, is_training=True
) if self.model_func_for_estimator is not None else None

self.train_model_func = functools.partial(
self.model_loss_func, is_training=True
)
self.eval_model_func = functools.partial(
self.model_loss_func, is_training=False
)
self.eval_batch = jax.pmap(
self._evaluate_single_batch, axis_name="eval_axis"

self.train_batch_pmap = jax.pmap(
self._train_batch, axis_name="train_axis"
)
self.eval_batch_pmap = jax.pmap(
self._eval_batch, axis_name="eval_axis"
)

# Log some useful information
getattr(self.batch_size, self.mode).log_machines_setup()

# Create the optimizer
self.optimizer = self.create_optimizer()

# Initialize the state
self._train_input, self._eval_input, self._init_batch = None, None, None

@@ -243,30 +248,50 @@ def __init__(
self._log_train_stats_with_polyak_avg_every_n_steps = config.get(
"log_train_stats_with_polyak_avg_every_n_steps", 0)

if (self._use_polyak_avg_with_decay_factor is None
and self._log_train_stats_with_polyak_avg_every_n_steps != 0):
raise ValueError(
"Polyak averaging must be enabled if setting"
"log_train_stats_with_polyak_avg_every_n_steps != 0.")

if self._log_train_stats_with_polyak_avg_every_n_steps != 0:
self._train_model_func_pmap = jax.pmap(self.train_model_func)
self._get_value_pmap = jax.pmap(lambda x: x.value)
self._get_value_pmap = jax.pmap(lambda x: x.value)

if self._use_polyak_avg_with_decay_factor:
self._update_polyak_average_pmap = jax.pmap(self._update_polyak_average,
donate_argnums=0)

self._refresh_func_state_for_eval_with_n_iters = config.get(
"refresh_func_state_for_eval_with_n_iters", 0)

if "schedule_free" in config:

self._schedule_free_config = config.schedule_free

if (self._schedule_free_config.enabled
and self._use_polyak_avg_with_decay_factor is not None):
raise ValueError("Cannot use Schedule Free method and Polyak averaging "
"together.")

else:
schedule_free_config = ml_collections.ConfigDict()
schedule_free_config.enabled = False
self._schedule_free_config = schedule_free_config.lock()

if self._schedule_free_config.enabled:
self._schedule_free_eval_params_pmap = jax.pmap(
optax.contrib.schedule_free_eval_params)

self._python_step = 0
self._num_tensors = 0
self._num_parameters = 0
self._optimizer_state_size = 0

# Create the optimizer
self.optimizer = self.create_optimizer()

@property
@abc.abstractmethod
def dataset_size(self) -> int:
"""The number of data points in the training set."""

@property
def _schedule_free_enabled(self):
return self._schedule_free_config.enabled

@property
def train_input(self) -> Iterator[Batch]:
"""Returns the current training iterator."""
@@ -340,6 +365,10 @@ def _polyak_weight(
del global_step, stats
return 1.0

@property
def _polyak_add_function(self):
return kfac_jax.utils.default_add_function

def _update_polyak_average(
self,
params_polyak: WeightedMovingAverage[Params] | None,
@@ -357,7 +386,8 @@ def _update_polyak_average(
params_polyak = params_polyak.copy()

params_polyak.update(
params, self._use_polyak_avg_with_decay_factor, weight)
params, self._use_polyak_avg_with_decay_factor, weight,
add_function=self._polyak_add_function)

return params_polyak

@@ -494,6 +524,28 @@ def _build_train_input(
) -> Iterator[Batch]:
"""Constructs the training dataset."""

def _train_batch(
self,
params: Params,
func_state: FuncState,
rng: PRNGKey,
batch: Batch,
) -> dict[str, Array]:
"""Evaluates a single batch in training mode."""

func_args = kfac_jax.optimizer.make_func_args(
params=params,
func_state=func_state,
rng=rng,
batch=batch,
has_state=self.has_func_state,
has_rng=self.has_rng,
)

return kfac_jax.optimizer.extract_func_outputs(
self.train_model_func(*func_args),
has_aux=self.has_aux, has_state=self.has_func_state)

def _maybe_update_polyak_average_and_stats(
self,
rng: PRNGKey,
@@ -514,19 +566,13 @@ def _maybe_update_polyak_average_and_stats(
else:
batch = self.train_inputs.peek() # pytype: disable=attribute-error

func_args = kfac_jax.optimizer.make_func_args(
loss_polyak, _, aux_polyak = self.train_batch_pmap(
params=self._get_value_pmap(self._params_polyak),
func_state=self._state,
rng=rng,
batch=batch,
has_state=self.has_func_state,
has_rng=self.has_rng,
)

loss_polyak, _, aux_polyak = kfac_jax.optimizer.extract_func_outputs(
self._train_model_func_pmap(*func_args),
has_aux=self.has_aux, has_state=self.has_func_state)

assert aux_polyak is not None
stats["loss_polyak"] = loss_polyak
stats.update({k + "_polyak": v for k, v in aux_polyak.items()})
@@ -599,11 +645,10 @@ def _build_eval_input(
) -> Iterator[Batch]:
"""Constructs the evaluation dataset."""

def _evaluate_single_batch(
def _eval_batch(
self,
global_step: Array,
params: Params,
params_polyak: WeightedMovingAverage[Params] | None,
func_state: FuncState,
opt_state: kfac_jax.Optimizer.State | optimizers.OptaxState,
rng: PRNGKey,
@@ -624,31 +669,39 @@ def _evaluate_single_batch(

loss, stats = self.eval_model_func(*func_args)

if params_polyak is not None:
stats["loss"] = loss

func_args = kfac_jax.optimizer.make_func_args(
params=params_polyak.value,
func_state=func_state,
rng=rng,
batch=batch,
has_state=self.has_func_state,
has_rng=self.has_rng,
)
if hasattr(opt_state, "data_seen"):
stats["data_seen"] = opt_state.data_seen

loss_no_polyak = loss
stats_no_polyak = stats
return stats

loss, stats = self.eval_model_func(*func_args)
def _refresh_func_state(
self,
params: Params,
func_state: FuncState,
rng: PRNGKey,
dataset_iter_thunk: Callable[[], Iterator[kfac_jax.utils.Batch]],
num_iters: int,
) -> FuncState:
"""Refreshes func_state on the given data using num_iters iterations."""

stats.update({k + "_no_polyak": v for k, v in stats_no_polyak.items()})
stats["loss_no_polyak"] = loss_no_polyak
dataset_iter = dataset_iter_thunk()

stats["loss"] = loss
for _ in range(num_iters):

if hasattr(opt_state, "data_seen"):
stats["data_seen"] = opt_state.data_seen
rng_batch, rng = kfac_jax.utils.p_split(rng)

return stats
try:
batch = next(dataset_iter)
except StopIteration:
dataset_iter = dataset_iter_thunk()
batch = next(dataset_iter)

_, func_state, _ = self.train_batch_pmap(
params, func_state, rng_batch, batch)

return func_state

def run_evaluation(
self,
@@ -657,22 +710,74 @@
) -> dict[str, Numeric]:
"""Runs the evaluation of the currently loaded model parameters."""

if self._use_polyak_avg_with_decay_factor is not None:
params_polyak = self._get_value_pmap(self._params_polyak)
else:
params_polyak = None

if self._schedule_free_enabled:

assert isinstance(self._opt_state,
optax_wrapper.OptaxAndPreconditionState)
assert isinstance(self._opt_state.optax_state,
optax.contrib.ScheduleFreeState)

params_schedule_free = self._schedule_free_eval_params_pmap(
self._opt_state.optax_state, self._params)
else:
params_schedule_free = None

all_stats = dict()

# Evaluates both the train and eval split metrics
for name, dataset_iter_thunk in self.eval_input.items():

logging.info("Running evaluation for %s", name)

if params_polyak is not None:
func_state_polyak = self._refresh_func_state(
params_polyak,
self._state,
rng,
dataset_iter_thunk,
self._refresh_func_state_for_eval_with_n_iters,
)

if params_schedule_free is not None:
func_state_schedule_free = self._refresh_func_state(
params_schedule_free,
self._state,
rng,
dataset_iter_thunk,
self._refresh_func_state_for_eval_with_n_iters,
)

averaged_stats = kfac_jax.utils.MultiChunkAccumulator.empty(True)

for batch in dataset_iter_thunk():

key, rng = kfac_jax.utils.p_split(rng)

stats = self.eval_batch(
global_step, self._params, self._params_polyak, self._state,
self._opt_state, key, batch
)
stats = self.eval_batch_pmap(
global_step, self._params, self._state, self._opt_state, key, batch)

if params_polyak is not None:
stats_no_polyak = stats
stats = self.eval_batch_pmap(
global_step, params_polyak, func_state_polyak, self._opt_state,
key, batch)
stats.update(
{k + "_no_polyak": v for k, v in stats_no_polyak.items()
if k != "data_seen"})

if params_schedule_free is not None:
stats_no_sf = stats
stats = self.eval_batch_pmap(
global_step, params_schedule_free, func_state_schedule_free,
self._opt_state, key, batch)
stats.update(
{k + "_no_sf": v for k, v in stats_no_sf.items()
if k != "data_seen"})

averaged_stats.add(stats, 1)

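Note: batch-norm statistics stored in func_state are accumulated under the raw training iterates, so they can be stale when evaluating with Polyak-averaged or schedule-free parameters; refresh_func_state_for_eval_with_n_iters addresses this by re-running a few training-mode forward passes under the averaged parameters before evaluation. Below is a condensed sketch (not part of this diff) of that pattern, with placeholder function names standing in for the pmapped methods above.

import jax

def eval_with_refreshed_state(avg_params, func_state, opt_state, rng,
                              make_train_iter, eval_batches, n_refresh,
                              train_batch_fn, eval_batch_fn):
  # Re-estimate func_state (e.g. BN statistics) under avg_params by running
  # n_refresh forward passes in training mode, as _refresh_func_state does.
  train_iter = make_train_iter()
  for _ in range(n_refresh):
    rng, key = jax.random.split(rng)
    batch = next(train_iter)
    _, func_state, _ = train_batch_fn(avg_params, func_state, key, batch)
  # Evaluate with the averaged parameters and the refreshed state.
  results = []
  for batch in eval_batches:
    rng, key = jax.random.split(rng)
    results.append(
        eval_batch_fn(avg_params, func_state, opt_state, key, batch))
  return results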
3 changes: 1 addition & 2 deletions kfac_jax/_src/curvature_blocks/diagonal.py
@@ -346,8 +346,7 @@ def update_curvature_matrix_estimate(

if self.has_scale:

assert (state.diagonal_factors[0].raw_value.shape ==
self.parameters_shapes[0])
assert state.diagonal_factors[0].shape == self.parameters_shapes[0]

scale_shape = estimation_data.primals.params[0].shape

1 change: 1 addition & 0 deletions kfac_jax/_src/utils/__init__.py
@@ -134,6 +134,7 @@
del math

# accumulators
default_add_function = accumulators.default_add_function
WeightedMovingAverage = accumulators.WeightedMovingAverage
MultiChunkAccumulator = accumulators.MultiChunkAccumulator
del accumulators
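Note: the add_function hook threaded through the Polyak update in examples/training.py can be illustrated with a standalone weighted-moving-average helper. This is a sketch of the idea only, with made-up names; it is not kfac_jax's WeightedMovingAverage implementation.

import jax
import jax.numpy as jnp

def default_add(old, new, old_weight, new_weight):
  # Plain convex combination; a custom add_function could instead clip,
  # re-project, or combine values in a transformed space.
  return old * old_weight + new * new_weight

def wma_update(avg, total_weight, value, decay, weight,
               add_function=default_add):
  # One step of a weighted moving average: total' = total * decay + weight,
  # avg' = add(avg, value, total * decay / total', weight / total').
  new_total = total_weight * decay + weight
  new_avg = jax.tree_util.tree_map(
      lambda a, v: add_function(a, v, total_weight * decay / new_total,
                                weight / new_total),
      avg, value)
  return new_avg, new_total

# Usage: maintain a Polyak average of params with decay factor 0.999.
params = {"w": jnp.ones(3)}
avg = jax.tree_util.tree_map(jnp.zeros_like, params)
avg, total = wma_update(avg, 0.0, params, decay=0.999, weight=1.0)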