Skip to content

Commit

Permalink
[JAX] add support for gather/scatter batching dims following the new …
Browse files Browse the repository at this point in the history
…attributes in stablehlo.

This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See openxla/stablehlo#2259

PiperOrigin-RevId: 647647825
  • Loading branch information
tomnatan30 authored and TF2JAXDev committed Sep 21, 2024
1 parent db3f7d1 commit cab8b2d
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions tf2jax/_src/xla_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,24 @@ def gather_dimension_numbers_from_proto(
message) -> jax.lax.GatherDimensionNumbers:
proto = xla_data_pb2.GatherDimensionNumbers().FromString(message)
return jax.lax.GatherDimensionNumbers(
tuple(proto.offset_dims), tuple(proto.collapsed_slice_dims),
tuple(proto.start_index_map))
tuple(proto.offset_dims),
tuple(proto.collapsed_slice_dims),
tuple(proto.start_index_map),
tuple(proto.operand_batching_dims),
tuple(proto.start_indices_batching_dims),
)


def scatter_dimension_numbers_from_proto(
message) -> jax.lax.ScatterDimensionNumbers:
proto = xla_data_pb2.ScatterDimensionNumbers().FromString(message)
return jax.lax.ScatterDimensionNumbers(
tuple(proto.update_window_dims), tuple(proto.inserted_window_dims),
tuple(proto.scatter_dims_to_operand_dims))
tuple(proto.update_window_dims),
tuple(proto.inserted_window_dims),
tuple(proto.scatter_dims_to_operand_dims),
tuple(proto.input_batching_dims),
tuple(proto.scatter_indices_batching_dims),
)


def precision_config_from_proto(
Expand Down

0 comments on commit cab8b2d

Please sign in to comment.