From 34508c54669b05de774af51c61f7e787d15b2b23 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 30 Jul 2024 05:27:23 -0700 Subject: [PATCH] Drop support for mhlo in JAX's public API. PiperOrigin-RevId: 657551590 --- tf2jax/experimental/mhlo_test.py | 85 ++++++++++++++------------------ 1 file changed, 36 insertions(+), 49 deletions(-) diff --git a/tf2jax/experimental/mhlo_test.py b/tf2jax/experimental/mhlo_test.py index a4eb91b..4ba5b8b 100644 --- a/tf2jax/experimental/mhlo_test.py +++ b/tf2jax/experimental/mhlo_test.py @@ -18,26 +18,23 @@ from absl import logging from absl.testing import absltest -from absl.testing import parameterized - import chex import jax import numpy as np - from tf2jax.experimental import mhlo -def _convert_to_mhlo(jax_fn, inputs, *, dialect): +def _convert_to_mhlo(jax_fn, inputs): lowered_forward = jax_fn.lower(*inputs) - mhlo_text = lowered_forward.as_text(dialect=dialect) + mhlo_text = lowered_forward.as_text(dialect="stablehlo") return mhlo_text -def _check_transforms(fn, inputs, *, dialect): +def _check_transforms(fn, inputs): jaxpr = jax.make_jaxpr(fn)(*inputs) logging.info(jaxpr) - mhlo_text = jax.jit(fn).lower(*inputs).as_text(dialect=dialect) + mhlo_text = jax.jit(fn).lower(*inputs).as_text(dialect="stablehlo") logging.info(mhlo_text) @@ -51,40 +48,34 @@ def _assert_all_close(self, expect_fn, actual_fn, inputs): chex.assert_trees_all_close(expect_outputs, actual_outputs) @chex.variants(with_jit=True, without_jit=True) - @parameterized.named_parameters( - ("mhlo", "mhlo"), - ("stablehlo", "stablehlo"), - ) - def test_one_input_and_one_output(self, dialect): + def test_one_input_and_one_output(self): @jax.jit def fn(x): return x * 2 + 3.14 inputs = (np.ones((3, 2), dtype=np.float32) * 10,) - mhlo_text = _convert_to_mhlo( - fn, jax.tree.map(np.zeros_like, inputs), dialect=dialect) + mhlo_text = _convert_to_mhlo(fn, jax.tree.map(np.zeros_like, inputs)) mhlo_module = mhlo.MhloModule(module=mhlo_text, fun_name="test_module") chex.assert_trees_all_close( - mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs)) + mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs) + ) def make_top_fn(sub_fn): def top_fn(x): return sub_fn(x + 8) * 10 + return top_fn expect_top_fn = make_top_fn(fn) actual_top_fn = make_top_fn( - functools.partial(mhlo.mhlo_apply, module=mhlo_module)) + functools.partial(mhlo.mhlo_apply, module=mhlo_module) + ) self._assert_all_close(expect_top_fn, actual_top_fn, inputs) - _check_transforms(actual_top_fn, inputs, dialect=dialect) + _check_transforms(actual_top_fn, inputs) @chex.variants(with_jit=True, without_jit=True) - @parameterized.named_parameters( - ("mhlo", "mhlo"), - ("stablehlo", "stablehlo"), - ) - def test_two_inputs_and_one_output(self, dialect): + def test_two_inputs_and_one_output(self): @jax.jit def fn(x, y): return x * 2 + y @@ -93,29 +84,27 @@ def fn(x, y): np.ones((3, 2), dtype=np.float32) * 10, np.ones((3, 2), dtype=np.float32) * 20, ) - mhlo_text = _convert_to_mhlo( - fn, jax.tree.map(np.zeros_like, inputs), dialect=dialect) + mhlo_text = _convert_to_mhlo(fn, jax.tree.map(np.zeros_like, inputs)) mhlo_module = mhlo.MhloModule(module=mhlo_text, fun_name="test_module") chex.assert_trees_all_close( - mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs)) + mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs) + ) def make_top_fn(sub_fn): def top_fn(x, y): return sub_fn(x + 8, y + 9) * 10 + return top_fn expect_top_fn = make_top_fn(fn) actual_top_fn = make_top_fn( - functools.partial(mhlo.mhlo_apply, module=mhlo_module)) + functools.partial(mhlo.mhlo_apply, module=mhlo_module) + ) self._assert_all_close(expect_top_fn, actual_top_fn, inputs) - _check_transforms(actual_top_fn, inputs, dialect=dialect) + _check_transforms(actual_top_fn, inputs) @chex.variants(with_jit=True, without_jit=True) - @parameterized.named_parameters( - ("mhlo", "mhlo"), - ("stablehlo", "stablehlo"), - ) - def test_two_inputs_and_two_outputs(self, dialect): + def test_two_inputs_and_two_outputs(self): @jax.jit def fn(x, y): return x * 2 + y, x + y * 3 @@ -124,51 +113,49 @@ def fn(x, y): np.ones((3, 2), dtype=np.float32) * 10, np.ones((3, 2), dtype=np.float32) * 20, ) - mhlo_text = _convert_to_mhlo( - fn, jax.tree.map(np.zeros_like, inputs), dialect=dialect) + mhlo_text = _convert_to_mhlo(fn, jax.tree.map(np.zeros_like, inputs)) mhlo_module = mhlo.MhloModule(module=mhlo_text, fun_name="test_module") chex.assert_trees_all_close( - mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs)) + mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs) + ) def make_top_fn(sub_fn): def top_fn(x, y): res0, res1 = sub_fn(x + 8, y + 9) return res0 * 10, res1 * 3.14 + return top_fn expect_top_fn = make_top_fn(fn) actual_top_fn = make_top_fn( - functools.partial(mhlo.mhlo_apply, module=mhlo_module)) + functools.partial(mhlo.mhlo_apply, module=mhlo_module) + ) self._assert_all_close(expect_top_fn, actual_top_fn, inputs) - _check_transforms(actual_top_fn, inputs, dialect=dialect) + _check_transforms(actual_top_fn, inputs) @chex.variants(with_jit=True, without_jit=True) - @parameterized.named_parameters( - ("mhlo", "mhlo"), - ("stablehlo", "stablehlo"), - ) - def test_boolean(self, dialect): + def test_boolean(self): @jax.jit def fn(x): return x > 0 - inputs = ( - np.ones((4,), dtype=np.int32) * 10, - ) - mhlo_text = _convert_to_mhlo( - fn, jax.tree.map(np.zeros_like, inputs), dialect=dialect) + inputs = (np.ones((4,), dtype=np.int32) * 10,) + mhlo_text = _convert_to_mhlo(fn, jax.tree.map(np.zeros_like, inputs)) mhlo_module = mhlo.MhloModule(module=mhlo_text, fun_name="test_module") chex.assert_trees_all_close( - mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs)) + mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs) + ) def make_top_fn(sub_fn): def top_fn(x): return sub_fn(x) * x + return top_fn expect_top_fn = make_top_fn(fn) actual_top_fn = make_top_fn( - functools.partial(mhlo.mhlo_apply, module=mhlo_module)) + functools.partial(mhlo.mhlo_apply, module=mhlo_module) + ) self._assert_all_close(expect_top_fn, actual_top_fn, inputs)