diff --git a/kfac_jax/_src/curvature_blocks.py b/kfac_jax/_src/curvature_blocks.py deleted file mode 100644 index cfc1f7c..0000000 --- a/kfac_jax/_src/curvature_blocks.py +++ /dev/null @@ -1,2488 +0,0 @@ -# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""K-FAC curvature approximation to single layer blocks.""" -import abc -import collections -import functools -import math -import string -from typing import Any, Sequence - -import jax -import jax.numpy as jnp -import jax.scipy -from kfac_jax._src import layers_and_loss_tags as tags -from kfac_jax._src import patches_second_moment as psm -from kfac_jax._src import tag_graph_matcher as tgm -from kfac_jax._src import tracer -from kfac_jax._src import utils -import numpy as np -from typing_extensions import Self - -# Types for annotation -Array = utils.Array -Scalar = utils.Scalar -Numeric = utils.Numeric -PRNGKey = utils.PRNGKey -Shape = utils.Shape -DType = utils.DType -ScalarOrSequence = Scalar | Sequence[Scalar] -Cache = dict[str, Array | dict[str, Array]] - -# Special global variables -# This is used for einsum strings -_ALPHABET = string.ascii_lowercase -# The default value that would be used for the argument -# ``max_elements_for_vmap``, when it is set to ``None`` in the -# ``Conv2DDiagonal`` and ``Conv2DFull` curvature blocks. -_MAX_PARALLEL_ELEMENTS: int = 2 ** 23 -# The default value that would be used for the argument -# ``eigen_decomposition_threshold``, when it is set to ``None`` in any of the -# curvature blocks that inherit from ``Full`. -_DEFAULT_EIGEN_DECOMPOSITION_THRESHOLD = 5 - - -def set_max_parallel_elements(value: int): - """Sets the default value of maximum parallel elements in the module. - - This value is used to determine the parallel-to-memory tradeoff in the - curvature estimation procedure of :class:`~Conv2DDiagonal` and - :class:`~Conv2DFull`. See their corresponding docs for further details. - - Args: - value: The default value for maximum number of parallel elements. - """ - global _MAX_PARALLEL_ELEMENTS - _MAX_PARALLEL_ELEMENTS = value - - -def get_max_parallel_elements() -> int: - """Returns the default value of maximum parallel elements in the module. - - This value is used to determine the parallel-to-memory tradeoff in the - curvature estimation procedure of :class:`~Conv2DDiagonal` and - :class:`~Conv2DFull`. See their corresponding docs for further details. - - Returns: - The default value for maximum number of parallel elements. - """ - return _MAX_PARALLEL_ELEMENTS - - -def set_default_eigen_decomposition_threshold(value: int): - """Sets the default value of the eigen decomposition threshold. - - This value is used in :class:`~Full` to determine when updating the cache, - at what number of different powers to switch the implementation from a simple - matrix power to an eigenvector decomposition. - - Args: - value: The default value for eigen decomposition threshold. 
- """ - global _DEFAULT_EIGEN_DECOMPOSITION_THRESHOLD - _DEFAULT_EIGEN_DECOMPOSITION_THRESHOLD = value - - -def get_default_eigen_decomposition_threshold() -> int: - """Returns the default value of the eigen decomposition threshold. - - This value is used in :class:`~Full` to determine when updating the cache, - at what number of different powers to switch the implementation from a simple - matrix power to an eigenvector decomposition. - - Returns: - The default value of the eigen decomposition threshold. - """ - return _DEFAULT_EIGEN_DECOMPOSITION_THRESHOLD - - -def _to_real_set( - number_or_sequence: ScalarOrSequence | None -) -> set[Scalar]: - """Converts the optional number or sequence to a set.""" - if number_or_sequence is None: - return set() - elif isinstance(number_or_sequence, set): - return number_or_sequence - elif isinstance(number_or_sequence, (float, int)): - return {number_or_sequence} - elif (isinstance(number_or_sequence, collections.abc.Sequence) and - all(isinstance(x, (int, float)) for x in number_or_sequence)): - return set(number_or_sequence) - else: - raise ValueError(f"Expecting a real-number or a sequence of reals, but got " - f"{type(number_or_sequence)}.") - - -class CurvatureBlock(utils.Finalizable): - """Abstract class for curvature approximation blocks. - - A CurvatureBlock defines a curvature matrix to be estimated, and gives methods - to multiply powers of this with a vector. Powers can be computed exactly or - with a class-determined approximation. Cached versions of the powers can be - pre-computed to make repeated multiplications cheaper. During initialization, - you would have to explicitly specify all powers that you will need to cache. - """ - - @utils.register_state_class - class State(utils.State): - """Persistent state of the block. - - Any subclasses of :class:`~CurvatureBlock` should also internally extend - this class, with any attributes needed for the curvature estimation. - - Attributes: - cache: A dictionary, containing any state data that is updated on - irregular intervals, such as inverses, eigenvalues, etc. Elements of - this are updated via calls to :func:`~CurvatureBlock.update_cache`, and - do not necessarily correspond to the most up-to-date curvature estimate. - """ - cache: dict[str, Array | dict[str, Array]] | None - - def __init__(self, layer_tag_eq: tags.LayerTagEqn): - """Initializes the block. - - Args: - layer_tag_eq: The Jax equation corresponding to the layer tag that this - block will approximate the curvature to. 
- """ - super().__init__() - - self._layer_tag_eq = layer_tag_eq - - self.finalize() - - @property - def name(self) -> str: - return tags.layer_eqn_name(self._layer_tag_eq) - - @property - def layer_tag_primitive(self) -> tags.LayerTag: - """The :class:`jax.core.Primitive` corresponding to the block's tag equation.""" - - primitive = self._layer_tag_eq.primitive - assert isinstance(primitive, tgm.tags.LayerTag) - - return primitive - - @property - def parameter_variables(self) -> tuple[jax.core.Var, ...]: - """The parameter variables of the underlying Jax equation.""" - - param_vars = [] - - for p in tags.layer_eqn_data(self._layer_tag_eq).params: - - assert isinstance(p, jax.core.Var) - param_vars.append(p) - - return tuple(param_vars) - - @property - def outputs_shapes(self) -> tuple[Shape, ...]: - """The shapes of the output variables of the block's tag equation.""" - - output_vars = tags.layer_eqn_data(self._layer_tag_eq).outputs - - return jax.tree.map(lambda x: x.aval.shape, output_vars) - - @property - def inputs_shapes(self) -> tuple[Shape, ...]: - """The shapes of the input variables of the block's tag equation.""" - - input_vars = tags.layer_eqn_data(self._layer_tag_eq).inputs - - return jax.tree.map(lambda x: x.aval.shape, input_vars) - - @property - def parameters_shapes(self) -> tuple[Shape, ...]: - """The shapes of the parameter variables of the block's tag equation.""" - return tuple(jax.tree.map( - lambda x: tuple(x.aval.shape), self.parameter_variables)) - - @property - def dtype(self) -> DType: - dtypes = set(p.aval.dtype for p in self.parameter_variables) # pytype: disable=attribute-error - if len(dtypes) > 1: - raise ValueError("Not all parameters are the same dtype.") - return dtypes.pop() - - @property - def parameters_canonical_order(self) -> tuple[int, ...]: - """The canonical order of the parameter variables.""" - - return tuple(np.argsort([p.count for p in self.parameter_variables])) - - @property - def layer_tag_extra_params(self) -> dict[str, Any]: - """Any extra parameters of passed into the Jax primitive of this block.""" - - return self._layer_tag_eq.params - - @property - def number_of_parameters(self) -> int: - """Number of parameter variables of this block.""" - - return len(self.parameters_shapes) - - @property - def dim(self) -> int: - """The number of elements of all parameter variables together.""" - - return sum(utils.product(shape) for shape in self.parameters_shapes) - - def scale(self, state: State, use_cache: bool) -> Numeric: - """A scalar pre-factor of the curvature approximation. - - Importantly, all methods assume that whenever a user requests cached values, - any state dependant scale is taken into account by the cache (e.g. either - stored explicitly and used or mathematically added to values). - - Args: - state: The state for this block. - use_cache: Whether the method requesting this is using cached values or - not. - - Returns: - A scalar value to be multiplied with any unscaled block representation. - """ - - # TODO(jamesmartens,botev): This way of handling state dependent scale is - # a bit hacky and leads to complexity in other parts of the code that must - # be aware of how this part works. Should try to replace this with something - # better. - - if use_cache: - return self.fixed_scale() - - return self.fixed_scale() * self.state_dependent_scale(state) - - def fixed_scale(self) -> Numeric: - """A fixed scalar pre-factor of the curvature (e.g. 
constant).""" - return 1.0 - - def state_dependent_scale(self, state: State) -> Numeric: - """A scalar pre-factor of the curvature, computed from the most fresh curvature estimate.""" - del state # Unused - return 1.0 - - def __str__(self): - return (f"{self.__class__.__name__}, tag name: {self.name}, " - f"params shapes: {self.parameters_shapes!r}") - - @utils.auto_scope_method - def init( - self, - rng: PRNGKey, - exact_powers_to_cache: ScalarOrSequence | None, - approx_powers_to_cache: ScalarOrSequence | None, - cache_eigenvalues: bool, - ) -> State: - """Initializes the state for this block. - - Args: - rng: The PRNGKey which to be used for any randomness of the initialization - exact_powers_to_cache: A single value, or multiple values in a list, which - specify which exact matrix powers the block should be caching. Matrix - powers, which are expected to be used in - :func:`~CurvatureBlock.multiply_matpower`, - :func:`~CurvatureBlock.multiply_inverse` or - :func:`~CurvatureBlock.multiply` with ``exact_power=True`` and - ``use_cached=True`` must be provided here. - approx_powers_to_cache: A single value, or multiple values in a list, - which specify approximate matrix powers the block should be caching. - Matrix powers, which are expected to be used in - :func:`~CurvatureBlock.multiply_matrix_power`, - :func:`~CurvatureBlock.multiply_inverse` or - :func:`~CurvatureBlock.multiply` with ``exact_power=False`` and - ``use_cached=True`` must be provided here. - cache_eigenvalues: Specifies whether the block should be caching the - eigenvalues of its approximate curvature. - Returns: - A dictionary with the initialized state. - """ - return self._init( - rng=rng, - exact_powers_to_cache=_to_real_set(exact_powers_to_cache), - approx_powers_to_cache=_to_real_set(approx_powers_to_cache), - cache_eigenvalues=cache_eigenvalues) - - @abc.abstractmethod - def _init( - self, - rng: PRNGKey, - exact_powers_to_cache: set[Scalar], - approx_powers_to_cache: set[Scalar], - cache_eigenvalues: bool, - ) -> State: - """The non-public interface of ``init``.""" - - @abc.abstractmethod - def sync( - self, - state: State, - pmap_axis_name: str, - ) -> State: - """Syncs the state across different devices (does not sync the cache).""" - - @utils.auto_scope_method - def multiply_matpower( - self, - state: State, - vector: Sequence[Array], - identity_weight: Numeric, - power: Scalar, - exact_power: bool, - use_cached: bool, - ) -> tuple[Array, ...]: - """Computes ``(BlockMatrix + identity_weight I)**power`` times ``vector``. - - Args: - state: The state for this block. - vector: A tuple of arrays that should have the same shapes as the block's - parameters_shapes, which represent the vector you want to multiply. - identity_weight: A scalar specifying the weight on the identity matrix - that is added to the block matrix before raising it to a power. If - ``use_cached=False`` it is guaranteed that this argument will be used in - the computation. When returning cached values, this argument *may* be - ignored in favor whatever value was last passed to - :func:`~CurvatureBlock.update_cache`. The precise semantics of this - depend on the concrete subclass and its particular behavior in regard to - caching. - power: The power to which to raise the matrix. - exact_power: Specifies whether to compute the exact matrix power of - ``BlockMatrix + identity_weight I``. 
When this argument is ``False`` - the exact behaviour will depend on the concrete subclass and the - result will *in general* be an approximation to - ``(BlockMatrix + identity_weight I)^power``, although some subclasses - may still compute the exact matrix power. - use_cached: Whether to use a cached version for computing the product or - to use the most recent curvature estimates. The cached version is - going to be *at least* as fresh as the value provided to the last call - to :func:`~CurvatureBlock.update_cache` with the same value of ``power`` - - Returns: - A tuple of arrays, representing the result of the matrix-vector product. - """ - - scale = self.scale(state, use_cached) - - result = self._multiply_matpower_unscaled( - state=state, - vector=vector, - identity_weight=identity_weight / scale, - power=power, - exact_power=exact_power, - use_cached=use_cached, - ) - - return utils.scalar_mul(result, jnp.power(scale, power)) - - @abc.abstractmethod - def _multiply_matpower_unscaled( - self, - state: State, - vector: Sequence[Array], - identity_weight: Numeric, - power: Scalar, - exact_power: bool, - use_cached: bool, - ) -> tuple[Array, ...]: - """Performs matrix-vector multiplication, ignoring ``self.scale``.""" - - def multiply( - self, - state: State, - vector: Sequence[Array], - identity_weight: Numeric, - exact_power: bool, - use_cached: bool, - ) -> tuple[Array, ...]: - """Computes ``(BlockMatrix + identity_weight I)`` times ``vector``.""" - - return self.multiply_matpower( - state=state, - vector=vector, - identity_weight=identity_weight, - power=1, - exact_power=exact_power, - use_cached=use_cached, - ) - - def multiply_inverse( - self, - state: State, - vector: Sequence[Array], - identity_weight: Numeric, - exact_power: bool, - use_cached: bool, - ) -> tuple[Array, ...]: - """Computes ``(BlockMatrix + identity_weight I)^-1`` times ``vector``.""" - - return self.multiply_matpower( - state=state, - vector=vector, - identity_weight=identity_weight, - power=-1, - exact_power=exact_power, - use_cached=use_cached, - ) - - @utils.auto_scope_method - def eigenvalues( - self, - state: State, - use_cached: bool, - ) -> Array: - """Computes the eigenvalues for this block approximation. - - Args: - state: The state dict for this block. - use_cached: Whether to use a cached versions of the eigenvalues or to use - the most recent curvature estimates to compute them. The cached version - are going to be *at least* as fresh as the last time you called - :func:`~CurvatureBlock.update_cache` with ``eigenvalues=True``. - - Returns: - An array containing the eigenvalues of the block. - """ - eigenvalues = self._eigenvalues_unscaled(state, use_cached) - - assert eigenvalues.size == self.dim - - return self.scale(state, use_cached) * eigenvalues - - @abc.abstractmethod - def _eigenvalues_unscaled( - self, - state: State, - use_cached: bool, - ) -> Array: - """Computes the eigenvalues for this block, ignoring `self.scale`.""" - - @abc.abstractmethod - def update_curvature_matrix_estimate( - self, - state: State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> State: - """Updates the block's curvature estimates using the ``info`` provided. - - Each block *in general* estimates a moving average of its associated - curvature matrix. If you don't want a moving average you can set - ``ema_old=0`` and ``ema_new=1``. - - Args: - state: The state dict for this block to update. 
- estimation_data: A map containing data used for updating the curvature - matrix estimate for this block. This can be computed by calling the - function returned from :func:`~layer_tags_vjp`. Please see its - implementation for more details on the name of the fields and how they - are constructed. - ema_old: Specifies the weight of the old value when computing the updated - estimate in the moving average. - ema_new: Specifies the weight of the new value when computing the updated - estimate in the moving average. - identity_weight: The weight of the identity added to the block's curvature - matrix before computing the cached matrix power. - batch_size: The batch size used in computing the values in ``info``. - """ - - @utils.auto_scope_method - def update_cache( - self, - state: State, - identity_weight: Numeric, - exact_powers: ScalarOrSequence | None, - approx_powers: ScalarOrSequence | None, - eigenvalues: bool, - ) -> State: - """Updates the cached estimates of the different powers specified. - - Args: - state: The state dict for this block to update. - identity_weight: The weight of the identity added to the block's curvature - matrix before computing the cached matrix power. - exact_powers: Specifies any cached exact matrix powers to be updated. - approx_powers: Specifies any cached approximate matrix powers to be - updated. - eigenvalues: Specifies whether to update the cached eigenvalues - of the block. If they have not been cached before, this will create an - entry with them in the block's cache. - - Returns: - The updated state. - """ - return self._update_cache( - state=state, - identity_weight=identity_weight / self.scale(state, False), - exact_powers=_to_real_set(exact_powers), - approx_powers=_to_real_set(approx_powers), - eigenvalues=eigenvalues, - ) - - @abc.abstractmethod - def _update_cache( - self, - state: State, - identity_weight: Numeric, - exact_powers: set[Scalar], - approx_powers: set[Scalar], - eigenvalues: bool, - ) -> State: - """The cache updating function, ignoring ``self.scale``.""" - - @utils.auto_scope_method - def to_dense_matrix(self, state: State) -> Array: - """Returns a dense representation of the curvature matrix.""" - return self.scale(state, False) * self._to_dense_unscaled(state) - - @abc.abstractmethod - def _to_dense_unscaled(self, state: State) -> Array: - """A dense representation of the curvature, ignoring ``self.scale``.""" - - def undamped_diagonal(self, state: State) -> tuple[Array, ...]: - """Returns the diagonal of the undamped curvature.""" - return utils.scalar_mul(self._undamped_diagonal_unscaled(state), - self.scale(state, False)) - - def _undamped_diagonal_unscaled(self, state: State) -> tuple[Array, ...]: - """Returns the diagonal of the undamped curvature, ignoring ``self.scale``.""" - raise NotImplementedError() - - def norm(self, state: State, norm_type: str) -> Numeric: - """Computes the norm of the curvature block, according to ``norm_type``.""" - - return self.scale(state, False) * self._norm_unscaled(state, norm_type) - - @abc.abstractmethod - def _norm_unscaled( - self, - state: State, - norm_type: str - ) -> Numeric: - """Like ``norm`` but with ``self.scale`` not included.""" - - -class ScaledIdentity(CurvatureBlock): - """A block that assumes that the curvature is a scaled identity matrix.""" - - def __init__( - self, - layer_tag_eq: tags.LayerTagEqn, - scale: Numeric = 1.0, - ): - """Initializes the block. 
- - Args: - layer_tag_eq: The Jax equation corresponding to the layer tag, that this - block will approximate the curvature to. - scale: The scale of the identity matrix. - """ - self._scale = scale - super().__init__(layer_tag_eq) - - def fixed_scale(self) -> Numeric: - return self._scale - - def _init( - self, - rng: PRNGKey, - exact_powers_to_cache: set[Scalar], - approx_powers_to_cache: set[Scalar], - cache_eigenvalues: bool, - ) -> CurvatureBlock.State: - - del rng, exact_powers_to_cache, approx_powers_to_cache # Unused - - return CurvatureBlock.State( - cache=None, - ) - - def sync( - self, - state: CurvatureBlock.State, - pmap_axis_name: str, - ) -> CurvatureBlock.State: - return state - - def _multiply_matpower_unscaled( - self, - state: CurvatureBlock.State, - vector: Sequence[Array], - identity_weight: Numeric, - power: Scalar, - exact_power: bool, - use_cached: bool, - ) -> tuple[Array, ...]: - - del exact_power # Unused - - # state_dependent_scale needs to be included because it won't be by the - # caller of this function (multiply_matpower) when use_cached=True - scale = self.state_dependent_scale(state) if use_cached else 1.0 - - identity_weight = identity_weight + scale - - if power == 1: - return jax.tree.map(lambda x: identity_weight * x, vector) - - elif power == -1: - return jax.tree.map(lambda x: x / identity_weight, vector) - - else: - identity_weight = jnp.power(identity_weight, power) - return jax.tree.map(lambda x: identity_weight * x, vector) - - def _eigenvalues_unscaled( - self, - state: CurvatureBlock.State, - use_cached: bool, - ) -> Array: - return jnp.ones([self.dim]) - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: CurvatureBlock.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> CurvatureBlock.State: - - return state.copy() - - def _update_cache( - self, - state: CurvatureBlock.State, - identity_weight: Numeric, - exact_powers: set[Scalar], - approx_powers: set[Scalar], - eigenvalues: bool, - ) -> CurvatureBlock.State: - - return state.copy() - - def _to_dense_unscaled(self, state: CurvatureBlock.State) -> Array: - del state # not used - return jnp.eye(self.dim) - - def _norm_unscaled( - self, - state: CurvatureBlock.State, - norm_type: str - ) -> Numeric: - - return utils.psd_matrix_norm(jnp.ones([self.dim]), norm_type=norm_type) - - -class Diagonal(CurvatureBlock, abc.ABC): - """An abstract class for approximating only the diagonal of curvature.""" - - @utils.register_state_class - class State(CurvatureBlock.State): - """Persistent state of the block. - - Attributes: - diagonal_factors: A tuple of the moving averages of the estimated - diagonals of the curvature for each parameter that is part of the - associated layer. - """ - diagonal_factors: tuple[utils.WeightedMovingAverage, ...] - - def _init( - self, - rng: PRNGKey, - exact_powers_to_cache: set[Scalar], - approx_powers_to_cache: set[Scalar], - cache_eigenvalues: bool, - ) -> State: - - del rng - - return Diagonal.State( - cache=None, - diagonal_factors=tuple( - utils.WeightedMovingAverage.zeros_array(shape, self.dtype) - for shape in self.parameters_shapes - ), - ) - - def sync( - self, - state: State, - pmap_axis_name: str, - ) -> State: - - # Copy this first since we mutate it later in this function. 
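- # Only the diagonal factors are synced across `pmap_axis_name` below; per the base-class contract, the cache is deliberately left untouched.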
- state = state.copy() - - for factor in state.diagonal_factors: - factor.sync(pmap_axis_name) - - return state - - def _multiply_matpower_unscaled( - self, - state: State, - vector: Sequence[Array], - identity_weight: Numeric, - power: Scalar, - exact_power: bool, - use_cached: bool, - ) -> tuple[Array, ...]: - - # state_dependent_scale needs to be included because it won't be by the - # caller of this function (multiply_matpower) when use_cached=True - scale = self.state_dependent_scale(state) if use_cached else 1.0 - - factors = tuple(scale * f.value + identity_weight - for f in state.diagonal_factors) - - assert len(factors) == len(vector) - - if power == 1: - return tuple(f * v for f, v in zip(factors, vector)) - elif power == -1: - return tuple(v / f for f, v in zip(factors, vector)) - else: - return tuple(jnp.power(f, power) * v for f, v in zip(factors, vector)) - - def _eigenvalues_unscaled( - self, - state: State, - use_cached: bool, - ) -> Array: - return jnp.concatenate([f.value.flatten() for f in state.diagonal_factors], - axis=0) - - def _update_cache( - self, - state: State, - identity_weight: Numeric, - exact_powers: set[Scalar], - approx_powers: set[Scalar], - eigenvalues: bool, - ) -> State: - - return state.copy() - - def _to_dense_unscaled(self, state: State) -> Array: - - # Extract factors in canonical order - factors = [state.diagonal_factors[i].value.flatten() - for i in self.parameters_canonical_order] - - # Construct diagonal matrix - return jnp.diag(jnp.concatenate(factors, axis=0)) - - def _norm_unscaled( - self, - state: CurvatureBlock.State, - norm_type: str - ) -> Numeric: - - return utils.product( - utils.psd_matrix_norm(f.value.flatten(), norm_type=norm_type) - for f in state.diagonal_factors) - - -class Full(CurvatureBlock, abc.ABC): - """An abstract class for approximating the block matrix with a full matrix.""" - - @utils.register_state_class - class State(CurvatureBlock.State): - """Persistent state of the block. - - Attributes: - matrix: A moving average of the estimated curvature matrix for all - parameters that are part of the associated layer. - """ - matrix: utils.WeightedMovingAverage - - def __init__( - self, - layer_tag_eq: tags.LayerTagEqn, - eigen_decomposition_threshold: int | None = None, - ): - """Initializes the block. - - Args: - layer_tag_eq: The Jax equation corresponding to the layer tag that this - block will approximate the curvature to. - eigen_decomposition_threshold: During calls to ``init`` and - ``update_cache`` if higher number of matrix powers than this threshold - are requested, instead of computing individual approximate powers, will - directly compute the eigen-decomposition instead (which provide access to - any matrix power). If this is ``None`` will use the value returned from - :func:`~get_default_eigen_decomposition_threshold()`. 
- """ - - if eigen_decomposition_threshold is None: - threshold = get_default_eigen_decomposition_threshold() - self._eigen_decomposition_threshold = threshold - - else: - self._eigen_decomposition_threshold = eigen_decomposition_threshold - - super().__init__(layer_tag_eq) - - def parameters_list_to_single_vector( - self, - parameters_shaped_list: Sequence[Array], - ) -> Array: - """Converts values corresponding to parameters of the block to vector.""" - - if len(parameters_shaped_list) != self.number_of_parameters: - - raise ValueError(f"Expected a list of {self.number_of_parameters} values," - f" but got {len(parameters_shaped_list)} instead.") - - for array, shape in zip(parameters_shaped_list, self.parameters_shapes): - - if array.shape != shape: - raise ValueError(f"Expected a value of shape {shape}, but got " - f"{array.shape} instead.") - - return jnp.concatenate([v.flatten() for v in parameters_shaped_list]) - - def single_vector_to_parameters_list( - self, - vector: Array, - ) -> tuple[Array, ...]: - """Reverses the transformation ``self.parameters_list_to_single_vector``.""" - - if vector.ndim != 1: - raise ValueError(f"Expecting a vector, got {vector.ndim}-tensor.") - - if vector.size != self.dim: - raise ValueError(f"Expected a vector of size {self.dim}, but got " - f"{vector.size} instead.") - - parameters_shaped_list = [] - index = 0 - - for shape in self.parameters_shapes: - - size = utils.product(shape) - parameters_shaped_list.append(vector[index: index + size].reshape(shape)) - index += size - - assert index == self.dim - - return tuple(parameters_shaped_list) - - def _init( - self, - rng: PRNGKey, - exact_powers_to_cache: set[Scalar], - approx_powers_to_cache: set[Scalar], - cache_eigenvalues: bool, - ) -> State: - - del rng - - # This block does not have any notion of "approximate" powers - exact_powers_to_cache = exact_powers_to_cache | approx_powers_to_cache - cache = {} - - if len(exact_powers_to_cache) > self._eigen_decomposition_threshold: - cache["eigenvalues"] = jnp.zeros([self.dim], self.dtype) - cache["eigen_vectors"] = jnp.zeros([self.dim, self.dim], self.dtype) - - elif cache_eigenvalues: - cache["eigenvalues"] = jnp.zeros([self.dim], self.dtype) - - if len(exact_powers_to_cache) <= self._eigen_decomposition_threshold: - for power in exact_powers_to_cache: - cache[str(power)] = jnp.zeros([self.dim, self.dim], self.dtype) - - return Full.State( - cache=cache, - matrix=utils.WeightedMovingAverage.zeros_array( - [self.dim, self.dim], self.dtype), - ) - - def sync( - self, - state: State, - pmap_axis_name: str, - ) -> State: - - # Copy this first since we mutate it later in this function. - state = state.copy() - - state.matrix.sync(pmap_axis_name) - - return state - - def _multiply_matpower_unscaled( - self, - state: State, - vector: Sequence[Array], - identity_weight: Numeric, - power: Scalar, - exact_power: bool, - use_cached: bool, - ) -> tuple[Array, ...]: - - vector = self.parameters_list_to_single_vector(vector) - - if power == 1: - - result = jnp.matmul(state.matrix.value, vector) - - if use_cached: - # state_dependent_scale needs to be included here because it won't be by - # the caller of this function (multiply_matpower) when use_cached=True. - # This is not an issue for other powers because they bake in - # state_dependent_scale. 
- result *= self.state_dependent_scale(state) - - result += identity_weight * vector - - elif not use_cached: - - matrix = state.matrix.value + identity_weight * jnp.eye(self.dim) - - if power == -1: - result = utils.psd_solve(matrix, vector) - else: - # TODO(jamesmartens,botev): investigate this for determinism on GPUs - # NOTE: this function only works for integer powers - result = jnp.matmul(jnp.linalg.matrix_power(matrix, power), vector) - - else: - - if str(power) in state.cache: - result = jnp.matmul(state.cache[str(power)], vector) - - else: - s = state.cache["eigenvalues"] - q = state.cache["eigen_vectors"] - - result = jnp.matmul(jnp.transpose(q), vector) - result = jnp.power(s + identity_weight, power) * result - result = jnp.matmul(q, result) - - return self.single_vector_to_parameters_list(result) - - def _eigenvalues_unscaled( - self, - state: State, - use_cached: bool, - ) -> Array: - - if not use_cached: - return utils.safe_psd_eigh(state.matrix.value)[0] - - else: - return state.cache["eigenvalues"] - - def _update_cache( - self, - state: State, - identity_weight: Numeric, - exact_powers: set[Scalar], - approx_powers: set[Scalar], - eigenvalues: bool, - ) -> State: - - # Copy this first since we mutate it later in this function. - state = state.copy() - - scale = self.state_dependent_scale(state) - - # This block does not have any notion of "approximate" powers - exact_powers = exact_powers | approx_powers - - if len(exact_powers) > self._eigen_decomposition_threshold: - - s, q = utils.safe_psd_eigh(state.matrix.value) - state.cache = dict(eigenvalues=scale * s, eigen_vectors=q) - - else: - - if eigenvalues: - state.cache["eigenvalues"] = scale * utils.safe_psd_eigh( - state.matrix.value)[0] - - for power in exact_powers: - - if power == -1: - state.cache[str(power)] = utils.psd_inv( - state.matrix.value + identity_weight * jnp.eye(self.dim)) / scale - else: - matrix = state.matrix.value + identity_weight * jnp.eye(self.dim) - state.cache[str(power)] = ( - (scale ** power) * jnp.linalg.matrix_power(matrix, power)) - - return state - - def _to_dense_unscaled(self, state: State) -> Array: - - # Permute the matrix according to the parameters canonical order - return utils.block_permuted( - state.matrix.value, - block_sizes=[utils.product(shape) for shape in self.parameters_shapes], - block_order=self.parameters_canonical_order - ) - - def _norm_unscaled( - self, - state: CurvatureBlock.State, - norm_type: str - ) -> Numeric: - - return utils.psd_matrix_norm(state.matrix.value, norm_type=norm_type) - - def _undamped_diagonal_unscaled(self, state: State) -> tuple[Array, ...]: - diag_vec = jnp.diag(state.matrix.value) - return self.single_vector_to_parameters_list(diag_vec) - - -class KroneckerFactored(CurvatureBlock, abc.ABC): - """An abstract class for approximating the block with a Kronecker product. - - The constructor takes two special arguments: - - parameters_specs: A list, where each element specifies for each - parameter a "rearrange string". This is in the format `abc->b(ca)` - similar to `einops.rearrange`. - - parameters_concat_axis: The axis along which the parameters will be - concatenated to form a single array after each parameter has been - rearranged according to its "rearrange string". - - The above implies that: - - All parameters must have the same rank after they have been rearranged. - - All parameters must have the same size along all axes except the - concatenation axis after they have been rearranged. 
- - By default, each parameter is rearanged to a matrix, by merging all dimensions - except the last one. If a parameter is a vector (rank 1), it is rearranged to - a matrix with the first dimension being 1. Then concatenation is done along - axis=0. - """ - - @utils.register_state_class - class State(CurvatureBlock.State): - """Persistent state of the block. - - Attributes: - factors: A tuple of the moving averages of the estimated factors of the - curvature for each axis group. - """ - - factors: tuple[utils.WeightedMovingAverage, ...] - - @classmethod - def from_dict(cls, dict_rep: dict[str, Any]) -> Self: - class_name = dict_rep.pop("__class__", cls.__name__) - assert class_name == cls.__name__ - return cls( - factors=tuple( - utils.WeightedMovingAverage.from_dict(rep) - for rep in dict_rep["factor"] - ) - ) - - def __init__( - self, - layer_tag_eq: tags.LayerTagEqn, - parameters_specs: Sequence[str] | None = None, - parameters_concat_axis: int = 0, - ): - - # Even though the superclass constructor will set this later, we need to do - # it now since it's used below. - self._layer_tag_eq = layer_tag_eq - - if parameters_specs is None: - parameters_specs = [] - - for shape in self.parameters_shapes: - - if len(shape) == 1: - parameters_specs.append("a -> 1a") - - else: - in_str = _ALPHABET[:len(shape)] - out_str = f"({in_str[:-1]}){in_str[-1]}" - parameters_specs.append(f"{in_str} -> {out_str}") - - else: - assert len(parameters_specs) == self.number_of_parameters - - self.parameters_specs = parameters_specs - self.parameters_concat_axis = parameters_concat_axis - - super().__init__(layer_tag_eq) - - def __str__(self): - return ( - f"{self.__class__.__name__}(parameter_specs={self.parameters_specs}, " - f"parameters_concat_axis={self.parameters_concat_axis}), " - f"tag name: {self.name}, params shapes: {self.parameters_shapes!r}" - ) - - def parameters_shaped_list_to_array( - self, - parameters_shaped_list: Sequence[Array], - ) -> Array: - """Combines all parameters to a single array.""" - values = [] - for p, spec in zip( - parameters_shaped_list, - self.parameters_specs, - strict=True, - ): - values.append(utils.rearrange(p, spec)) - - return jnp.concatenate(values, axis=self.parameters_concat_axis) - - def array_to_parameters_shaped_list(self, array: Array) -> tuple[Array, ...]: - """An inverse transformation of ``self.parameters_shaped_list_to_array``.""" - parameters_list = [] - n = 0 - index = [slice(None)] * array.ndim - - for shape, spec in zip( - self.parameters_shapes, - self.parameters_specs, - strict=True, - ): - zero = utils.rearrange(jnp.zeros(shape), spec) - d = zero.shape[self.parameters_concat_axis] - index[self.parameters_concat_axis] = slice(n, n + d) - p = array[tuple(index)] - parameters_list.append(p.reshape(shape)) - n += d - - return tuple(parameters_list) - - @property - def array_shape(self) -> Shape: - """The shape of the single non axis grouped array.""" - avals = [jnp.zeros(shape) for shape in self.parameters_shapes] - return self.parameters_shaped_list_to_array(avals).shape - - @property - def array_ndim(self) -> int: - """The number of dimensions of the single non axis grouped array.""" - return len(self.array_shape) - - def _init( - self, - rng: PRNGKey, - exact_powers_to_cache: set[Scalar], - approx_powers_to_cache: set[Scalar], - cache_eigenvalues: bool, - ) -> State: - - cache = {} - factors = [] - - for i, d in enumerate(self.array_shape): - - factors.append( - utils.WeightedMovingAverage.zeros_array((d, d), self.dtype) - ) - - if cache_eigenvalues or 
exact_powers_to_cache: - cache[f"{i}_factor_eigenvalues"] = jnp.zeros((d,), dtype=self.dtype) - - if exact_powers_to_cache: - cache[f"{i}_factor_eigen_vectors"] = jnp.zeros((d, d), dtype=self.dtype) - - for power in approx_powers_to_cache: - - if power != -1: - raise NotImplementedError( - f"Approximations for power {power} is not yet implemented." - ) - - if str(power) not in cache: - cache[str(power)] = {} - - cache[str(power)][f"{i}_factor"] = jnp.zeros((d, d), dtype=self.dtype) - - return KroneckerFactored.State( - cache=cache, - factors=tuple(factors), - ) - - def sync( - self, - state: State, - pmap_axis_name: str, - ) -> State: - - # Copy this first since we mutate it later in this function. - state = state.copy() - - for factor in state.factors: - factor.sync(pmap_axis_name) - - return state - - def _multiply_matpower_unscaled( - self, - state: State, - vector: Sequence[Array], - identity_weight: Numeric, - power: Scalar, - exact_power: bool, - use_cached: bool, - ) -> tuple[Array, ...]: - - assert len(state.factors) == self.array_ndim - - vector = self.parameters_shaped_list_to_array(vector) - - if power == 1: - - factors = [f.value for f in state.factors] - - # state_dependent_scale needs to be included here because it won't be by - # the caller of this function (multiply_matpower) when use_cached=True. - # This is not an issue for other powers because they bake in - # state_dependent_scale. - scale = self.state_dependent_scale(state) if use_cached else 1.0 - - if exact_power: - result = scale * utils.kronecker_product_axis_mul_v(factors, vector) - result = result + identity_weight * vector - - else: - # If compute pi_adjusted_kronecker_factors used a more expensive matrix - # norm in its computation, it might make sense to cache it. But we - # currently don't do that. - - result = scale * utils.kronecker_product_axis_mul_v( - utils.pi_adjusted_kronecker_factors( - *factors, damping=identity_weight / scale), - vector) - - elif exact_power: - - if use_cached: - s = [ - state.cache[f"{i}_factor_eigenvalues"] - for i in range(len(state.factors)) - ] - q = [ - state.cache[f"{i}_factor_eigen_vectors"] - for i in range(len(state.factors)) - ] - - else: - s, q = zip( - *[utils.safe_psd_eigh(factor.value) for factor in state.factors] - ) - - eigenvalues = utils.outer_product(*s) + identity_weight - eigenvalues = jnp.power(eigenvalues, power) - - result = utils.kronecker_eigen_basis_axis_mul_v(q, eigenvalues, vector) - - else: - - if power != -1 and power != -0.5: - raise NotImplementedError( - f"Approximations for power {power} is not yet implemented." 
- ) - - if use_cached: - - assert power != -0.5 - - factors = [ - state.cache[str(power)][f"{i}_factor"] - for i in range(len(state.factors)) - ] - - else: - - factors = [factor.value for factor in state.factors] - - factors = utils.pi_adjusted_kronecker_factors( - *factors, damping=identity_weight) - - if power == -1: - factors = utils.invert_psd_matrices(factors) - elif power == -0.5: - factors = utils.inverse_sqrt_psd_matrices(factors) - else: - raise NotImplementedError() - - result = utils.kronecker_product_axis_mul_v(factors, vector) - - return self.array_to_parameters_shaped_list(result) - - def _eigenvalues_unscaled( - self, - state: State, - use_cached: bool, - ) -> Array: - - assert len(state.factors) == self.array_ndim - - if use_cached: - s = [ - state.cache[f"{i}_factor_eigenvalues"] - for i in range(len(state.factors)) - ] - else: - s_q = [utils.safe_psd_eigh(factor.value) for factor in state.factors] - s, _ = zip(*s_q) - - return utils.outer_product(*s) - - def _update_cache( - self, - state: State, - identity_weight: Numeric, - exact_powers: set[Scalar], - approx_powers: set[Scalar], - eigenvalues: bool, - ) -> State: - - assert len(state.factors) == self.array_ndim - - # Copy this first since we mutate it later in this function. - state = state.copy() - - scale = self.state_dependent_scale(state) - factor_scale = jnp.power(scale, 1.0 / self.array_ndim) - - if eigenvalues or exact_powers: - - s_q = [utils.safe_psd_eigh(factor.value) for factor in state.factors] - - s, q = zip(*s_q) - - for i in range(len(state.factors)): - state.cache[f"{i}_factor_eigenvalues"] = factor_scale * s[i] - - if exact_powers: - state.cache[f"{i}_factor_eigen_vectors"] = q[i] - - for power in approx_powers: - - if power != -1: - raise NotImplementedError( - f"Approximations for power {power} is not yet implemented." - ) - - cache = state.cache[str(power)] - - # This computes the approximate inverse factors using the generalization - # of the pi-adjusted inversion from the original KFAC paper. - inv_factors = utils.pi_adjusted_kronecker_inverse( - *[factor.value for factor in state.factors], - damping=identity_weight, - ) - - for i in range(len(state.factors)): - cache[f"{i}_factor"] = inv_factors[i] / factor_scale - - return state - - def _norm_unscaled( - self, - state: CurvatureBlock.State, - norm_type: str - ) -> Numeric: - - return utils.product( - utils.psd_matrix_norm(f.value, norm_type=norm_type) - for f in state.factors) - - def _to_dense_unscaled(self, state: "KroneckerFactored.State") -> Array: - - # We currently support this only for 2 parameters - assert 0 < self.number_of_parameters <= 2 - inputs_factor = state.factors[0].value - - if (self.number_of_parameters == 2 and - self.parameters_canonical_order[0] != 0): - - # Permute the matrix according to the parameters canonical order - inputs_factor = utils.block_permuted( - state.factors[0].value, - block_sizes=[state.factors[0].raw_value.shape[0] - 1, 1], - block_order=(1, 0), - ) - - return jnp.kron(inputs_factor, state.factors[1].value) - - -# _ _ _ -# | \ | | (_) -# | \| | __ _ ___ _____ -# | . ` |/ _` | \ \ / / _ \ -# | |\ | (_| | |\ V / __/ -# |_| \_|\__,_|_| \_/ \___| -# - - -class NaiveDiagonal(Diagonal): - """Approximates the diagonal of the curvature with in the most obvious way. - - The update to the curvature estimate is computed by ``(sum_i g_i) ** 2 / N``. - where `g_i` is the gradient of each individual data point, and ``N`` is the - batch size. 
- """ - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: Diagonal.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> Diagonal.State: - del identity_weight - - # Copy this first since we mutate it later in this function. - state = state.copy() - - for factor, dw in zip( - state.diagonal_factors, estimation_data.tangents.params - ): - factor.update(dw * dw / batch_size, ema_old, ema_new) - - return state - - -class NaiveFull(Full): - """Approximates the full curvature with in the most obvious way. - - The update to the curvature estimate is computed by - ``(sum_i g_i) (sum_i g_i)^T / N``, where ``g_i`` is the gradient of each - individual data point, and ``N`` is the batch size. - """ - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: Full.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> Full.State: - del identity_weight - - # This method supports the case where the param tangents have an extra - # leading dimension that should be summed over (after the outer products). - # TODO(jamesmartens): add support for this to NaiveDiagonal - - # Copy this first since we mutate it later in this function. - state = state.copy() - - params_tangents = jax.tree_util.tree_leaves( - estimation_data.tangents.params) - - params_tangents_flattened = [] - - assert len(params_tangents) == self.number_of_parameters - - for p_shape, pt in zip(self.parameters_shapes, params_tangents): - - if p_shape: - assert ( - pt.shape[-len(p_shape) :] == p_shape - ), f"{pt.shape=} and {p_shape=}" - - p_size = utils.product(p_shape) - - params_tangents_flattened.append(pt.reshape([-1, p_size])) - - tangents = jnp.concatenate(params_tangents_flattened, axis=1) - - if jnp.iscomplexobj(tangents): - stats = ( - jnp.einsum("ay,az->yz", tangents.real, tangents.real) - - jnp.einsum("ay,az->yz", tangents.imag, tangents.imag)) / batch_size - else: - stats = jnp.einsum("ay,az->yz", tangents, tangents) / batch_size - - state.matrix.update(stats, ema_old, ema_new) - - return state - - -class NaiveTNT(KroneckerFactored): - """A standard TNT block for a single parameter, or weights + bias. - - Each factor of the standard TNT curvature approximation estimates the expected - value of the contraction of the gradients with themselves along all but a - single axis `i`: - ``F_i ~~ E[contract_all_but_one(g, g, i)].`` - where `contrat_all_but_one` is defined as the contraction over all axes except - the i-th of its first two inputs, e.g.: - ``contract_all_but_one(A, B, 1)[a,b] = sum_{i,j} A[i, a, j] B[i, b, j]`` - - The estimation is performed in a naive way by contracting the sum of each - examples' gradients and then dividing by the batch size: - ``F_i = contract_all_but_one(sum_n g_n, sum_n g_n, i) / N`` - where `g_n` is the model gradient of a single example and `N` is the - batch size. Since the expectations of the gradients over the model - distribution is zero and they are independent across cases, this is still an - unbiased estimator. 
- """ - - def state_dependent_scale( - self, - state: "NaiveTNT.State", - ) -> Numeric: - return utils.tnt_scale([factor.value for factor in state.factors]) - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: KroneckerFactored.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> KroneckerFactored.State: - del identity_weight - - # Copy this first since we mutate it later in this function. - state = state.copy() - - dw = self.parameters_shaped_list_to_array(estimation_data.tangents.params) - - assert dw.ndim == len(state.factors) - - in_str = _ALPHABET[: dw.ndim] - - for i, factor in enumerate(state.factors): - # For factor i we contract the gradient with itself along all axes, - # except the i-th. - - lhs_str = utils.replace_char(in_str, "y", i) - rhs_str = utils.replace_char(in_str, "z", i) - - # This is a rank-1 mod since it's like we flattened all but dim i together - # and then did an outer product - factor_update = ( - jnp.einsum(f"{lhs_str},{rhs_str}->yz", dw, dw) / batch_size - ) - - factor.update(factor_update, ema_old, ema_new) - - return state - - -# _____ -# | __ \ -# | | | | ___ _ __ ___ ___ -# | | | |/ _ \ '_ \/ __|/ _ \ -# | |__| | __/ | | \__ \ __/ -# |_____/ \___|_| |_|___/\___| -# - - -class DenseDiagonal(Diagonal): - """A `Diagonal` block specifically for dense layers.""" - - @property - def has_bias(self) -> bool: - """Whether the layer has a bias parameter.""" - return len(self.parameter_variables) == 2 - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: Diagonal.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> Diagonal.State: - del identity_weight - - # Copy this first since we mutate it later in this function. - state = state.copy() - - [x] = estimation_data.primals.inputs - [dy] = estimation_data.tangents.outputs - - assert utils.first_dim_is_size(batch_size, x, dy) - - diagonals = (jnp.matmul((x * x).T, dy * dy) / batch_size,) - if self.has_bias: - diagonals += (jnp.mean(dy * dy, axis=0),) - - assert len(diagonals) == self.number_of_parameters - - for diagonal_factor, diagonal in zip(state.diagonal_factors, diagonals): - diagonal_factor.update(diagonal, ema_old, ema_new) - - return state - - -class DenseFull(Full): - """A `Full` block specifically for dense layers.""" - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: Full.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> Full.State: - del identity_weight - - # Copy this first since we mutate it later in this function. 
- state = state.copy() - - [x] = estimation_data.primals.inputs - [dy] = estimation_data.tangents.outputs - - assert utils.first_dim_is_size(batch_size, x, dy) - - params_tangents = x[:, :, None] * dy[:, None, :] - - if self.number_of_parameters == 2: - params_tangents = jnp.concatenate([params_tangents, dy[:, None]], axis=1) - - params_tangents = jnp.reshape(params_tangents, [batch_size, -1]) - - matrix_update = jnp.matmul(params_tangents.T, params_tangents) / batch_size - state.matrix.update(matrix_update, ema_old, ema_new) - - return state - - -class DenseTwoKroneckerFactored(KroneckerFactored): - """A :class:`~TwoKroneckerFactored` block specifically for dense layers.""" - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: KroneckerFactored.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> KroneckerFactored.State: - del identity_weight - assert 1 <= self.number_of_parameters <= 2 - - # Copy this first since we mutate it later in this function. - state = state.copy() - - [x] = estimation_data.primals.inputs - [dy] = estimation_data.tangents.outputs - - assert utils.first_dim_is_size(batch_size, x, dy) - - if self.number_of_parameters == 2: - x_one = jnp.ones_like(x[:, :1]) - x = jnp.concatenate([x, x_one], axis=1) - - input_stats = jnp.einsum("ay,az->yz", x, x) / batch_size - output_stats = jnp.einsum("ay,az->yz", dy, dy) / batch_size - - state.factors[0].update(input_stats, ema_old, ema_new) - state.factors[1].update(output_stats, ema_old, ema_new) - - return state - - -class RepeatedDenseKroneckerFactored(DenseTwoKroneckerFactored): - """Block for dense layers applied to tensors with extra time/loc dims.""" - - @utils.register_state_class - class State(KroneckerFactored.State): - """Persistent state of the block. - - Attributes: - average_repeats: A decayed average of the per-case number of non-masked - repeats in the data used to compute the block's statistics. We use the - same decayed averaging for this quantity that we do for the statistics, - so that they "match". - """ - - average_repeats: utils.WeightedMovingAverage - - def __init__( - self, - layer_tag_eq: tags.LayerTagEqn, - use_masking: bool = True, - parameters_specs: Sequence[str] | None = None, - parameters_concat_axis: int = 0, - ): - self._use_masking = use_masking - super().__init__( - layer_tag_eq=layer_tag_eq, - parameters_specs=parameters_specs, - parameters_concat_axis=parameters_concat_axis, - ) - - def _init( - self, - rng: PRNGKey, - exact_powers_to_cache: set[Scalar], - approx_powers_to_cache: set[Scalar], - cache_eigenvalues: bool, - ) -> "RepeatedDenseKroneckerFactored.State": - - super_state = super()._init( - rng, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues - ) - - return RepeatedDenseKroneckerFactored.State( - average_repeats=utils.WeightedMovingAverage.zeros_array((), self.dtype), - **super_state.__dict__, - ) - - def state_dependent_scale( - self, - state: "RepeatedDenseKroneckerFactored.State", - ) -> Numeric: - return 1.0 / state.average_repeats.value - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: KroneckerFactored.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> KroneckerFactored.State: - del identity_weight - - # Copy this first since we mutate it later in this function. 
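- # Repeats whose output tangents dy are entirely zero are treated as masked: the matching rows of x are zeroed out of the input statistics, and the average number of non-masked repeats per case is tracked so that state_dependent_scale can divide it out.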
- state = state.copy() - - [x] = estimation_data.primals.inputs - [dy] = estimation_data.tangents.outputs - - assert utils.first_dim_is_size(batch_size, x, dy) - - if self._use_masking: - - # hack: we identify masked repeats by checking if all corresponding - # entries of dy are zero - mask = 1.0 - jnp.all(dy == 0.0, axis=-1, keepdims=True) - - # zero out corresponding elts of x - x = x * mask - - # compute total non-masked - total = jnp.sum(mask) - - else: - total = math.prod(dy.shape[:-1]) - - x = x.reshape([-1, x.shape[-1]]) - dy = dy.reshape([-1, dy.shape[-1]]) - - if self.number_of_parameters == 2: - x_one = jnp.ones_like(x[:, :1]) - x = jnp.concatenate([x, x_one], axis=1) - - input_stats = jnp.einsum("ay,az->yz", x, x) / batch_size - output_stats = jnp.einsum("ay,az->yz", dy, dy) / batch_size - - state.factors[0].update(input_stats, ema_old, ema_new) - state.factors[1].update(output_stats, ema_old, ema_new) - state.average_repeats.update(total / batch_size, ema_old, ema_new) - - return state - - -class DenseTNT(DenseTwoKroneckerFactored): - """A TNT block for dense layers. - - This TNT block modifies :class:`~NaiveTNTBlock` by the way it estimates each - factor specifically for a dense layer. Instead of using the contraction over - the summed gradient, it performs the contraction over each individual batch - elements and then averages over the batch: - ``F_i = sum_n contract_all_but_one(g_n, g_n, i) / N`` - The estimator is unbiased, and will have lower variance then the naive one. - """ - - def state_dependent_scale(self, state: "DenseTNT.State") -> Numeric: - return utils.tnt_scale([factor.value for factor in state.factors]) - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: KroneckerFactored.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> KroneckerFactored.State: - del identity_weight - - # Copy this first since we mutate it later in this function. - state = state.copy() - - [x] = estimation_data.primals.inputs - [dy] = estimation_data.tangents.outputs - - assert utils.first_dim_is_size(batch_size, x, dy) - - if self.number_of_parameters == 2: - x_one = jnp.ones_like(x[:, :1]) - x = jnp.concatenate([x, x_one], axis=1) - - # We multiply each x by the norm_y, and each dy by the norm of x - dy_norms = jnp.linalg.norm(dy, axis=-1, keepdims=True) - x_norms = jnp.linalg.norm(x, axis=-1, keepdims=True) - x = x * dy_norms - dy = dy * x_norms - - input_stats = jnp.einsum("ay,az->yz", x, x) / batch_size - output_stats = jnp.einsum("ay,az->yz", dy, dy) / batch_size - - state.factors[0].update(input_stats, ema_old, ema_new) - state.factors[1].update(output_stats, ema_old, ema_new) - - return state - - -# _____ ___ _____ -# / ____| |__ \| __ \ -# | | ___ _ ____ __ ) | | | | -# | | / _ \| '_ \ \ / // /| | | | -# | |___| (_) | | | \ V // /_| |__| | -# \_____\___/|_| |_|\_/|____|_____/ -# - - -class Conv2DDiagonal(Diagonal): - """A :class:`~Diagonal` block specifically for 2D convolution layers.""" - - def __init__( - self, - layer_tag_eq: tags.LayerTagEqn, - max_elements_for_vmap: int | None = None, - ): - """Initializes the block. - - Since there is no 'nice' formula for computing the average of the - tangents for a 2D convolution, what we do is that we have a function - - ``self.conv2d_tangent_squared`` - that computes for a single feature map the - square of the tangents for the kernel of the convolution. 
To average over - the batch we have two choices - vmap or loop over the batch sequentially - using scan. This utility function provides a trade-off by being able to - specify the maximum number of batch size that we can vmap over. This means - that the maximum memory usage will be ``max_batch_size_for_vmap`` times the - memory needed when calling ``self.conv2d_tangent_squared``. And the actual - ``vmap`` will be called ``ceil(total_batch_size / max_batch_size_for_vmap)`` - number of times in a loop to find the final average. - - Args: - layer_tag_eq: The Jax equation corresponding to the layer tag, that this - block will approximate the curvature to. - max_elements_for_vmap: The threshold used for determining how much - computation to the in parallel and how much in serial manner. If - ``None`` will use the value returned by - :func:`~get_max_parallel_elements`. - """ - self._averaged_kernel_squared_tangents = utils.loop_and_parallelize_average( - func=self.conv2d_tangent_squared, - max_parallel_size=max_elements_for_vmap or get_max_parallel_elements(), - ) - super().__init__(layer_tag_eq) - - @property - def has_bias(self) -> bool: - return len(self.parameter_variables) == 2 - - def conv2d_tangent_squared( - self, - image_features_map: Array, - output_tangent: Array, - ) -> Array: - """Computes the elementwise square of a tangent for a single feature map.""" - - extra_params = {k: v for k, v in self.layer_tag_extra_params.items() - if k not in ("lhs_shape", "rhs_shape", "meta")} - - _, vjp = jax.vjp( - functools.partial( - jax.lax.conv_general_dilated, - **extra_params - ), - image_features_map[None], jnp.zeros(self.parameters_shapes[0]) - ) - - return jnp.square(vjp(output_tangent[None])[1]) - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: Diagonal.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> Diagonal.State: - del identity_weight - - # Copy this first since we mutate it later in this function. - state = state.copy() - - [x] = estimation_data.primals.inputs - [dy] = estimation_data.tangents.outputs - - assert utils.first_dim_is_size(batch_size, x, dy) - - diagonals = (self._averaged_kernel_squared_tangents(x, dy),) - - if self.has_bias: - sum_axis = tuple(range(1, dy.ndim - len(self.parameters_shapes[1]))) - bias_dy = jnp.sum(dy, axis=sum_axis) - diagonals += (jnp.mean(bias_dy * bias_dy, axis=0),) - - assert len(diagonals) == self.number_of_parameters - - for diagonal_factor, diagonal in zip(state.diagonal_factors, diagonals): - diagonal_factor.update(diagonal, ema_old, ema_new) - - return state - - -class Conv2DFull(Full): - """A :class:`~Full` block specifically for 2D convolution layers.""" - - def __init__( - self, - layer_tag_eq: tags.LayerTagEqn, - max_elements_for_vmap: int | None = None, - ): - """Initializes the block. - - Since there is no 'nice' formula for computing the average of the - tangents for a 2D convolution, what we do is that we have a function - - ``self.conv2d_tangent_squared`` - that computes for a single feature map the - square of the tangents for the kernel of the convolution. To average over - the batch we have two choices - vmap or loop over the batch sequentially - using scan. This utility function provides a trade-off by being able to - specify the maximum batch that that will be handled in a single iteration - of the loop. 
This means that the maximum memory usage will be - ``max_batch_size_for_vmap`` times the memory needed when calling - ``self.conv2d_tangent_squared``. And the actual ``vmap`` will be - called ``ceil(total_batch_size / max_batch_size_for_vmap)`` number of times - in a loop to find the final average. - - Args: - layer_tag_eq: The Jax equation corresponding to the layer tag, that this - block will approximate the curvature to. - max_elements_for_vmap: The threshold used for determining how much - computation to the in parallel and how much in serial manner. If - ``None`` will use the value returned by - :func:`~get_max_parallel_elements`. - """ - - self._averaged_tangents_outer_product = utils.loop_and_parallelize_average( - func=self.conv2d_tangent_outer_product, - max_parallel_size=max_elements_for_vmap or get_max_parallel_elements(), - ) - - super().__init__(layer_tag_eq) - - def conv2d_tangent_outer_product( - self, - inputs: Array, - tangent_of_outputs: Array, - ) -> Array: - """Computes the outer product of a tangent for a single feature map.""" - - extra_params = {k: v for k, v in self.layer_tag_extra_params.items() - if k not in ("lhs_shape", "rhs_shape", "meta")} - - _, vjp = jax.vjp( - functools.partial( - jax.lax.conv_general_dilated, - **extra_params - ), - inputs[None], jnp.zeros(self.parameters_shapes[0]) - ) - - tangents = (vjp(tangent_of_outputs[None])[1],) - - if self.number_of_parameters == 2: - num_axis = tangent_of_outputs.ndim - len(self.parameters_shapes[1]) - sum_axis = tuple(range(num_axis)) - tangents += (jnp.sum(tangent_of_outputs, axis=sum_axis),) - - flat_tangents = self.parameters_list_to_single_vector(tangents) - - return jnp.outer(flat_tangents, flat_tangents) - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: Full.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> Full.State: - del identity_weight - - # Copy this first since we mutate it later in this function. 
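The per-example kernel tangent that ``conv2d_tangent_outer_product`` builds comes from a VJP through ``jax.lax.conv_general_dilated`` with respect to the kernel. A standalone sketch of that step, with illustrative NHWC/HWIO shapes and convolution parameters (not taken from a real layer tag):

import functools
import jax
import jax.numpy as jnp

x = jnp.ones([1, 8, 8, 3])       # a single image, NHWC
kernel_shape = (3, 3, 3, 16)     # HWIO
dy = jnp.ones([1, 8, 8, 16])     # tangent of the layer outputs

conv = functools.partial(
    jax.lax.conv_general_dilated,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
_, vjp = jax.vjp(conv, x, jnp.zeros(kernel_shape))
dkernel = vjp(dy)[1]             # cotangent with respect to the kernel
assert dkernel.shape == kernel_shape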
- state = state.copy() - - [x] = estimation_data.primals.inputs - [dy] = estimation_data.tangents.outputs - assert utils.first_dim_is_size(batch_size, x, dy) - - matrix_update = self._averaged_tangents_outer_product(x, dy) - state.matrix.update(matrix_update, ema_old, ema_new) - - return state - - -class Conv2DTwoKroneckerFactored(KroneckerFactored): - """A :class:`~TwoKroneckerFactored` block specifically for 2D convolution layers.""" - - def fixed_scale(self) -> Numeric: - return float(self.num_locations) - - @property - def kernel_output_axis(self) -> int: - return self._layer_tag_eq.params["dimension_numbers"].rhs_spec[0] - - @property - def outputs_channel_index(self) -> int: - """The ``channels`` index in the outputs of the layer.""" - return self._layer_tag_eq.params["dimension_numbers"].out_spec[1] - - @property - def inputs_channel_index(self) -> int: - """The ``channels`` index in the inputs of the layer.""" - return self._layer_tag_eq.params["dimension_numbers"].lhs_spec[1] - - @property - def weights_output_channel_index(self) -> int: - """The ``channels`` index in weights of the layer.""" - return self._layer_tag_eq.params["dimension_numbers"].rhs_spec[0] - - @property - def weights_spatial_shape(self) -> Shape: - spatial_index = self._layer_tag_eq.params["dimension_numbers"].rhs_spec[2:] - return tuple(self.parameters_shapes[0][i] for i in spatial_index) - - @property - def weights_spatial_size(self) -> int: - """The spatial filter size of the weights.""" - return utils.product(dim for dim in self.weights_spatial_shape) - - @property - def inputs_spatial_shape(self) -> Shape: - spatial_index = self._layer_tag_eq.params["dimension_numbers"].lhs_spec[2:] - return tuple(self.inputs_shapes[0][i] for i in spatial_index) - - @property - def num_locations(self) -> int: - """The number of spatial locations that each filter is applied to.""" - return psm.num_conv_locations( - self.inputs_spatial_shape, - self.weights_spatial_shape, - self._layer_tag_eq.params["window_strides"], - self._layer_tag_eq.params["padding"]) - - def input_size(self) -> int: - if self.has_bias: - return self.num_inputs_channels * self.weights_spatial_size + 1 - else: - return self.num_inputs_channels * self.weights_spatial_size - - def output_size(self) -> int: - return self.num_outputs_channels - - @property - def num_inputs_channels(self) -> int: - """The number of channels in the inputs to the layer.""" - return self._layer_tag_eq.invars[0].aval.shape[ # pytype: disable=attribute-error - self.inputs_channel_index] - - @property - def num_outputs_channels(self) -> int: - """The number of channels in the outputs to the layer.""" - return self._layer_tag_eq.invars[1].aval.shape[ # pytype: disable=attribute-error - self.weights_output_channel_index] - - def compute_inputs_stats( - self, - inputs: Array, - weighting_array: Array | None = None, - ) -> Array: - """Computes the statistics for the inputs factor.""" - batch_size = inputs.shape[0] - - input_cov_m, input_cov_v = psm.patches_moments( - inputs, - kernel_spatial_shape=self.weights_spatial_shape, - strides=self._layer_tag_eq.params["window_strides"], - padding=self._layer_tag_eq.params["padding"], - data_format=None, - dim_numbers=self._layer_tag_eq.params["dimension_numbers"], - precision=self._layer_tag_eq.params.get("precision"), - weighting_array=weighting_array, - ) - - # Flatten the kernel and channels dimensions - k, h, c = input_cov_v.shape - input_cov_v = jnp.reshape(input_cov_v, (k * h * c,)) - input_cov_m = jnp.reshape(input_cov_m, (k * h * c, k 
* h * c)) - - # Normalize by the `batch size` * `num_locations` - normalizer = batch_size * self.num_locations - input_cov_m = input_cov_m / normalizer - input_cov_v = input_cov_v / normalizer - - if self.number_of_parameters == 1: - return input_cov_m - - if weighting_array is None: - corner = jnp.ones([1], dtype=input_cov_m.dtype) - else: - corner = jnp.mean(weighting_array).reshape([1]) - - input_cov = jnp.concatenate([input_cov_m, input_cov_v[None]], axis=0) - input_cov_v = jnp.concatenate([input_cov_v, corner], axis=0) - - return jnp.concatenate([input_cov, input_cov_v[:, None]], axis=1) - - def compute_outputs_stats(self, tangent_of_output: Array) -> Array: - """Computes the statistics for the outputs factor.""" - lhs_str = utils.replace_char(_ALPHABET[:4], "y", self.outputs_channel_index) - rhs_str = utils.replace_char(_ALPHABET[:4], "z", self.outputs_channel_index) - ein_str = f"{lhs_str},{rhs_str}->yz" - stats = jnp.einsum(ein_str, tangent_of_output, tangent_of_output) - - # Normalize by the `batch size` * `num_locations` - normalizer = tangent_of_output.shape[0] * self.num_locations - return stats / normalizer - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: KroneckerFactored.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> KroneckerFactored.State: - del identity_weight - assert 1 <= self.number_of_parameters <= 2 - - # Copy this first since we mutate it later in this function. - state = state.copy() - - [x] = estimation_data.primals.inputs - [dy] = estimation_data.tangents.outputs - - assert utils.first_dim_is_size(batch_size, x, dy) - - input_stats = self.compute_inputs_stats(x) - output_stats = self.compute_outputs_stats(dy) - - state.factors[0].update(input_stats, ema_old, ema_new) - state.factors[1].update(output_stats, ema_old, ema_new) - - return state - - -class Conv2DTNT(Conv2DTwoKroneckerFactored): - """A TNT block for Conv2D layers. - - This TNT block modifies :class:`~NaiveTNTBlock` by the way it estimates each - factor specifically for a conv2D layer. Importantly, it assumes "location - independence" similar to :class:~`Conv2DTwoKroneckerFactored`. Given this - assumption, instead of using the contraction over the summed gradient, it - performs the contraction for each individual example in the batch, and each - individual spatial location, and then averages over these: - ``F_i = sum_n sum_t contract_all_but_one(g_{n,t}, g_{n,t}, i) / (N * T)`` - where T here is the number of spatial locations. The estimator is unbiased - (under the "location independence" approximation), and will have lower - variance then the naive one. - - If the argument `weighting_per_location` is set to `False`, then the block - uses a mixture between location-independence and not, in the sense that it - computes the contractions per example, while the matrix factor statistics - still assume location independence. 
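For reference, the output-factor statistic computed by ``compute_outputs_stats`` above amounts to contracting the output tangents over the batch and all spatial locations, then normalizing by ``batch_size * num_locations``. A sketch with an illustrative channels-last layout:

import jax.numpy as jnp
import numpy as np

dy = jnp.asarray(np.random.default_rng(0).normal(size=(4, 8, 8, 16)))  # N, H, W, C
n, h, w, c = dy.shape
num_locations = h * w   # 'T', the number of spatial locations

# With the channel axis last, the generated einsum string is "abcy,abcz->yz".
output_stats = jnp.einsum("nhwy,nhwz->yz", dy, dy) / (n * num_locations)
assert output_stats.shape == (c, c)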
- """ - - def __init__( - self, - layer_tag_eq: tags.LayerTagEqn, - weighting_per_location: bool = True, - parameters_specs: Sequence[str] | None = None, - parameters_concat_axis: int = 0, - ): - self.weighting_per_location = weighting_per_location - super().__init__( - layer_tag_eq=layer_tag_eq, - parameters_specs=parameters_specs, - parameters_concat_axis=parameters_concat_axis, - ) - - def state_dependent_scale( - self, state: "Conv2DTNT.State" - ) -> Numeric: - return utils.tnt_scale([factor.value for factor in state.factors]) - - def x_squared_spatial_norms(self, x: Array) -> Array: - - kernel_shape = list(self.parameters_shapes[0]) - kernel_shape[self.kernel_output_axis] = 1 - - return jax.lax.conv_general_dilated( - lhs=x * x, - rhs=jnp.ones(kernel_shape), - window_strides=self.layer_tag_extra_params["window_strides"], - padding=self.layer_tag_extra_params["padding"], - lhs_dilation=self.layer_tag_extra_params["lhs_dilation"], - rhs_dilation=self.layer_tag_extra_params["rhs_dilation"], - dimension_numbers=self.layer_tag_extra_params["dimension_numbers"], - feature_group_count=self.layer_tag_extra_params["feature_group_count"], - precision=self.layer_tag_extra_params["precision"], - preferred_element_type= - self.layer_tag_extra_params["preferred_element_type"], - ) - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: KroneckerFactored.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> KroneckerFactored.State: - del identity_weight - - # Copy this first since we mutate it later in this function. - state = state.copy() - - [x] = estimation_data.primals.inputs - [dy] = estimation_data.tangents.outputs - - assert utils.first_dim_is_size(batch_size, x, dy) - - # We multiply each x by the norm_y, and each dy by the norm of x - dy_sq_norms = jnp.sum(dy * dy, axis=self.outputs_channel_index) - x_sq_norms = self.x_squared_spatial_norms(x) - - if self.number_of_parameters == 2: - # When we have a bias we need to add 1 coming from it to the squared norm - x_sq_norms = x_sq_norms + 1 - - if not self.weighting_per_location: - dy_sq_norms = jnp.sum(dy_sq_norms, axis=[1, 2]) - x_sq_norms = jnp.sum(x_sq_norms, axis=[1, 2, 3], keepdims=True) - - input_cov = self.compute_inputs_stats(x, weighting_array=dy_sq_norms) - output_cov = self.compute_outputs_stats(dy * jnp.sqrt(x_sq_norms)) - - state.factors[0].update(input_cov, ema_old, ema_new) - state.factors[1].update(output_cov, ema_old, ema_new) - - return state - - -# _____ _ _ _____ _ _ __ _ -# / ____| | | /\ | |/ ____| | (_)/ _| | -# | (___ ___ __ _| | ___ / \ _ __ __| | (___ | |__ _| |_| |_ -# \___ \ / __/ _` | |/ _ \ / /\ \ | '_ \ / _` |\___ \| '_ \| | _| __| -# ____) | (_| (_| | | __// ____ \| | | | (_| |____) | | | | | | | |_ -# |_____/ \___\__,_|_|\___/_/ \_\_| |_|\__,_|_____/|_| |_|_|_| \__| -# - - -def compatible_shapes(ref_shape, target_shape): - - if len(target_shape) > len(ref_shape): - raise ValueError("Target shape should be smaller.") - - for ref_d, target_d in zip(reversed(ref_shape), reversed(target_shape)): - if ref_d != target_d and target_d != 1: - raise ValueError(f"{target_shape} is incompatible with {ref_shape}.") - - -def compatible_sum(tensor, target_shape, skip_axes): - """Compute sum over ``tensor`` to achieve shape given by ``target_shape``.""" - - compatible_shapes(tensor.shape, target_shape) - - n = tensor.ndim - len(target_shape) - - axis = [i + n for i, t in enumerate(target_shape) - if t 
== 1 and i + n not in skip_axes] - - tensor = jnp.sum(tensor, axis=axis, keepdims=True) - - axis = [i for i in range(tensor.ndim - len(target_shape)) - if i not in skip_axes] - - return jnp.sum(tensor, axis=axis) - - -class ScaleAndShiftDiagonal(Diagonal): - """A diagonal approximation specifically for a scale and shift layers.""" - - @property - def has_scale(self) -> bool: - """Whether this layer's equation has a scale.""" - return self._layer_tag_eq.params["has_scale"] - - @property - def has_shift(self) -> bool: - """Whether this layer's equation has a shift.""" - return self._layer_tag_eq.params["has_shift"] - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: Diagonal.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> Diagonal.State: - del identity_weight - - # Copy this first since we mutate it later in this function. - state = state.copy() - - [x] = estimation_data.primals.inputs - [dy] = estimation_data.tangents.outputs - - assert utils.first_dim_is_size(batch_size, x, dy) - - if self.has_scale: - - assert (state.diagonal_factors[0].raw_value.shape == - self.parameters_shapes[0]) - - scale_shape = estimation_data.primals.params[0].shape - - d_scale = compatible_sum(x * dy, scale_shape, skip_axes=[0]) - - scale_diag_update = jnp.sum( - d_scale * d_scale, - axis=0, keepdims=d_scale.ndim == len(scale_shape) - ) / batch_size - - state.diagonal_factors[0].update(scale_diag_update, ema_old, ema_new) - - if self.has_shift: - - shift_shape = estimation_data.primals.params[-1].shape - d_shift = compatible_sum(dy, shift_shape, skip_axes=[0]) - - shift_diag_update = jnp.sum( - d_shift * d_shift, - axis=0, keepdims=d_shift.ndim == len(shift_shape) - ) / batch_size - - state.diagonal_factors[-1].update(shift_diag_update, ema_old, ema_new) - - return state - - -class ScaleAndShiftFull(Full): - """A full dense approximation specifically for a scale and shift layers.""" - - @property - def _has_scale(self) -> bool: - """Whether this layer's equation has a scale.""" - return self._layer_tag_eq.params["has_scale"] - - @property - def _has_shift(self) -> bool: - """Whether this layer's equation has a shift.""" - return self._layer_tag_eq.params["has_shift"] - - @utils.auto_scope_method - def update_curvature_matrix_estimate( - self, - state: Full.State, - estimation_data: tracer.LayerVjpData[Array], - ema_old: Numeric, - ema_new: Numeric, - identity_weight: Numeric, - batch_size: Numeric, - ) -> Full.State: - del identity_weight - - # Copy this first since we mutate it later in this function. 
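To make the role of ``compatible_sum`` in the scale-and-shift blocks concrete: for a per-channel scale parameter it keeps the batch axis (``skip_axes=[0]``) and sums every axis that is not part of the parameter shape, giving one tangent per example. A toy check of the shapes involved (illustrative sizes only):

import jax.numpy as jnp

x_times_dy = jnp.ones([32, 8, 8, 16])   # batch, H, W, channels
scale_shape = (16,)                      # per-channel scale parameter

# compatible_sum(x_times_dy, scale_shape, skip_axes=[0]) reduces, for this case,
# to a plain sum over the spatial axes, leaving one tangent per example.
d_scale = jnp.sum(x_times_dy, axis=(1, 2))
assert d_scale.shape == (32, 16)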
- state = state.copy() - - [x] = estimation_data.primals.inputs - [dy] = estimation_data.tangents.outputs - assert utils.first_dim_is_size(batch_size, x, dy) - - tangents = [] - - if self._has_scale: - # Scale tangent - scale_shape = estimation_data.primals.params[0].shape - - d_scale = compatible_sum(x * dy, scale_shape, skip_axes=[0]) - d_scale = d_scale.reshape([batch_size, -1]) - - tangents.append(d_scale) - - if self._has_shift: - # Shift tangent - - shift_shape = estimation_data.primals.params[-1].shape - - d_shift = compatible_sum(dy, shift_shape, skip_axes=[0]) - d_shift = d_shift.reshape([batch_size, -1]) - - tangents.append(d_shift) - - tangents = jnp.concatenate(tangents, axis=1) - matrix_update = jnp.matmul(tangents.T, tangents) / batch_size - - state.matrix.update(matrix_update, ema_old, ema_new) - - return state diff --git a/kfac_jax/_src/curvature_blocks/__init__.py b/kfac_jax/_src/curvature_blocks/__init__.py new file mode 100644 index 0000000..29f23d9 --- /dev/null +++ b/kfac_jax/_src/curvature_blocks/__init__.py @@ -0,0 +1,51 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""K-FAC curvature approximation to single layer blocks.""" +from kfac_jax._src.curvature_blocks import curvature_block +from kfac_jax._src.curvature_blocks import diagonal +from kfac_jax._src.curvature_blocks import full +from kfac_jax._src.curvature_blocks import kronecker_factored +from kfac_jax._src.curvature_blocks import tnt +from kfac_jax._src.curvature_blocks import utils + +CurvatureBlock = curvature_block.CurvatureBlock +ScaledIdentity = curvature_block.ScaledIdentity +ScalarOrSequence = curvature_block.ScalarOrSequence + +Diagonal = diagonal.Diagonal +Full = full.Full +KroneckerFactored = kronecker_factored.KroneckerFactored +NaiveDiagonal = diagonal.NaiveDiagonal +NaiveFull = full.NaiveFull +NaiveTNT = tnt.NaiveTNT +DenseDiagonal = diagonal.DenseDiagonal +DenseFull = full.DenseFull +DenseTwoKroneckerFactored = kronecker_factored.DenseTwoKroneckerFactored +RepeatedDenseKroneckerFactored = ( + kronecker_factored.RepeatedDenseKroneckerFactored) +DenseTNT = tnt.DenseTNT +Conv2DDiagonal = diagonal.Conv2DDiagonal +Conv2DFull = full.Conv2DFull +Conv2DTwoKroneckerFactored = kronecker_factored.Conv2DTwoKroneckerFactored +Conv2DTNT = tnt.Conv2DTNT +ScaleAndShiftDiagonal = diagonal.ScaleAndShiftDiagonal +ScaleAndShiftFull = full.ScaleAndShiftFull + +set_max_parallel_elements = utils.set_max_parallel_elements +get_max_parallel_elements = utils.get_max_parallel_elements +set_default_eigen_decomposition_threshold = ( + utils.set_default_eigen_decomposition_threshold) +get_default_eigen_decomposition_threshold = ( + utils.get_default_eigen_decomposition_threshold) + diff --git a/kfac_jax/_src/curvature_blocks/curvature_block.py b/kfac_jax/_src/curvature_blocks/curvature_block.py new file mode 100644 index 0000000..9d54c29 --- /dev/null +++ b/kfac_jax/_src/curvature_blocks/curvature_block.py @@ -0,0 +1,605 @@ +# Copyright 2022 DeepMind 
Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module containing the abstract base class for curvature blocks.""" + +import abc +from typing import Any, Sequence + +import jax +import jax.numpy as jnp +import jax.scipy +from kfac_jax._src import layers_and_loss_tags as tags +from kfac_jax._src import tag_graph_matcher as tgm +from kfac_jax._src import tracer +from kfac_jax._src import utils +from kfac_jax._src.curvature_blocks import utils as cb_utils +import numpy as np + + +# Types for annotation +Array = utils.Array +Scalar = utils.Scalar +Numeric = utils.Numeric +PRNGKey = utils.PRNGKey +Shape = utils.Shape +DType = utils.DType +ScalarOrSequence = Scalar | Sequence[Scalar] + + +class CurvatureBlock(utils.Finalizable): + """Abstract class for curvature approximation blocks. + + A CurvatureBlock defines a curvature matrix to be estimated, and gives methods + to multiply powers of this with a vector. Powers can be computed exactly or + with a class-determined approximation. Cached versions of the powers can be + pre-computed to make repeated multiplications cheaper. During initialization, + you would have to explicitly specify all powers that you will need to cache. + """ + + @utils.register_state_class + class State(utils.State): + """Persistent state of the block. + + Any subclasses of :class:`~CurvatureBlock` should also internally extend + this class, with any attributes needed for the curvature estimation. + + Attributes: + cache: A dictionary, containing any state data that is updated on + irregular intervals, such as inverses, eigenvalues, etc. Elements of + this are updated via calls to :func:`~CurvatureBlock.update_cache`, and + do not necessarily correspond to the most up-to-date curvature estimate. + """ + cache: dict[str, Array | dict[str, Array]] | None + + def __init__(self, layer_tag_eq: tags.LayerTagEqn): + """Initializes the block. + + Args: + layer_tag_eq: The Jax equation corresponding to the layer tag that this + block will approximate the curvature to. 
+ """ + super().__init__() + + self._layer_tag_eq = layer_tag_eq + + self.finalize() + + @property + def name(self) -> str: + return tags.layer_eqn_name(self._layer_tag_eq) + + @property + def layer_tag_primitive(self) -> tags.LayerTag: + """The :class:`jax.core.Primitive` corresponding to the block's tag equation.""" + + primitive = self._layer_tag_eq.primitive + assert isinstance(primitive, tgm.tags.LayerTag) + + return primitive + + @property + def parameter_variables(self) -> tuple[jax.core.Var, ...]: + """The parameter variables of the underlying Jax equation.""" + + param_vars = [] + + for p in tags.layer_eqn_data(self._layer_tag_eq).params: + + assert isinstance(p, jax.core.Var) + param_vars.append(p) + + return tuple(param_vars) + + @property + def outputs_shapes(self) -> tuple[Shape, ...]: + """The shapes of the output variables of the block's tag equation.""" + + output_vars = tags.layer_eqn_data(self._layer_tag_eq).outputs + + return jax.tree.map(lambda x: x.aval.shape, output_vars) + + @property + def inputs_shapes(self) -> tuple[Shape, ...]: + """The shapes of the input variables of the block's tag equation.""" + + input_vars = tags.layer_eqn_data(self._layer_tag_eq).inputs + + return jax.tree.map(lambda x: x.aval.shape, input_vars) + + @property + def parameters_shapes(self) -> tuple[Shape, ...]: + """The shapes of the parameter variables of the block's tag equation.""" + return tuple(jax.tree.map( + lambda x: tuple(x.aval.shape), self.parameter_variables)) + + @property + def dtype(self) -> DType: + dtypes = set(p.aval.dtype for p in self.parameter_variables) # pytype: disable=attribute-error + if len(dtypes) > 1: + raise ValueError("Not all parameters are the same dtype.") + return dtypes.pop() + + @property + def parameters_canonical_order(self) -> tuple[int, ...]: + """The canonical order of the parameter variables.""" + + return tuple(np.argsort([p.count for p in self.parameter_variables])) + + @property + def layer_tag_extra_params(self) -> dict[str, Any]: + """Any extra parameters of passed into the Jax primitive of this block.""" + + return self._layer_tag_eq.params + + @property + def number_of_parameters(self) -> int: + """Number of parameter variables of this block.""" + + return len(self.parameters_shapes) + + @property + def dim(self) -> int: + """The number of elements of all parameter variables together.""" + + return sum(utils.product(shape) for shape in self.parameters_shapes) + + def scale(self, state: State, use_cache: bool) -> Numeric: + """A scalar pre-factor of the curvature approximation. + + Importantly, all methods assume that whenever a user requests cached values, + any state dependant scale is taken into account by the cache (e.g. either + stored explicitly and used or mathematically added to values). + + Args: + state: The state for this block. + use_cache: Whether the method requesting this is using cached values or + not. + + Returns: + A scalar value to be multiplied with any unscaled block representation. + """ + + # TODO(jamesmartens,botev): This way of handling state dependent scale is + # a bit hacky and leads to complexity in other parts of the code that must + # be aware of how this part works. Should try to replace this with something + # better. + + if use_cache: + return self.fixed_scale() + + return self.fixed_scale() * self.state_dependent_scale(state) + + def fixed_scale(self) -> Numeric: + """A fixed scalar pre-factor of the curvature (e.g. 
constant).""" + return 1.0 + + def state_dependent_scale(self, state: State) -> Numeric: + """A scalar pre-factor of the curvature, computed from the most fresh curvature estimate.""" + del state # Unused + return 1.0 + + def __str__(self): + return (f"{self.__class__.__name__}, tag name: {self.name}, " + f"params shapes: {self.parameters_shapes!r}") + + @utils.auto_scope_method + def init( + self, + rng: PRNGKey, + exact_powers_to_cache: ScalarOrSequence | None, + approx_powers_to_cache: ScalarOrSequence | None, + cache_eigenvalues: bool, + ) -> State: + """Initializes the state for this block. + + Args: + rng: The PRNGKey which to be used for any randomness of the initialization + exact_powers_to_cache: A single value, or multiple values in a list, which + specify which exact matrix powers the block should be caching. Matrix + powers, which are expected to be used in + :func:`~CurvatureBlock.multiply_matpower`, + :func:`~CurvatureBlock.multiply_inverse` or + :func:`~CurvatureBlock.multiply` with ``exact_power=True`` and + ``use_cached=True`` must be provided here. + approx_powers_to_cache: A single value, or multiple values in a list, + which specify approximate matrix powers the block should be caching. + Matrix powers, which are expected to be used in + :func:`~CurvatureBlock.multiply_matrix_power`, + :func:`~CurvatureBlock.multiply_inverse` or + :func:`~CurvatureBlock.multiply` with ``exact_power=False`` and + ``use_cached=True`` must be provided here. + cache_eigenvalues: Specifies whether the block should be caching the + eigenvalues of its approximate curvature. + Returns: + A dictionary with the initialized state. + """ + return self._init( + rng=rng, + exact_powers_to_cache=cb_utils.to_real_set(exact_powers_to_cache), + approx_powers_to_cache=cb_utils.to_real_set(approx_powers_to_cache), + cache_eigenvalues=cache_eigenvalues) + + @abc.abstractmethod + def _init( + self, + rng: PRNGKey, + exact_powers_to_cache: set[Scalar], + approx_powers_to_cache: set[Scalar], + cache_eigenvalues: bool, + ) -> State: + """The non-public interface of ``init``.""" + + @abc.abstractmethod + def sync( + self, + state: State, + pmap_axis_name: str, + ) -> State: + """Syncs the state across different devices (does not sync the cache).""" + + @utils.auto_scope_method + def multiply_matpower( + self, + state: State, + vector: Sequence[Array], + identity_weight: Numeric, + power: Scalar, + exact_power: bool, + use_cached: bool, + ) -> tuple[Array, ...]: + """Computes ``(BlockMatrix + identity_weight I)**power`` times ``vector``. + + Args: + state: The state for this block. + vector: A tuple of arrays that should have the same shapes as the block's + parameters_shapes, which represent the vector you want to multiply. + identity_weight: A scalar specifying the weight on the identity matrix + that is added to the block matrix before raising it to a power. If + ``use_cached=False`` it is guaranteed that this argument will be used in + the computation. When returning cached values, this argument *may* be + ignored in favor whatever value was last passed to + :func:`~CurvatureBlock.update_cache`. The precise semantics of this + depend on the concrete subclass and its particular behavior in regard to + caching. + power: The power to which to raise the matrix. + exact_power: Specifies whether to compute the exact matrix power of + ``BlockMatrix + identity_weight I``. 
When this argument is ``False`` + the exact behaviour will depend on the concrete subclass and the + result will *in general* be an approximation to + ``(BlockMatrix + identity_weight I)^power``, although some subclasses + may still compute the exact matrix power. + use_cached: Whether to use a cached version for computing the product or + to use the most recent curvature estimates. The cached version is + going to be *at least* as fresh as the value provided to the last call + to :func:`~CurvatureBlock.update_cache` with the same value of ``power`` + + Returns: + A tuple of arrays, representing the result of the matrix-vector product. + """ + + scale = self.scale(state, use_cached) + + result = self._multiply_matpower_unscaled( + state=state, + vector=vector, + identity_weight=identity_weight / scale, + power=power, + exact_power=exact_power, + use_cached=use_cached, + ) + + return utils.scalar_mul(result, jnp.power(scale, power)) + + @abc.abstractmethod + def _multiply_matpower_unscaled( + self, + state: State, + vector: Sequence[Array], + identity_weight: Numeric, + power: Scalar, + exact_power: bool, + use_cached: bool, + ) -> tuple[Array, ...]: + """Performs matrix-vector multiplication, ignoring ``self.scale``.""" + + def multiply( + self, + state: State, + vector: Sequence[Array], + identity_weight: Numeric, + exact_power: bool, + use_cached: bool, + ) -> tuple[Array, ...]: + """Computes ``(BlockMatrix + identity_weight I)`` times ``vector``.""" + + return self.multiply_matpower( + state=state, + vector=vector, + identity_weight=identity_weight, + power=1, + exact_power=exact_power, + use_cached=use_cached, + ) + + def multiply_inverse( + self, + state: State, + vector: Sequence[Array], + identity_weight: Numeric, + exact_power: bool, + use_cached: bool, + ) -> tuple[Array, ...]: + """Computes ``(BlockMatrix + identity_weight I)^-1`` times ``vector``.""" + + return self.multiply_matpower( + state=state, + vector=vector, + identity_weight=identity_weight, + power=-1, + exact_power=exact_power, + use_cached=use_cached, + ) + + @utils.auto_scope_method + def eigenvalues( + self, + state: State, + use_cached: bool, + ) -> Array: + """Computes the eigenvalues for this block approximation. + + Args: + state: The state dict for this block. + use_cached: Whether to use a cached versions of the eigenvalues or to use + the most recent curvature estimates to compute them. The cached version + are going to be *at least* as fresh as the last time you called + :func:`~CurvatureBlock.update_cache` with ``eigenvalues=True``. + + Returns: + An array containing the eigenvalues of the block. + """ + eigenvalues = self._eigenvalues_unscaled(state, use_cached) + + assert eigenvalues.size == self.dim + + return self.scale(state, use_cached) * eigenvalues + + @abc.abstractmethod + def _eigenvalues_unscaled( + self, + state: State, + use_cached: bool, + ) -> Array: + """Computes the eigenvalues for this block, ignoring `self.scale`.""" + + @abc.abstractmethod + def update_curvature_matrix_estimate( + self, + state: State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> State: + """Updates the block's curvature estimates using the ``info`` provided. + + Each block *in general* estimates a moving average of its associated + curvature matrix. If you don't want a moving average you can set + ``ema_old=0`` and ``ema_new=1``. + + Args: + state: The state dict for this block to update. 
+ estimation_data: A map containing data used for updating the curvature + matrix estimate for this block. This can be computed by calling the + function returned from :func:`~layer_tags_vjp`. Please see its + implementation for more details on the name of the fields and how they + are constructed. + ema_old: Specifies the weight of the old value when computing the updated + estimate in the moving average. + ema_new: Specifies the weight of the new value when computing the updated + estimate in the moving average. + identity_weight: The weight of the identity added to the block's curvature + matrix before computing the cached matrix power. + batch_size: The batch size used in computing the values in ``info``. + """ + + @utils.auto_scope_method + def update_cache( + self, + state: State, + identity_weight: Numeric, + exact_powers: ScalarOrSequence | None, + approx_powers: ScalarOrSequence | None, + eigenvalues: bool, + ) -> State: + """Updates the cached estimates of the different powers specified. + + Args: + state: The state dict for this block to update. + identity_weight: The weight of the identity added to the block's curvature + matrix before computing the cached matrix power. + exact_powers: Specifies any cached exact matrix powers to be updated. + approx_powers: Specifies any cached approximate matrix powers to be + updated. + eigenvalues: Specifies whether to update the cached eigenvalues + of the block. If they have not been cached before, this will create an + entry with them in the block's cache. + + Returns: + The updated state. + """ + return self._update_cache( + state=state, + identity_weight=identity_weight / self.scale(state, False), + exact_powers=cb_utils.to_real_set(exact_powers), + approx_powers=cb_utils.to_real_set(approx_powers), + eigenvalues=eigenvalues, + ) + + @abc.abstractmethod + def _update_cache( + self, + state: State, + identity_weight: Numeric, + exact_powers: set[Scalar], + approx_powers: set[Scalar], + eigenvalues: bool, + ) -> State: + """The cache updating function, ignoring ``self.scale``.""" + + @utils.auto_scope_method + def to_dense_matrix(self, state: State) -> Array: + """Returns a dense representation of the curvature matrix.""" + return self.scale(state, False) * self._to_dense_unscaled(state) + + @abc.abstractmethod + def _to_dense_unscaled(self, state: State) -> Array: + """A dense representation of the curvature, ignoring ``self.scale``.""" + + def undamped_diagonal(self, state: State) -> tuple[Array, ...]: + """Returns the diagonal of the undamped curvature.""" + return utils.scalar_mul(self._undamped_diagonal_unscaled(state), + self.scale(state, False)) + + def _undamped_diagonal_unscaled(self, state: State) -> tuple[Array, ...]: + """Returns the diagonal of the undamped curvature, ignoring ``self.scale``.""" + raise NotImplementedError() + + def norm(self, state: State, norm_type: str) -> Numeric: + """Computes the norm of the curvature block, according to ``norm_type``.""" + + return self.scale(state, False) * self._norm_unscaled(state, norm_type) + + @abc.abstractmethod + def _norm_unscaled( + self, + state: State, + norm_type: str + ) -> Numeric: + """Like ``norm`` but with ``self.scale`` not included.""" + + +class ScaledIdentity(CurvatureBlock): + """A block that assumes that the curvature is a scaled identity matrix.""" + + def __init__( + self, + layer_tag_eq: tags.LayerTagEqn, + scale: Numeric = 1.0, + ): + """Initializes the block. 
+ + Args: + layer_tag_eq: The Jax equation corresponding to the layer tag, that this + block will approximate the curvature to. + scale: The scale of the identity matrix. + """ + self._scale = scale + super().__init__(layer_tag_eq) + + def fixed_scale(self) -> Numeric: + return self._scale + + def _init( + self, + rng: PRNGKey, + exact_powers_to_cache: set[Scalar], + approx_powers_to_cache: set[Scalar], + cache_eigenvalues: bool, + ) -> CurvatureBlock.State: + + del rng, exact_powers_to_cache, approx_powers_to_cache # Unused + + return CurvatureBlock.State( + cache=None, + ) + + def sync( + self, + state: CurvatureBlock.State, + pmap_axis_name: str, + ) -> CurvatureBlock.State: + return state + + def _multiply_matpower_unscaled( + self, + state: CurvatureBlock.State, + vector: Sequence[Array], + identity_weight: Numeric, + power: Scalar, + exact_power: bool, + use_cached: bool, + ) -> tuple[Array, ...]: + + del exact_power # Unused + + # state_dependent_scale needs to be included because it won't be by the + # caller of this function (multiply_matpower) when use_cached=True + scale = self.state_dependent_scale(state) if use_cached else 1.0 + + identity_weight = identity_weight + scale + + if power == 1: + return jax.tree.map(lambda x: identity_weight * x, vector) + + elif power == -1: + return jax.tree.map(lambda x: x / identity_weight, vector) + + else: + identity_weight = jnp.power(identity_weight, power) + return jax.tree.map(lambda x: identity_weight * x, vector) + + def _eigenvalues_unscaled( + self, + state: CurvatureBlock.State, + use_cached: bool, + ) -> Array: + return jnp.ones([self.dim]) + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: CurvatureBlock.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> CurvatureBlock.State: + + return state.copy() + + def _update_cache( + self, + state: CurvatureBlock.State, + identity_weight: Numeric, + exact_powers: set[Scalar], + approx_powers: set[Scalar], + eigenvalues: bool, + ) -> CurvatureBlock.State: + + return state.copy() + + def _to_dense_unscaled(self, state: CurvatureBlock.State) -> Array: + del state # not used + return jnp.eye(self.dim) + + def _norm_unscaled( + self, + state: CurvatureBlock.State, + norm_type: str + ) -> Numeric: + + return utils.psd_matrix_norm(jnp.ones([self.dim]), norm_type=norm_type) diff --git a/kfac_jax/_src/curvature_blocks/diagonal.py b/kfac_jax/_src/curvature_blocks/diagonal.py new file mode 100644 index 0000000..43f0539 --- /dev/null +++ b/kfac_jax/_src/curvature_blocks/diagonal.py @@ -0,0 +1,375 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
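The ``Diagonal`` block defined just below stores one moving-average diagonal per parameter and applies damped matrix powers elementwise. A minimal sketch of that product (toy values, not library code):

import jax.numpy as jnp

# One diagonal factor per parameter, e.g. a 3-element weight and a 1-element bias.
diagonals = (jnp.array([0.5, 1.0, 2.0]), jnp.array([4.0]))
vector = (jnp.array([1.0, 1.0, 1.0]), jnp.array([2.0]))
identity_weight, power = 0.1, -1

result = tuple(jnp.power(d + identity_weight, power) * v
               for d, v in zip(diagonals, vector))
# power = -1 gives the damped inverse-diagonal preconditioner v / (d + identity_weight).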
+"""Module containing the diagonal curvature blocks.""" +import abc +import functools +from typing import Sequence + +import jax +import jax.numpy as jnp +import jax.scipy +from kfac_jax._src import layers_and_loss_tags as tags +from kfac_jax._src import tracer +from kfac_jax._src import utils +from kfac_jax._src.curvature_blocks import curvature_block +from kfac_jax._src.curvature_blocks import utils as cb_utils + +# Types for annotation +Array = utils.Array +Scalar = utils.Scalar +Numeric = utils.Numeric +PRNGKey = utils.PRNGKey +CurvatureBlock = curvature_block.CurvatureBlock + + +class Diagonal(CurvatureBlock, abc.ABC): + """An abstract class for approximating only the diagonal of curvature.""" + + @utils.register_state_class + class State(CurvatureBlock.State): + """Persistent state of the block. + + Attributes: + diagonal_factors: A tuple of the moving averages of the estimated + diagonals of the curvature for each parameter that is part of the + associated layer. + """ + diagonal_factors: tuple[utils.WeightedMovingAverage, ...] + + def _init( + self, + rng: PRNGKey, + exact_powers_to_cache: set[Scalar], + approx_powers_to_cache: set[Scalar], + cache_eigenvalues: bool, + ) -> State: + + del rng + + return Diagonal.State( + cache=None, + diagonal_factors=tuple( + utils.WeightedMovingAverage.zeros_array(shape, self.dtype) + for shape in self.parameters_shapes + ), + ) + + def sync( + self, + state: State, + pmap_axis_name: str, + ) -> State: + + # Copy this first since we mutate it later in this function. + state = state.copy() + + for factor in state.diagonal_factors: + factor.sync(pmap_axis_name) + + return state + + def _multiply_matpower_unscaled( + self, + state: State, + vector: Sequence[Array], + identity_weight: Numeric, + power: Scalar, + exact_power: bool, + use_cached: bool, + ) -> tuple[Array, ...]: + + # state_dependent_scale needs to be included because it won't be by the + # caller of this function (multiply_matpower) when use_cached=True + scale = self.state_dependent_scale(state) if use_cached else 1.0 + + factors = tuple(scale * f.value + identity_weight + for f in state.diagonal_factors) + + assert len(factors) == len(vector) + + if power == 1: + return tuple(f * v for f, v in zip(factors, vector)) + elif power == -1: + return tuple(v / f for f, v in zip(factors, vector)) + else: + return tuple(jnp.power(f, power) * v for f, v in zip(factors, vector)) + + def _eigenvalues_unscaled( + self, + state: State, + use_cached: bool, + ) -> Array: + return jnp.concatenate([f.value.flatten() for f in state.diagonal_factors], + axis=0) + + def _update_cache( + self, + state: State, + identity_weight: Numeric, + exact_powers: set[Scalar], + approx_powers: set[Scalar], + eigenvalues: bool, + ) -> State: + + return state.copy() + + def _to_dense_unscaled(self, state: State) -> Array: + + # Extract factors in canonical order + factors = [state.diagonal_factors[i].value.flatten() + for i in self.parameters_canonical_order] + + # Construct diagonal matrix + return jnp.diag(jnp.concatenate(factors, axis=0)) + + def _norm_unscaled( + self, + state: CurvatureBlock.State, + norm_type: str + ) -> Numeric: + + return utils.product( + utils.psd_matrix_norm(f.value.flatten(), norm_type=norm_type) + for f in state.diagonal_factors) + + +class NaiveDiagonal(Diagonal): + """Approximates the diagonal of the curvature with in the most obvious way. + + The update to the curvature estimate is computed by ``(sum_i g_i) ** 2 / N``. 
+ where `g_i` is the gradient of each individual data point, and ``N`` is the + batch size. + """ + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: Diagonal.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> Diagonal.State: + del identity_weight + + # Copy this first since we mutate it later in this function. + state = state.copy() + + for factor, dw in zip( + state.diagonal_factors, estimation_data.tangents.params + ): + factor.update(dw * dw / batch_size, ema_old, ema_new) + + return state + + +class DenseDiagonal(Diagonal): + """A `Diagonal` block specifically for dense layers.""" + + @property + def has_bias(self) -> bool: + """Whether the layer has a bias parameter.""" + return len(self.parameter_variables) == 2 + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: Diagonal.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> Diagonal.State: + del identity_weight + + # Copy this first since we mutate it later in this function. + state = state.copy() + + [x] = estimation_data.primals.inputs + [dy] = estimation_data.tangents.outputs + + assert utils.first_dim_is_size(batch_size, x, dy) + + diagonals = (jnp.matmul((x * x).T, dy * dy) / batch_size,) + if self.has_bias: + diagonals += (jnp.mean(dy * dy, axis=0),) + + assert len(diagonals) == self.number_of_parameters + + for diagonal_factor, diagonal in zip(state.diagonal_factors, diagonals): + diagonal_factor.update(diagonal, ema_old, ema_new) + + return state + + +class Conv2DDiagonal(Diagonal): + """A :class:`~Diagonal` block specifically for 2D convolution layers.""" + + def __init__( + self, + layer_tag_eq: tags.LayerTagEqn, + max_elements_for_vmap: int | None = None, + ): + """Initializes the block. + + Since there is no 'nice' formula for computing the average of the + tangents for a 2D convolution, what we do is that we have a function - + ``self.conv2d_tangent_squared`` - that computes for a single feature map the + square of the tangents for the kernel of the convolution. To average over + the batch we have two choices - vmap or loop over the batch sequentially + using scan. This utility function provides a trade-off by being able to + specify the maximum number of batch size that we can vmap over. This means + that the maximum memory usage will be ``max_batch_size_for_vmap`` times the + memory needed when calling ``self.conv2d_tangent_squared``. And the actual + ``vmap`` will be called ``ceil(total_batch_size / max_batch_size_for_vmap)`` + number of times in a loop to find the final average. + + Args: + layer_tag_eq: The Jax equation corresponding to the layer tag, that this + block will approximate the curvature to. + max_elements_for_vmap: The threshold used for determining how much + computation to the in parallel and how much in serial manner. If + ``None`` will use the value returned by + :func:`~get_max_parallel_elements`. 
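The vmap-versus-scan trade-off described in the docstring above can be pictured with a simplified helper: vmap over chunks of at most ``max_parallel`` examples and scan over the chunks, so peak memory scales with the chunk size while the result is still the full-batch mean. The ``chunked_average`` below is a sketch only (it assumes the batch divides evenly); the actual implementation used by the block is ``utils.loop_and_parallelize_average``.

import jax
import jax.numpy as jnp

def chunked_average(func, max_parallel):
  """Averages ``func`` over the leading axis, vmap-ing at most ``max_parallel`` rows at once."""
  vmapped = jax.vmap(func)

  def averaged(x, dy):
    batch = x.shape[0]
    assert batch % max_parallel == 0  # sketch only; the real utility handles remainders
    n_chunks = batch // max_parallel
    x_chunks = x.reshape((n_chunks, max_parallel) + x.shape[1:])
    dy_chunks = dy.reshape((n_chunks, max_parallel) + dy.shape[1:])

    def body(carry, chunk):
      cx, cdy = chunk
      return carry, jnp.mean(vmapped(cx, cdy), axis=0)

    _, chunk_means = jax.lax.scan(body, None, (x_chunks, dy_chunks))
    return jnp.mean(chunk_means, axis=0)  # equal-sized chunks, so this is the full-batch mean

  return averaged

# Example: average of per-example outer products, two examples vmapped per scan step.
avg = chunked_average(lambda xi, dyi: jnp.outer(xi, dyi), max_parallel=2)
x, dy = jnp.ones([8, 3]), jnp.ones([8, 5])
assert avg(x, dy).shape == (3, 5)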
+ """ + self._averaged_kernel_squared_tangents = utils.loop_and_parallelize_average( + func=self.conv2d_tangent_squared, + max_parallel_size=max_elements_for_vmap or + cb_utils.get_max_parallel_elements(), + ) + super().__init__(layer_tag_eq) + + @property + def has_bias(self) -> bool: + return len(self.parameter_variables) == 2 + + def conv2d_tangent_squared( + self, + image_features_map: Array, + output_tangent: Array, + ) -> Array: + """Computes the elementwise square of a tangent for a single feature map.""" + + extra_params = {k: v for k, v in self.layer_tag_extra_params.items() + if k not in ("lhs_shape", "rhs_shape", "meta")} + + _, vjp = jax.vjp( + functools.partial( + jax.lax.conv_general_dilated, + **extra_params + ), + image_features_map[None], jnp.zeros(self.parameters_shapes[0]) + ) + + return jnp.square(vjp(output_tangent[None])[1]) + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: Diagonal.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> Diagonal.State: + del identity_weight + + # Copy this first since we mutate it later in this function. + state = state.copy() + + [x] = estimation_data.primals.inputs + [dy] = estimation_data.tangents.outputs + + assert utils.first_dim_is_size(batch_size, x, dy) + + diagonals = (self._averaged_kernel_squared_tangents(x, dy),) + + if self.has_bias: + sum_axis = tuple(range(1, dy.ndim - len(self.parameters_shapes[1]))) + bias_dy = jnp.sum(dy, axis=sum_axis) + diagonals += (jnp.mean(bias_dy * bias_dy, axis=0),) + + assert len(diagonals) == self.number_of_parameters + + for diagonal_factor, diagonal in zip(state.diagonal_factors, diagonals): + diagonal_factor.update(diagonal, ema_old, ema_new) + + return state + + +class ScaleAndShiftDiagonal(Diagonal): + """A diagonal approximation specifically for a scale and shift layers.""" + + @property + def has_scale(self) -> bool: + """Whether this layer's equation has a scale.""" + return self._layer_tag_eq.params["has_scale"] + + @property + def has_shift(self) -> bool: + """Whether this layer's equation has a shift.""" + return self._layer_tag_eq.params["has_shift"] + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: Diagonal.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> Diagonal.State: + del identity_weight + + # Copy this first since we mutate it later in this function. 
+ state = state.copy() + + [x] = estimation_data.primals.inputs + [dy] = estimation_data.tangents.outputs + + assert utils.first_dim_is_size(batch_size, x, dy) + + if self.has_scale: + + assert (state.diagonal_factors[0].raw_value.shape == + self.parameters_shapes[0]) + + scale_shape = estimation_data.primals.params[0].shape + + d_scale = cb_utils.compatible_sum(x * dy, scale_shape, skip_axes=[0]) + + scale_diag_update = jnp.sum( + d_scale * d_scale, + axis=0, keepdims=d_scale.ndim == len(scale_shape) + ) / batch_size + + state.diagonal_factors[0].update(scale_diag_update, ema_old, ema_new) + + if self.has_shift: + + shift_shape = estimation_data.primals.params[-1].shape + d_shift = cb_utils.compatible_sum(dy, shift_shape, skip_axes=[0]) + + shift_diag_update = jnp.sum( + d_shift * d_shift, + axis=0, keepdims=d_shift.ndim == len(shift_shape) + ) / batch_size + + state.diagonal_factors[-1].update(shift_diag_update, ema_old, ema_new) + + return state diff --git a/kfac_jax/_src/curvature_blocks/full.py b/kfac_jax/_src/curvature_blocks/full.py new file mode 100644 index 0000000..1d9fedc --- /dev/null +++ b/kfac_jax/_src/curvature_blocks/full.py @@ -0,0 +1,537 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module containing the full matrix curvature blocks.""" +import abc +import functools +from typing import Sequence + +import jax +import jax.numpy as jnp +import jax.scipy +from kfac_jax._src import layers_and_loss_tags as tags +from kfac_jax._src import tracer +from kfac_jax._src import utils +from kfac_jax._src.curvature_blocks import curvature_block +from kfac_jax._src.curvature_blocks import utils as cb_utils + +# Types for annotation +Array = utils.Array +Scalar = utils.Scalar +Numeric = utils.Numeric +PRNGKey = utils.PRNGKey +CurvatureBlock = curvature_block.CurvatureBlock + + +class Full(CurvatureBlock, abc.ABC): + """An abstract class for approximating the block matrix with a full matrix.""" + + @utils.register_state_class + class State(CurvatureBlock.State): + """Persistent state of the block. + + Attributes: + matrix: A moving average of the estimated curvature matrix for all + parameters that are part of the associated layer. + """ + matrix: utils.WeightedMovingAverage + + def __init__( + self, + layer_tag_eq: tags.LayerTagEqn, + eigen_decomposition_threshold: int | None = None, + ): + """Initializes the block. + + Args: + layer_tag_eq: The Jax equation corresponding to the layer tag that this + block will approximate the curvature to. + eigen_decomposition_threshold: During calls to ``init`` and + ``update_cache`` if higher number of matrix powers than this threshold + are requested, instead of computing individual approximate powers, will + directly compute the eigen-decomposition instead (which provide access to + any matrix power). If this is ``None`` will use the value returned from + :func:`~get_default_eigen_decomposition_threshold()`. 
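When more matrix powers are requested than this threshold, the cache stores an eigen-decomposition instead of individual powers (see ``_update_cache`` and ``_multiply_matpower_unscaled`` further below), since ``(M + w I)^p v`` can be formed from the eigenvalues of ``M`` alone. A small numeric check of that identity with a toy matrix (not library code):

import jax.numpy as jnp

m = jnp.array([[2.0, 0.5],
               [0.5, 1.0]])          # a small PSD block matrix
v = jnp.array([1.0, -1.0])
identity_weight, power = 0.1, -0.5

# Eigen-decomposition route: factor M once, reuse it for any power.
s, q = jnp.linalg.eigh(m)
via_eigh = q @ (jnp.power(s + identity_weight, power) * (q.T @ v))

# Direct route: decompose the damped matrix for this particular power.
sd, qd = jnp.linalg.eigh(m + identity_weight * jnp.eye(2))
direct = qd @ (jnp.power(sd, power) * (qd.T @ v))

assert jnp.allclose(via_eigh, direct, atol=1e-5)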
+ """ + + if eigen_decomposition_threshold is None: + threshold = cb_utils.get_default_eigen_decomposition_threshold() + self._eigen_decomposition_threshold = threshold + + else: + self._eigen_decomposition_threshold = eigen_decomposition_threshold + + super().__init__(layer_tag_eq) + + def parameters_list_to_single_vector( + self, + parameters_shaped_list: Sequence[Array], + ) -> Array: + """Converts values corresponding to parameters of the block to vector.""" + + if len(parameters_shaped_list) != self.number_of_parameters: + + raise ValueError(f"Expected a list of {self.number_of_parameters} values," + f" but got {len(parameters_shaped_list)} instead.") + + for array, shape in zip(parameters_shaped_list, self.parameters_shapes): + + if array.shape != shape: + raise ValueError(f"Expected a value of shape {shape}, but got " + f"{array.shape} instead.") + + return jnp.concatenate([v.flatten() for v in parameters_shaped_list]) + + def single_vector_to_parameters_list( + self, + vector: Array, + ) -> tuple[Array, ...]: + """Reverses the transformation ``self.parameters_list_to_single_vector``.""" + + if vector.ndim != 1: + raise ValueError(f"Expecting a vector, got {vector.ndim}-tensor.") + + if vector.size != self.dim: + raise ValueError(f"Expected a vector of size {self.dim}, but got " + f"{vector.size} instead.") + + parameters_shaped_list = [] + index = 0 + + for shape in self.parameters_shapes: + + size = utils.product(shape) + parameters_shaped_list.append(vector[index: index + size].reshape(shape)) + index += size + + assert index == self.dim + + return tuple(parameters_shaped_list) + + def _init( + self, + rng: PRNGKey, + exact_powers_to_cache: set[Scalar], + approx_powers_to_cache: set[Scalar], + cache_eigenvalues: bool, + ) -> State: + + del rng + + # This block does not have any notion of "approximate" powers + exact_powers_to_cache = exact_powers_to_cache | approx_powers_to_cache + cache = {} + + if len(exact_powers_to_cache) > self._eigen_decomposition_threshold: + cache["eigenvalues"] = jnp.zeros([self.dim], self.dtype) + cache["eigen_vectors"] = jnp.zeros([self.dim, self.dim], self.dtype) + + elif cache_eigenvalues: + cache["eigenvalues"] = jnp.zeros([self.dim], self.dtype) + + if len(exact_powers_to_cache) <= self._eigen_decomposition_threshold: + for power in exact_powers_to_cache: + cache[str(power)] = jnp.zeros([self.dim, self.dim], self.dtype) + + return Full.State( + cache=cache, + matrix=utils.WeightedMovingAverage.zeros_array( + [self.dim, self.dim], self.dtype), + ) + + def sync( + self, + state: State, + pmap_axis_name: str, + ) -> State: + + # Copy this first since we mutate it later in this function. + state = state.copy() + + state.matrix.sync(pmap_axis_name) + + return state + + def _multiply_matpower_unscaled( + self, + state: State, + vector: Sequence[Array], + identity_weight: Numeric, + power: Scalar, + exact_power: bool, + use_cached: bool, + ) -> tuple[Array, ...]: + + vector = self.parameters_list_to_single_vector(vector) + + if power == 1: + + result = jnp.matmul(state.matrix.value, vector) + + if use_cached: + # state_dependent_scale needs to be included here because it won't be by + # the caller of this function (multiply_matpower) when use_cached=True. + # This is not an issue for other powers because they bake in + # state_dependent_scale. 
+ result *= self.state_dependent_scale(state) + + result += identity_weight * vector + + elif not use_cached: + + matrix = state.matrix.value + identity_weight * jnp.eye(self.dim) + + if power == -1: + result = utils.psd_solve(matrix, vector) + else: + # TODO(jamesmartens,botev): investigate this for determinism on GPUs + # NOTE: this function only works for integer powers + result = jnp.matmul(jnp.linalg.matrix_power(matrix, power), vector) + + else: + + if str(power) in state.cache: + result = jnp.matmul(state.cache[str(power)], vector) + + else: + s = state.cache["eigenvalues"] + q = state.cache["eigen_vectors"] + + result = jnp.matmul(jnp.transpose(q), vector) + result = jnp.power(s + identity_weight, power) * result + result = jnp.matmul(q, result) + + return self.single_vector_to_parameters_list(result) + + def _eigenvalues_unscaled( + self, + state: State, + use_cached: bool, + ) -> Array: + + if not use_cached: + return utils.safe_psd_eigh(state.matrix.value)[0] + + else: + return state.cache["eigenvalues"] + + def _update_cache( + self, + state: State, + identity_weight: Numeric, + exact_powers: set[Scalar], + approx_powers: set[Scalar], + eigenvalues: bool, + ) -> State: + + # Copy this first since we mutate it later in this function. + state = state.copy() + + scale = self.state_dependent_scale(state) + + # This block does not have any notion of "approximate" powers + exact_powers = exact_powers | approx_powers + + if len(exact_powers) > self._eigen_decomposition_threshold: + + s, q = utils.safe_psd_eigh(state.matrix.value) + state.cache = dict(eigenvalues=scale * s, eigen_vectors=q) + + else: + + if eigenvalues: + state.cache["eigenvalues"] = scale * utils.safe_psd_eigh( + state.matrix.value)[0] + + for power in exact_powers: + + if power == -1: + state.cache[str(power)] = utils.psd_inv( + state.matrix.value + identity_weight * jnp.eye(self.dim)) / scale + else: + matrix = state.matrix.value + identity_weight * jnp.eye(self.dim) + state.cache[str(power)] = ( + (scale ** power) * jnp.linalg.matrix_power(matrix, power)) + + return state + + def _to_dense_unscaled(self, state: State) -> Array: + + # Permute the matrix according to the parameters canonical order + return utils.block_permuted( + state.matrix.value, + block_sizes=[utils.product(shape) for shape in self.parameters_shapes], + block_order=self.parameters_canonical_order + ) + + def _norm_unscaled( + self, + state: CurvatureBlock.State, + norm_type: str + ) -> Numeric: + + return utils.psd_matrix_norm(state.matrix.value, norm_type=norm_type) + + def _undamped_diagonal_unscaled(self, state: State) -> tuple[Array, ...]: + diag_vec = jnp.diag(state.matrix.value) + return self.single_vector_to_parameters_list(diag_vec) + + +class NaiveFull(Full): + """Approximates the full curvature with in the most obvious way. + + The update to the curvature estimate is computed by + ``(sum_i g_i) (sum_i g_i)^T / N``, where ``g_i`` is the gradient of each + individual data point, and ``N`` is the batch size. + """ + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: Full.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> Full.State: + del identity_weight + + # This method supports the case where the param tangents have an extra + # leading dimension that should be summed over (after the outer products). 
+    # TODO(jamesmartens): add support for this to NaiveDiagonal
+
+    # Copy this first since we mutate it later in this function.
+    state = state.copy()
+
+    params_tangents = jax.tree_util.tree_leaves(
+        estimation_data.tangents.params)
+
+    params_tangents_flattened = []
+
+    assert len(params_tangents) == self.number_of_parameters
+
+    for p_shape, pt in zip(self.parameters_shapes, params_tangents):
+
+      if p_shape:
+        assert (
+            pt.shape[-len(p_shape) :] == p_shape
+        ), f"{pt.shape=} and {p_shape=}"
+
+      p_size = utils.product(p_shape)
+
+      params_tangents_flattened.append(pt.reshape([-1, p_size]))
+
+    tangents = jnp.concatenate(params_tangents_flattened, axis=1)
+
+    if jnp.iscomplexobj(tangents):
+      stats = (
+          jnp.einsum("ay,az->yz", tangents.real, tangents.real)
+          - jnp.einsum("ay,az->yz", tangents.imag, tangents.imag)) / batch_size
+    else:
+      stats = jnp.einsum("ay,az->yz", tangents, tangents) / batch_size
+
+    state.matrix.update(stats, ema_old, ema_new)
+
+    return state
+
+
+class DenseFull(Full):
+  """A `Full` block specifically for dense layers."""
+
+  @utils.auto_scope_method
+  def update_curvature_matrix_estimate(
+      self,
+      state: Full.State,
+      estimation_data: tracer.LayerVjpData[Array],
+      ema_old: Numeric,
+      ema_new: Numeric,
+      identity_weight: Numeric,
+      batch_size: Numeric,
+  ) -> Full.State:
+    del identity_weight
+
+    # Copy this first since we mutate it later in this function.
+    state = state.copy()
+
+    [x] = estimation_data.primals.inputs
+    [dy] = estimation_data.tangents.outputs
+
+    assert utils.first_dim_is_size(batch_size, x, dy)
+
+    params_tangents = x[:, :, None] * dy[:, None, :]
+
+    if self.number_of_parameters == 2:
+      params_tangents = jnp.concatenate([params_tangents, dy[:, None]], axis=1)
+
+    params_tangents = jnp.reshape(params_tangents, [batch_size, -1])
+
+    matrix_update = jnp.matmul(params_tangents.T, params_tangents) / batch_size
+    state.matrix.update(matrix_update, ema_old, ema_new)
+
+    return state
+
+
+class Conv2DFull(Full):
+  """A :class:`~Full` block specifically for 2D convolution layers."""
+
+  def __init__(
+      self,
+      layer_tag_eq: tags.LayerTagEqn,
+      max_elements_for_vmap: int | None = None,
+  ):
+    """Initializes the block.
+
+    Since there is no 'nice' formula for computing the average of the
+    tangents for a 2D convolution, we instead use a function -
+    ``self.conv2d_tangent_outer_product`` - that computes the outer product
+    of the kernel tangents for a single feature map. To average over the
+    batch we have two choices - vmap, or looping over the batch sequentially
+    using scan. This class provides a trade-off between the two by letting
+    the maximum number of elements handled in a single iteration of the loop
+    be capped. The larger this cap, the more feature maps are processed in
+    parallel by a single ``vmap`` call and the higher the memory usage, while
+    the sequential loop runs correspondingly fewer times to cover the whole
+    batch and produce the final average.
+
+    Args:
+      layer_tag_eq: The Jax equation corresponding to the layer tag that this
+        block will approximate the curvature to.
+      max_elements_for_vmap: The threshold used for determining how much
+        computation to do in parallel and how much in a serial manner. If
+        ``None``, the value returned by
+        :func:`~get_max_parallel_elements` is used.
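+
+    For example (purely illustrative numbers): if the per-iteration batch
+    derived from ``max_elements_for_vmap`` works out to 64 feature maps and
+    the total batch is 256, the vmapped outer-product computation runs over
+    chunks of 64 maps and the sequential loop executes ceil(256 / 64) = 4
+    times before the per-map outer products are averaged.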
+ """ + + self._averaged_tangents_outer_product = utils.loop_and_parallelize_average( + func=self.conv2d_tangent_outer_product, + max_parallel_size=max_elements_for_vmap or + cb_utils.get_max_parallel_elements(), + ) + + super().__init__(layer_tag_eq) + + def conv2d_tangent_outer_product( + self, + inputs: Array, + tangent_of_outputs: Array, + ) -> Array: + """Computes the outer product of a tangent for a single feature map.""" + + extra_params = {k: v for k, v in self.layer_tag_extra_params.items() + if k not in ("lhs_shape", "rhs_shape", "meta")} + + _, vjp = jax.vjp( + functools.partial( + jax.lax.conv_general_dilated, + **extra_params + ), + inputs[None], jnp.zeros(self.parameters_shapes[0]) + ) + + tangents = (vjp(tangent_of_outputs[None])[1],) + + if self.number_of_parameters == 2: + num_axis = tangent_of_outputs.ndim - len(self.parameters_shapes[1]) + sum_axis = tuple(range(num_axis)) + tangents += (jnp.sum(tangent_of_outputs, axis=sum_axis),) + + flat_tangents = self.parameters_list_to_single_vector(tangents) + + return jnp.outer(flat_tangents, flat_tangents) + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: Full.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> Full.State: + del identity_weight + + # Copy this first since we mutate it later in this function. + state = state.copy() + + [x] = estimation_data.primals.inputs + [dy] = estimation_data.tangents.outputs + assert utils.first_dim_is_size(batch_size, x, dy) + + matrix_update = self._averaged_tangents_outer_product(x, dy) + state.matrix.update(matrix_update, ema_old, ema_new) + + return state + + +class ScaleAndShiftFull(Full): + """A full dense approximation specifically for a scale and shift layers.""" + + @property + def _has_scale(self) -> bool: + """Whether this layer's equation has a scale.""" + return self._layer_tag_eq.params["has_scale"] + + @property + def _has_shift(self) -> bool: + """Whether this layer's equation has a shift.""" + return self._layer_tag_eq.params["has_shift"] + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: Full.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> Full.State: + del identity_weight + + # Copy this first since we mutate it later in this function. 
+ state = state.copy() + + [x] = estimation_data.primals.inputs + [dy] = estimation_data.tangents.outputs + assert utils.first_dim_is_size(batch_size, x, dy) + + tangents = [] + + if self._has_scale: + # Scale tangent + scale_shape = estimation_data.primals.params[0].shape + + d_scale = cb_utils.compatible_sum(x * dy, scale_shape, skip_axes=[0]) + d_scale = d_scale.reshape([batch_size, -1]) + + tangents.append(d_scale) + + if self._has_shift: + # Shift tangent + + shift_shape = estimation_data.primals.params[-1].shape + + d_shift = cb_utils.compatible_sum(dy, shift_shape, skip_axes=[0]) + d_shift = d_shift.reshape([batch_size, -1]) + + tangents.append(d_shift) + + tangents = jnp.concatenate(tangents, axis=1) + matrix_update = jnp.matmul(tangents.T, tangents) / batch_size + + state.matrix.update(matrix_update, ema_old, ema_new) + + return state diff --git a/kfac_jax/_src/curvature_blocks/kronecker_factored.py b/kfac_jax/_src/curvature_blocks/kronecker_factored.py new file mode 100644 index 0000000..b0670a9 --- /dev/null +++ b/kfac_jax/_src/curvature_blocks/kronecker_factored.py @@ -0,0 +1,707 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module containing the Kronecker factored curvature blocks.""" +import abc +import math +from typing import Any, Sequence + +import jax.numpy as jnp +from kfac_jax._src import layers_and_loss_tags as tags +from kfac_jax._src import patches_second_moment as psm +from kfac_jax._src import tracer +from kfac_jax._src import utils +from kfac_jax._src.curvature_blocks import curvature_block +from kfac_jax._src.curvature_blocks import utils as cb_utils +from typing_extensions import Self + + +# Types for annotation +Array = utils.Array +Scalar = utils.Scalar +Numeric = utils.Numeric +PRNGKey = utils.PRNGKey +Shape = utils.Shape +CurvatureBlock = curvature_block.CurvatureBlock + + +class KroneckerFactored(CurvatureBlock, abc.ABC): + """An abstract class for approximating the block with a Kronecker product. + + The constructor takes two special arguments: + - parameters_specs: A list, where each element specifies for each + parameter a "rearrange string". This is in the format `abc->b(ca)` + similar to `einops.rearrange`. + - parameters_concat_axis: The axis along which the parameters will be + concatenated to form a single array after each parameter has been + rearranged according to its "rearrange string". + + The above implies that: + - All parameters must have the same rank after they have been rearranged. + - All parameters must have the same size along all axes except the + concatenation axis after they have been rearranged. + + By default, each parameter is rearanged to a matrix, by merging all dimensions + except the last one. If a parameter is a vector (rank 1), it is rearranged to + a matrix with the first dimension being 1. Then concatenation is done along + axis=0. + """ + + @utils.register_state_class + class State(CurvatureBlock.State): + """Persistent state of the block. 
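+
+    One factor is kept per axis of the single array obtained by rearranging
+    and concatenating the block's parameters (see the class docstring above).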
+ + Attributes: + factors: A tuple of the moving averages of the estimated factors of the + curvature for each axis group. + """ + + factors: tuple[utils.WeightedMovingAverage, ...] + + @classmethod + def from_dict(cls, dict_rep: dict[str, Any]) -> Self: + class_name = dict_rep.pop("__class__", cls.__name__) + assert class_name == cls.__name__ + return cls( + factors=tuple( + utils.WeightedMovingAverage.from_dict(rep) + for rep in dict_rep["factor"] + ) + ) + + def __init__( + self, + layer_tag_eq: tags.LayerTagEqn, + parameters_specs: Sequence[str] | None = None, + parameters_concat_axis: int = 0, + ): + + # Even though the superclass constructor will set this later, we need to do + # it now since it's used below. + self._layer_tag_eq = layer_tag_eq + + if parameters_specs is None: + parameters_specs = [] + + for shape in self.parameters_shapes: + + if len(shape) == 1: + parameters_specs.append("a -> 1a") + + else: + in_str = cb_utils.ALPHABET[:len(shape)] + out_str = f"({in_str[:-1]}){in_str[-1]}" + parameters_specs.append(f"{in_str} -> {out_str}") + + else: + assert len(parameters_specs) == self.number_of_parameters + + self.parameters_specs = parameters_specs + self.parameters_concat_axis = parameters_concat_axis + + super().__init__(layer_tag_eq) + + def __str__(self): + return ( + f"{self.__class__.__name__}(parameter_specs={self.parameters_specs}, " + f"parameters_concat_axis={self.parameters_concat_axis}), " + f"tag name: {self.name}, params shapes: {self.parameters_shapes!r}" + ) + + def parameters_shaped_list_to_array( + self, + parameters_shaped_list: Sequence[Array], + ) -> Array: + """Combines all parameters to a single array.""" + values = [] + for p, spec in zip( + parameters_shaped_list, + self.parameters_specs, + strict=True, + ): + values.append(utils.rearrange(p, spec)) + + return jnp.concatenate(values, axis=self.parameters_concat_axis) + + def array_to_parameters_shaped_list(self, array: Array) -> tuple[Array, ...]: + """An inverse transformation of ``self.parameters_shaped_list_to_array``.""" + parameters_list = [] + n = 0 + index = [slice(None)] * array.ndim + + for shape, spec in zip( + self.parameters_shapes, + self.parameters_specs, + strict=True, + ): + zero = utils.rearrange(jnp.zeros(shape), spec) + d = zero.shape[self.parameters_concat_axis] + index[self.parameters_concat_axis] = slice(n, n + d) + p = array[tuple(index)] + parameters_list.append(p.reshape(shape)) + n += d + + return tuple(parameters_list) + + @property + def array_shape(self) -> Shape: + """The shape of the single non axis grouped array.""" + avals = [jnp.zeros(shape) for shape in self.parameters_shapes] + return self.parameters_shaped_list_to_array(avals).shape + + @property + def array_ndim(self) -> int: + """The number of dimensions of the single non axis grouped array.""" + return len(self.array_shape) + + def _init( + self, + rng: PRNGKey, + exact_powers_to_cache: set[Scalar], + approx_powers_to_cache: set[Scalar], + cache_eigenvalues: bool, + ) -> State: + + cache = {} + factors = [] + + for i, d in enumerate(self.array_shape): + + factors.append( + utils.WeightedMovingAverage.zeros_array((d, d), self.dtype) + ) + + if cache_eigenvalues or exact_powers_to_cache: + cache[f"{i}_factor_eigenvalues"] = jnp.zeros((d,), dtype=self.dtype) + + if exact_powers_to_cache: + cache[f"{i}_factor_eigen_vectors"] = jnp.zeros((d, d), dtype=self.dtype) + + for power in approx_powers_to_cache: + + if power != -1: + raise NotImplementedError( + f"Approximations for power {power} is not yet implemented." 
+ ) + + if str(power) not in cache: + cache[str(power)] = {} + + cache[str(power)][f"{i}_factor"] = jnp.zeros((d, d), dtype=self.dtype) + + return KroneckerFactored.State( + cache=cache, + factors=tuple(factors), + ) + + def sync( + self, + state: State, + pmap_axis_name: str, + ) -> State: + + # Copy this first since we mutate it later in this function. + state = state.copy() + + for factor in state.factors: + factor.sync(pmap_axis_name) + + return state + + def _multiply_matpower_unscaled( + self, + state: State, + vector: Sequence[Array], + identity_weight: Numeric, + power: Scalar, + exact_power: bool, + use_cached: bool, + ) -> tuple[Array, ...]: + + assert len(state.factors) == self.array_ndim + + vector = self.parameters_shaped_list_to_array(vector) + + if power == 1: + + factors = [f.value for f in state.factors] + + # state_dependent_scale needs to be included here because it won't be by + # the caller of this function (multiply_matpower) when use_cached=True. + # This is not an issue for other powers because they bake in + # state_dependent_scale. + scale = self.state_dependent_scale(state) if use_cached else 1.0 + + if exact_power: + result = scale * utils.kronecker_product_axis_mul_v(factors, vector) + result = result + identity_weight * vector + + else: + # If compute pi_adjusted_kronecker_factors used a more expensive matrix + # norm in its computation, it might make sense to cache it. But we + # currently don't do that. + + result = scale * utils.kronecker_product_axis_mul_v( + utils.pi_adjusted_kronecker_factors( + *factors, damping=identity_weight / scale), + vector) + + elif exact_power: + + if use_cached: + s = [ + state.cache[f"{i}_factor_eigenvalues"] + for i in range(len(state.factors)) + ] + q = [ + state.cache[f"{i}_factor_eigen_vectors"] + for i in range(len(state.factors)) + ] + + else: + s, q = zip( + *[utils.safe_psd_eigh(factor.value) for factor in state.factors] + ) + + eigenvalues = utils.outer_product(*s) + identity_weight + eigenvalues = jnp.power(eigenvalues, power) + + result = utils.kronecker_eigen_basis_axis_mul_v(q, eigenvalues, vector) + + else: + + if power != -1 and power != -0.5: + raise NotImplementedError( + f"Approximations for power {power} is not yet implemented." + ) + + if use_cached: + + assert power != -0.5 + + factors = [ + state.cache[str(power)][f"{i}_factor"] + for i in range(len(state.factors)) + ] + + else: + + factors = [factor.value for factor in state.factors] + + factors = utils.pi_adjusted_kronecker_factors( + *factors, damping=identity_weight) + + if power == -1: + factors = utils.invert_psd_matrices(factors) + elif power == -0.5: + factors = utils.inverse_sqrt_psd_matrices(factors) + else: + raise NotImplementedError() + + result = utils.kronecker_product_axis_mul_v(factors, vector) + + return self.array_to_parameters_shaped_list(result) + + def _eigenvalues_unscaled( + self, + state: State, + use_cached: bool, + ) -> Array: + + assert len(state.factors) == self.array_ndim + + if use_cached: + s = [ + state.cache[f"{i}_factor_eigenvalues"] + for i in range(len(state.factors)) + ] + else: + s_q = [utils.safe_psd_eigh(factor.value) for factor in state.factors] + s, _ = zip(*s_q) + + return utils.outer_product(*s) + + def _update_cache( + self, + state: State, + identity_weight: Numeric, + exact_powers: set[Scalar], + approx_powers: set[Scalar], + eigenvalues: bool, + ) -> State: + + assert len(state.factors) == self.array_ndim + + # Copy this first since we mutate it later in this function. 
+ state = state.copy() + + scale = self.state_dependent_scale(state) + factor_scale = jnp.power(scale, 1.0 / self.array_ndim) + + if eigenvalues or exact_powers: + + s_q = [utils.safe_psd_eigh(factor.value) for factor in state.factors] + + s, q = zip(*s_q) + + for i in range(len(state.factors)): + state.cache[f"{i}_factor_eigenvalues"] = factor_scale * s[i] + + if exact_powers: + state.cache[f"{i}_factor_eigen_vectors"] = q[i] + + for power in approx_powers: + + if power != -1: + raise NotImplementedError( + f"Approximations for power {power} is not yet implemented." + ) + + cache = state.cache[str(power)] + + # This computes the approximate inverse factors using the generalization + # of the pi-adjusted inversion from the original KFAC paper. + inv_factors = utils.pi_adjusted_kronecker_inverse( + *[factor.value for factor in state.factors], + damping=identity_weight, + ) + + for i in range(len(state.factors)): + cache[f"{i}_factor"] = inv_factors[i] / factor_scale + + return state + + def _norm_unscaled( + self, + state: CurvatureBlock.State, + norm_type: str + ) -> Numeric: + + return utils.product( + utils.psd_matrix_norm(f.value, norm_type=norm_type) + for f in state.factors) + + def _to_dense_unscaled(self, state: "KroneckerFactored.State") -> Array: + + # We currently support this only for 2 parameters + assert 0 < self.number_of_parameters <= 2 + inputs_factor = state.factors[0].value + + if (self.number_of_parameters == 2 and + self.parameters_canonical_order[0] != 0): + + # Permute the matrix according to the parameters canonical order + inputs_factor = utils.block_permuted( + state.factors[0].value, + block_sizes=[state.factors[0].raw_value.shape[0] - 1, 1], + block_order=(1, 0), + ) + + return jnp.kron(inputs_factor, state.factors[1].value) + + +class DenseTwoKroneckerFactored(KroneckerFactored): + """A :class:`~TwoKroneckerFactored` block specifically for dense layers.""" + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: KroneckerFactored.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> KroneckerFactored.State: + del identity_weight + assert 1 <= self.number_of_parameters <= 2 + + # Copy this first since we mutate it later in this function. + state = state.copy() + + [x] = estimation_data.primals.inputs + [dy] = estimation_data.tangents.outputs + + assert utils.first_dim_is_size(batch_size, x, dy) + + if self.number_of_parameters == 2: + x_one = jnp.ones_like(x[:, :1]) + x = jnp.concatenate([x, x_one], axis=1) + + input_stats = jnp.einsum("ay,az->yz", x, x) / batch_size + output_stats = jnp.einsum("ay,az->yz", dy, dy) / batch_size + + state.factors[0].update(input_stats, ema_old, ema_new) + state.factors[1].update(output_stats, ema_old, ema_new) + + return state + + +class RepeatedDenseKroneckerFactored(DenseTwoKroneckerFactored): + """Block for dense layers applied to tensors with extra time/loc dims.""" + + @utils.register_state_class + class State(KroneckerFactored.State): + """Persistent state of the block. + + Attributes: + average_repeats: A decayed average of the per-case number of non-masked + repeats in the data used to compute the block's statistics. We use the + same decayed averaging for this quantity that we do for the statistics, + so that they "match". 
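+        This average is what ``state_dependent_scale`` divides by, since the
+        factor statistics of this block are accumulated over all (non-masked)
+        repeats of every example rather than once per example.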
+ """ + + average_repeats: utils.WeightedMovingAverage + + def __init__( + self, + layer_tag_eq: tags.LayerTagEqn, + use_masking: bool = True, + parameters_specs: Sequence[str] | None = None, + parameters_concat_axis: int = 0, + ): + self._use_masking = use_masking + super().__init__( + layer_tag_eq=layer_tag_eq, + parameters_specs=parameters_specs, + parameters_concat_axis=parameters_concat_axis, + ) + + def _init( + self, + rng: PRNGKey, + exact_powers_to_cache: set[Scalar], + approx_powers_to_cache: set[Scalar], + cache_eigenvalues: bool, + ) -> "RepeatedDenseKroneckerFactored.State": + + super_state = super()._init( + rng, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues + ) + + return RepeatedDenseKroneckerFactored.State( + average_repeats=utils.WeightedMovingAverage.zeros_array((), self.dtype), + **super_state.__dict__, + ) + + def state_dependent_scale( + self, + state: "RepeatedDenseKroneckerFactored.State", + ) -> Numeric: + return 1.0 / state.average_repeats.value + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: KroneckerFactored.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> KroneckerFactored.State: + del identity_weight + + # Copy this first since we mutate it later in this function. + state = state.copy() + + [x] = estimation_data.primals.inputs + [dy] = estimation_data.tangents.outputs + + assert utils.first_dim_is_size(batch_size, x, dy) + + if self._use_masking: + + # hack: we identify masked repeats by checking if all corresponding + # entries of dy are zero + mask = 1.0 - jnp.all(dy == 0.0, axis=-1, keepdims=True) + + # zero out corresponding elts of x + x = x * mask + + # compute total non-masked + total = jnp.sum(mask) + + else: + total = math.prod(dy.shape[:-1]) + + x = x.reshape([-1, x.shape[-1]]) + dy = dy.reshape([-1, dy.shape[-1]]) + + if self.number_of_parameters == 2: + x_one = jnp.ones_like(x[:, :1]) + x = jnp.concatenate([x, x_one], axis=1) + + input_stats = jnp.einsum("ay,az->yz", x, x) / batch_size + output_stats = jnp.einsum("ay,az->yz", dy, dy) / batch_size + + state.factors[0].update(input_stats, ema_old, ema_new) + state.factors[1].update(output_stats, ema_old, ema_new) + state.average_repeats.update(total / batch_size, ema_old, ema_new) + + return state + + +class Conv2DTwoKroneckerFactored(KroneckerFactored): + """A :class:`~TwoKroneckerFactored` block specifically for 2D convolution layers.""" + + def fixed_scale(self) -> Numeric: + return float(self.num_locations) + + @property + def kernel_output_axis(self) -> int: + return self._layer_tag_eq.params["dimension_numbers"].rhs_spec[0] + + @property + def outputs_channel_index(self) -> int: + """The ``channels`` index in the outputs of the layer.""" + return self._layer_tag_eq.params["dimension_numbers"].out_spec[1] + + @property + def inputs_channel_index(self) -> int: + """The ``channels`` index in the inputs of the layer.""" + return self._layer_tag_eq.params["dimension_numbers"].lhs_spec[1] + + @property + def weights_output_channel_index(self) -> int: + """The ``channels`` index in weights of the layer.""" + return self._layer_tag_eq.params["dimension_numbers"].rhs_spec[0] + + @property + def weights_spatial_shape(self) -> Shape: + spatial_index = self._layer_tag_eq.params["dimension_numbers"].rhs_spec[2:] + return tuple(self.parameters_shapes[0][i] for i in spatial_index) + + @property + def weights_spatial_size(self) -> int: + """The spatial 
filter size of the weights.""" + return utils.product(dim for dim in self.weights_spatial_shape) + + @property + def inputs_spatial_shape(self) -> Shape: + spatial_index = self._layer_tag_eq.params["dimension_numbers"].lhs_spec[2:] + return tuple(self.inputs_shapes[0][i] for i in spatial_index) + + @property + def num_locations(self) -> int: + """The number of spatial locations that each filter is applied to.""" + return psm.num_conv_locations( + self.inputs_spatial_shape, + self.weights_spatial_shape, + self._layer_tag_eq.params["window_strides"], + self._layer_tag_eq.params["padding"]) + + def input_size(self) -> int: + if self.has_bias: + return self.num_inputs_channels * self.weights_spatial_size + 1 + else: + return self.num_inputs_channels * self.weights_spatial_size + + def output_size(self) -> int: + return self.num_outputs_channels + + @property + def num_inputs_channels(self) -> int: + """The number of channels in the inputs to the layer.""" + return self._layer_tag_eq.invars[0].aval.shape[ # pytype: disable=attribute-error + self.inputs_channel_index] + + @property + def num_outputs_channels(self) -> int: + """The number of channels in the outputs to the layer.""" + return self._layer_tag_eq.invars[1].aval.shape[ # pytype: disable=attribute-error + self.weights_output_channel_index] + + def compute_inputs_stats( + self, + inputs: Array, + weighting_array: Array | None = None, + ) -> Array: + """Computes the statistics for the inputs factor.""" + batch_size = inputs.shape[0] + + input_cov_m, input_cov_v = psm.patches_moments( + inputs, + kernel_spatial_shape=self.weights_spatial_shape, + strides=self._layer_tag_eq.params["window_strides"], + padding=self._layer_tag_eq.params["padding"], + data_format=None, + dim_numbers=self._layer_tag_eq.params["dimension_numbers"], + precision=self._layer_tag_eq.params.get("precision"), + weighting_array=weighting_array, + ) + + # Flatten the kernel and channels dimensions + k, h, c = input_cov_v.shape + input_cov_v = jnp.reshape(input_cov_v, (k * h * c,)) + input_cov_m = jnp.reshape(input_cov_m, (k * h * c, k * h * c)) + + # Normalize by the `batch size` * `num_locations` + normalizer = batch_size * self.num_locations + input_cov_m = input_cov_m / normalizer + input_cov_v = input_cov_v / normalizer + + if self.number_of_parameters == 1: + return input_cov_m + + if weighting_array is None: + corner = jnp.ones([1], dtype=input_cov_m.dtype) + else: + corner = jnp.mean(weighting_array).reshape([1]) + + input_cov = jnp.concatenate([input_cov_m, input_cov_v[None]], axis=0) + input_cov_v = jnp.concatenate([input_cov_v, corner], axis=0) + + return jnp.concatenate([input_cov, input_cov_v[:, None]], axis=1) + + def compute_outputs_stats(self, tangent_of_output: Array) -> Array: + """Computes the statistics for the outputs factor.""" + lhs_str = utils.replace_char( + cb_utils.ALPHABET[:4], "y", self.outputs_channel_index) + rhs_str = utils.replace_char( + cb_utils.ALPHABET[:4], "z", self.outputs_channel_index) + ein_str = f"{lhs_str},{rhs_str}->yz" + stats = jnp.einsum(ein_str, tangent_of_output, tangent_of_output) + + # Normalize by the `batch size` * `num_locations` + normalizer = tangent_of_output.shape[0] * self.num_locations + return stats / normalizer + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: KroneckerFactored.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> KroneckerFactored.State: + del identity_weight 
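+
+    # As in the Kronecker-factored approximation for convolutions (KFC), the
+    # input factor is built from the second moments of the extracted
+    # convolution patches (augmented with a homogeneous coordinate when there
+    # is a bias), and the output factor from the second moments of the output
+    # tangents; both are normalized by ``batch_size * self.num_locations``.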
+    assert 1 <= self.number_of_parameters <= 2
+
+    # Copy this first since we mutate it later in this function.
+    state = state.copy()
+
+    [x] = estimation_data.primals.inputs
+    [dy] = estimation_data.tangents.outputs
+
+    assert utils.first_dim_is_size(batch_size, x, dy)
+
+    input_stats = self.compute_inputs_stats(x)
+    output_stats = self.compute_outputs_stats(dy)
+
+    state.factors[0].update(input_stats, ema_old, ema_new)
+    state.factors[1].update(output_stats, ema_old, ema_new)
+
+    return state
diff --git a/kfac_jax/_src/curvature_blocks/tnt.py b/kfac_jax/_src/curvature_blocks/tnt.py
new file mode 100644
index 0000000..956c244
--- /dev/null
+++ b/kfac_jax/_src/curvature_blocks/tnt.py
@@ -0,0 +1,247 @@
+# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module containing TNT curvature blocks."""
+from typing import Sequence
+
+import jax
+import jax.numpy as jnp
+import jax.scipy
+from kfac_jax._src import layers_and_loss_tags as tags
+from kfac_jax._src import tracer
+from kfac_jax._src import utils
+from kfac_jax._src.curvature_blocks import kronecker_factored
+from kfac_jax._src.curvature_blocks import utils as cb_utils
+
+
+# Types for annotation
+Array = utils.Array
+Numeric = utils.Numeric
+KroneckerFactored = kronecker_factored.KroneckerFactored
+
+
+class NaiveTNT(KroneckerFactored):
+  """A standard TNT block for a single parameter, or weights + bias.
+
+  Each factor of the standard TNT curvature approximation estimates the
+  expected value of the contraction of the gradients with themselves along
+  all but a single axis `i`:
+    ``F_i ~~ E[contract_all_but_one(g, g, i)]``,
+  where `contract_all_but_one` is defined as the contraction over all axes
+  except the i-th of its first two inputs, e.g.:
+    ``contract_all_but_one(A, B, 1)[a,b] = sum_{i,j} A[i, a, j] B[i, b, j]``
+
+  The estimation is performed in a naive way, by contracting the sum of the
+  per-example gradients with itself and then dividing by the batch size:
+    ``F_i = contract_all_but_one(sum_n g_n, sum_n g_n, i) / N``
+  where `g_n` is the model gradient of a single example and `N` is the
+  batch size. Since the expectations of the gradients over the model
+  distribution are zero and they are independent across cases, this is still
+  an unbiased estimator.
+  """
+
+  def state_dependent_scale(
+      self,
+      state: "NaiveTNT.State",
+  ) -> Numeric:
+    return utils.tnt_scale([factor.value for factor in state.factors])
+
+  @utils.auto_scope_method
+  def update_curvature_matrix_estimate(
+      self,
+      state: KroneckerFactored.State,
+      estimation_data: tracer.LayerVjpData[Array],
+      ema_old: Numeric,
+      ema_new: Numeric,
+      identity_weight: Numeric,
+      batch_size: Numeric,
+  ) -> KroneckerFactored.State:
+    del identity_weight
+
+    # Copy this first since we mutate it later in this function.
+ state = state.copy() + + dw = self.parameters_shaped_list_to_array(estimation_data.tangents.params) + + assert dw.ndim == len(state.factors) + + in_str = cb_utils.ALPHABET[: dw.ndim] + + for i, factor in enumerate(state.factors): + # For factor i we contract the gradient with itself along all axes, + # except the i-th. + + lhs_str = utils.replace_char(in_str, "y", i) + rhs_str = utils.replace_char(in_str, "z", i) + + # This is a rank-1 mod since it's like we flattened all but dim i together + # and then did an outer product + factor_update = ( + jnp.einsum(f"{lhs_str},{rhs_str}->yz", dw, dw) / batch_size + ) + + factor.update(factor_update, ema_old, ema_new) + + return state + + +class DenseTNT(kronecker_factored.DenseTwoKroneckerFactored): + """A TNT block for dense layers. + + This TNT block modifies :class:`~NaiveTNTBlock` by the way it estimates each + factor specifically for a dense layer. Instead of using the contraction over + the summed gradient, it performs the contraction over each individual batch + elements and then averages over the batch: + ``F_i = sum_n contract_all_but_one(g_n, g_n, i) / N`` + The estimator is unbiased, and will have lower variance then the naive one. + """ + + def state_dependent_scale(self, state: "DenseTNT.State") -> Numeric: + return utils.tnt_scale([factor.value for factor in state.factors]) + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: KroneckerFactored.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> KroneckerFactored.State: + del identity_weight + + # Copy this first since we mutate it later in this function. + state = state.copy() + + [x] = estimation_data.primals.inputs + [dy] = estimation_data.tangents.outputs + + assert utils.first_dim_is_size(batch_size, x, dy) + + if self.number_of_parameters == 2: + x_one = jnp.ones_like(x[:, :1]) + x = jnp.concatenate([x, x_one], axis=1) + + # We multiply each x by the norm_y, and each dy by the norm of x + dy_norms = jnp.linalg.norm(dy, axis=-1, keepdims=True) + x_norms = jnp.linalg.norm(x, axis=-1, keepdims=True) + x = x * dy_norms + dy = dy * x_norms + + input_stats = jnp.einsum("ay,az->yz", x, x) / batch_size + output_stats = jnp.einsum("ay,az->yz", dy, dy) / batch_size + + state.factors[0].update(input_stats, ema_old, ema_new) + state.factors[1].update(output_stats, ema_old, ema_new) + + return state + + +class Conv2DTNT(kronecker_factored.Conv2DTwoKroneckerFactored): + """A TNT block for Conv2D layers. + + This TNT block modifies :class:`~NaiveTNTBlock` by the way it estimates each + factor specifically for a conv2D layer. Importantly, it assumes "location + independence" similar to :class:~`Conv2DTwoKroneckerFactored`. Given this + assumption, instead of using the contraction over the summed gradient, it + performs the contraction for each individual example in the batch, and each + individual spatial location, and then averages over these: + ``F_i = sum_n sum_t contract_all_but_one(g_{n,t}, g_{n,t}, i) / (N * T)`` + where T here is the number of spatial locations. The estimator is unbiased + (under the "location independence" approximation), and will have lower + variance then the naive one. + + If the argument `weighting_per_location` is set to `False`, then the block + uses a mixture between location-independence and not, in the sense that it + computes the contractions per example, while the matrix factor statistics + still assume location independence. 
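+  (In that case the squared input and output norms used for the per-example
+  weighting are first summed over all spatial locations of each example
+  before being applied.)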
+ """ + + def __init__( + self, + layer_tag_eq: tags.LayerTagEqn, + weighting_per_location: bool = True, + parameters_specs: Sequence[str] | None = None, + parameters_concat_axis: int = 0, + ): + self.weighting_per_location = weighting_per_location + super().__init__( + layer_tag_eq=layer_tag_eq, + parameters_specs=parameters_specs, + parameters_concat_axis=parameters_concat_axis, + ) + + def state_dependent_scale( + self, state: "Conv2DTNT.State" + ) -> Numeric: + return utils.tnt_scale([factor.value for factor in state.factors]) + + def x_squared_spatial_norms(self, x: Array) -> Array: + + kernel_shape = list(self.parameters_shapes[0]) + kernel_shape[self.kernel_output_axis] = 1 + + return jax.lax.conv_general_dilated( + lhs=x * x, + rhs=jnp.ones(kernel_shape), + window_strides=self.layer_tag_extra_params["window_strides"], + padding=self.layer_tag_extra_params["padding"], + lhs_dilation=self.layer_tag_extra_params["lhs_dilation"], + rhs_dilation=self.layer_tag_extra_params["rhs_dilation"], + dimension_numbers=self.layer_tag_extra_params["dimension_numbers"], + feature_group_count=self.layer_tag_extra_params["feature_group_count"], + precision=self.layer_tag_extra_params["precision"], + preferred_element_type= + self.layer_tag_extra_params["preferred_element_type"], + ) + + @utils.auto_scope_method + def update_curvature_matrix_estimate( + self, + state: KroneckerFactored.State, + estimation_data: tracer.LayerVjpData[Array], + ema_old: Numeric, + ema_new: Numeric, + identity_weight: Numeric, + batch_size: Numeric, + ) -> KroneckerFactored.State: + del identity_weight + + # Copy this first since we mutate it later in this function. + state = state.copy() + + [x] = estimation_data.primals.inputs + [dy] = estimation_data.tangents.outputs + + assert utils.first_dim_is_size(batch_size, x, dy) + + # We multiply each x by the norm_y, and each dy by the norm of x + dy_sq_norms = jnp.sum(dy * dy, axis=self.outputs_channel_index) + x_sq_norms = self.x_squared_spatial_norms(x) + + if self.number_of_parameters == 2: + # When we have a bias we need to add 1 coming from it to the squared norm + x_sq_norms = x_sq_norms + 1 + + if not self.weighting_per_location: + dy_sq_norms = jnp.sum(dy_sq_norms, axis=[1, 2]) + x_sq_norms = jnp.sum(x_sq_norms, axis=[1, 2, 3], keepdims=True) + + input_cov = self.compute_inputs_stats(x, weighting_array=dy_sq_norms) + output_cov = self.compute_outputs_stats(dy * jnp.sqrt(x_sq_norms)) + + state.factors[0].update(input_cov, ema_old, ema_new) + state.factors[1].update(output_cov, ema_old, ema_new) + + return state diff --git a/kfac_jax/_src/curvature_blocks/utils.py b/kfac_jax/_src/curvature_blocks/utils.py new file mode 100644 index 0000000..0dbad99 --- /dev/null +++ b/kfac_jax/_src/curvature_blocks/utils.py @@ -0,0 +1,137 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Module containing utility functions for curvature blocks.""" +import collections +import string +from typing import Sequence + +import jax.numpy as jnp +from kfac_jax._src import utils + + +# Types for annotation +Scalar = utils.Scalar +ScalarOrSequence = Scalar | Sequence[Scalar] + +# Special global variables +# This is used for einsum strings +ALPHABET = string.ascii_lowercase +# The default value that would be used for the argument +# ``max_elements_for_vmap``, when it is set to ``None`` in the +# ``Conv2DDiagonal`` and ``Conv2DFull` curvature blocks. +_MAX_PARALLEL_ELEMENTS: int = 2 ** 23 +# The default value that would be used for the argument +# ``eigen_decomposition_threshold``, when it is set to ``None`` in any of the +# curvature blocks that inherit from ``Full`. +_DEFAULT_EIGEN_DECOMPOSITION_THRESHOLD = 5 + + +def set_max_parallel_elements(value: int): + """Sets the default value of maximum parallel elements in the module. + + This value is used to determine the parallel-to-memory tradeoff in the + curvature estimation procedure of :class:`~Conv2DDiagonal` and + :class:`~Conv2DFull`. See their corresponding docs for further details. + + Args: + value: The default value for maximum number of parallel elements. + """ + global _MAX_PARALLEL_ELEMENTS + _MAX_PARALLEL_ELEMENTS = value + + +def get_max_parallel_elements() -> int: + """Returns the default value of maximum parallel elements in the module. + + This value is used to determine the parallel-to-memory tradeoff in the + curvature estimation procedure of :class:`~Conv2DDiagonal` and + :class:`~Conv2DFull`. See their corresponding docs for further details. + + Returns: + The default value for maximum number of parallel elements. + """ + return _MAX_PARALLEL_ELEMENTS + + +def set_default_eigen_decomposition_threshold(value: int): + """Sets the default value of the eigen decomposition threshold. + + This value is used in :class:`~Full` to determine when updating the cache, + at what number of different powers to switch the implementation from a simple + matrix power to an eigenvector decomposition. + + Args: + value: The default value for eigen decomposition threshold. + """ + global _DEFAULT_EIGEN_DECOMPOSITION_THRESHOLD + _DEFAULT_EIGEN_DECOMPOSITION_THRESHOLD = value + + +def get_default_eigen_decomposition_threshold() -> int: + """Returns the default value of the eigen decomposition threshold. + + This value is used in :class:`~Full` to determine when updating the cache, + at what number of different powers to switch the implementation from a simple + matrix power to an eigenvector decomposition. + + Returns: + The default value of the eigen decomposition threshold. 
+ """ + return _DEFAULT_EIGEN_DECOMPOSITION_THRESHOLD + + +def to_real_set( + number_or_sequence: ScalarOrSequence | None +) -> set[Scalar]: + """Converts the optional number or sequence to a set.""" + if number_or_sequence is None: + return set() + elif isinstance(number_or_sequence, set): + return number_or_sequence + elif isinstance(number_or_sequence, (float, int)): + return {number_or_sequence} + elif (isinstance(number_or_sequence, collections.abc.Sequence) and + all(isinstance(x, (int, float)) for x in number_or_sequence)): + return set(number_or_sequence) + else: + raise ValueError(f"Expecting a real-number or a sequence of reals, but got " + f"{type(number_or_sequence)}.") + + +def compatible_shapes(ref_shape, target_shape): + + if len(target_shape) > len(ref_shape): + raise ValueError("Target shape should be smaller.") + + for ref_d, target_d in zip(reversed(ref_shape), reversed(target_shape)): + if ref_d != target_d and target_d != 1: + raise ValueError(f"{target_shape} is incompatible with {ref_shape}.") + + +def compatible_sum(tensor, target_shape, skip_axes): + """Compute sum over ``tensor`` to achieve shape given by ``target_shape``.""" + + compatible_shapes(tensor.shape, target_shape) + + n = tensor.ndim - len(target_shape) + + axis = [i + n for i, t in enumerate(target_shape) + if t == 1 and i + n not in skip_axes] + + tensor = jnp.sum(tensor, axis=axis, keepdims=True) + + axis = [i for i in range(tensor.ndim - len(target_shape)) + if i not in skip_axes] + + return jnp.sum(tensor, axis=axis)