Skip to content

Commit

Permalink
Remove usage of native_serialization.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653208445
  • Loading branch information
shaobohou authored and TF2JAXDev committed Jul 22, 2024
1 parent 2d2cb8d commit fa36da7
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 144 deletions.
4 changes: 2 additions & 2 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ else
N_JOBS=$(grep -c ^processor /proc/cpuinfo)
fi

pytest -n "${N_JOBS}" --pyargs tf2jax
CHECK_CUSTOM_CALLS_TEST=0 pytest -n "${N_JOBS}" --pyargs tf2jax

# Native lowering is in active development so we test against nightly and github head.
pip uninstall --yes tensorflow
pip install tf-nightly
pip install git+https://github.com/google/jax.git
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
USE_JAX2TF_NATIVE_SERIALIZATION_IN_ROUNDTRIP_TEST=1 pytest -n "${N_JOBS}" --pyargs tf2jax._src.roundtrip_test
CHECK_CUSTOM_CALLS_TEST=0 pytest -n "${N_JOBS}" --pyargs tf2jax._src.roundtrip_test
cd ..

set +u
Expand Down
15 changes: 0 additions & 15 deletions tf2jax/_src/numpy_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax.experimental import jax2tf
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -46,19 +44,6 @@ def test_dtype_conversion(self, np_module, dtype_map):
else:
self.assertIs(dtype_map[src], getattr(np_module, dst))

def test_is_poly_dim(self):
@jax.jit
def jax_func(x):
self.assertTrue(numpy_compat.is_poly_dim(x.shape[0]))
self.assertEqual(x.shape[1], 4)
return x + 3.14

tf_func = jax2tf.convert(
jax_func, polymorphic_shapes=["(b, _)"], native_serialization=False
)
tf_forward = tf.function(tf_func, autograph=False)
tf_forward.get_concrete_function(tf.TensorSpec(shape=(None, 4)))


if __name__ == "__main__":
absltest.main()
145 changes: 31 additions & 114 deletions tf2jax/_src/roundtrip_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def _bool_env(varname: str, value: bool) -> bool:
return {"1": True, "0": False, "True": True, "False": False}[val]


