diff --git a/test.sh b/test.sh
index 75f11a7..ae68641 100755
--- a/test.sh
+++ b/test.sh
@@ -72,7 +72,7 @@
 pip uninstall --yes tensorflow
 pip install tf-nightly
 pip install git+https://github.com/google/jax.git
 pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-USE_JAX2TF_NATIVE_SERIALIZATION_IN_ROUNDTRIP_TEST=1 pytest -n "${N_JOBS}" --pyargs tf2jax._src.roundtrip_test
+pytest -n "${N_JOBS}" --pyargs tf2jax._src.roundtrip_test
 cd ..
 set +u
diff --git a/tf2jax/_src/numpy_compat_test.py b/tf2jax/_src/numpy_compat_test.py
index cfc933a..82186e1 100644
--- a/tf2jax/_src/numpy_compat_test.py
+++ b/tf2jax/_src/numpy_compat_test.py
@@ -17,8 +17,6 @@
 from absl.testing import absltest
 from absl.testing import parameterized
 
-import jax
-from jax.experimental import jax2tf
 import jax.numpy as jnp
 import numpy as np
 import tensorflow as tf
@@ -46,19 +44,6 @@ def test_dtype_conversion(self, np_module, dtype_map):
     else:
       self.assertIs(dtype_map[src], getattr(np_module, dst))
 
-  def test_is_poly_dim(self):
-    @jax.jit
-    def jax_func(x):
-      self.assertTrue(numpy_compat.is_poly_dim(x.shape[0]))
-      self.assertEqual(x.shape[1], 4)
-      return x + 3.14
-
-    tf_func = jax2tf.convert(
-        jax_func, polymorphic_shapes=["(b, _)"], native_serialization=False
-    )
-    tf_forward = tf.function(tf_func, autograph=False)
-    tf_forward.get_concrete_function(tf.TensorSpec(shape=(None, 4)))
-
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/tf2jax/_src/roundtrip_test.py b/tf2jax/_src/roundtrip_test.py
index 74e0513..fee303b 100644
--- a/tf2jax/_src/roundtrip_test.py
+++ b/tf2jax/_src/roundtrip_test.py
@@ -14,9 +14,6 @@
 # ==============================================================================
 """Tests for JAX -> TF -> JAX."""
 
-import os
-
-from absl import flags
 from absl.testing import absltest
 from absl.testing import parameterized
 import chex
@@ -37,18 +34,6 @@
 jax.config.parse_flags_with_absl()
 
 
-def _bool_env(varname: str, value: bool) -> bool:
-  val = os.getenv(varname, str(value))
-  return {"1": True, "0": False, "True": True, "False": False}[val]
-
-
-_NATIVE_SERIALIZATION = flags.DEFINE_bool(
-    "use_jax2tf_native_serialization_in_roundtrip_test",
-    _bool_env("USE_JAX2TF_NATIVE_SERIALIZATION_IN_ROUNDTRIP_TEST", False),
-    "Whether to call jax2tf.convert with native serialization enabled.",
-)
-
-
 def _compute_gradients(func, *inputs):
   def fn(*args):
     return jax.tree_util.tree_leaves(func(*args))[0]
@@ -56,29 +41,8 @@ def fn(*args):
   return jax.grad(lambda *args: jnp.sum(fn(*args)))(*inputs)
 
 
-def _jax2tf_convert(func, **kwargs):
-  return jax2tf.convert(
-      func, native_serialization=_NATIVE_SERIALIZATION.value, **kwargs
-  )
-
-
-def uses_native_serialization():
-  return _NATIVE_SERIALIZATION.value
-
-
 class Jax2TfTest(test_util.TestCase):
 
-  def setUp(self):
-    super().setUp()
-    if not uses_native_serialization():
-      self._xla_op = tf2jax.ops._jax_ops.pop("XlaCallModule", None)
-
-  def tearDown(self):
-    super().tearDown()
-    if not uses_native_serialization():
-      if self._xla_op is not None:
-        tf2jax.ops._jax_ops["XlaCallModule"] = self._xla_op
-
   def _test_convert(
       self,
       jax_func,
@@ -88,12 +52,11 @@
       with_custom_grad=False,
       grad_tols=None,
   ):
