From b026fe26007fd9138b3b625d81ab4d767998a693 Mon Sep 17 00:00:00 2001 From: Joel Jennings Date: Mon, 19 Aug 2024 06:20:10 -0700 Subject: [PATCH] Add an option to specify a different value function for the preconditioner's curvature estimator. This is useful for cases where the value function used for training is expensive to add to the preconditioner, e.g. because it has expensive regularizers. PiperOrigin-RevId: 664776807 --- examples/optimizers.py | 2 ++ examples/training.py | 12 ++++++++++++ kfac_jax/_src/optimizer.py | 9 ++++++++- 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/examples/optimizers.py b/examples/optimizers.py index adaff3e..414e04b 100644 --- a/examples/optimizers.py +++ b/examples/optimizers.py @@ -1300,6 +1300,7 @@ def create_optimizer( has_aux: bool, has_func_state: bool, has_rng: bool, + model_func_for_estimator: kfac_jax.optimizer.ValueFunc | None, dataset_size: int, train_total_batch_size: int, total_steps: int | None, @@ -1341,6 +1342,7 @@ def create_optimizer( value_func_has_aux=has_aux, value_func_has_state=has_func_state, value_func_has_rng=has_rng, + value_func_for_estimator=model_func_for_estimator, multi_device=True, **kwargs, ) diff --git a/examples/training.py b/examples/training.py index 5e6060c..f242923 100644 --- a/examples/training.py +++ b/examples/training.py @@ -141,12 +141,16 @@ class SupervisedExperiment(abc.ABC): has_aux: Whether the model function returns any auxiliary data. has_rng: Whether the model function needs a PRNG key. has_func_state: Whether the model function has a state. + model_func_for_estimator: A function that allows a different + computation of the loss of the model for the estimator. eval_splits: Evaluation splits of the evaluation dataset loader. batch_size: An instance of `ExperimentBatchSizes`. init_parameters_func: A function that initializes the parameters and optionally the state of the model if it has one. params_init: A function that initializes the model parameters. model_loss_func: A function that computes the loss for the model. + estimator_model_func: The `model_func_for_estimator` with `is_training` set + to `True`. train_model_func: The `model_loss_func` with `is_training` set to `True`. eval_model_func: The `model_loss_func` with `is_training` set to `False`. eval_batch: A pmapped version of `self._evaluate_single_batch`. @@ -163,6 +167,7 @@ def __init__( has_aux: bool, has_rng: bool, has_func_state: bool, + model_func_for_estimator: kfac_jax.optimizer.ValueFunc | None = None, eval_splits: tuple[str, ...] = ("train", "test"), batch_size_calculator_ctor: BatchSizeCalculatorCtor = BatchSizeCalculator, ): @@ -180,6 +185,8 @@ def __init__( has_aux: Whether the model function returns auxiliary data. has_rng: Whether the model function requires an RNG. has_func_state: Whether the model function has a state. + model_func_for_estimator: A function that allows a different + computation of the loss of the model for the estimator. eval_splits: Evaluation splits of the evaluation dataset loader. batch_size_calculator_ctor: A constructor function to create a batch size calculator. @@ -203,6 +210,10 @@ def __init__( self.params_init = jax.pmap(init_parameters_func) self.model_loss_func = model_loss_func + self.model_func_for_estimator = model_func_for_estimator + self.estimator_model_func = functools.partial( + self.model_func_for_estimator, is_training=True + ) self.train_model_func = functools.partial( self.model_loss_func, is_training=True ) @@ -386,6 +397,7 @@ def create_optimizer( has_aux=self.has_aux, has_func_state=self.has_func_state, has_rng=self.has_rng, + model_func_for_estimator=self.estimator_model_func, dataset_size=self.dataset_size, train_total_batch_size=self.batch_size.train.total, total_steps=self.config.training.steps, diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py index 77e7b98..3fbdca0 100644 --- a/kfac_jax/_src/optimizer.py +++ b/kfac_jax/_src/optimizer.py @@ -99,6 +99,7 @@ def __init__( value_func_has_aux: bool = False, value_func_has_state: bool = False, value_func_has_rng: bool = False, + value_func_for_estimator: ValueFunc | None = None, use_adaptive_learning_rate: bool = False, learning_rate_schedule: ScheduleType | None = None, use_adaptive_momentum: bool = False, @@ -212,6 +213,11 @@ def __init__( value_func_has_rng: Boolean. Specifies whether the provided callable ``value_and_grad_func`` additionally takes as input an rng key. (Default: ``False``) + value_func_for_estimator: ValueFunc. If specified, this function will be + used by the preconditioner estimator instead of ``value_and_grad_func``. + This is useful for cases where the value function used for training is + expensive to add to the preconditioner, e.g. because it has costly + regularizers. (Default: ``None``) use_adaptive_learning_rate: Boolean. Specifies whether to use the special rule from the original K-FAC paper for picking the learning rate at each step. Note that this won't work well for stochastic objectives. If this @@ -488,7 +494,8 @@ def __init__( # Curvature estimator self._estimator = estimator_ctor( - func=self._value_func, + func=(self._value_func if value_func_for_estimator is None else + value_func_for_estimator), default_estimation_mode=estimation_mode, params_index=self._params_index, batch_index=batch_index,