diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index c08f3e6..9e93cff 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -2808,3 +2808,34 @@ def _xla_select_and_scatter(proto): select = proto.attr["select"].func.name return _XlaSelectAndScatter(dict(scatter=scatter, select=select)) + + +def _searchsorted(a: jnp.ndarray, v: jnp.ndarray, side: str): + """Vmapped version of searchsorted to implement LowerBound and UpperBound.""" + return jax.vmap( + functools.partial(jnp.searchsorted, side=side), + in_axes=0, + out_axes=0, + )(a, v) + + +def _lower_upper_bound(proto, side: str): + """Parse a LowerBound or UpperBound op using searchsorted.""" + _check_attrs(proto, {"T", "Tvalues", "out_type"}) + dtype = tf.as_dtype(proto.attr["out_type"].type) + if dtype != tf.int32: + raise ValueError( + f"Return type {dtype} not supported for LowerBound and UpperBound.") + return lambda a, v: _searchsorted(a, v, side=side) + + +@register_operation("LowerBound") +def _lower_bound(proto): + """Parse a LowerBound op.""" + return _lower_upper_bound(proto, side="left") + + +@register_operation("UpperBound") +def _upper_bound(proto): + """Parse an UpperBound op.""" + return _lower_upper_bound(proto, side="right") diff --git a/tf2jax/_src/ops_test.py b/tf2jax/_src/ops_test.py index cc1495f..5f5e6bd 100644 --- a/tf2jax/_src/ops_test.py +++ b/tf2jax/_src/ops_test.py @@ -2283,6 +2283,27 @@ def var_handle(): ): self._test_convert(var_handle, []) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.parameters( + "LowerBound", + "UpperBound", + "SearchSorted", + ) + def test_lower_upper_bound(self, op_name): + np.random.seed(42) + inputs = ( + np.array([[0, 1, 2, 3, 4], [-4, -3, -2, -1, 0]], dtype=np.float32), + np.array( + [[3.5, 0, 1.5, 10, -1], [-3.5, 0, -1.5, -10, 1]], dtype=np.float32) + ) + if op_name == "SearchSorted": + tf_func = lambda x, y: tf.searchsorted(x, y, out_type=tf.int32) + else: + # LowerBound and UpperBound expect keyword arguments. + def tf_func(x, y): + return getattr(tf.raw_ops, op_name)(sorted_inputs=x, values=y) + self._test_convert(tf_func, inputs) + if __name__ == "__main__": tf.test.main()