Skip to content

Commit

Permalink
Handle bad_indices_policy for GatherNd and ScatterNd
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654685263
  • Loading branch information
TF2JAXDev authored and TF2JAXDev committed Jul 22, 2024
1 parent 12d8393 commit 5214129
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2249,7 +2249,7 @@ def _func(xs: jnp.ndarray, shape: jnp.ndarray) -> ArrayLike:
@register_operation("TensorScatterUpdate")
def _tensor_scatter_update(proto):
"""Parse an TensorScatterUpdate Op."""
_check_attrs(proto, {"T", "Tindices"})
_check_attrs(proto, {"T", "Tindices", "bad_indices_policy"})

def _func(
operand: jnp.ndarray,
Expand Down

0 comments on commit 5214129

Please sign in to comment.