diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index 5205693..cc296b1 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -2249,7 +2249,7 @@ def _func(xs: jnp.ndarray, shape: jnp.ndarray) -> ArrayLike: @register_operation("TensorScatterUpdate") def _tensor_scatter_update(proto): """Parse an TensorScatterUpdate Op.""" - _check_attrs(proto, {"T", "Tindices"}) + _check_attrs(proto, {"T", "Tindices", "bad_indices_policy"}) def _func( operand: jnp.ndarray,