Skip to content

Commit

Permalink
[kfac-jax] Update _Linear layer to support the new "algorithm" tuning…
Browse files Browse the repository at this point in the history
… parameters for dot_general that will be included in the next JAX release.

PiperOrigin-RevId: 674292750
  • Loading branch information
dfm authored and KfacJaxDev committed Sep 20, 2024
1 parent baaec40 commit 20b9ce8
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,19 @@ def __call__(self, inputs: LayerInputs, *_) -> LayerInputs: # pytype: disable=s
else:
assert all(p.dtype == y.dtype for p in params if p is not None)
preferred_element_type = y.dtype
# The dot algorithm and transpose_algorithm parameters are available (and
# required) after JAX version 0.4.33.
if jax_version <= (0, 4, 33):
algorithm_kwargs = {}
else:
algorithm_kwargs = dict(algorithm=None, transpose_algorithm=None)

y = tags.register_dense(
y, x, *params,
dimension_numbers=(((1,), (0,)), ((), ())),
precision=(jax.lax.Precision.HIGHEST, jax.lax.Precision.HIGHEST),
preferred_element_type=preferred_element_type,
**algorithm_kwargs,
)
layer_values.append((x, y))

Expand Down

0 comments on commit 20b9ce8

Please sign in to comment.