-    if uses_native_serialization():
-      if with_grad and not with_custom_grad:
-        self.skipTest(
-            "native_serialization does not support differentiation without "
-            "custom gradient."
-        )
+    if with_grad and not with_custom_grad:
+      self.skipTest(
+          "native_serialization does not support differentiation without "
+          "custom gradient."
+      )
 
     grad_tols = grad_tols or {}
 
@@ -107,7 +70,7 @@ def assert_grad_all_close(*args):
       jax_grads = _compute_gradients(jax_func, *inputs)
 
     # Jax -> TF
-    tf_func = _jax2tf_convert(jax_func, with_gradient=with_grad)
+    tf_func = jax2tf.convert(jax_func, with_gradient=with_grad)
     tf_func = tf.function(tf_func, jit_compile=True, autograph=False)
     tf_outputs = tf_func(*inputs)
     jax.tree.map(self.assertAllClose, jax_outputs, tf_outputs)
@@ -520,13 +483,6 @@ def forward(x):
     with config.override_config(
         "infer_cumulative_reduction_from_jax2tf", use_heuristic
     ):
-      roundtrip_forward = tf2jax.convert_functional(
-          tf.function(_jax2tf_convert(forward), autograph=False), inputs
-      )
-      roundtrip_jaxpr = jax.make_jaxpr(roundtrip_forward)(inputs)
-      if use_heuristic and not uses_native_serialization():
-        self.assertNotIn("reduce_window", roundtrip_jaxpr.pretty_print())
-
       if with_grad and reducer is jax.lax.cumprod and not use_heuristic:
         self.skipTest(
             "No differentiation rule for `reduce_window` with "
@@ -689,7 +645,7 @@ def forward(x):
     self.assertAllClose(tf_outputs, jax_outputs)
 
     # TF -> JAX -> TF
-    new_tf_forward = _jax2tf_convert(
+    new_tf_forward = jax2tf.convert(
         jax_func, polymorphic_shapes=["(b, _)"], with_gradient=with_grad
     )
     new_tf_forward = tf.function(new_tf_forward, autograph=False)
@@ -719,7 +675,7 @@ def forward(x, w):
     expected_outputs = forward(x, w)
 
     # JAX -> TF
-    tf_fn = _jax2tf_convert(
+    tf_fn = jax2tf.convert(
         forward, polymorphic_shapes=["(b, _)", None], with_gradient=with_grad
     )
     tf_fn = tf.function(tf_fn, autograph=False)
@@ -737,7 +693,7 @@ def forward(x, w):
     self.assertAllClose(expected_outputs, jax_outputs)
 
     # JAX -> TF -> JAX -> TF
-    tf_fn2 = _jax2tf_convert(
+    tf_fn2 = jax2tf.convert(
         jax_fn, polymorphic_shapes=["(b, _)", None], with_gradient=with_grad
     )
     tf_fn2 = tf.function(tf_fn2, autograph=False)
@@ -771,7 +727,7 @@ def forward(x, y):
     expected_outputs = forward(x, y)
 
     # JAX -> TF
-    tf_fn = _jax2tf_convert(
+    tf_fn = jax2tf.convert(
         forward,
         polymorphic_shapes=["(b, _)", "(_, b, _)"],
         with_gradient=with_grad,
@@ -791,7 +747,7 @@ def forward(x, y):
     self.assertAllClose(expected_outputs, jax_outputs)
 
     # JAX -> TF -> JAX -> TF
-    tf_fn2 = _jax2tf_convert(
+    tf_fn2 = jax2tf.convert(
         jax_fn,
         polymorphic_shapes=["(b, _)", "(_, b, _)"],
         with_gradient=with_grad,
@@ -836,7 +792,7 @@ def grad(dy):
     expected_grads = jax.grad(forward)(inputs)
 
     # JAX -> TF
-    tf_forward = _jax2tf_convert(forward, with_gradient=with_grad)
+    tf_forward = jax2tf.convert(forward, with_gradient=with_grad)
     tf_forward = tf.function(tf_forward, autograph=False)
 
     # JAX -> TF -> JAX
@@ -920,13 +876,13 @@ def grad(dy):
     expected_grads = jax.grad(forward)(inputs)
 
     # JAX -> TF
-    tf_fn = _jax2tf_convert(forward, with_gradient=with_grad)
+    tf_fn = jax2tf.convert(forward, with_gradient=with_grad)
     tf_fn = tf.function(tf_fn, autograph=False)
 
     # JAX -> TF -> CALL_TF -> TF.
     # This creates dependencies between custom gradients.
     call_tf_fn = jax2tf.call_tf(tf_fn)
-    tf_fn_too = _jax2tf_convert(call_tf_fn, with_gradient=with_grad)
+    tf_fn_too = jax2tf.convert(call_tf_fn, with_gradient=with_grad)
     tf_fn_too = tf.function(tf_fn_too, autograph=False)
 
     # JAX -> TF -> CALL_TF -> TF -> JAX
@@ -978,7 +934,7 @@ def forward(x):
       jax.grad(forward)(inputs)
 
     # JAX -> TF
-    tf_forward = _jax2tf_convert(forward, with_gradient=True)
+    tf_forward = jax2tf.convert(forward, with_gradient=True)
     tf_forward = tf.function(tf_forward, autograph=False)
 
     # JAX -> TF -> JAX
@@ -1023,46 +979,12 @@ def tf2jax_fn(x):
       jax_fn = tf2jax.convert_functional(tf_fn, x)
      return jnp.sin(self.variant(jax_fn)(x))
 
-    tf2jax2tf_fn = _jax2tf_convert(tf2jax_fn)
+    tf2jax2tf_fn = jax2tf.convert(tf2jax_fn)
     tf2jax2tf_fn = tf.function(tf2jax2tf_fn, autograph=False)
 
     inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3))
     self.assertAllClose(tf.sin(tf_fn(inputs)), tf2jax2tf_fn(inputs))
 
-  @chex.variants(with_jit=True, without_jit=True)
-  @parameterized.named_parameters(
-      chex.params_product(
-          (("without_gradient", False), ("with_gradient", True)),
-          named=True,
-      )
-  )
-  def test_remat(self, with_gradient):
-    def fn(x):
-      return jnp.sin(jnp.sin(x))
-
-    remat_fn = jax.checkpoint(fn)
-
-    inputs = (np.linspace(0, 1, 10 * 5, dtype=np.float32).reshape(10, 5),)
-    self._test_convert(
-        remat_fn, inputs, with_grad=with_gradient, with_custom_grad=True
-    )
-
-    if uses_native_serialization():
-      self.skipTest("Skip remat jaxpr test with native_serialization.")
-
-    # Check jaxpr.
-    tf_fn = tf.function(
-        _jax2tf_convert(remat_fn, with_gradient=True), autograph=False
-    )
-    jax_fn = tf2jax.convert_functional(
-        tf_fn, tf.TensorSpec((10, 5), tf.float32)
-    )
-    jax_fn = self.variant(jax_fn)
-    out_jaxpr = jax.make_jaxpr(jax_fn)(*inputs)
-    self.assertNotRegex(str(out_jaxpr), "remat")
-    grad_jaxpr = jax.make_jaxpr(jax.grad(lambda x: jnp.sum(jax_fn(x))))(*inputs)
-    self.assertRegex(str(grad_jaxpr), "optimization_barrier")
-
   @chex.variants(with_jit=True, without_jit=True)
   @parameterized.named_parameters(
       chex.params_product(
@@ -1121,12 +1043,6 @@ def test_triangular_solve(
       unit_diagonal,
       shapes,
   ):
-    if uses_native_serialization():
-      self.skipTest(
-          "native_serialization: Cannot serialize code with custom calls whose "
-          "targets have no compatibility guarantees."
-      )
-
     np.random.seed(42)
 
     lhs_shape, rhs_shape = shapes
@@ -1150,35 +1066,19 @@ def forward(a, b):
         grad_tols=tols,
     )
 
-  def test_explicit_native_serialization(self):
+  def test_platform_check(self):
     def forward(x):
       return x + 3.14
 
-    tf_fn = jax2tf.convert(forward, native_serialization=True)
+    tf_fn = jax2tf.convert(forward)
     tf_fn = tf.function(tf_fn, autograph=False)
 
-    try:
-      jax_fn = tf2jax.convert_functional(
-          tf_fn, tf.TensorSpec((2, 3), tf.float32)
-      )
-      jax_fn = jax.jit(jax_fn)
-      inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3))
-      self.assertAllClose(jax_fn(inputs), tf_fn(inputs))
-    except ValueError as e:
-      if uses_native_serialization():
-        raise ValueError("Native lowering support failed.") from e
-      elif r"Unsupported operations in graph: ['XlaCallModule']" not in str(e):
-        raise ValueError("Unexpected unsupported operations found.") from e
-      elif r"check for import error if you see this message" not in str(e):
-        raise ValueError("Import/dependency error message not found.") from e
-      else:  # Expected failure.
-        return
-
-    if not uses_native_serialization():
-      raise ValueError(
-          "Unexpected success with native serialization. Test may be "
-          "misconfigured."
-      )
+    jax_fn = tf2jax.convert_functional(
+        tf_fn, tf.TensorSpec((2, 3), tf.float32)
+    )
+    jax_fn = jax.jit(jax_fn)
+    inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3))
+    self.assertAllClose(jax_fn(inputs), tf_fn(inputs))
 
     if jax.default_backend().lower() != "cpu":
       with jax.default_device(jax.local_devices(backend="cpu")[0]):
@@ -1191,9 +1091,6 @@ def forward(x):
 
   @chex.variants(with_jit=True, without_jit=True)
   def test_platform_index(self):
-    if not uses_native_serialization():
-      self.skipTest("Skip platform_index test without native serialization.")
-
     @jax.jit
     def forward(x):
       return jnp.tile(x**2, 2).reshape((2, *x.shape))
@@ -1203,7 +1100,6 @@ def forward(x):
 
     func = jax2tf.convert(
         forward,
-        native_serialization=True,
         polymorphic_shapes=("(B, d * 128)",),
         native_serialization_platforms=("cuda", "cpu", "tpu"),
     )
diff --git a/tf2jax/_src/sharding_test.py b/tf2jax/_src/sharding_test.py
index 444afcc..59ca714 100644
--- a/tf2jax/_src/sharding_test.py
+++ b/tf2jax/_src/sharding_test.py
@@ -15,8 +15,6 @@
 """Tests for JAX -> TF -> JAX with partitioning."""
 
 from absl.testing import absltest
-from absl.testing import parameterized
-import chex
 import haiku as hk
 import jax
 from jax.experimental import jax2tf
@@ -49,13 +47,7 @@ def _get_param_pspecs():
 
 class ShardingTest(test_util.TestCase):
 
-  @parameterized.named_parameters(
-      chex.params_product(
-          (('native_serialization', True), ('graph_serialization', False)),
-          named=True,
-      )
-  )
-  def test_sharding(self, native_serialization):
+  def test_sharding(self):
     if jax.default_backend().upper() != 'TPU':
       self.skipTest('Only run sharding tests on TPU.')
 
@@ -107,10 +99,7 @@ def partitioned_grad(params, xs):
     # Convert to TF and save.
     @tf.function(autograph=False, jit_compile=True)
     def tf_fn(params, inputs):
-      return jax2tf.convert(
-          partitioned_apply,
-          native_serialization=native_serialization,
-      )(params, inputs)
+      return jax2tf.convert(partitioned_apply)(params, inputs)
 
     tf_fn(params, images)
     module = tf.Module()