_NATIVE_SERIALIZATION = flags.DEFINE_bool(
"use_jax2tf_native_serialization_in_roundtrip_test",
_bool_env("USE_JAX2TF_NATIVE_SERIALIZATION_IN_ROUNDTRIP_TEST", False),
"Whether to call jax2tf.convert with native serialization enabled.",
_CHECK_CUSTOM_CALLS = flags.DEFINE_bool(
"check_custom_calls_test",
_bool_env("CHECK_CUSTOM_CALLS_TEST", True),
"Whether to enable custom calls test.",
)


Expand All @@ -56,29 +56,8 @@ def fn(*args):
return jax.grad(lambda *args: jnp.sum(fn(*args)))(*inputs)


def _jax2tf_convert(func, **kwargs):
return jax2tf.convert(
func, native_serialization=_NATIVE_SERIALIZATION.value, **kwargs
)


def uses_native_serialization():
return _NATIVE_SERIALIZATION.value


class Jax2TfTest(test_util.TestCase):

def setUp(self):
super().setUp()
if not uses_native_serialization():
self._xla_op = tf2jax.ops._jax_ops.pop("XlaCallModule", None)

def tearDown(self):
super().tearDown()
if not uses_native_serialization():
if self._xla_op is not None:
tf2jax.ops._jax_ops["XlaCallModule"] = self._xla_op

def _test_convert(
self,
jax_func,
Expand All @@ -88,12 +67,11 @@ def _test_convert(
with_custom_grad=False,
grad_tols=None,
):
if uses_native_serialization():
if with_grad and not with_custom_grad:
self.skipTest(
"native_serialization does not support differentiation without "
"custom gradient."
)
if with_grad and not with_custom_grad:
self.skipTest(
"native_serialization does not support differentiation without "
"custom gradient."
)

grad_tols = grad_tols or {}

Expand All @@ -107,7 +85,7 @@ def assert_grad_all_close(*args):
jax_grads = _compute_gradients(jax_func, *inputs)

# Jax -> TF
tf_func = _jax2tf_convert(jax_func, with_gradient=with_grad)
tf_func = jax2tf.convert(jax_func, with_gradient=with_grad)
tf_func = tf.function(tf_func, jit_compile=True, autograph=False)
tf_outputs = tf_func(*inputs)
jax.tree.map(self.assertAllClose, jax_outputs, tf_outputs)
Expand Down Expand Up @@ -520,13 +498,6 @@ def forward(x):
with config.override_config(
"infer_cumulative_reduction_from_jax2tf", use_heuristic
):
roundtrip_forward = tf2jax.convert_functional(
tf.function(_jax2tf_convert(forward), autograph=False), inputs
)
roundtrip_jaxpr = jax.make_jaxpr(roundtrip_forward)(inputs)
if use_heuristic and not uses_native_serialization():
self.assertNotIn("reduce_window", roundtrip_jaxpr.pretty_print())

if with_grad and reducer is jax.lax.cumprod and not use_heuristic:
self.skipTest(
"No differentiation rule for `reduce_window` with "
Expand Down Expand Up @@ -689,7 +660,7 @@ def forward(x):
self.assertAllClose(tf_outputs, jax_outputs)

# TF -> JAX -> TF
new_tf_forward = _jax2tf_convert(
new_tf_forward = jax2tf.convert(
jax_func, polymorphic_shapes=["(b, _)"], with_gradient=with_grad
)
new_tf_forward = tf.function(new_tf_forward, autograph=False)
Expand Down Expand Up @@ -719,7 +690,7 @@ def forward(x, w):
expected_outputs = forward(x, w)

# JAX -> TF
tf_fn = _jax2tf_convert(
tf_fn = jax2tf.convert(
forward, polymorphic_shapes=["(b, _)", None], with_gradient=with_grad
)
tf_fn = tf.function(tf_fn, autograph=False)
Expand All @@ -737,7 +708,7 @@ def forward(x, w):
self.assertAllClose(expected_outputs, jax_outputs)

# JAX -> TF -> JAX -> TF
tf_fn2 = _jax2tf_convert(
tf_fn2 = jax2tf.convert(
jax_fn, polymorphic_shapes=["(b, _)", None], with_gradient=with_grad
)
tf_fn2 = tf.function(tf_fn2, autograph=False)
Expand Down Expand Up @@ -771,7 +742,7 @@ def forward(x, y):
expected_outputs = forward(x, y)

# JAX -> TF
tf_fn = _jax2tf_convert(
tf_fn = jax2tf.convert(
forward,
polymorphic_shapes=["(b, _)", "(_, b, _)"],
with_gradient=with_grad,
Expand All @@ -791,7 +762,7 @@ def forward(x, y):
self.assertAllClose(expected_outputs, jax_outputs)

# JAX -> TF -> JAX -> TF
tf_fn2 = _jax2tf_convert(
tf_fn2 = jax2tf.convert(
jax_fn,
polymorphic_shapes=["(b, _)", "(_, b, _)"],
with_gradient=with_grad,
Expand Down Expand Up @@ -836,7 +807,7 @@ def grad(dy):
expected_grads = jax.grad(forward)(inputs)

# JAX -> TF
tf_forward = _jax2tf_convert(forward, with_gradient=with_grad)
tf_forward = jax2tf.convert(forward, with_gradient=with_grad)
tf_forward = tf.function(tf_forward, autograph=False)

# JAX -> TF -> JAX
Expand Down Expand Up @@ -920,13 +891,13 @@ def grad(dy):
expected_grads = jax.grad(forward)(inputs)

# JAX -> TF
tf_fn = _jax2tf_convert(forward, with_gradient=with_grad)
tf_fn = jax2tf.convert(forward, with_gradient=with_grad)
tf_fn = tf.function(tf_fn, autograph=False)

# JAX -> TF -> CALL_TF -> TF.
# This creates dependencies between custom gradients.
call_tf_fn = jax2tf.call_tf(tf_fn)
tf_fn_too = _jax2tf_convert(call_tf_fn, with_gradient=with_grad)
tf_fn_too = jax2tf.convert(call_tf_fn, with_gradient=with_grad)
tf_fn_too = tf.function(tf_fn_too, autograph=False)

# JAX -> TF -> CALL_TF -> TF -> JAX
Expand Down Expand Up @@ -978,7 +949,7 @@ def forward(x):
jax.grad(forward)(inputs)

# JAX -> TF
tf_forward = _jax2tf_convert(forward, with_gradient=True)
tf_forward = jax2tf.convert(forward, with_gradient=True)
tf_forward = tf.function(tf_forward, autograph=False)

# JAX -> TF -> JAX
Expand Down Expand Up @@ -1023,46 +994,12 @@ def tf2jax_fn(x):
jax_fn = tf2jax.convert_functional(tf_fn, x)
return jnp.sin(self.variant(jax_fn)(x))

tf2jax2tf_fn = _jax2tf_convert(tf2jax_fn)
tf2jax2tf_fn = jax2tf.convert(tf2jax_fn)
tf2jax2tf_fn = tf.function(tf2jax2tf_fn, autograph=False)

inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3))
self.assertAllClose(tf.sin(tf_fn(inputs)), tf2jax2tf_fn(inputs))

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
chex.params_product(
(("without_gradient", False), ("with_gradient", True)),
named=True,
)
)
def test_remat(self, with_gradient):
def fn(x):
return jnp.sin(jnp.sin(x))

remat_fn = jax.checkpoint(fn)

inputs = (np.linspace(0, 1, 10 * 5, dtype=np.float32).reshape(10, 5),)
self._test_convert(
remat_fn, inputs, with_grad=with_gradient, with_custom_grad=True
)

if uses_native_serialization():
self.skipTest("Skip remat jaxpr test with native_serialization.")

# Check jaxpr.
tf_fn = tf.function(
_jax2tf_convert(remat_fn, with_gradient=True), autograph=False
)
jax_fn = tf2jax.convert_functional(
tf_fn, tf.TensorSpec((10, 5), tf.float32)
)
jax_fn = self.variant(jax_fn)
out_jaxpr = jax.make_jaxpr(jax_fn)(*inputs)
self.assertNotRegex(str(out_jaxpr), "remat")
grad_jaxpr = jax.make_jaxpr(jax.grad(lambda x: jnp.sum(jax_fn(x))))(*inputs)
self.assertRegex(str(grad_jaxpr), "optimization_barrier")

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
chex.params_product(
Expand Down Expand Up @@ -1121,10 +1058,10 @@ def test_triangular_solve(
unit_diagonal,
shapes,
):
if uses_native_serialization():
if not _CHECK_CUSTOM_CALLS.value:
self.skipTest(
"native_serialization: Cannot serialize code with custom calls whose "
"targets have no compatibility guarantees."
"Disable tests with custom calls whose targets have no compatibility"
" guarantees."
)

np.random.seed(42)
Expand All @@ -1150,35 +1087,19 @@ def forward(a, b):
grad_tols=tols,
)

def test_explicit_native_serialization(self):
def test_platform_check(self):
def forward(x):
return x + 3.14

tf_fn = jax2tf.convert(forward, native_serialization=True)
tf_fn = jax2tf.convert(forward)
tf_fn = tf.function(tf_fn, autograph=False)

try:
jax_fn = tf2jax.convert_functional(
tf_fn, tf.TensorSpec((2, 3), tf.float32)
)
jax_fn = jax.jit(jax_fn)
inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3))
self.assertAllClose(jax_fn(inputs), tf_fn(inputs))
except ValueError as e:
if uses_native_serialization():
raise ValueError("Native lowering support failed.") from e
elif r"Unsupported operations in graph: ['XlaCallModule']" not in str(e):
raise ValueError("Unexpected unsupported operations found.") from e
elif r"check for import error if you see this message" not in str(e):
raise ValueError("Import/dependency error message not found.") from e
else: # Expected failure.
return

if not uses_native_serialization():
raise ValueError(
"Unexpected success with native serialization. Test may be "
"misconfigured."
)
jax_fn = tf2jax.convert_functional(
tf_fn, tf.TensorSpec((2, 3), tf.float32)
)
jax_fn = jax.jit(jax_fn)
inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3))
self.assertAllClose(jax_fn(inputs), tf_fn(inputs))

if jax.default_backend().lower() != "cpu":
with jax.default_device(jax.local_devices(backend="cpu")[0]):
Expand All @@ -1191,9 +1112,6 @@ def forward(x):

@chex.variants(with_jit=True, without_jit=True)
def test_platform_index(self):
if not uses_native_serialization():
self.skipTest("Skip platform_index test without native serialization.")

@jax.jit
def forward(x):
return jnp.tile(x**2, 2).reshape((2, *x.shape))
Expand All @@ -1203,7 +1121,6 @@ def forward(x):

func = jax2tf.convert(
forward,
native_serialization=True,
polymorphic_shapes=("(B, d * 128)",),
native_serialization_platforms=("cuda", "cpu", "tpu"),
)
Expand Down
15 changes: 2 additions & 13 deletions tf2jax/_src/sharding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
"""Tests for JAX -> TF -> JAX with partitioning."""

from absl.testing import absltest
from absl.testing import parameterized
import chex
import haiku as hk
import jax
from jax.experimental import jax2tf
Expand Down Expand Up @@ -49,13 +47,7 @@ def _get_param_pspecs():

class ShardingTest(test_util.TestCase):

@parameterized.named_parameters(
chex.params_product(
(('native_serialization', True), ('graph_serialization', False)),
named=True,
)
)
def test_sharding(self, native_serialization):
def test_sharding(self):
if jax.default_backend().upper() != 'TPU':
self.skipTest('Only run sharding tests on TPU.')

Expand Down Expand Up @@ -107,10 +99,7 @@ def partitioned_grad(params, xs):
# Convert to TF and save.
@tf.function(autograph=False, jit_compile=True)
def tf_fn(params, inputs):
return jax2tf.convert(
partitioned_apply,
native_serialization=native_serialization,
)(params, inputs)
return jax2tf.convert(partitioned_apply)(params, inputs)

tf_fn(params, images)
module = tf.Module()
Expand Down

0 comments on commit fa36da7

Please sign in to comment.