Skip to content

Commit

Permalink
Add LowerBound and UpperBound operation to tf2jax to support searchso…
Browse files Browse the repository at this point in the history
…rted.

PiperOrigin-RevId: 541680002
  • Loading branch information
TF2JAXDev authored and TF2JAXDev committed Jun 21, 2023
1 parent 475aeff commit 9738677
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2808,3 +2808,34 @@ def _xla_select_and_scatter(proto):
select = proto.attr["select"].func.name

return _XlaSelectAndScatter(dict(scatter=scatter, select=select))


def _searchsorted(a: jnp.ndarray, v: jnp.ndarray, side: str):
"""Vmapped version of searchsorted to implement LowerBound and UpperBound."""
return jax.vmap(
functools.partial(jnp.searchsorted, side=side),
in_axes=0,
out_axes=0,
)(a, v)


def _lower_upper_bound(proto, side: str):
"""Parse a LowerBound or UpperBound op using searchsorted."""
_check_attrs(proto, {"T", "Tvalues", "out_type"})
dtype = tf.as_dtype(proto.attr["out_type"].type)
if dtype != tf.int32:
raise ValueError(
f"Return type {dtype} not supported for LowerBound and UpperBound.")
return lambda a, v: _searchsorted(a, v, side=side)


@register_operation("LowerBound")
def _lower_bound(proto):
"""Parse a LowerBound op."""
return _lower_upper_bound(proto, side="left")


@register_operation("UpperBound")
def _upper_bound(proto):
"""Parse an UpperBound op."""
return _lower_upper_bound(proto, side="right")
21 changes: 21 additions & 0 deletions tf2jax/_src/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2283,6 +2283,27 @@ def var_handle():
):
self._test_convert(var_handle, [])

@chex.variants(with_jit=True, without_jit=True)
@parameterized.parameters(
"LowerBound",
"UpperBound",
"SearchSorted",
)
def test_lower_upper_bound(self, op_name):
np.random.seed(42)
inputs = (
np.array([[0, 1, 2, 3, 4], [-4, -3, -2, -1, 0]], dtype=np.float32),
np.array(
[[3.5, 0, 1.5, 10, -1], [-3.5, 0, -1.5, -10, 1]], dtype=np.float32)
)
if op_name == "SearchSorted":
tf_func = lambda x, y: tf.searchsorted(x, y, out_type=tf.int32)
else:
# LowerBound and UpperBound expect keyword arguments.
def tf_func(x, y):
return getattr(tf.raw_ops, op_name)(sorted_inputs=x, values=y)
self._test_convert(tf_func, inputs)


if __name__ == "__main__":
tf.test.main()

0 comments on commit 9738677

Please sign in to comment.