Skip to content

Commit

Permalink
Drop support for mhlo in JAX's public API.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657551590
  • Loading branch information
hawkinsp authored and TF2JAXDev committed Jul 30, 2024
1 parent 8c12d97 commit 34508c5
Showing 1 changed file with 36 additions and 49 deletions.
85 changes: 36 additions & 49 deletions tf2jax/experimental/mhlo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,23 @@

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized

import chex
import jax
import numpy as np

from tf2jax.experimental import mhlo


def _convert_to_mhlo(jax_fn, inputs, *, dialect):
def _convert_to_mhlo(jax_fn, inputs):
lowered_forward = jax_fn.lower(*inputs)
mhlo_text = lowered_forward.as_text(dialect=dialect)
mhlo_text = lowered_forward.as_text(dialect="stablehlo")
return mhlo_text


def _check_transforms(fn, inputs, *, dialect):
def _check_transforms(fn, inputs):
jaxpr = jax.make_jaxpr(fn)(*inputs)
logging.info(jaxpr)

mhlo_text = jax.jit(fn).lower(*inputs).as_text(dialect=dialect)
mhlo_text = jax.jit(fn).lower(*inputs).as_text(dialect="stablehlo")
logging.info(mhlo_text)


Expand All @@ -51,40 +48,34 @@ def _assert_all_close(self, expect_fn, actual_fn, inputs):
chex.assert_trees_all_close(expect_outputs, actual_outputs)

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
("mhlo", "mhlo"),
("stablehlo", "stablehlo"),
)
def test_one_input_and_one_output(self, dialect):
def test_one_input_and_one_output(self):
@jax.jit
def fn(x):
return x * 2 + 3.14

inputs = (np.ones((3, 2), dtype=np.float32) * 10,)
mhlo_text = _convert_to_mhlo(
fn, jax.tree.map(np.zeros_like, inputs), dialect=dialect)
mhlo_text = _convert_to_mhlo(fn, jax.tree.map(np.zeros_like, inputs))
mhlo_module = mhlo.MhloModule(module=mhlo_text, fun_name="test_module")

chex.assert_trees_all_close(
mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs))
mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs)
)

def make_top_fn(sub_fn):
def top_fn(x):
return sub_fn(x + 8) * 10

return top_fn

expect_top_fn = make_top_fn(fn)
actual_top_fn = make_top_fn(
functools.partial(mhlo.mhlo_apply, module=mhlo_module))
functools.partial(mhlo.mhlo_apply, module=mhlo_module)
)
self._assert_all_close(expect_top_fn, actual_top_fn, inputs)
_check_transforms(actual_top_fn, inputs, dialect=dialect)
_check_transforms(actual_top_fn, inputs)

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
("mhlo", "mhlo"),
("stablehlo", "stablehlo"),
)
def test_two_inputs_and_one_output(self, dialect):
def test_two_inputs_and_one_output(self):
@jax.jit
def fn(x, y):
return x * 2 + y
Expand All @@ -93,29 +84,27 @@ def fn(x, y):
np.ones((3, 2), dtype=np.float32) * 10,
np.ones((3, 2), dtype=np.float32) * 20,
)
mhlo_text = _convert_to_mhlo(
fn, jax.tree.map(np.zeros_like, inputs), dialect=dialect)
mhlo_text = _convert_to_mhlo(fn, jax.tree.map(np.zeros_like, inputs))
mhlo_module = mhlo.MhloModule(module=mhlo_text, fun_name="test_module")
chex.assert_trees_all_close(
mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs))
mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs)
)

def make_top_fn(sub_fn):
def top_fn(x, y):
return sub_fn(x + 8, y + 9) * 10

return top_fn

expect_top_fn = make_top_fn(fn)
actual_top_fn = make_top_fn(
functools.partial(mhlo.mhlo_apply, module=mhlo_module))
functools.partial(mhlo.mhlo_apply, module=mhlo_module)
)
self._assert_all_close(expect_top_fn, actual_top_fn, inputs)
_check_transforms(actual_top_fn, inputs, dialect=dialect)
_check_transforms(actual_top_fn, inputs)

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
("mhlo", "mhlo"),
("stablehlo", "stablehlo"),
)
def test_two_inputs_and_two_outputs(self, dialect):
def test_two_inputs_and_two_outputs(self):
@jax.jit
def fn(x, y):
return x * 2 + y, x + y * 3
Expand All @@ -124,51 +113,49 @@ def fn(x, y):
np.ones((3, 2), dtype=np.float32) * 10,
np.ones((3, 2), dtype=np.float32) * 20,
)
mhlo_text = _convert_to_mhlo(
fn, jax.tree.map(np.zeros_like, inputs), dialect=dialect)
mhlo_text = _convert_to_mhlo(fn, jax.tree.map(np.zeros_like, inputs))
mhlo_module = mhlo.MhloModule(module=mhlo_text, fun_name="test_module")
chex.assert_trees_all_close(
mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs))
mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs)
)

def make_top_fn(sub_fn):
def top_fn(x, y):
res0, res1 = sub_fn(x + 8, y + 9)
return res0 * 10, res1 * 3.14

return top_fn

expect_top_fn = make_top_fn(fn)
actual_top_fn = make_top_fn(
functools.partial(mhlo.mhlo_apply, module=mhlo_module))
functools.partial(mhlo.mhlo_apply, module=mhlo_module)
)
self._assert_all_close(expect_top_fn, actual_top_fn, inputs)
_check_transforms(actual_top_fn, inputs, dialect=dialect)
_check_transforms(actual_top_fn, inputs)

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
("mhlo", "mhlo"),
("stablehlo", "stablehlo"),
)
def test_boolean(self, dialect):
def test_boolean(self):
@jax.jit
def fn(x):
return x > 0

inputs = (
np.ones((4,), dtype=np.int32) * 10,
)
mhlo_text = _convert_to_mhlo(
fn, jax.tree.map(np.zeros_like, inputs), dialect=dialect)
inputs = (np.ones((4,), dtype=np.int32) * 10,)
mhlo_text = _convert_to_mhlo(fn, jax.tree.map(np.zeros_like, inputs))
mhlo_module = mhlo.MhloModule(module=mhlo_text, fun_name="test_module")
chex.assert_trees_all_close(
mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs))
mhlo.mhlo_apply(*inputs, module=mhlo_module), fn(*inputs)
)

def make_top_fn(sub_fn):
def top_fn(x):
return sub_fn(x) * x

return top_fn

expect_top_fn = make_top_fn(fn)
actual_top_fn = make_top_fn(
functools.partial(mhlo.mhlo_apply, module=mhlo_module))
functools.partial(mhlo.mhlo_apply, module=mhlo_module)
)
self._assert_all_close(expect_top_fn, actual_top_fn, inputs)


Expand Down

0 comments on commit 34508c5

Please sign in to comment.