diff --git a/examples/optimizers.py b/examples/optimizers.py index 9583523..6b84090 100644 --- a/examples/optimizers.py +++ b/examples/optimizers.py @@ -113,9 +113,17 @@ def create_optimizer( train_total_batch_size: int, total_steps: int | None, total_epochs: float | None, + schedule_free_config: config_dict.ConfigDict, ) -> optax_wrapper.OptaxWrapper | kfac_jax.Optimizer: """Creates an optimizer from the provided configuration.""" + is_optax = "kfac" not in name and hasattr(optax, name) + + if not is_optax and schedule_free_config.enabled: + raise ValueError( + "Schedule Free is only supported for optax optimizers." + ) + value_and_grad_func = jax.value_and_grad(train_model_func, has_aux=has_aux) kwargs = dict(**config[name]) @@ -155,7 +163,7 @@ def create_optimizer( **kwargs, ) - elif hasattr(optax, name): + elif is_optax: learning_rate_schedule = schedules.construct_schedule( dataset_size=dataset_size, @@ -164,7 +172,15 @@ def create_optimizer( total_epochs=total_epochs, **kwargs.pop("learning_rate_schedule") ) - optax_ctor = lambda lr: (getattr(optax, name)(learning_rate=lr, **kwargs)) + + if schedule_free_config.enabled: + optax_ctor = lambda lr: optax.contrib.schedule_free( + base_optimizer=getattr(optax, name)(learning_rate=lr, **kwargs), + learning_rate=lr, + **schedule_free_config.kwargs + ) + else: + optax_ctor = lambda lr: (getattr(optax, name)(learning_rate=lr, **kwargs)) return optax_wrapper.OptaxWrapper( value_and_grad_func=value_and_grad_func, diff --git a/examples/training.py b/examples/training.py index f914695..489744e 100644 --- a/examples/training.py +++ b/examples/training.py @@ -33,6 +33,7 @@ from examples import optimizers import ml_collections import more_itertools +import optax # Types for annotation @@ -42,6 +43,7 @@ Params = kfac_jax.utils.Params Batch = kfac_jax.utils.Batch FuncState = kfac_jax.utils.FuncState +FuncAux = kfac_jax.utils.FuncAux WeightedMovingAverage = kfac_jax.utils.WeightedMovingAverage InitFunc = Callable[[PRNGKey, Batch], Params] @@ -154,7 +156,8 @@ class SupervisedExperiment(abc.ABC): to `True`. train_model_func: The `model_loss_func` with `is_training` set to `True`. eval_model_func: The `model_loss_func` with `is_training` set to `False`. - eval_batch: A pmapped version of `self._evaluate_single_batch`. + train_batch_pmap: A pmapped version of `self._train_batch`. + eval_batch_pmap: A pmapped version of `self._eval_batch`. optimizer: The optimizer instance used for training. """ @@ -200,6 +203,7 @@ def __init__( self.has_rng = has_rng self.has_func_state = has_func_state self.eval_splits = eval_splits + self.batch_size = ExperimentBatchSizes( train=batch_size_calculator_ctor( mode="train", **self.config.batch_size.train @@ -215,58 +219,87 @@ def __init__( self.estimator_model_func = functools.partial( self.model_func_for_estimator, is_training=True ) if self.model_func_for_estimator is not None else None + self.train_model_func = functools.partial( self.model_loss_func, is_training=True ) self.eval_model_func = functools.partial( self.model_loss_func, is_training=False ) - self.eval_batch = jax.pmap( - self._evaluate_single_batch, axis_name="eval_axis" + + self.train_batch_pmap = jax.pmap( + self._train_batch, axis_name="train_axis" + ) + self.eval_batch_pmap = jax.pmap( + self._eval_batch, axis_name="eval_axis" ) # Log some useful information getattr(self.batch_size, self.mode).log_machines_setup() - # Create the optimizer - self.optimizer = self.create_optimizer() - # Initialize the state self._train_input, self._eval_input, self._init_batch = None, None, None self._params, self._state, self._opt_state = None, None, None self._params_polyak = None + # None corresponds to not using Polyak averaging. To get non-decayed Polyak + # averaging (aka classic Polyak averaging with a straight average), set this + # to 1.0. self._use_polyak_avg_with_decay_factor = config.get( "use_polyak_avg_with_decay_factor", None) self._log_train_stats_with_polyak_avg_every_n_steps = config.get( "log_train_stats_with_polyak_avg_every_n_steps", 0) - if (self._use_polyak_avg_with_decay_factor is None - and self._log_train_stats_with_polyak_avg_every_n_steps != 0): - raise ValueError( - "Polyak averaging must be enabled if setting" - "log_train_stats_with_polyak_avg_every_n_steps != 0.") - - if self._log_train_stats_with_polyak_avg_every_n_steps != 0: - self._train_model_func_pmap = jax.pmap(self.train_model_func) - self._get_value_pmap = jax.pmap(lambda x: x.value) + self._get_value_pmap = jax.pmap(lambda x: x.value) if self._use_polyak_avg_with_decay_factor: self._update_polyak_average_pmap = jax.pmap(self._update_polyak_average, donate_argnums=0) + self._refresh_func_state_for_eval_with_n_iters = config.get( + "refresh_func_state_for_eval_with_n_iters", 0) + + if "schedule_free" in config: + + self._schedule_free_config = config.schedule_free + + if (self._schedule_free_config.enabled + and self._use_polyak_avg_with_decay_factor is not None): + raise ValueError("Cannot use Schedule Free method and Polyak averaging " + "together.") + + else: + schedule_free_config = ml_collections.ConfigDict() + schedule_free_config.enabled = False + self._schedule_free_config = schedule_free_config.lock() + + if self._schedule_free_config.enabled: + self._schedule_free_eval_params_pmap = jax.pmap( + optax.contrib.schedule_free_eval_params) + self._python_step = 0 self._num_tensors = 0 self._num_parameters = 0 self._optimizer_state_size = 0 + # Create the optimizer + self.optimizer = self.create_optimizer() + @property @abc.abstractmethod def dataset_size(self) -> int: """The number of data points in the training set.""" + @property + def _schedule_free_enabled(self): + return self._schedule_free_config.enabled + + @property + def _polyak_avg_enabled(self): + return self._use_polyak_avg_with_decay_factor is not None + @property def train_input(self) -> Iterator[Batch]: """Returns the current training iterator.""" @@ -340,6 +373,10 @@ def _polyak_weight( del global_step, stats return 1.0 + @property + def _polyak_add_function(self): + return kfac_jax.utils.default_add_function + def _update_polyak_average( self, params_polyak: WeightedMovingAverage[Params] | None, @@ -357,7 +394,8 @@ def _update_polyak_average( params_polyak = params_polyak.copy() params_polyak.update( - params, self._use_polyak_avg_with_decay_factor, weight) + params, self._use_polyak_avg_with_decay_factor, weight, + add_function=self._polyak_add_function) return params_polyak @@ -403,6 +441,7 @@ def create_optimizer( train_total_batch_size=self.batch_size.train.total, total_steps=self.config.training.steps, total_epochs=self.config.training.epochs, + schedule_free_config=self._schedule_free_config, ) def maybe_initialize_state(self): @@ -494,6 +533,28 @@ def _build_train_input( ) -> Iterator[Batch]: """Constructs the training dataset.""" + def _train_batch( + self, + params: Params, + func_state: FuncState | None, + rng: PRNGKey | None, + batch: Batch, + ) -> tuple[Array, FuncState | None, FuncAux | None]: + """Evaluates a single batch in training mode.""" + + func_args = kfac_jax.optimizer.make_func_args( + params=params, + func_state=func_state, + rng=rng, + batch=batch, + has_state=self.has_func_state, + has_rng=self.has_rng, + ) + + return kfac_jax.optimizer.extract_func_outputs( + self.train_model_func(*func_args), + has_aux=self.has_aux, has_state=self.has_func_state) + def _maybe_update_polyak_average_and_stats( self, rng: PRNGKey, @@ -501,7 +562,7 @@ def _maybe_update_polyak_average_and_stats( ): """Updates the polyak-averaged version of the parameters and gets stats.""" - if self._use_polyak_avg_with_decay_factor is not None: + if self._polyak_avg_enabled: if (self._log_train_stats_with_polyak_avg_every_n_steps and ( (self._python_step + 1) % @@ -514,19 +575,13 @@ def _maybe_update_polyak_average_and_stats( else: batch = self.train_inputs.peek() # pytype: disable=attribute-error - func_args = kfac_jax.optimizer.make_func_args( + loss_polyak, _, aux_polyak = self.train_batch_pmap( params=self._get_value_pmap(self._params_polyak), func_state=self._state, rng=rng, batch=batch, - has_state=self.has_func_state, - has_rng=self.has_rng, ) - loss_polyak, _, aux_polyak = kfac_jax.optimizer.extract_func_outputs( - self._train_model_func_pmap(*func_args), - has_aux=self.has_aux, has_state=self.has_func_state) - assert aux_polyak is not None stats["loss_polyak"] = loss_polyak stats.update({k + "_polyak": v for k, v in aux_polyak.items()}) @@ -599,14 +654,13 @@ def _build_eval_input( ) -> Iterator[Batch]: """Constructs the evaluation dataset.""" - def _evaluate_single_batch( + def _eval_batch( self, global_step: Array, params: Params, - params_polyak: WeightedMovingAverage[Params] | None, - func_state: FuncState, + func_state: FuncState | None, opt_state: kfac_jax.Optimizer.State | optimizers.OptaxState, - rng: PRNGKey, + rng: PRNGKey | None, batch: Batch, ) -> dict[str, Array]: """Evaluates a single batch.""" @@ -624,31 +678,39 @@ def _evaluate_single_batch( loss, stats = self.eval_model_func(*func_args) - if params_polyak is not None: + stats["loss"] = loss - func_args = kfac_jax.optimizer.make_func_args( - params=params_polyak.value, - func_state=func_state, - rng=rng, - batch=batch, - has_state=self.has_func_state, - has_rng=self.has_rng, - ) + if hasattr(opt_state, "data_seen"): + stats["data_seen"] = opt_state.data_seen - loss_no_polyak = loss - stats_no_polyak = stats + return stats - loss, stats = self.eval_model_func(*func_args) + def _refresh_func_state( + self, + params: Params, + func_state: FuncState, + rng: PRNGKey, + dataset_iter_thunk: Callable[[], Iterator[kfac_jax.utils.Batch]], + num_iters: int, + ) -> FuncState: + """Refreshes func_state on the given data using num_iters iterations.""" - stats.update({k + "_no_polyak": v for k, v in stats_no_polyak.items()}) - stats["loss_no_polyak"] = loss_no_polyak + dataset_iter = dataset_iter_thunk() - stats["loss"] = loss + for _ in range(num_iters): - if hasattr(opt_state, "data_seen"): - stats["data_seen"] = opt_state.data_seen + rng_batch, rng = kfac_jax.utils.p_split(rng) - return stats + try: + batch = next(dataset_iter) + except StopIteration: + dataset_iter = dataset_iter_thunk() + batch = next(dataset_iter) + + _, func_state, _ = self.train_batch_pmap( + params, func_state, rng_batch, batch) + + return func_state def run_evaluation( self, @@ -657,22 +719,80 @@ def run_evaluation( ) -> dict[str, Numeric]: """Runs the evaluation of the currently loaded model parameters.""" + if self._polyak_avg_enabled: + params_polyak = self._get_value_pmap(self._params_polyak) + else: + params_polyak = None + + if self._schedule_free_enabled: + + assert isinstance(self._opt_state, + optax_wrapper.OptaxAndPreconditionState) + assert isinstance(self._opt_state.optax_state, + optax.contrib.ScheduleFreeState) + + params_schedule_free = self._schedule_free_eval_params_pmap( + self._opt_state.optax_state, self._params) + else: + params_schedule_free = None + all_stats = dict() # Evaluates both the train and eval split metrics for name, dataset_iter_thunk in self.eval_input.items(): + logging.info("Running evaluation for %s", name) + if params_polyak is not None and self.has_func_state: + assert self._state is not None + func_state_polyak = self._refresh_func_state( + params_polyak, + self._state, + rng, + dataset_iter_thunk, + self._refresh_func_state_for_eval_with_n_iters, + ) + else: + func_state_polyak = self._state + + if params_schedule_free is not None and self.has_func_state: + assert self._state is not None + func_state_schedule_free = self._refresh_func_state( + params_schedule_free, + self._state, + rng, + dataset_iter_thunk, + self._refresh_func_state_for_eval_with_n_iters, + ) + else: + func_state_schedule_free = self._state + averaged_stats = kfac_jax.utils.MultiChunkAccumulator.empty(True) for batch in dataset_iter_thunk(): key, rng = kfac_jax.utils.p_split(rng) - stats = self.eval_batch( - global_step, self._params, self._params_polyak, self._state, - self._opt_state, key, batch - ) + stats = self.eval_batch_pmap( + global_step, self._params, self._state, self._opt_state, key, batch) + + if params_polyak is not None: + stats_no_polyak = stats + stats = self.eval_batch_pmap( + global_step, params_polyak, func_state_polyak, self._opt_state, + key, batch) + stats.update( + {k + "_no_polyak": v for k, v in stats_no_polyak.items() + if k != "data_seen"}) + + if params_schedule_free is not None: + stats_no_sf = stats + stats = self.eval_batch_pmap( + global_step, params_schedule_free, func_state_schedule_free, + self._opt_state, key, batch) + stats.update( + {k + "_no_sf": v for k, v in stats_no_sf.items() + if k != "data_seen"}) averaged_stats.add(stats, 1) diff --git a/kfac_jax/_src/curvature_blocks/diagonal.py b/kfac_jax/_src/curvature_blocks/diagonal.py index 43f0539..a8cfb4d 100644 --- a/kfac_jax/_src/curvature_blocks/diagonal.py +++ b/kfac_jax/_src/curvature_blocks/diagonal.py @@ -346,8 +346,7 @@ def update_curvature_matrix_estimate( if self.has_scale: - assert (state.diagonal_factors[0].raw_value.shape == - self.parameters_shapes[0]) + assert state.diagonal_factors[0].shape == self.parameters_shapes[0] scale_shape = estimation_data.primals.params[0].shape diff --git a/kfac_jax/_src/curvature_blocks/kronecker_factored.py b/kfac_jax/_src/curvature_blocks/kronecker_factored.py index b0670a9..7e3a824 100644 --- a/kfac_jax/_src/curvature_blocks/kronecker_factored.py +++ b/kfac_jax/_src/curvature_blocks/kronecker_factored.py @@ -403,7 +403,7 @@ def _to_dense_unscaled(self, state: "KroneckerFactored.State") -> Array: # Permute the matrix according to the parameters canonical order inputs_factor = utils.block_permuted( state.factors[0].value, - block_sizes=[state.factors[0].raw_value.shape[0] - 1, 1], + block_sizes=[state.factors[0].shape[0] - 1, 1], block_order=(1, 0), ) diff --git a/kfac_jax/_src/utils/__init__.py b/kfac_jax/_src/utils/__init__.py index 55a5ea5..60b4fd9 100644 --- a/kfac_jax/_src/utils/__init__.py +++ b/kfac_jax/_src/utils/__init__.py @@ -134,6 +134,7 @@ del math # accumulators +default_add_function = accumulators.default_add_function WeightedMovingAverage = accumulators.WeightedMovingAverage MultiChunkAccumulator = accumulators.MultiChunkAccumulator del accumulators diff --git a/kfac_jax/_src/utils/accumulators.py b/kfac_jax/_src/utils/accumulators.py index 77d34da..4acc66e 100644 --- a/kfac_jax/_src/utils/accumulators.py +++ b/kfac_jax/_src/utils/accumulators.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """K-FAC for accumulating statistics.""" -from typing import Any, Generic +from typing import Any, Callable, Generic import jax import jax.numpy as jnp @@ -28,72 +28,76 @@ ArrayTree = types.ArrayTree TArrayTree = types.TArrayTree +AddFunction = Callable[[TArrayTree, TArrayTree, Numeric, Numeric], TArrayTree] + + +def default_add_function( + obj1: TArrayTree, + obj2: TArrayTree, + coeff1: Numeric, + coeff2: Numeric +) -> TArrayTree: + + return jax.tree_util.tree_map( + lambda x, y: coeff1 * x + coeff2 * y, obj1, obj2) + @misc.register_state_class class WeightedMovingAverage(Generic[TArrayTree], misc.State): """A wrapped class for an arbitrary weighted moving average.""" weight: Numeric - raw_value: TArrayTree | None + value: TArrayTree | None @property def ndim(self) -> int: - assert self.raw_value is not None - return self.raw_value.ndim + assert self.value is not None + return self.value.ndim @property def shape(self) -> Shape: - assert self.raw_value is not None - return self.raw_value.shape + assert self.value is not None + return self.value.shape @property def dtype(self) -> DType: - assert self.raw_value is not None - return self.raw_value.dtype - - @property - def value(self) -> TArrayTree: - """The value of the underlying arrays data structure.""" - return jax.tree_util.tree_map(lambda x: x / self.weight, self.raw_value) + assert self.value is not None + return self.value.dtype def update( self, value: TArrayTree, old_weight_multiplier: Numeric, new_weight: Numeric, + add_function: AddFunction = default_add_function, ): """Updates the underlying array and weight accordingly.""" - assert self.raw_value is not None + assert self.value is not None # A negative value of new_weight means we should only update the value # (with -new_weight) and not the total running weight. This roughly # corresponds to summation instead of averaging, and is useful in a few # contexts. - new_weight_for_value = jnp.abs(new_weight) - new_weight_for_weight = jax.nn.relu(new_weight) + self.weight = old_weight_multiplier * self.weight + jax.nn.relu(new_weight) + eta_for_old = jax.nn.relu(new_weight) / self.weight + eta_for_new = jnp.abs(new_weight) / self.weight - self.weight = self.weight * old_weight_multiplier + new_weight_for_weight - - self.raw_value = jax.tree_util.tree_map( - lambda x, y: x * old_weight_multiplier + y * new_weight_for_value, - self.raw_value, - value, - ) + self.value = add_function(self.value, value, 1.0 - eta_for_old, eta_for_new) def sync(self, pmap_axis_name: str | None): """Syncs the underlying array across devices.""" - if self.raw_value is None: - raise ValueError("`raw_value` has not been set yet.") + if self.value is None: + raise ValueError("`_value` has not been set yet.") - self.raw_value = parallel.pmean_if_pmap(self.raw_value, pmap_axis_name) + self.value = parallel.pmean_if_pmap(self.value, pmap_axis_name) def clear(self, value_to_none: bool = False): """Resets the weighted average.""" self.weight = jnp.zeros_like(self.weight) - self.raw_value = None if value_to_none else jnp.zeros_like(self.raw_value) + self.value = None if value_to_none else jnp.zeros_like(self.value) def value_and_clear(self) -> TArrayTree: """Retrieves the value of the weighted average and clears it.""" @@ -101,6 +105,7 @@ def value_and_clear(self) -> TArrayTree: value = self.value self.clear() + assert value is not None return value @classmethod @@ -113,7 +118,7 @@ def zeros_array( return cls( # pytype: disable=wrong-keyword-args weight=jnp.zeros([], dtype=dtype), - raw_value=jnp.zeros(shape, dtype=dtype), + value=jnp.zeros(shape, dtype=dtype), ) @classmethod @@ -124,7 +129,7 @@ def zeros_like(cls, value: TArrayTree) -> "WeightedMovingAverage[TArrayTree]": weight=jnp.array( 0.0, dtype=types.get_float_dtype_and_check_consistency(value) ), - raw_value=jax.tree_util.tree_map(jnp.zeros_like, value), + value=jax.tree_util.tree_map(jnp.zeros_like, value), )