diff --git a/kfac_jax/_src/curvature_estimator/block_diagonal.py b/kfac_jax/_src/curvature_estimator/block_diagonal.py index 91ccb36..806dd87 100644 --- a/kfac_jax/_src/curvature_estimator/block_diagonal.py +++ b/kfac_jax/_src/curvature_estimator/block_diagonal.py @@ -42,6 +42,7 @@ conv2d=curvature_blocks.Conv2DTwoKroneckerFactored, generic=curvature_blocks.NaiveDiagonal, scale_and_shift=curvature_blocks.ScaleAndShiftDiagonal, + repeated_dense=curvature_blocks.RepeatedDenseKroneckerFactored, ) @@ -860,4 +861,3 @@ def to_diagonal_block_dense_matrix(self, state: State) -> tuple[Array, ...]: @utils.auto_scope_method def to_dense_matrix(self, state: State) -> Array: return scipy.linalg.block_diag(*self.to_diagonal_block_dense_matrix(state)) - diff --git a/kfac_jax/_src/layers_and_loss_tags.py b/kfac_jax/_src/layers_and_loss_tags.py index c9c8716..48ed380 100644 --- a/kfac_jax/_src/layers_and_loss_tags.py +++ b/kfac_jax/_src/layers_and_loss_tags.py @@ -13,11 +13,13 @@ # limitations under the License. 
"""K-FAC losses and layers tagging Jax primitives.""" import dataclasses +import functools from typing import Any, Generic, Sequence, TypeVar import jax from jax import core + # Types for annotation T = TypeVar("T") Array = jax.Array @@ -452,6 +454,12 @@ def register_scale_and_shift( ) +register_repeated_dense = functools.partial( + register_dense, + variant="repeated_dense", +) + + class LossTagEqn(core.JaxprEqn): """A class used only for annotation purposes.""" primitive: LossTag diff --git a/kfac_jax/_src/tag_graph_matcher.py b/kfac_jax/_src/tag_graph_matcher.py index d9253ff..a478503 100644 --- a/kfac_jax/_src/tag_graph_matcher.py +++ b/kfac_jax/_src/tag_graph_matcher.py @@ -1113,19 +1113,31 @@ def _dense_parameter_extractor( assert False -def _make_dense_pattern( +def _make_general_dense_pattern( with_bias: bool, reshape: bool, + num_repeated_axes: int, in_dim: int = 13, out_dim: int = 7, ) -> GraphPattern: - x_shape = [2, in_dim] + """Creates a pattern for a dense layer vmapped `num_repeated_axes` times.""" + in_axes = [0, [None, None]] if with_bias else [0, [None]] + f = _dense_with_reshape if reshape else _dense + for _ in range(num_repeated_axes): + f = jax.vmap(f, in_axes=in_axes) + + x_shape = [9] * num_repeated_axes + [2, in_dim] p_shapes = ([[in_dim, out_dim], [out_dim]] if with_bias else [[in_dim, out_dim]]) + + name = "dense_with_bias" if with_bias else "dense_no_bias" + if num_repeated_axes > 0: + name = f"repeated[{num_repeated_axes}]_{name}" + return GraphPattern( - name="dense_with_bias" if with_bias else "dense_no_bias", + name=name, tag_primitive=tags.layer_tag, - compute_func=_dense_with_reshape if reshape else _dense, + compute_func=f, parameters_extractor_func=_dense_parameter_extractor, example_args=[np.zeros(x_shape), [np.zeros(s) for s in p_shapes]], ) @@ -1338,9 +1350,15 @@ def _make_normalization_haiku_pattern( DEFAULT_GRAPH_PATTERNS = ( - _make_dense_pattern(True, False), - _make_dense_pattern(True, True), - _make_dense_pattern(False, False), + 
_make_general_dense_pattern(True, False, 0), + _make_general_dense_pattern(True, False, 1), + _make_general_dense_pattern(True, False, 2), + _make_general_dense_pattern(True, True, 0), + _make_general_dense_pattern(True, True, 1), + _make_general_dense_pattern(True, True, 2), + _make_general_dense_pattern(False, False, 0), + _make_general_dense_pattern(False, False, 1), + _make_general_dense_pattern(False, False, 2), _make_conv2d_pattern(True), _make_conv2d_pattern(False), _make_scale_and_shift_pattern(1, True, True),