Remove usage of native_serialization.
PiperOrigin-RevId: 653208445
shaobohou authored and TF2JAXDev committed Jul 22, 2024
1 parent 2d2cb8d commit 516bb65
Showing 4 changed files with 42 additions and 143 deletions.
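As context for this change, here is a minimal sketch of the JAX -> TF -> JAX round trip these tests exercise once the flag is gone: jax2tf.convert is called with no native_serialization argument, relying on jax2tf's default lowering (native serialization), and tf2jax.convert_functional maps the resulting tf.function back to JAX. The forward function, shapes, and tolerance below are illustrative assumptions, not taken from the commit.

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf
import tf2jax

def forward(x):
  # Illustrative function; any jittable JAX function fits this pattern.
  return jnp.sin(x) + 3.14

# JAX -> TF: no native_serialization flag is passed any more.
tf_fn = tf.function(jax2tf.convert(forward), autograph=False)

# TF -> JAX: round-trip back through tf2jax.
jax_fn = tf2jax.convert_functional(tf_fn, tf.TensorSpec((2, 3), tf.float32))

inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3))
np.testing.assert_allclose(jax.jit(jax_fn)(inputs), tf_fn(inputs).numpy(), rtol=1e-6)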
4 changes: 2 additions & 2 deletions test.sh
@@ -65,14 +65,14 @@ else
N_JOBS=$(grep -c ^processor /proc/cpuinfo)
fi

pytest -n "${N_JOBS}" --pyargs tf2jax
CHECK_CUSTOM_CALLS_TEST=0 pytest -n "${N_JOBS}" --pyargs tf2jax

# Native lowering is in active development so we test against nightly and github head.
pip uninstall --yes tensorflow
pip install tf-nightly
pip install git+https://github.com/google/jax.git
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
USE_JAX2TF_NATIVE_SERIALIZATION_IN_ROUNDTRIP_TEST=1 pytest -n "${N_JOBS}" --pyargs tf2jax._src.roundtrip_test
CHECK_CUSTOM_CALLS_TEST=0 pytest -n "${N_JOBS}" --pyargs tf2jax._src.roundtrip_test
cd ..

set +u
21 changes: 7 additions & 14 deletions tf2jax/_src/numpy_compat_test.py
@@ -17,8 +17,7 @@
from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax.experimental import jax2tf
from jax import export
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
@@ -46,18 +46,12 @@ def test_dtype_conversion(self, np_module, dtype_map):
else:
self.assertIs(dtype_map[src], getattr(np_module, dst))

def test_is_poly_dim(self):
@jax.jit
def jax_func(x):
self.assertTrue(numpy_compat.is_poly_dim(x.shape[0]))
self.assertEqual(x.shape[1], 4)
return x + 3.14

tf_func = jax2tf.convert(
jax_func, polymorphic_shapes=["(b, _)"], native_serialization=False
)
tf_forward = tf.function(tf_func, autograph=False)
tf_forward.get_concrete_function(tf.TensorSpec(shape=(None, 4)))
def test_poly_dim(self):
dims = export.symbolic_shape("(h*2, w, 2, _)", like=[None, None, 2, 7])
self.assertTrue(numpy_compat.is_poly_dim(dims[0]))
self.assertTrue(numpy_compat.is_poly_dim(dims[1]))
self.assertFalse(numpy_compat.is_poly_dim(dims[2]))
self.assertFalse(numpy_compat.is_poly_dim(dims[3]))


if __name__ == "__main__":
145 changes: 31 additions & 114 deletions tf2jax/_src/roundtrip_test.py
@@ -42,10 +42,10 @@ def _bool_env(varname: str, value: bool) -> bool:
return {"1": True, "0": False, "True": True, "False": False}[val]


_NATIVE_SERIALIZATION = flags.DEFINE_bool(
"use_jax2tf_native_serialization_in_roundtrip_test",
_bool_env("USE_JAX2TF_NATIVE_SERIALIZATION_IN_ROUNDTRIP_TEST", False),
"Whether to call jax2tf.convert with native serialization enabled.",
_CHECK_CUSTOM_CALLS = flags.DEFINE_bool(
"check_custom_calls_test",
_bool_env("CHECK_CUSTOM_CALLS_TEST", True),
"Whether to enable custom calls test.",
)


@@ -56,29 +56,8 @@ def fn(*args):
return jax.grad(lambda *args: jnp.sum(fn(*args)))(*inputs)


def _jax2tf_convert(func, **kwargs):
return jax2tf.convert(
func, native_serialization=_NATIVE_SERIALIZATION.value, **kwargs
)


def uses_native_serialization():
return _NATIVE_SERIALIZATION.value


class Jax2TfTest(test_util.TestCase):

def setUp(self):
super().setUp()
if not uses_native_serialization():
self._xla_op = tf2jax.ops._jax_ops.pop("XlaCallModule", None)

def tearDown(self):
super().tearDown()
if not uses_native_serialization():
if self._xla_op is not None:
tf2jax.ops._jax_ops["XlaCallModule"] = self._xla_op

def _test_convert(
self,
jax_func,
@@ -88,12 +67,11 @@ def _test_convert(
with_custom_grad=False,
grad_tols=None,
):
if uses_native_serialization():
if with_grad and not with_custom_grad:
self.skipTest(
"native_serialization does not support differentiation without "
"custom gradient."
)
if with_grad and not with_custom_grad:
self.skipTest(
"native_serialization does not support differentiation without "
"custom gradient."
)

grad_tols = grad_tols or {}

Expand All @@ -107,7 +85,7 @@ def assert_grad_all_close(*args):
jax_grads = _compute_gradients(jax_func, *inputs)

# Jax -> TF
tf_func = _jax2tf_convert(jax_func, with_gradient=with_grad)
tf_func = jax2tf.convert(jax_func, with_gradient=with_grad)
tf_func = tf.function(tf_func, jit_compile=True, autograph=False)
tf_outputs = tf_func(*inputs)
jax.tree.map(self.assertAllClose, jax_outputs, tf_outputs)
@@ -520,13 +498,6 @@ def forward(x):
with config.override_config(
"infer_cumulative_reduction_from_jax2tf", use_heuristic
):
roundtrip_forward = tf2jax.convert_functional(
tf.function(_jax2tf_convert(forward), autograph=False), inputs
)
roundtrip_jaxpr = jax.make_jaxpr(roundtrip_forward)(inputs)
if use_heuristic and not uses_native_serialization():
self.assertNotIn("reduce_window", roundtrip_jaxpr.pretty_print())

if with_grad and reducer is jax.lax.cumprod and not use_heuristic:
self.skipTest(
"No differentiation rule for `reduce_window` with "
@@ -689,7 +660,7 @@ def forward(x):
self.assertAllClose(tf_outputs, jax_outputs)

# TF -> JAX -> TF
new_tf_forward = _jax2tf_convert(
new_tf_forward = jax2tf.convert(
jax_func, polymorphic_shapes=["(b, _)"], with_gradient=with_grad
)
new_tf_forward = tf.function(new_tf_forward, autograph=False)
@@ -719,7 +690,7 @@ def forward(x, w):
expected_outputs = forward(x, w)

# JAX -> TF
tf_fn = _jax2tf_convert(
tf_fn = jax2tf.convert(
forward, polymorphic_shapes=["(b, _)", None], with_gradient=with_grad
)
tf_fn = tf.function(tf_fn, autograph=False)
@@ -737,7 +708,7 @@ def forward(x, w):
self.assertAllClose(expected_outputs, jax_outputs)

# JAX -> TF -> JAX -> TF
tf_fn2 = _jax2tf_convert(
tf_fn2 = jax2tf.convert(
jax_fn, polymorphic_shapes=["(b, _)", None], with_gradient=with_grad
)
tf_fn2 = tf.function(tf_fn2, autograph=False)
@@ -771,7 +742,7 @@ def forward(x, y):
expected_outputs = forward(x, y)

# JAX -> TF
tf_fn = _jax2tf_convert(
tf_fn = jax2tf.convert(
forward,
polymorphic_shapes=["(b, _)", "(_, b, _)"],
with_gradient=with_grad,
@@ -791,7 +762,7 @@ def forward(x, y):
self.assertAllClose(expected_outputs, jax_outputs)

# JAX -> TF -> JAX -> TF
tf_fn2 = _jax2tf_convert(
tf_fn2 = jax2tf.convert(
jax_fn,
polymorphic_shapes=["(b, _)", "(_, b, _)"],
with_gradient=with_grad,
@@ -836,7 +807,7 @@ def grad(dy):
expected_grads = jax.grad(forward)(inputs)

# JAX -> TF
tf_forward = _jax2tf_convert(forward, with_gradient=with_grad)
tf_forward = jax2tf.convert(forward, with_gradient=with_grad)
tf_forward = tf.function(tf_forward, autograph=False)

# JAX -> TF -> JAX
@@ -920,13 +891,13 @@ def grad(dy):
expected_grads = jax.grad(forward)(inputs)

# JAX -> TF
tf_fn = _jax2tf_convert(forward, with_gradient=with_grad)
tf_fn = jax2tf.convert(forward, with_gradient=with_grad)
tf_fn = tf.function(tf_fn, autograph=False)

# JAX -> TF -> CALL_TF -> TF.
# This creates dependencies between custom gradients.
call_tf_fn = jax2tf.call_tf(tf_fn)
tf_fn_too = _jax2tf_convert(call_tf_fn, with_gradient=with_grad)
tf_fn_too = jax2tf.convert(call_tf_fn, with_gradient=with_grad)
tf_fn_too = tf.function(tf_fn_too, autograph=False)

# JAX -> TF -> CALL_TF -> TF -> JAX
@@ -978,7 +949,7 @@ def forward(x):
jax.grad(forward)(inputs)

# JAX -> TF
tf_forward = _jax2tf_convert(forward, with_gradient=True)
tf_forward = jax2tf.convert(forward, with_gradient=True)
tf_forward = tf.function(tf_forward, autograph=False)

# JAX -> TF -> JAX
@@ -1023,46 +994,12 @@ def tf2jax_fn(x):
jax_fn = tf2jax.convert_functional(tf_fn, x)
return jnp.sin(self.variant(jax_fn)(x))

tf2jax2tf_fn = _jax2tf_convert(tf2jax_fn)
tf2jax2tf_fn = jax2tf.convert(tf2jax_fn)
tf2jax2tf_fn = tf.function(tf2jax2tf_fn, autograph=False)

inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3))
self.assertAllClose(tf.sin(tf_fn(inputs)), tf2jax2tf_fn(inputs))

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
chex.params_product(
(("without_gradient", False), ("with_gradient", True)),
named=True,
)
)
def test_remat(self, with_gradient):
def fn(x):
return jnp.sin(jnp.sin(x))

remat_fn = jax.checkpoint(fn)

inputs = (np.linspace(0, 1, 10 * 5, dtype=np.float32).reshape(10, 5),)
self._test_convert(
remat_fn, inputs, with_grad=with_gradient, with_custom_grad=True
)

if uses_native_serialization():
self.skipTest("Skip remat jaxpr test with native_serialization.")

# Check jaxpr.
tf_fn = tf.function(
_jax2tf_convert(remat_fn, with_gradient=True), autograph=False
)
jax_fn = tf2jax.convert_functional(
tf_fn, tf.TensorSpec((10, 5), tf.float32)
)
jax_fn = self.variant(jax_fn)
out_jaxpr = jax.make_jaxpr(jax_fn)(*inputs)
self.assertNotRegex(str(out_jaxpr), "remat")
grad_jaxpr = jax.make_jaxpr(jax.grad(lambda x: jnp.sum(jax_fn(x))))(*inputs)
self.assertRegex(str(grad_jaxpr), "optimization_barrier")

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
chex.params_product(
@@ -1121,10 +1058,10 @@ def test_triangular_solve(
unit_diagonal,
shapes,
):
if uses_native_serialization():
if not _CHECK_CUSTOM_CALLS.value:
self.skipTest(
"native_serialization: Cannot serialize code with custom calls whose "
"targets have no compatibility guarantees."
"Disable tests with custom calls whose targets have no compatibility"
" guarantees."
)

np.random.seed(42)
@@ -1150,35 +1087,19 @@ def forward(a, b):
grad_tols=tols,
)

def test_explicit_native_serialization(self):
def test_platform_check(self):
def forward(x):
return x + 3.14

tf_fn = jax2tf.convert(forward, native_serialization=True)
tf_fn = jax2tf.convert(forward)
tf_fn = tf.function(tf_fn, autograph=False)

try:
jax_fn = tf2jax.convert_functional(
tf_fn, tf.TensorSpec((2, 3), tf.float32)
)
jax_fn = jax.jit(jax_fn)
inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3))
self.assertAllClose(jax_fn(inputs), tf_fn(inputs))
except ValueError as e:
if uses_native_serialization():
raise ValueError("Native lowering support failed.") from e
elif r"Unsupported operations in graph: ['XlaCallModule']" not in str(e):
raise ValueError("Unexpected unsupported operations found.") from e
elif r"check for import error if you see this message" not in str(e):
raise ValueError("Import/dependency error message not found.") from e
else: # Expected failure.
return

if not uses_native_serialization():
raise ValueError(
"Unexpected success with native serialization. Test may be "
"misconfigured."
)
jax_fn = tf2jax.convert_functional(
tf_fn, tf.TensorSpec((2, 3), tf.float32)
)
jax_fn = jax.jit(jax_fn)
inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3))
self.assertAllClose(jax_fn(inputs), tf_fn(inputs))

if jax.default_backend().lower() != "cpu":
with jax.default_device(jax.local_devices(backend="cpu")[0]):
@@ -1191,9 +1112,6 @@ def forward(x):

@chex.variants(with_jit=True, without_jit=True)
def test_platform_index(self):
if not uses_native_serialization():
self.skipTest("Skip platform_index test without native serialization.")

@jax.jit
def forward(x):
return jnp.tile(x**2, 2).reshape((2, *x.shape))
Expand All @@ -1203,7 +1121,6 @@ def forward(x):

func = jax2tf.convert(
forward,
native_serialization=True,
polymorphic_shapes=("(B, d * 128)",),
native_serialization_platforms=("cuda", "cpu", "tpu"),
)
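A note on test_platform_index above: with graph serialization gone, the test always goes through the multi-platform native path. The hedged sketch below mirrors the conversion call from that test; the serialized module carries cuda/cpu/tpu lowerings and, as the test name suggests, selects one via a platform index at run time. The concrete input shape (3, 256) is an illustrative assumption.

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

@jax.jit
def forward(x):
  return jnp.tile(x**2, 2).reshape((2, *x.shape))

# Shape-polymorphic, multi-platform conversion, as in test_platform_index.
tf_fn = tf.function(
    jax2tf.convert(
        forward,
        polymorphic_shapes=("(B, d * 128)",),
        native_serialization_platforms=("cuda", "cpu", "tpu"),
    ),
    autograph=False,
)

x = np.ones((3, 256), dtype=np.float32)  # B = 3, d = 2 (illustrative).
out = tf_fn(x)
assert out.shape == (2, 3, 256)  # tile + reshape adds a leading axis of size 2.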
15 changes: 2 additions & 13 deletions tf2jax/_src/sharding_test.py
@@ -15,8 +15,6 @@
"""Tests for JAX -> TF -> JAX with partitioning."""

from absl.testing import absltest
from absl.testing import parameterized
import chex
import haiku as hk
import jax
from jax.experimental import jax2tf
@@ -49,13 +47,7 @@ def _get_param_pspecs():

class ShardingTest(test_util.TestCase):

@parameterized.named_parameters(
chex.params_product(
(('native_serialization', True), ('graph_serialization', False)),
named=True,
)
)
def test_sharding(self, native_serialization):
def test_sharding(self):
if jax.default_backend().upper() != 'TPU':
self.skipTest('Only run sharding tests on TPU.')

@@ -107,10 +99,7 @@ def partitioned_grad(params, xs):
# Convert to TF and save.
@tf.function(autograph=False, jit_compile=True)
def tf_fn(params, inputs):
return jax2tf.convert(
partitioned_apply,
native_serialization=native_serialization,
)(params, inputs)
return jax2tf.convert(partitioned_apply)(params, inputs)

tf_fn(params, images)
module = tf.Module()
