From e8a6bfe7c3dfa16fc01b7a2ef739a8a1482dee02 Mon Sep 17 00:00:00 2001 From: TF2JAXDev Date: Mon, 22 Jul 2024 03:55:51 -0700 Subject: [PATCH] Handle bad_indices_policy for GatherNd and ScatterNd PiperOrigin-RevId: 654685263 --- tf2jax/_src/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index 5205693..cc296b1 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -2249,7 +2249,7 @@ def _func(xs: jnp.ndarray, shape: jnp.ndarray) -> ArrayLike: @register_operation("TensorScatterUpdate") def _tensor_scatter_update(proto): """Parse an TensorScatterUpdate Op.""" - _check_attrs(proto, {"T", "Tindices"}) + _check_attrs(proto, {"T", "Tindices", "bad_indices_policy"}) def _func( operand: jnp.ndarray,