Add an option to specify a different value function for the preconditioner's curvature estimator.

This is useful for cases where the value function used for training is expensive to add to the preconditioner, e.g. because it has costly regularizers.

PiperOrigin-RevId: 664776807
joeljennings authored and KfacJaxDev committed Aug 22, 2024
1 parent 60644ef commit b026fe2
Showing 3 changed files with 22 additions and 1 deletion.
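Before the diffs, a minimal sketch of how the new option can be used. Everything here except the `value_func_for_estimator` argument itself is illustrative (the function names, shapes, and hyperparameters are not part of the commit); the loss registration and constructor arguments follow the public `kfac_jax` API:

```python
import jax
import jax.numpy as jnp
import kfac_jax


def estimator_loss(params, batch):
  """Cheap loss used only for the curvature estimator."""
  preds = jnp.dot(batch["x"], params["w"])
  # Register the loss so K-FAC knows the model's predictive distribution.
  kfac_jax.register_squared_error_loss(preds, batch["y"])
  return 0.5 * jnp.sum((preds - batch["y"]) ** 2)


def training_loss(params, batch):
  """Full training loss: the same objective plus an (illustratively)
  expensive regularizer that we keep out of the preconditioner."""
  reg = 1e-4 * jnp.sum(params["w"] ** 2)  # stand-in for a costly term
  return estimator_loss(params, batch) + reg


optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(training_loss),
    l2_reg=0.0,
    value_func_for_estimator=estimator_loss,  # the option added by this commit
    use_adaptive_damping=True,
    initial_damping=1.0,
    multi_device=False,
)
```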
2 changes: 2 additions & 0 deletions examples/optimizers.py
@@ -1300,6 +1300,7 @@ def create_optimizer(
     has_aux: bool,
     has_func_state: bool,
     has_rng: bool,
+    model_func_for_estimator: kfac_jax.optimizer.ValueFunc | None,
     dataset_size: int,
     train_total_batch_size: int,
     total_steps: int | None,
@@ -1341,6 +1342,7 @@ def create_optimizer(
       value_func_has_aux=has_aux,
       value_func_has_state=has_func_state,
       value_func_has_rng=has_rng,
+      value_func_for_estimator=model_func_for_estimator,
       multi_device=True,
       **kwargs,
   )
12 changes: 12 additions & 0 deletions examples/training.py
@@ -141,12 +141,16 @@ class SupervisedExperiment(abc.ABC):
     has_aux: Whether the model function returns any auxiliary data.
     has_rng: Whether the model function needs a PRNG key.
     has_func_state: Whether the model function has a state.
+    model_func_for_estimator: A function that allows a different
+      computation of the loss of the model for the estimator.
     eval_splits: Evaluation splits of the evaluation dataset loader.
     batch_size: An instance of `ExperimentBatchSizes`.
     init_parameters_func: A function that initializes the parameters and
       optionally the state of the model if it has one.
     params_init: A function that initializes the model parameters.
     model_loss_func: A function that computes the loss for the model.
+    estimator_model_func: The `model_func_for_estimator` with `is_training` set
+      to `True`.
     train_model_func: The `model_loss_func` with `is_training` set to `True`.
     eval_model_func: The `model_loss_func` with `is_training` set to `False`.
     eval_batch: A pmapped version of `self._evaluate_single_batch`.
@@ -163,6 +167,7 @@ def __init__(
       has_aux: bool,
       has_rng: bool,
       has_func_state: bool,
+      model_func_for_estimator: kfac_jax.optimizer.ValueFunc | None = None,
       eval_splits: tuple[str, ...] = ("train", "test"),
       batch_size_calculator_ctor: BatchSizeCalculatorCtor = BatchSizeCalculator,
   ):
@@ -180,6 +185,8 @@ def __init__(
       has_aux: Whether the model function returns auxiliary data.
       has_rng: Whether the model function requires an RNG.
       has_func_state: Whether the model function has a state.
+      model_func_for_estimator: A function that allows a different
+        computation of the loss of the model for the estimator.
       eval_splits: Evaluation splits of the evaluation dataset loader.
       batch_size_calculator_ctor: A constructor function to create a batch size
         calculator.
@@ -203,6 +210,10 @@ def __init__(
 
     self.params_init = jax.pmap(init_parameters_func)
     self.model_loss_func = model_loss_func
+    self.model_func_for_estimator = model_func_for_estimator
+    self.estimator_model_func = functools.partial(
+        self.model_func_for_estimator, is_training=True
+    )
     self.train_model_func = functools.partial(
         self.model_loss_func, is_training=True
     )
@@ -386,6 +397,7 @@ def create_optimizer(
         has_aux=self.has_aux,
         has_func_state=self.has_func_state,
         has_rng=self.has_rng,
+        model_func_for_estimator=self.estimator_model_func,
         dataset_size=self.dataset_size,
         train_total_batch_size=self.batch_size.train.total,
         total_steps=self.config.training.steps,
9 changes: 8 additions & 1 deletion kfac_jax/_src/optimizer.py
@@ -99,6 +99,7 @@ def __init__(
       value_func_has_aux: bool = False,
       value_func_has_state: bool = False,
       value_func_has_rng: bool = False,
+      value_func_for_estimator: ValueFunc | None = None,
       use_adaptive_learning_rate: bool = False,
       learning_rate_schedule: ScheduleType | None = None,
       use_adaptive_momentum: bool = False,
@@ -212,6 +213,11 @@ def __init__(
       value_func_has_rng: Boolean. Specifies whether the provided callable
         ``value_and_grad_func`` additionally takes as input an rng key.
         (Default: ``False``)
+      value_func_for_estimator: ValueFunc. If specified, this function will be
+        used by the preconditioner estimator instead of ``value_and_grad_func``.
+        This is useful for cases where the value function used for training is
+        expensive to add to the preconditioner, e.g. because it has costly
+        regularizers. (Default: ``None``)
       use_adaptive_learning_rate: Boolean. Specifies whether to use the special
         rule from the original K-FAC paper for picking the learning rate at each
         step. Note that this won't work well for stochastic objectives. If this
@@ -488,7 +494,8 @@ def __init__(
 
     # Curvature estimator
     self._estimator = estimator_ctor(
-        func=self._value_func,
+        func=(self._value_func if value_func_for_estimator is None else
+              value_func_for_estimator),
         default_estimation_mode=estimation_mode,
         params_index=self._params_index,
         batch_index=batch_index,
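Continuing the sketch from the top of the page (`optimizer`, `jax`, and `jnp` as defined there), the optimizer is then initialized and stepped as usual; the estimator-specific function only changes what the curvature estimator differentiates, not the training objective. The `init`/`step` calls follow the public `kfac_jax.Optimizer` API, though the exact keyword set may vary by version:

```python
rng = jax.random.PRNGKey(42)
params = {"w": jnp.zeros((8, 1))}
batch = {"x": jnp.ones((32, 8)), "y": jnp.zeros((32, 1))}

# `optimizer` is the instance constructed in the earlier sketch. The
# curvature estimator it builds internally sees `estimator_loss`, while
# gradients still come from `training_loss`.
opt_state = optimizer.init(params, rng, batch)
params, opt_state, stats = optimizer.step(
    params, opt_state, rng, batch=batch, global_step_int=0,
    learning_rate=1e-2, momentum=0.9,
)
```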
