From 60644ef3c46d68923d04748126b672058ab538be Mon Sep 17 00:00:00 2001
From: Alex Botev
Date: Wed, 21 Aug 2024 11:37:26 -0700
Subject: [PATCH] Add TNT blocks to kfac_jax.

This CL adds TNT blocks to kfac_jax. TNT blocks are a type of curvature
approximation that makes use of less structural knowledge of the computation,
and relies mainly on the array shapes of the parameters.

It also adds a RepeatedDenseKroneckerFactored block, which extends the usual
DenseTwoKroneckerFactored to the case where a dense layer is applied in
parallel over an axis of its inputs (such as a time axis in sequence models).

PiperOrigin-RevId: 665965233
---
 kfac_jax/__init__.py              |   8 +
 kfac_jax/_src/curvature_blocks.py | 332 ++++++++++++++++++++++++++++++
 kfac_jax/_src/utils/__init__.py   |   1 +
 kfac_jax/_src/utils/math.py       |  16 ++
 4 files changed, 357 insertions(+)

diff --git a/kfac_jax/__init__.py b/kfac_jax/__init__.py
index ddd9e33..64a988d 100644
--- a/kfac_jax/__init__.py
+++ b/kfac_jax/__init__.py
@@ -84,12 +84,16 @@
 KroneckerFactored = curvature_blocks.KroneckerFactored
 NaiveDiagonal = curvature_blocks.NaiveDiagonal
 NaiveFull = curvature_blocks.NaiveFull
+NaiveTNT = curvature_blocks.NaiveTNT
 DenseDiagonal = curvature_blocks.DenseDiagonal
 DenseFull = curvature_blocks.DenseFull
 DenseTwoKroneckerFactored = curvature_blocks.DenseTwoKroneckerFactored
+RepeatedDenseKroneckerFactored = curvature_blocks.RepeatedDenseKroneckerFactored
+DenseTNT = curvature_blocks.DenseTNT
 Conv2DDiagonal = curvature_blocks.Conv2DDiagonal
 Conv2DFull = curvature_blocks.Conv2DFull
 Conv2DTwoKroneckerFactored = curvature_blocks.Conv2DTwoKroneckerFactored
+Conv2DTNT = curvature_blocks.Conv2DTNT
 ScaleAndShiftDiagonal = curvature_blocks.ScaleAndShiftDiagonal
 ScaleAndShiftFull = curvature_blocks.ScaleAndShiftFull
 set_max_parallel_elements = curvature_blocks.set_max_parallel_elements
@@ -165,12 +169,16 @@
     "KroneckerFactored",
     "NaiveDiagonal",
     "NaiveFull",
+    "NaiveTNT",
     "DenseDiagonal",
     "DenseFull",
     "DenseTwoKroneckerFactored",
+    "RepeatedDenseKroneckerFactored",
+    "DenseTNT",
     "Conv2DDiagonal",
     "Conv2DFull",
     "Conv2DTwoKroneckerFactored",
+    "Conv2DTNT",
     "ScaleAndShiftDiagonal",
     "ScaleAndShiftFull",
     "set_max_parallel_elements",
diff --git a/kfac_jax/_src/curvature_blocks.py b/kfac_jax/_src/curvature_blocks.py
index 92d8006..cfc1f7c 100644
--- a/kfac_jax/_src/curvature_blocks.py
+++ b/kfac_jax/_src/curvature_blocks.py
@@ -15,6 +15,7 @@
 import abc
 import collections
 import functools
+import math
 import string
 from typing import Any, Sequence
@@ -1432,6 +1433,15 @@ def _to_dense_unscaled(self, state: "KroneckerFactored.State") -> Array:
     return jnp.kron(inputs_factor, state.factors[1].value)
 
 
+#  _   _       _
+# | \ | |     (_)
+# |  \| | __ _ ___   _____
+# | . ` |/ _` | \ \ / / _ \
+# | |\  | (_| | |\ V /  __/
+# |_| \_|\__,_|_| \_/ \___|
+#
+
+
 class NaiveDiagonal(Diagonal):
   """Approximates the diagonal of the curvature with in the most obvious way.
 
@@ -1522,6 +1532,71 @@ def update_curvature_matrix_estimate(
     return state
 
 
+class NaiveTNT(KroneckerFactored):
+  """A standard TNT block for a single parameter, or weights + bias.
+
+  Each factor of the standard TNT curvature approximation estimates the
+  expected value of the contraction of the gradients with themselves along all
+  but a single axis `i`:
+  ``F_i ~~ E[contract_all_but_one(g, g, i)]``
+  where `contract_all_but_one` is defined as the contraction over all axes
+  except the i-th of its first two inputs, e.g.:
+  ``contract_all_but_one(A, B, 1)[a,b] = sum_{i,j} A[i, a, j] B[i, b, j]``
+
+  The estimation is performed in a naive way, by contracting the sum of the
+  per-example gradients with itself and then dividing by the batch size:
+  ``F_i = contract_all_but_one(sum_n g_n, sum_n g_n, i) / N``
+  where `g_n` is the model gradient of a single example and `N` is the batch
+  size. Since the expectations of the gradients over the model distribution
+  are zero, and they are independent across cases, this is still an unbiased
+  estimator.
+  """
+
+  def state_dependent_scale(
+      self,
+      state: "NaiveTNT.State",
+  ) -> Numeric:
+    return utils.tnt_scale([factor.value for factor in state.factors])
+
+  @utils.auto_scope_method
+  def update_curvature_matrix_estimate(
+      self,
+      state: KroneckerFactored.State,
+      estimation_data: tracer.LayerVjpData[Array],
+      ema_old: Numeric,
+      ema_new: Numeric,
+      identity_weight: Numeric,
+      batch_size: Numeric,
+  ) -> KroneckerFactored.State:
+    del identity_weight
+
+    # Copy this first since we mutate it later in this function.
+    state = state.copy()
+
+    dw = self.parameters_shaped_list_to_array(estimation_data.tangents.params)
+
+    assert dw.ndim == len(state.factors)
+
+    in_str = _ALPHABET[: dw.ndim]
+
+    for i, factor in enumerate(state.factors):
+      # For factor i we contract the gradient with itself along all axes
+      # except the i-th.
+      lhs_str = utils.replace_char(in_str, "y", i)
+      rhs_str = utils.replace_char(in_str, "z", i)
+
+      # This is like a rank-1 update, since it is as if we flattened all dims
+      # but dim i together and then took an outer product.
+      factor_update = (
+          jnp.einsum(f"{lhs_str},{rhs_str}->yz", dw, dw) / batch_size
+      )
+
+      factor.update(factor_update, ema_old, ema_new)
+
+    return state
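To make the `NaiveTNT` factor update concrete, here is a minimal, self-contained sketch. It is illustrative only: the `contract_all_but_one` helper and the shapes are hypothetical stand-ins for the `_ALPHABET` / `utils.replace_char` machinery used by the block.

import string

import jax
import jax.numpy as jnp


def contract_all_but_one(a, b, i):
  # Contracts `a` and `b` over every axis except axis `i`.
  in_str = string.ascii_lowercase[: a.ndim]
  lhs = in_str[:i] + "y" + in_str[i + 1:]
  rhs = in_str[:i] + "z" + in_str[i + 1:]
  return jnp.einsum(f"{lhs},{rhs}->yz", a, b)


# A hypothetical summed weights-plus-bias gradient of shape
# [in_dim + 1, out_dim], for a batch of size n.
n = 32
dw = jax.random.normal(jax.random.PRNGKey(0), (5, 3))

# One TNT factor per parameter axis, as in the update above.
factors = [contract_all_but_one(dw, dw, i) / n for i in range(dw.ndim)]
assert factors[0].shape == (5, 5) and factors[1].shape == (3, 3)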
+
+
 #  _____
 # |  __ \
 # | |  | | ___ _ __  ___  ___
@@ -1644,6 +1719,164 @@ def update_curvature_matrix_estimate(
     return state
 
 
+class RepeatedDenseKroneckerFactored(DenseTwoKroneckerFactored):
+  """Block for dense layers applied to tensors with extra time/location dims."""
+
+  @utils.register_state_class
+  class State(KroneckerFactored.State):
+    """Persistent state of the block.
+
+    Attributes:
+      average_repeats: A decayed average of the per-case number of non-masked
+        repeats in the data used to compute the block's statistics. We use the
+        same decayed averaging for this quantity that we use for the
+        statistics, so that the two "match".
+    """
+
+    average_repeats: utils.WeightedMovingAverage
+
+  def __init__(
+      self,
+      layer_tag_eq: tags.LayerTagEqn,
+      use_masking: bool = True,
+      parameters_specs: Sequence[str] | None = None,
+      parameters_concat_axis: int = 0,
+  ):
+    self._use_masking = use_masking
+    super().__init__(
+        layer_tag_eq=layer_tag_eq,
+        parameters_specs=parameters_specs,
+        parameters_concat_axis=parameters_concat_axis,
+    )
+
+  def _init(
+      self,
+      rng: PRNGKey,
+      exact_powers_to_cache: set[Scalar],
+      approx_powers_to_cache: set[Scalar],
+      cache_eigenvalues: bool,
+  ) -> "RepeatedDenseKroneckerFactored.State":
+
+    super_state = super()._init(
+        rng, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues
+    )
+
+    return RepeatedDenseKroneckerFactored.State(
+        average_repeats=utils.WeightedMovingAverage.zeros_array((), self.dtype),
+        **super_state.__dict__,
+    )
+
+  def state_dependent_scale(
+      self,
+      state: "RepeatedDenseKroneckerFactored.State",
+  ) -> Numeric:
+    return 1.0 / state.average_repeats.value
+
+  @utils.auto_scope_method
+  def update_curvature_matrix_estimate(
+      self,
+      state: KroneckerFactored.State,
+      estimation_data: tracer.LayerVjpData[Array],
+      ema_old: Numeric,
+      ema_new: Numeric,
+      identity_weight: Numeric,
+      batch_size: Numeric,
+  ) -> KroneckerFactored.State:
+    del identity_weight
+
+    # Copy this first since we mutate it later in this function.
+    state = state.copy()
+
+    [x] = estimation_data.primals.inputs
+    [dy] = estimation_data.tangents.outputs
+
+    assert utils.first_dim_is_size(batch_size, x, dy)
+
+    if self._use_masking:
+
+      # Hack: we identify masked repeats by checking whether all of the
+      # corresponding entries of dy are zero.
+      mask = 1.0 - jnp.all(dy == 0.0, axis=-1, keepdims=True)
+
+      # Zero out the corresponding elements of x.
+      x = x * mask
+
+      # Compute the total number of non-masked repeats.
+      total = jnp.sum(mask)
+
+    else:
+      total = math.prod(dy.shape[:-1])
+
+    x = x.reshape([-1, x.shape[-1]])
+    dy = dy.reshape([-1, dy.shape[-1]])
+
+    if self.number_of_parameters == 2:
+      x_one = jnp.ones_like(x[:, :1])
+      x = jnp.concatenate([x, x_one], axis=1)
+
+    input_stats = jnp.einsum("ay,az->yz", x, x) / batch_size
+    output_stats = jnp.einsum("ay,az->yz", dy, dy) / batch_size
+
+    state.factors[0].update(input_stats, ema_old, ema_new)
+    state.factors[1].update(output_stats, ema_old, ema_new)
+    state.average_repeats.update(total / batch_size, ema_old, ema_new)
+
+    return state
+
+
+class DenseTNT(DenseTwoKroneckerFactored):
+  """A TNT block for dense layers.
+
+  This TNT block modifies :class:`~NaiveTNT` in the way it estimates each
+  factor, specifically for a dense layer. Instead of using the contraction
+  over the summed gradient, it performs the contraction over each individual
+  batch element and then averages over the batch:
+  ``F_i = sum_n contract_all_but_one(g_n, g_n, i) / N``
+  The estimator is unbiased, and will have lower variance than the naive one.
+  """
+
+  def state_dependent_scale(self, state: "DenseTNT.State") -> Numeric:
+    return utils.tnt_scale([factor.value for factor in state.factors])
+
+  @utils.auto_scope_method
+  def update_curvature_matrix_estimate(
+      self,
+      state: KroneckerFactored.State,
+      estimation_data: tracer.LayerVjpData[Array],
+      ema_old: Numeric,
+      ema_new: Numeric,
+      identity_weight: Numeric,
+      batch_size: Numeric,
+  ) -> KroneckerFactored.State:
+    del identity_weight
+
+    # Copy this first since we mutate it later in this function.
+    state = state.copy()
+
+    [x] = estimation_data.primals.inputs
+    [dy] = estimation_data.tangents.outputs
+
+    assert utils.first_dim_is_size(batch_size, x, dy)
+
+    if self.number_of_parameters == 2:
+      x_one = jnp.ones_like(x[:, :1])
+      x = jnp.concatenate([x, x_one], axis=1)
+
+    # We multiply each x_n by the norm of the corresponding dy_n, and each
+    # dy_n by the norm of the corresponding x_n.
+    dy_norms = jnp.linalg.norm(dy, axis=-1, keepdims=True)
+    x_norms = jnp.linalg.norm(x, axis=-1, keepdims=True)
+    x = x * dy_norms
+    dy = dy * x_norms
+
+    input_stats = jnp.einsum("ay,az->yz", x, x) / batch_size
+    output_stats = jnp.einsum("ay,az->yz", dy, dy) / batch_size
+
+    state.factors[0].update(input_stats, ema_old, ema_new)
+    state.factors[1].update(output_stats, ema_old, ema_new)
+
+    return state
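A minimal sketch of the masking and flattening that `RepeatedDenseKroneckerFactored` performs, assuming a hypothetical `[batch, time, features]` input in which the last two time steps of every sequence are masked (all names and shapes here are illustrative):

import jax
import jax.numpy as jnp

key_x, key_dy = jax.random.split(jax.random.PRNGKey(0))
b, t, d_in, d_out = 4, 8, 5, 3
x = jax.random.normal(key_x, (b, t, d_in))
dy = jax.random.normal(key_dy, (b, t, d_out))
dy = dy * (jnp.arange(t) < 6)[None, :, None]  # zero gradients at masked steps

# Masked repeats are identified by all-zero dy entries, as in the block above.
mask = 1.0 - jnp.all(dy == 0.0, axis=-1, keepdims=True)
x = x * mask
total = jnp.sum(mask)  # number of non-masked (example, repeat) pairs: 24

# Flatten the batch and repeat axes together before computing the statistics.
x2d = x.reshape([-1, x.shape[-1]])
dy2d = dy.reshape([-1, dy.shape[-1]])
input_stats = jnp.einsum("ay,az->yz", x2d, x2d) / b
output_stats = jnp.einsum("ay,az->yz", dy2d, dy2d) / b
average_repeats = total / b  # 6.0, later inverted by state_dependent_scale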
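The norm-weighting in `DenseTNT` relies on the identity `sum_n contract_all_but_one(g_n, g_n, 0) = sum_n ||dy_n||^2 x_n x_n^T` for per-example dense-layer gradients `g_n = x_n dy_n^T`. The sketch below checks this numerically on hypothetical shapes:

import jax
import jax.numpy as jnp

key_x, key_dy = jax.random.split(jax.random.PRNGKey(0))
n, d_in, d_out = 8, 4, 3
x = jax.random.normal(key_x, (n, d_in))
dy = jax.random.normal(key_dy, (n, d_out))

# Per-example dense-layer gradients g_n = x_n dy_n^T, contracted with
# themselves over all axes except the input one.
g = jnp.einsum("ni,no->nio", x, dy)
direct = jnp.einsum("nio,njo->ij", g, g) / n

# The DenseTNT shortcut: scale each x_n by ||dy_n||, then take second moments.
x_weighted = x * jnp.linalg.norm(dy, axis=-1, keepdims=True)
shortcut = jnp.einsum("ny,nz->yz", x_weighted, x_weighted) / n

assert jnp.allclose(direct, shortcut, atol=1e-5)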
+
+
 #   _____                ___  _____
 #  / ____|              |__ \|  __ \
 # | |     ___  _ ____   __ ) | |  | |
@@ -1994,6 +2227,105 @@ def update_curvature_matrix_estimate(
     return state
 
 
+class Conv2DTNT(Conv2DTwoKroneckerFactored):
+  """A TNT block for Conv2D layers.
+
+  This TNT block modifies :class:`~NaiveTNT` in the way it estimates each
+  factor, specifically for a conv2D layer. Importantly, it assumes "location
+  independence", similar to :class:`~Conv2DTwoKroneckerFactored`. Given this
+  assumption, instead of using the contraction over the summed gradient, it
+  performs the contraction for each individual example in the batch and each
+  individual spatial location, and then averages over these:
+  ``F_i = sum_n sum_t contract_all_but_one(g_{n,t}, g_{n,t}, i) / (N * T)``
+  where `T` is the number of spatial locations. The estimator is unbiased
+  (under the "location independence" approximation), and will have lower
+  variance than the naive one.
+
+  If the argument `weighting_per_location` is set to `False`, the block uses a
+  mixture between location independence and not, in the sense that it computes
+  the contractions per example, while the matrix factor statistics still
+  assume location independence.
+  """
+
+  def __init__(
+      self,
+      layer_tag_eq: tags.LayerTagEqn,
+      weighting_per_location: bool = True,
+      parameters_specs: Sequence[str] | None = None,
+      parameters_concat_axis: int = 0,
+  ):
+    self.weighting_per_location = weighting_per_location
+    super().__init__(
+        layer_tag_eq=layer_tag_eq,
+        parameters_specs=parameters_specs,
+        parameters_concat_axis=parameters_concat_axis,
+    )
+
+  def state_dependent_scale(
+      self, state: "Conv2DTNT.State"
+  ) -> Numeric:
+    return utils.tnt_scale([factor.value for factor in state.factors])
+
+  def x_squared_spatial_norms(self, x: Array) -> Array:
+    """Computes the squared norms of the input patches at each location."""
+
+    kernel_shape = list(self.parameters_shapes[0])
+    kernel_shape[self.kernel_output_axis] = 1
+
+    return jax.lax.conv_general_dilated(
+        lhs=x * x,
+        rhs=jnp.ones(kernel_shape),
+        window_strides=self.layer_tag_extra_params["window_strides"],
+        padding=self.layer_tag_extra_params["padding"],
+        lhs_dilation=self.layer_tag_extra_params["lhs_dilation"],
+        rhs_dilation=self.layer_tag_extra_params["rhs_dilation"],
+        dimension_numbers=self.layer_tag_extra_params["dimension_numbers"],
+        feature_group_count=self.layer_tag_extra_params["feature_group_count"],
+        precision=self.layer_tag_extra_params["precision"],
+        preferred_element_type=(
+            self.layer_tag_extra_params["preferred_element_type"]
+        ),
+    )
+
+  @utils.auto_scope_method
+  def update_curvature_matrix_estimate(
+      self,
+      state: KroneckerFactored.State,
+      estimation_data: tracer.LayerVjpData[Array],
+      ema_old: Numeric,
+      ema_new: Numeric,
+      identity_weight: Numeric,
+      batch_size: Numeric,
+  ) -> KroneckerFactored.State:
+    del identity_weight
+
+    # Copy this first since we mutate it later in this function.
+    state = state.copy()
+
+    [x] = estimation_data.primals.inputs
+    [dy] = estimation_data.tangents.outputs
+
+    assert utils.first_dim_is_size(batch_size, x, dy)
+
+    # We weight each x by the squared norm of the corresponding dy, and each
+    # dy by the squared patch norm of the corresponding x.
+    dy_sq_norms = jnp.sum(dy * dy, axis=self.outputs_channel_index)
+    x_sq_norms = self.x_squared_spatial_norms(x)
+
+    if self.number_of_parameters == 2:
+      # When we have a bias, we need to add the 1 coming from it to the
+      # squared norms.
+      x_sq_norms = x_sq_norms + 1
+
+    if not self.weighting_per_location:
+      dy_sq_norms = jnp.sum(dy_sq_norms, axis=(1, 2))
+      x_sq_norms = jnp.sum(x_sq_norms, axis=(1, 2, 3), keepdims=True)
+
+    input_cov = self.compute_inputs_stats(x, weighting_array=dy_sq_norms)
+    output_cov = self.compute_outputs_stats(dy * jnp.sqrt(x_sq_norms))
+
+    state.factors[0].update(input_cov, ema_old, ema_new)
+    state.factors[1].update(output_cov, ema_old, ema_new)
+
+    return state
+
+
 #   _____           _          _              _  _____ _     _  __ _
 #  / ____|         | |        /\             | |/ ____| |   (_)/ _| |
 # | (___   ___ __ _| | ___   /  \   _ __   __| | (___ | |__  _| |_| |_
diff --git a/kfac_jax/_src/utils/__init__.py b/kfac_jax/_src/utils/__init__.py
index 0ccce37..c44cfd1 100644
--- a/kfac_jax/_src/utils/__init__.py
+++ b/kfac_jax/_src/utils/__init__.py
@@ -123,6 +123,7 @@
 kronecker_product_mul_v = math.kronecker_product_mul_v
 kronecker_eigen_basis_mul_v = math.kronecker_eigen_basis_mul_v
 safe_psd_eigh = math.safe_psd_eigh
+tnt_scale = math.tnt_scale
 loop_and_parallelize_average = math.loop_and_parallelize_average
 psd_matrix_norm = math.psd_matrix_norm
 invert_psd_matrices = math.invert_psd_matrices
diff --git a/kfac_jax/_src/utils/math.py b/kfac_jax/_src/utils/math.py
index f10a9b6..3ad4275 100644
--- a/kfac_jax/_src/utils/math.py
+++ b/kfac_jax/_src/utils/math.py
@@ -1001,6 +1001,22 @@ def safe_psd_eigh(
   return jnp.clip(s, a_min=0.0), q
 
 
+def tnt_scale(factors: Sequence[Array]) -> Numeric:
+  """Computes the correct scaling factor for a TNT factorization."""
+
+  if len(factors) == 1:
+    return 1.0
+
+  # In the exact computation these traces are all equal.
+  zs = jnp.asarray([jnp.trace(factor) for factor in factors])
+
+  # We want to compute geometric_mean(zs) ** -(num_factors - 1).
+
+  mean_log = jnp.mean(jnp.log(zs))
+
+  return jnp.exp(-(len(factors) - 1) * mean_log)
+
+
 def loop_and_parallelize_average(
     func: Callable[..., ArrayTree],
     max_parallel_size: int,
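For intuition about `Conv2DTNT.x_squared_spatial_norms` above: convolving `x * x` with an all-ones kernel yields, at every spatial location, the squared norm of the input patch that the layer's kernel sees there. A minimal sketch with illustrative NHWC shapes and settings ('SAME' padding, unit strides), rather than the `layer_tag_extra_params` the block reads from the tagged layer:

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 3))  # NHWC
kh, kw, c = 3, 3, 3

# All-ones kernel with the output axis collapsed to size 1, as in the block.
sq_norms = jax.lax.conv_general_dilated(
    lhs=x * x,
    rhs=jnp.ones((kh, kw, c, 1)),
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
assert sq_norms.shape == (2, 8, 8, 1)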
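As for `tnt_scale`: each TNT factor's trace estimates the same quantity (the total second moment of the gradient), so the Kronecker product of `k` factors carries that common scale `k` times instead of once, and dividing by the geometric mean of the traces raised to the power `k - 1` removes the overcounting. A standalone numeric sketch (the function is restated here only for illustration):

import jax.numpy as jnp


def tnt_scale(factors):
  if len(factors) == 1:
    return 1.0
  zs = jnp.asarray([jnp.trace(f) for f in factors])
  return jnp.exp(-(len(factors) - 1) * jnp.mean(jnp.log(zs)))


# Two factors with equal traces (6.0 each): their Kronecker product has trace
# 36.0, and tnt_scale supplies the 1 / 6.0 correction, so the scaled product
# has trace 6.0 again.
f1 = jnp.diag(jnp.array([1.0, 2.0, 3.0]))
f2 = jnp.diag(jnp.array([2.0, 4.0]))
print(tnt_scale([f1, f2]))  # ~0.1667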