Skip to content

Commit

Permalink
Adding the repeated dense graph patterns.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668823269
  • Loading branch information
botev authored and KfacJaxDev committed Aug 29, 2024
1 parent 14dc8fa commit 5d89401
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
2 changes: 1 addition & 1 deletion kfac_jax/_src/curvature_estimator/block_diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
conv2d=curvature_blocks.Conv2DTwoKroneckerFactored,
generic=curvature_blocks.NaiveDiagonal,
scale_and_shift=curvature_blocks.ScaleAndShiftDiagonal,
repeated_dense=curvature_blocks.RepeatedDenseKroneckerFactored,
)


Expand Down Expand Up @@ -860,4 +861,3 @@ def to_diagonal_block_dense_matrix(self, state: State) -> tuple[Array, ...]:
@utils.auto_scope_method
def to_dense_matrix(self, state: State) -> Array:
return scipy.linalg.block_diag(*self.to_diagonal_block_dense_matrix(state))

8 changes: 8 additions & 0 deletions kfac_jax/_src/layers_and_loss_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.
"""K-FAC losses and layers tagging Jax primitives."""
import dataclasses
import functools
from typing import Any, Generic, Sequence, TypeVar

import jax
from jax import core


# Types for annotation
T = TypeVar("T")
Array = jax.Array
Expand Down Expand Up @@ -452,6 +454,12 @@ def register_scale_and_shift(
)


register_repeated_dense = functools.partial(
register_dense,
variant="repeated_dense",
)


class LossTagEqn(core.JaxprEqn):
"""A class used only for annotation purposes."""
primitive: LossTag
Expand Down
32 changes: 25 additions & 7 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,19 +1113,31 @@ def _dense_parameter_extractor(
assert False


def _make_dense_pattern(
def _make_general_dense_pattern(
with_bias: bool,
reshape: bool,
num_repeated_axes: int,
in_dim: int = 13,
out_dim: int = 7,
) -> GraphPattern:
x_shape = [2, in_dim]
"""Creates a pattern for a dense or repeated dense layer."""
in_axes = [0, [None, None]] if with_bias else [0, [None]]
f = _dense_with_reshape if reshape else _dense
for _ in range(num_repeated_axes):
f = jax.vmap(f, in_axes=in_axes)

x_shape = [9] * num_repeated_axes + [2, in_dim]
p_shapes = ([[in_dim, out_dim], [out_dim]] if with_bias else
[[in_dim, out_dim]])

name = "dense_with_bias" if with_bias else "dense_no_bias",
if num_repeated_axes > 0:
name = f"repeated[{num_repeated_axes}]_{name}"

return GraphPattern(
name="dense_with_bias" if with_bias else "dense_no_bias",
name=name,
tag_primitive=tags.layer_tag,
compute_func=_dense_with_reshape if reshape else _dense,
compute_func=f,
parameters_extractor_func=_dense_parameter_extractor,
example_args=[np.zeros(x_shape), [np.zeros(s) for s in p_shapes]],
)
Expand Down Expand Up @@ -1338,9 +1350,15 @@ def _make_normalization_haiku_pattern(


DEFAULT_GRAPH_PATTERNS = (
_make_dense_pattern(True, False),
_make_dense_pattern(True, True),
_make_dense_pattern(False, False),
_make_general_dense_pattern(True, False, 0),
_make_general_dense_pattern(True, False, 1),
_make_general_dense_pattern(True, False, 2),
_make_general_dense_pattern(True, True, 0),
_make_general_dense_pattern(True, True, 1),
_make_general_dense_pattern(True, True, 2),
_make_general_dense_pattern(False, False, 0),
_make_general_dense_pattern(False, False, 1),
_make_general_dense_pattern(False, False, 2),
_make_conv2d_pattern(True),
_make_conv2d_pattern(False),
_make_scale_and_shift_pattern(1, True, True),
Expand Down

0 comments on commit 5d89401

Please sign in to comment.