diff --git a/tf2jax/_src/roundtrip_test.py b/tf2jax/_src/roundtrip_test.py index 5302710..74e0513 100644 --- a/tf2jax/_src/roundtrip_test.py +++ b/tf2jax/_src/roundtrip_test.py @@ -19,13 +19,11 @@ from absl import flags from absl.testing import absltest from absl.testing import parameterized - import chex import haiku as hk import jax from jax.experimental import jax2tf import jax.numpy as jnp - import numpy as np import tensorflow as tf from tf2jax._src import config @@ -54,6 +52,7 @@ def _bool_env(varname: str, value: bool) -> bool: def _compute_gradients(func, *inputs): def fn(*args): return jax.tree_util.tree_leaves(func(*args))[0] + return jax.grad(lambda *args: jnp.sum(fn(*args)))(*inputs) @@ -86,19 +85,18 @@ def _test_convert( inputs, *, with_grad, - enable_xla, with_custom_grad=False, grad_tols=None, ): if uses_native_serialization(): - if not enable_xla: - self.skipTest("native_serialization does not support enable_xla=False.") if with_grad and not with_custom_grad: self.skipTest( "native_serialization does not support differentiation without " - "custom gradient.") + "custom gradient." + ) grad_tols = grad_tols or {} + def assert_grad_all_close(*args): return self.assertAllClose(*args, **grad_tols) @@ -109,9 +107,7 @@ def assert_grad_all_close(*args): jax_grads = _compute_gradients(jax_func, *inputs) # Jax -> TF - tf_func = _jax2tf_convert( - jax_func, with_gradient=with_grad, enable_xla=enable_xla - ) + tf_func = _jax2tf_convert(jax_func, with_gradient=with_grad) tf_func = tf.function(tf_func, jit_compile=True, autograph=False) tf_outputs = tf_func(*inputs) jax.tree.map(self.assertAllClose, jax_outputs, tf_outputs) @@ -119,7 +115,8 @@ def assert_grad_all_close(*args): # Jax -> TF -> Jax with config.override_config("convert_custom_gradient", with_custom_grad): rejax_func = tf2jax.convert_functional( - tf_func, *tree.map_structure(np.zeros_like, inputs)) + tf_func, *tree.map_structure(np.zeros_like, inputs) + ) rejax_func = self.variant(rejax_func) rejax_outputs = rejax_func(*inputs) jax.tree.map(self.assertAllClose, rejax_outputs, tf_outputs) @@ -140,7 +137,8 @@ def assert_grad_all_close(*args): # Jax -> TF -> SavedModel -> TF -> Jax with config.override_config("convert_custom_gradient", with_custom_grad): rejax_too_func = tf2jax.convert_functional( - restored.f, *tree.map_structure(np.zeros_like, inputs)) + restored.f, *tree.map_structure(np.zeros_like, inputs) + ) rejax_too_func = self.variant(rejax_too_func) rejax_too_outputs = rejax_too_func(*inputs) jax.tree.map(self.assertAllClose, rejax_too_outputs, tf_outputs) @@ -152,11 +150,11 @@ def assert_grad_all_close(*args): @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("disable_xla", False), ("enable_xla", True)), (("without_custom_gradient", False), ("with_custom_gradient", True)), named=True, - )) - def test_simple(self, with_grad, enable_xla, with_custom_grad): + ) + ) + def test_simple(self, with_grad, with_custom_grad): np.random.seed(42) def forward(x): @@ -164,20 +162,21 @@ def forward(x): inputs = np.random.normal((3, 2)).astype(np.float32) self._test_convert( - forward, [inputs], + forward, + [inputs], with_grad=with_grad, - enable_xla=enable_xla, - with_custom_grad=with_custom_grad) + with_custom_grad=with_custom_grad, + ) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("disable_xla", False), ("enable_xla", True)), (("without_custom_gradient", False), ("with_custom_gradient", True)), named=True, - )) - def test_mlp(self, with_grad, enable_xla, with_custom_grad): + ) + ) + def test_mlp(self, with_grad, with_custom_grad): np.random.seed(42) def forward(x): @@ -190,20 +189,21 @@ def forward(x): variables = hk.data_structures.to_mutable_dict(variables) jax_fn = hk.without_apply_rng(forward).apply self._test_convert( - jax_fn, [variables, inputs], + jax_fn, + [variables, inputs], with_grad=with_grad, - enable_xla=enable_xla, - with_custom_grad=with_custom_grad) + with_custom_grad=with_custom_grad, + ) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("disable_xla", False), ("enable_xla", True)), (("without_custom_gradient", False), ("with_custom_gradient", True)), named=True, - )) - def test_batch_norm(self, with_grad, enable_xla, with_custom_grad): + ) + ) + def test_batch_norm(self, with_grad, with_custom_grad): np.random.seed(42) def forward(x): @@ -213,40 +213,46 @@ def forward(x): inputs = np.random.normal(size=(8, 17)) forward = hk.transform_with_state(forward) variables, states = forward.init( - jax.random.PRNGKey(42), jnp.zeros_like(inputs)) + jax.random.PRNGKey(42), jnp.zeros_like(inputs) + ) variables = hk.data_structures.to_mutable_dict(variables) states = hk.data_structures.to_mutable_dict(states) + def jax_fn(params, states, x): outputs, states = hk.without_apply_rng(forward).apply(params, states, x) return outputs, hk.data_structures.to_mutable_dict(states) # Perturb variables and states. variables = tree.map_structure( - lambda x: x + np.random.uniform(size=x.shape), variables) + lambda x: x + np.random.uniform(size=x.shape), variables + ) states = tree.map_structure( - lambda x: x + np.random.normal(size=x.shape), states) + lambda x: x + np.random.normal(size=x.shape), states + ) self._test_convert( - jax_fn, [variables, states, inputs], + jax_fn, + [variables, states, inputs], with_grad=with_grad, - enable_xla=enable_xla, - with_custom_grad=with_custom_grad) + with_custom_grad=with_custom_grad, + ) # Conv2D uses jax.lax.conv_general_dilated which is translated to XlaConv. @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("disable_xla", False), ("enable_xla", True)), named=True, - )) - def test_conv2d(self, with_grad, enable_xla): + ) + ) + def test_conv2d(self, with_grad): np.random.seed(42) tols = dict(rtol=1e-5) if jax.default_backend().lower() == "gpu" else {} def forward(x): conv = hk.Conv2D( - output_channels=7, kernel_shape=3, stride=1, padding="SAME") + output_channels=7, kernel_shape=3, stride=1, padding="SAME" + ) return conv(x) inputs = np.random.normal(size=(8, 28, 28, 3)) @@ -255,29 +261,28 @@ def forward(x): variables = hk.data_structures.to_mutable_dict(variables) jax_fn = hk.without_apply_rng(forward).apply self._test_convert( - jax_fn, [variables, inputs], - with_grad=with_grad, - enable_xla=enable_xla, - grad_tols=tols) + jax_fn, [variables, inputs], with_grad=with_grad, grad_tols=tols + ) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("enable_xla", True),), ( ("default_group_counts", dict()), ("feature_group_count", dict(feature_group_count=3)), ("batch_group_count", dict(batch_group_count=2)), ), named=True, - )) - def test_xla_conv(self, with_grad, enable_xla, group_counts): + ) + ) + def test_xla_conv(self, with_grad, group_counts): np.random.seed(42) kernels = np.random.normal(size=(3, 3, 3, 12)) dimension_numbers = jax.lax.ConvDimensionNumbers( - lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) + lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2) + ) def forward(x): return jax.lax.conv_general_dilated( @@ -288,7 +293,8 @@ def forward(x): lhs_dilation=(1, 1), rhs_dilation=(1, 1), dimension_numbers=dimension_numbers, - **group_counts) + **group_counts, + ) feature_dim = 3 * group_counts.get("feature_group_count", 1) inputs = np.random.normal(size=(8, 28, 28, feature_dim)) @@ -296,17 +302,16 @@ def forward(x): variables = forward.init(jax.random.PRNGKey(42), jnp.zeros_like(inputs)) variables = hk.data_structures.to_mutable_dict(variables) jax_fn = hk.without_apply_rng(forward).apply - self._test_convert( - jax_fn, [variables, inputs], with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(jax_fn, [variables, inputs], with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("enable_xla", True),), named=True, - )) - def test_dot(self, with_grad, enable_xla): + ) + ) + def test_dot(self, with_grad): def forward(lhs, rhs): return jax.lax.dot(lhs, rhs) @@ -314,17 +319,16 @@ def forward(lhs, rhs): np.linspace(0, 1, 10 * 5).reshape(10, 5), np.linspace(-1, 0, 5 * 3).reshape(5, 3), ) - self._test_convert( - forward, inputs, with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, inputs, with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("enable_xla", True),), named=True, - )) - def test_dot_general(self, with_grad, enable_xla): + ) + ) + def test_dot_general(self, with_grad): dimension_numbers = (((2,), (1,)), ((0,), (0,))) def forward(lhs, rhs): @@ -334,32 +338,30 @@ def forward(lhs, rhs): np.linspace(0, 1, 2 * 10 * 5).reshape((2, 10, 5)), np.linspace(-1, 0, 2 * 5 * 3).reshape((2, 5, 3)), ) - self._test_convert( - forward, inputs, with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, inputs, with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("disable_xla", False), ("enable_xla", True)), named=True, - )) - def test_dynamic_slice(self, with_grad, enable_xla): + ) + ) + def test_dynamic_slice(self, with_grad): def forward(x): return jax.lax.dynamic_slice(x, (1, 1), (2, 3)) inputs = np.linspace(0, 1, 12).reshape(3, 4) - self._test_convert( - forward, [inputs], with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, [inputs], with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("disable_xla", False), ("enable_xla", True)), named=True, - )) - def test_dynamic_update_slice(self, with_grad, enable_xla): + ) + ) + def test_dynamic_update_slice(self, with_grad): def forward(x, y): return jax.lax.dynamic_update_slice(x, y, (1, 2)) @@ -367,17 +369,16 @@ def forward(x, y): np.linspace(0, 1, 12).reshape(3, 4), 1.0 - np.linspace(0, 1, 12).reshape(3, 4), ] - self._test_convert( - forward, inputs, with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, inputs, with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("enable_xla", True),), named=True, - )) - def test_gather(self, with_grad, enable_xla): + ) + ) + def test_gather(self, with_grad): dimension_numbers = jax.lax.GatherDimensionNumbers((1,), (0,), (0, 1)) slice_sizes = (1, 3) @@ -389,17 +390,16 @@ def forward(operand, indices): np.linspace(0, 1, 10 * 5).reshape(10, 5), np.array([[4, 2], [3, 2]]), ) - self._test_convert( - forward, inputs, with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, inputs, with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("enable_xla", True),), named=True, - )) - def test_pad(self, with_grad, enable_xla): + ) + ) + def test_pad(self, with_grad): padding_config = [(1, 2, 1), (0, 1, 0)] def forward(operand, padding_value): @@ -409,39 +409,37 @@ def forward(operand, padding_value): np.linspace(0, 1, 2 * 3).reshape(2, 3), np.array(0.42), ) - self._test_convert( - forward, inputs, with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, inputs, with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("disable_xla", False), ("enable_xla", True)), ( ("min", jax.lax.min, jnp.inf), ("max", jax.lax.max, -jnp.inf), ("add", jax.lax.add, 0.0), ), named=True, - )) - def test_reduce(self, with_grad, enable_xla, reduce_fn, init_value): + ) + ) + def test_reduce(self, with_grad, reduce_fn, init_value): def forward(x): dimensions = [1, 2] return jax.lax.reduce(x, init_value, reduce_fn, dimensions) inputs = np.linspace(0, 1, 2 * 5 * 5 * 3).reshape((2, 5, 5, 3)) - self._test_convert( - forward, [inputs], with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, [inputs], with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("enable_xla", True),), named=True, - )) - def test_reduce_variadic(self, with_grad, enable_xla): + ) + ) + def test_reduce_variadic(self, with_grad): def forward(args): return jax.lax.reduce(args, (0.0, 1.0), lambda xs, ys: xs, [1, 2]) @@ -450,39 +448,37 @@ def forward(args): np.linspace(0, 1, 2 * 5 * 5 * 3).reshape((2, 5, 5, 3)), np.linspace(2, 3, 2 * 5 * 5 * 3).reshape((2, 5, 5, 3)), ) - self._test_convert( - forward, [inputs], with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, [inputs], with_grad=with_grad) # jax.lax.reduce_window is translated to XlaReduceWindow. @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("enable_xla", True),), ( ("min", jax.lax.min, jnp.inf), ("max", jax.lax.max, -jnp.inf), ("add", jax.lax.add, 0.0), ), named=True, - )) - def test_reduce_window(self, with_grad, enable_xla, reduce_fn, init_value): + ) + ) + def test_reduce_window(self, with_grad, reduce_fn, init_value): np.random.seed(42) def forward(x): window_shape = [1, 2, 2, 1] - return jax.lax.reduce_window(x, init_value, reduce_fn, window_shape, - window_shape, "SAME") + return jax.lax.reduce_window( + x, init_value, reduce_fn, window_shape, window_shape, "SAME" + ) inputs = np.random.normal(size=(8, 28, 28, 3)) - self._test_convert( - forward, [inputs], with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, [inputs], with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("enable_xla", True),), ( ("min", jax.lax.cummin), ("max", jax.lax.cummax), @@ -502,11 +498,16 @@ def forward(x): ("no_heuristic", False), ), named=True, - )) - def test_cumulative_reduction(self, with_grad, enable_xla, reducer, axis, - reverse, use_heuristic): - if (with_grad and not use_heuristic and - jax.default_backend().lower() == "tpu"): + ) + ) + def test_cumulative_reduction( + self, with_grad, reducer, axis, reverse, use_heuristic + ): + if ( + with_grad + and not use_heuristic + and jax.default_backend().lower() == "tpu" + ): self.skipTest("Gradient of reduce-window not always supported on TPU") np.random.seed(42) @@ -516,28 +517,28 @@ def forward(x): inputs = np.random.normal(size=(4, 3)) - with config.override_config("infer_cumulative_reduction_from_jax2tf", - use_heuristic): + with config.override_config( + "infer_cumulative_reduction_from_jax2tf", use_heuristic + ): roundtrip_forward = tf2jax.convert_functional( - tf.function(_jax2tf_convert(forward), autograph=False), inputs) + tf.function(_jax2tf_convert(forward), autograph=False), inputs + ) roundtrip_jaxpr = jax.make_jaxpr(roundtrip_forward)(inputs) - if (use_heuristic and - not uses_native_serialization()): + if use_heuristic and not uses_native_serialization(): self.assertNotIn("reduce_window", roundtrip_jaxpr.pretty_print()) - if (with_grad and enable_xla and reducer is jax.lax.cumprod and - not use_heuristic): - self.skipTest("No differentiation rule for `reduce_window` with " - "`jax.lax.cumprod`.") + if with_grad and reducer is jax.lax.cumprod and not use_heuristic: + self.skipTest( + "No differentiation rule for `reduce_window` with " + "`jax.lax.cumprod`." + ) - self._test_convert( - forward, [inputs], with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, [inputs], with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False),), - (("enable_xla", True),), (("uint32", np.uint32),), ( ("default", xla_data_pb2.RandomAlgorithm.RNG_DEFAULT), @@ -545,19 +546,20 @@ def forward(x): ("philox", xla_data_pb2.RandomAlgorithm.RNG_PHILOX), ), named=True, - )) - def test_rng_bit_generator(self, with_grad, enable_xla, dtype, algorithm): + ) + ) + def test_rng_bit_generator(self, with_grad, dtype, algorithm): def forward(key): return jax.lax.rng_bit_generator( - key, shape=(10, 5), dtype=dtype, algorithm=algorithm) + key, shape=(10, 5), dtype=dtype, algorithm=algorithm + ) if dtype == np.uint32: key = np.array([6, 7, 8, 9], dtype=np.uint32) else: raise ValueError(f"Unsupported dtype={dtype}") - self._test_convert( - forward, [key], with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, [key], with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( @@ -570,14 +572,16 @@ def forward(key): ("scatter_max", jax.lax.scatter_max), ), (("without_gradient", False), ("with_gradient", True)), - (("enable_xla", True),), (("unique_indices", True), ("non_unique_indices", False)), named=True, - )) - def test_scatter(self, scatter_fn, with_grad, enable_xla, unique_indices): + ) + ) + def test_scatter(self, scatter_fn, with_grad, unique_indices): if scatter_fn is jax.lax.scatter_mul and with_grad and not unique_indices: - self.skipTest("Gradient is disallowed for jax.lax.scatter_mul if " - "unique_indices=False") + self.skipTest( + "Gradient is disallowed for jax.lax.scatter_mul if " + "unique_indices=False" + ) dimension_numbers = jax.lax.ScatterDimensionNumbers((1,), (0,), (0,)) @@ -587,44 +591,43 @@ def forward(operand, indices, updates): indices, updates, dimension_numbers, - unique_indices=unique_indices) + unique_indices=unique_indices, + ) inputs = ( - np.linspace(0, 1, 10*5).reshape(10, 5), + np.linspace(0, 1, 10 * 5).reshape(10, 5), np.array([[1], [8], [4]]), np.linspace(0, 9, 9).reshape(3, 3), ) - self._test_convert( - forward, inputs, with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, inputs, with_grad=with_grad) # Derivative of jax.lax.reduce_window uses XlaSelectAndScatter. @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False),), - (("enable_xla", True),), ( ("min", jax.lax.min, jnp.inf), ("max", jax.lax.max, -jnp.inf), ("add", jax.lax.add, 0.0), ), named=True, - )) - def test_select_and_scatter(self, with_grad, enable_xla, reduce_fn, - init_value): + ) + ) + def test_select_and_scatter(self, with_grad, reduce_fn, init_value): np.random.seed(42) def forward(x): window_shape = [1, 2, 2, 1] - return jax.lax.reduce_window(x, init_value, reduce_fn, window_shape, - window_shape, "SAME") + return jax.lax.reduce_window( + x, init_value, reduce_fn, window_shape, window_shape, "SAME" + ) inputs = np.random.normal(size=(8, 5, 5, 3)) jax_fn = jax.jacrev(forward) try: - self._test_convert( - jax_fn, [inputs], with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(jax_fn, [inputs], with_grad=with_grad) except tf.errors.InvalidArgumentError as e: if jax.default_backend().lower() == "tpu": # Can fail on older TPUs. @@ -636,39 +639,34 @@ def forward(x): @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("enable_xla", True),), (("2nd_last_dim", -2), ("last_dim", -1)), (("not_stable", False), ("is_stable", True)), (("one_keys", 1), ("two_keys", 2), ("three_keys", 3)), named=True, - )) - def test_sort_variadic(self, with_grad, enable_xla, dim, is_stable, num_keys): + ) + ) + def test_sort_variadic(self, with_grad, dim, is_stable, num_keys): def forward(args): return jax.lax.sort( - args, dimension=dim, is_stable=is_stable, num_keys=num_keys) + args, dimension=dim, is_stable=is_stable, num_keys=num_keys + ) inputs = ( - np.array([[6., 2.], [4., 2.], [4., 1.]], np.float32), - np.array([[1., 2.], [3., 4.], [5., 6.]], np.float32), - np.array([[6., 5.], [4., 3.], [2., 1.]], np.float32), + np.array([[6.0, 2.0], [4.0, 2.0], [4.0, 1.0]], np.float32), + np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], np.float32), + np.array([[6.0, 5.0], [4.0, 3.0], [2.0, 1.0]], np.float32), ) - self._test_convert( - forward, [inputs], with_grad=with_grad, enable_xla=enable_xla) + self._test_convert(forward, [inputs], with_grad=with_grad) @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("disable_xla", False), ("enable_xla", True)), named=True, ) ) - def test_polymorphic_shape(self, with_grad, enable_xla): - if uses_native_serialization(): - if not enable_xla: - self.skipTest("native_serialization does not support enable_xla=False.") - + def test_polymorphic_shape(self, with_grad): inputs = np.array(range(36), dtype=np.float32).reshape(9, 4) # TF @@ -681,6 +679,7 @@ def forward(x): outputs = tf.concat([outputs] * 2, axis=0) # Stack along unknown dim outputs = tf.concat([outputs] * 2, axis=1) # Stack along knonwn dim return outputs / tf.cast(shape[0], tf.float32) # Divide by unknown dim + tf_outputs = forward(inputs) # TF -> JAX @@ -691,15 +690,15 @@ def forward(x): # TF -> JAX -> TF new_tf_forward = _jax2tf_convert( - jax_func, - polymorphic_shapes=["(b, _)"], - with_gradient=with_grad, - enable_xla=enable_xla) + jax_func, polymorphic_shapes=["(b, _)"], with_gradient=with_grad + ) new_tf_forward = tf.function(new_tf_forward, autograph=False) concrete_new_tf_forward = new_tf_forward.get_concrete_function( - tf.TensorSpec(shape=(None, 4))) - self.assertEqual(concrete_new_tf_forward.structured_outputs.shape.as_list(), - [None, 8]) + tf.TensorSpec(shape=(None, 4)) + ) + self.assertEqual( + concrete_new_tf_forward.structured_outputs.shape.as_list(), [None, 8] + ) new_tf_outputs = concrete_new_tf_forward(inputs) self.assertAllClose(new_tf_outputs, jax_outputs) @@ -707,15 +706,10 @@ def forward(x): @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("disable_xla", False), ("enable_xla", True)), named=True, ) ) - def test_polymorphic_shape_refinement_dot(self, with_grad, enable_xla): - if uses_native_serialization(): - if not enable_xla: - self.skipTest("native_serialization does not support enable_xla=False.") - + def test_polymorphic_shape_refinement_dot(self, with_grad): @jax.jit def forward(x, w): return jnp.dot(x, w) @@ -726,13 +720,12 @@ def forward(x, w): # JAX -> TF tf_fn = _jax2tf_convert( - forward, - polymorphic_shapes=["(b, _)", None], - with_gradient=with_grad, - enable_xla=enable_xla) + forward, polymorphic_shapes=["(b, _)", None], with_gradient=with_grad + ) tf_fn = tf.function(tf_fn, autograph=False) concrete_tf_fn = tf_fn.get_concrete_function( - tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(4, 5))) + tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(4, 5)) + ) tf_outputs = concrete_tf_fn(x, w) self.assertAllClose(expected_outputs, tf_outputs) @@ -745,13 +738,12 @@ def forward(x, w): # JAX -> TF -> JAX -> TF tf_fn2 = _jax2tf_convert( - jax_fn, - polymorphic_shapes=["(b, _)", None], - with_gradient=with_grad, - enable_xla=enable_xla) + jax_fn, polymorphic_shapes=["(b, _)", None], with_gradient=with_grad + ) tf_fn2 = tf.function(tf_fn2, autograph=False) concrete_tf_fn2 = tf_fn2.get_concrete_function( - tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(4, 5))) + tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(4, 5)) + ) tf_outputs2 = concrete_tf_fn2(x, w) self.assertAllClose(expected_outputs, tf_outputs2) @@ -766,15 +758,10 @@ def forward(x, w): @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("disable_xla", False), ("enable_xla", True)), named=True, ) ) - def test_polymorphic_shape_refinement_broadcast(self, with_grad, enable_xla): - if uses_native_serialization(): - if not enable_xla: - self.skipTest("native_serialization does not support enable_xla=False.") - + def test_polymorphic_shape_refinement_broadcast(self, with_grad): @jax.jit def forward(x, y): return (jnp.broadcast_to(x, y.shape), x + y) @@ -788,10 +775,11 @@ def forward(x, y): forward, polymorphic_shapes=["(b, _)", "(_, b, _)"], with_gradient=with_grad, - enable_xla=enable_xla) + ) tf_fn = tf.function(tf_fn, autograph=False) concrete_tf_fn = tf_fn.get_concrete_function( - tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(2, None, 4))) + tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(2, None, 4)) + ) tf_outputs = concrete_tf_fn(x, y) self.assertAllClose(expected_outputs, tf_outputs) @@ -807,10 +795,11 @@ def forward(x, y): jax_fn, polymorphic_shapes=["(b, _)", "(_, b, _)"], with_gradient=with_grad, - enable_xla=enable_xla) + ) tf_fn2 = tf.function(tf_fn2, autograph=False) concrete_tf_fn2 = tf_fn2.get_concrete_function( - tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(2, None, 4))) + tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(2, None, 4)) + ) tf_outputs2 = concrete_tf_fn2(x, y) self.assertAllClose(expected_outputs, tf_outputs2) @@ -825,37 +814,34 @@ def forward(x, y): @parameterized.named_parameters( chex.params_product( (("with_gradient", True),), - (("disable_xla", False), ("enable_xla", True)), named=True, - )) - def test_custom_gradient(self, with_grad, enable_xla): - if uses_native_serialization(): - if not enable_xla: - self.skipTest("native_serialization does not support enable_xla=False.") - + ) + ) + def test_custom_gradient(self, with_grad): inputs = np.array(range(6), dtype=np.float32).reshape(3, 2) # JAX @jax.custom_gradient def forward(x): e = jnp.exp(x) + def grad(dy): # This is deliberately the wrong gradient. return dy * (1 - 1 / (1 + e)) * jnp.sin(x) + 0.42 + return jnp.sum(jnp.log(1 + e)), grad + forward = self.variant(forward) expected_outputs = forward(inputs) expected_grads = jax.grad(forward)(inputs) # JAX -> TF - tf_forward = _jax2tf_convert( - forward, with_gradient=with_grad, enable_xla=enable_xla) + tf_forward = _jax2tf_convert(forward, with_gradient=with_grad) tf_forward = tf.function(tf_forward, autograph=False) # JAX -> TF -> JAX with config.override_config("convert_custom_gradient", True): - jax_forward = tf2jax.convert_functional(tf_forward, - tf.zeros_like(inputs)) + jax_forward = tf2jax.convert_functional(tf_forward, tf.zeros_like(inputs)) jax_forward = self.variant(jax_forward) jax_outputs = jax_forward(inputs) jax_grads = jax.grad(jax_forward)(inputs) @@ -872,8 +858,9 @@ def grad(dy): # Jax -> TF -> SavedModel -> TF -> Jax with config.override_config("convert_custom_gradient", True): - re_jax_forward = tf2jax.convert_functional(restored.f, - tf.zeros_like(inputs)) + re_jax_forward = tf2jax.convert_functional( + restored.f, tf.zeros_like(inputs) + ) re_jax_forward = self.variant(re_jax_forward) re_jax_outputs = re_jax_forward(inputs) re_jax_grads = jax.grad(re_jax_forward)(inputs) @@ -911,23 +898,21 @@ def test_custom_gradient_saved_model(self): @parameterized.named_parameters( chex.params_product( (("with_gradient", True),), - (("disable_xla", False), ("enable_xla", True)), named=True, - )) - def test_custom_gradient_nested(self, with_grad, enable_xla): - if uses_native_serialization(): - if not enable_xla: - self.skipTest("native_serialization does not support enable_xla=False.") - + ) + ) + def test_custom_gradient_nested(self, with_grad): inputs = np.array(range(6), dtype=np.float32).reshape(3, 2) # JAX @jax.custom_gradient def forward(x): e = jnp.exp(x) + def grad(dy): # This is deliberately the wrong gradient. return dy * (1 - 1 / (1 + e)) * jnp.sin(x) + 0.42 + return jnp.sum(jnp.log(1 + e)), grad forward = self.variant(forward) @@ -935,8 +920,7 @@ def grad(dy): expected_grads = jax.grad(forward)(inputs) # JAX -> TF - tf_fn = _jax2tf_convert( - forward, with_gradient=with_grad, enable_xla=enable_xla) + tf_fn = _jax2tf_convert(forward, with_gradient=with_grad) tf_fn = tf.function(tf_fn, autograph=False) # JAX -> TF -> CALL_TF -> TF. @@ -958,29 +942,26 @@ def grad(dy): @parameterized.named_parameters( chex.params_product( (("without_gradient", False), ("with_gradient", True)), - (("disable_xla", False), ("enable_xla", True)), (("without_custom_gradient", False), ("with_custom_gradient", True)), named=True, - )) - def test_relu(self, with_grad, enable_xla, with_custom_grad): + ) + ) + def test_relu(self, with_grad, with_custom_grad): inputs = np.array([-1.0, 0.0, 1.0], np.float32) self._test_convert( - jax.nn.relu, [inputs], + jax.nn.relu, + [inputs], with_grad=with_grad, - enable_xla=enable_xla, - with_custom_grad=with_custom_grad) + with_custom_grad=with_custom_grad, + ) @chex.variants(with_jit=True, without_jit=True) @parameterized.named_parameters( chex.params_product( - (("disable_xla", False), ("enable_xla", True)), named=True, - )) - def test_empty_return(self, enable_xla): - if uses_native_serialization(): - if not enable_xla: - self.skipTest("native_serialization does not support enable_xla=False.") - + ) + ) + def test_empty_return(self): np.random.seed(42) def forward(x): @@ -992,23 +973,23 @@ def forward(x): inputs = np.random.normal((3, 2)) with self.assertRaisesRegex( TypeError, - "Gradient only defined for scalar-output functions. Output was ()."): + "Gradient only defined for scalar-output functions. Output was ().", + ): jax.grad(forward)(inputs) # JAX -> TF - tf_forward = _jax2tf_convert( - forward, with_gradient=True, enable_xla=enable_xla) + tf_forward = _jax2tf_convert(forward, with_gradient=True) tf_forward = tf.function(tf_forward, autograph=False) # JAX -> TF -> JAX with config.override_config("convert_custom_gradient", True): - jax_forward = tf2jax.convert_functional(tf_forward, - tf.zeros_like(inputs)) + jax_forward = tf2jax.convert_functional(tf_forward, tf.zeros_like(inputs)) jax_forward = self.variant(jax_forward) with self.assertRaisesRegex( TypeError, - "Gradient only defined for scalar-output functions. Output was ()."): + "Gradient only defined for scalar-output functions. Output was ().", + ): jax.grad(jax_forward)(inputs) # Jax -> TF -> SavedModel -> TF @@ -1021,13 +1002,15 @@ def forward(x): # Jax -> TF -> SavedModel -> TF -> Jax with config.override_config("convert_custom_gradient", True): - re_jax_forward = tf2jax.convert_functional(restored.f, - tf.zeros_like(inputs)) + re_jax_forward = tf2jax.convert_functional( + restored.f, tf.zeros_like(inputs) + ) re_jax_forward = self.variant(re_jax_forward) with self.assertRaisesRegex( TypeError, - "Gradient only defined for scalar-output functions. Output was ()."): + "Gradient only defined for scalar-output functions. Output was ().", + ): jax.grad(re_jax_forward)(inputs) @chex.variants(with_jit=True, without_jit=True) @@ -1043,7 +1026,7 @@ def tf2jax_fn(x): tf2jax2tf_fn = _jax2tf_convert(tf2jax_fn) tf2jax2tf_fn = tf.function(tf2jax2tf_fn, autograph=False) - inputs = np.linspace(-1., 1., 6, dtype=np.float32).reshape((2, 3)) + inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3)) self.assertAllClose(tf.sin(tf_fn(inputs)), tf2jax2tf_fn(inputs)) @chex.variants(with_jit=True, without_jit=True) @@ -1051,31 +1034,29 @@ def tf2jax_fn(x): chex.params_product( (("without_gradient", False), ("with_gradient", True)), named=True, - )) + ) + ) def test_remat(self, with_gradient): def fn(x): return jnp.sin(jnp.sin(x)) + remat_fn = jax.checkpoint(fn) - inputs = ( - np.linspace(0, 1, 10 * 5, dtype=np.float32).reshape(10, 5), - ) + inputs = (np.linspace(0, 1, 10 * 5, dtype=np.float32).reshape(10, 5),) self._test_convert( - remat_fn, - inputs, - with_grad=with_gradient, - enable_xla=True, - with_custom_grad=True) + remat_fn, inputs, with_grad=with_gradient, with_custom_grad=True + ) if uses_native_serialization(): self.skipTest("Skip remat jaxpr test with native_serialization.") # Check jaxpr. tf_fn = tf.function( - _jax2tf_convert(remat_fn, with_gradient=True, enable_xla=True), - autograph=False) - jax_fn = tf2jax.convert_functional(tf_fn, tf.TensorSpec((10, 5), - tf.float32)) + _jax2tf_convert(remat_fn, with_gradient=True), autograph=False + ) + jax_fn = tf2jax.convert_functional( + tf_fn, tf.TensorSpec((10, 5), tf.float32) + ) jax_fn = self.variant(jax_fn) out_jaxpr = jax.make_jaxpr(jax_fn)(*inputs) self.assertNotRegex(str(out_jaxpr), "remat") @@ -1086,14 +1067,16 @@ def fn(x): @parameterized.named_parameters( chex.params_product( (("without_gradient", False),), - (("enable_xla", True),), (("without_custom_gradient", False), ("with_custom_gradient", True)), named=True, - )) - def test_reduce_precision(self, with_grad, enable_xla, with_custom_grad): + ) + ) + def test_reduce_precision(self, with_grad, with_custom_grad): if jax.__version_info__ <= (0, 4, 4): - self.skipTest("jax.lax.reduce_precision is only supported from 0.4.4 and " - f"onward, found {jax.__version__}.") + self.skipTest( + "jax.lax.reduce_precision is only supported from 0.4.4 and " + f"onward, found {jax.__version__}." + ) np.random.seed(42) @@ -1103,16 +1086,16 @@ def forward(x): inputs = np.random.normal((3, 2)).astype(np.float32) self._test_convert( - forward, [inputs], + forward, + [inputs], with_grad=with_grad, - enable_xla=enable_xla, - with_custom_grad=with_custom_grad) + with_custom_grad=with_custom_grad, + ) @chex.variants(with_jit=True, without_jit=True) @parameterized.named_parameters( chex.params_product( (("without_gradient", True),), - (("enable_xla", True), ("disable_xla", False)), (("with_custom_gradient", True),), ( ("lower", True), @@ -1128,11 +1111,11 @@ def forward(x): ("more_batched", ((2, 3, 5, 5), (2, 3, 5, 6))), ), named=True, - )) + ) + ) def test_triangular_solve( self, with_grad, - enable_xla, with_custom_grad, lower, unit_diagonal, @@ -1141,7 +1124,8 @@ def test_triangular_solve( if uses_native_serialization(): self.skipTest( "native_serialization: Cannot serialize code with custom calls whose " - "targets have no compatibility guarantees.") + "targets have no compatibility guarantees." + ) np.random.seed(42) @@ -1159,11 +1143,12 @@ def forward(a, b): tols = dict(atol=1e-5) if jax.default_backend().lower() == "tpu" else {} self._test_convert( - forward, inputs, + forward, + inputs, with_grad=with_grad, - enable_xla=enable_xla, with_custom_grad=with_custom_grad, - grad_tols=tols) + grad_tols=tols, + ) def test_explicit_native_serialization(self): def forward(x): @@ -1177,7 +1162,7 @@ def forward(x): tf_fn, tf.TensorSpec((2, 3), tf.float32) ) jax_fn = jax.jit(jax_fn) - inputs = np.linspace(-1., 1., 6, dtype=np.float32).reshape((2, 3)) + inputs = np.linspace(-1.0, 1.0, 6, dtype=np.float32).reshape((2, 3)) self.assertAllClose(jax_fn(inputs), tf_fn(inputs)) except ValueError as e: if uses_native_serialization(): diff --git a/tf2jax/_src/sharding_test.py b/tf2jax/_src/sharding_test.py index e516a5b..444afcc 100644 --- a/tf2jax/_src/sharding_test.py +++ b/tf2jax/_src/sharding_test.py @@ -16,13 +16,11 @@ from absl.testing import absltest from absl.testing import parameterized - import chex import haiku as hk import jax from jax.experimental import jax2tf import jax.numpy as jnp - import numpy as np import tensorflow as tf from tf2jax._src import test_util @@ -53,20 +51,14 @@ class ShardingTest(test_util.TestCase): @parameterized.named_parameters( chex.params_product( - (('enable_xla', True), ('disable_xla', False)), (('native_serialization', True), ('graph_serialization', False)), named=True, ) ) - def test_sharding(self, enable_xla, native_serialization): + def test_sharding(self, native_serialization): if jax.default_backend().upper() != 'TPU': self.skipTest('Only run sharding tests on TPU.') - if not enable_xla and native_serialization: - self.skipTest( - 'native_serializaton is only supported with enable_xla=True.' - ) - # Set up network and inputs. transformed = hk.without_apply_rng(hk.transform(_net)) rng = jax.random.PRNGKey(42) @@ -79,11 +71,12 @@ def test_sharding(self, enable_xla, native_serialization): # Partitioned to 8 devices. assert jax.device_count() == 8, jax.device_count() mesh = jax.sharding.Mesh( - np.array(jax.devices()).reshape((2, 4)), ('data', 'model')) + np.array(jax.devices()).reshape((2, 4)), ('data', 'model') + ) params_pspecs = _get_param_pspecs() + def to_xla_sharding(pspecs): - return jax.tree.map( - lambda x: jax.sharding.NamedSharding(mesh, x), pspecs) + return jax.tree.map(lambda x: jax.sharding.NamedSharding(mesh, x), pspecs) partitioned_apply = jax.jit( transformed.apply, @@ -116,7 +109,6 @@ def partitioned_grad(params, xs): def tf_fn(params, inputs): return jax2tf.convert( partitioned_apply, - enable_xla=enable_xla, native_serialization=native_serialization, )(params, inputs) @@ -145,6 +137,7 @@ def tf_fn(params, inputs): @jax.grad def reloaded_grad(params, xs): return jnp.sum(jax.jit(jax_fn)(params, xs)) + self.assertAllClose( jax.jit(unpartitioned_grad)(params, images), jax.jit(reloaded_grad)(params, images),