From 475aeff4c141244c778ec4b872755efca4ae0787 Mon Sep 17 00:00:00 2001 From: Shaobo Hou Date: Fri, 16 Jun 2023 06:38:04 -0700 Subject: [PATCH] Support refinement of polymorphic shapes given static input shapes. PiperOrigin-RevId: 540862936 --- tf2jax/_src/roundtrip_test.py | 76 +++++++++++++++++++++++++++++++++++ tf2jax/experimental/mhlo.py | 25 ++++++++++++ tf2jax/experimental/ops.py | 73 +++++++++++++++++++++++++++++++-- 3 files changed, 171 insertions(+), 3 deletions(-) diff --git a/tf2jax/_src/roundtrip_test.py b/tf2jax/_src/roundtrip_test.py index 812bcb2..178051a 100644 --- a/tf2jax/_src/roundtrip_test.py +++ b/tf2jax/_src/roundtrip_test.py @@ -692,6 +692,82 @@ def forward(x): new_tf_outputs = concrete_new_tf_forward(inputs) self.assertAllClose(new_tf_outputs, jax_outputs) + @chex.variants(with_jit=True) + @parameterized.named_parameters( + chex.params_product( + (("without_gradient", False), ("with_gradient", True)), + (("disable_xla", False), ("enable_xla", True)), + named=True, + ) + ) + def test_polymorphic_shape_refinement_dot(self, with_grad, enable_xla): + if uses_native_serialization(): + if not enable_xla: + self.skipTest("native_serialization does not support enable_xla=False.") + + @jax.jit + def forward(x, w): + return jnp.dot(x, w) + + x = np.array(range(12), dtype=np.float32).reshape((3, 4)) + w = np.array(range(20), dtype=np.float32).reshape((4, 5)) + expected_outputs = forward(x, w) + + tf_fn = _jax2tf_convert( + forward, + polymorphic_shapes=["(b, _)", None], + with_gradient=with_grad, + enable_xla=enable_xla) + tf_fn = tf.function(tf_fn, autograph=False) + concrete_tf_fn = tf_fn.get_concrete_function( + tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(4, 5))) + tf_outputs = concrete_tf_fn(x, w) + self.assertAllClose(expected_outputs, tf_outputs) + + jax_fn = tf2jax.convert_functional( + tf_fn, np.zeros_like(x), np.zeros_like(w) + ) + jax_outputs = self.variant(jax_fn)(x, w) + self.assertAllClose(expected_outputs, jax_outputs) + + @chex.variants(with_jit=True) + @parameterized.named_parameters( + chex.params_product( + (("without_gradient", False), ("with_gradient", True)), + (("disable_xla", False), ("enable_xla", True)), + named=True, + ) + ) + def test_polymorphic_shape_refinement_broadcast(self, with_grad, enable_xla): + if uses_native_serialization(): + if not enable_xla: + self.skipTest("native_serialization does not support enable_xla=False.") + + @jax.jit + def forward(x, y): + return (jnp.broadcast_to(x, y.shape), x + y) + + x = np.array(range(12), dtype=np.float32).reshape((3, 4)) + y = np.array(range(24), dtype=np.float32).reshape((2, 3, 4)) + expected_outputs = forward(x, y) + + tf_fn = _jax2tf_convert( + forward, + polymorphic_shapes=["(b, _)", "(_, b, _)"], + with_gradient=with_grad, + enable_xla=enable_xla) + tf_fn = tf.function(tf_fn, autograph=False) + concrete_tf_fn = tf_fn.get_concrete_function( + tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(2, None, 4))) + tf_outputs = concrete_tf_fn(x, y) + self.assertAllClose(expected_outputs, tf_outputs) + + jax_fn = tf2jax.convert_functional( + tf_fn, np.zeros_like(x), np.zeros_like(y) + ) + jax_outputs = self.variant(jax_fn)(x, y) + self.assertAllClose(expected_outputs, jax_outputs) + @chex.variants(with_jit=True) @parameterized.named_parameters( chex.params_product( diff --git a/tf2jax/experimental/mhlo.py b/tf2jax/experimental/mhlo.py index 240fa7c..dcd851a 100644 --- a/tf2jax/experimental/mhlo.py +++ b/tf2jax/experimental/mhlo.py @@ -104,6 +104,31 @@ def mhlo_apply_abstract_eval(*args, module: 
MhloModule): mhlo_apply_p.def_abstract_eval(mhlo_apply_abstract_eval) +# Taken from +# github.com/google/jax/blob/main/jax/experimental/jax2tf/jax_export.py#L859 +def refine_polymorphic_shapes(module: ir.Module) -> ir.Module: + """Refine the polymorphic shapes inside a module. + + Given a module with static input shapes, but using dynamic shapes due to + shape polymorphism, run shape refinement to resolve all the dynamic shapes. + + Args: + module: A module with static input shapes but dynamic shapes inside. + + Returns: + The refined module. + """ + if xc.mlir_api_version < 50: + raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12") + + refined_module_str = xc._xla.mlir.refine_polymorphic_shapes( # pylint: disable=protected-access + mlir.module_to_bytecode(module) + ) + context = mlir.make_ir_context() + with context: + return ir.Module.parse(refined_module_str) + + def mhlo_apply_lowering( ctx: mlir.LoweringRuleContext, *args, module: MhloModule ): diff --git a/tf2jax/experimental/ops.py b/tf2jax/experimental/ops.py index 16d8205..5c855a6 100644 --- a/tf2jax/experimental/ops.py +++ b/tf2jax/experimental/ops.py @@ -14,7 +14,8 @@ # ============================================================================== """Experimental ops".""" -from typing import Tuple +import functools +from typing import List, Tuple from absl import logging @@ -39,6 +40,69 @@ def _platform_to_alias(platform: str) -> str: return aliases.get(platform, platform) +# Adapted from +# https://github.com/google/jax/commit/ec8b855fa16962b1394716622c8cbc006ce76b1c +@functools.lru_cache(None) +def _refine_with_static_input_shapes( + module_text: str, operands: Tuple[jax.core.ShapedArray, ...] +) -> str: + """Refine the polymorphic shapes inside a module.""" + # Wrap original main within another function with static input shapes. + context = mlir.make_ir_context() + with context, ir.Location.unknown(context): + module = ir.Module.parse(module_text) + symbol_table = ir.SymbolTable(module.operation) + orig_main = symbol_table["main"] + orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private") + symbol_table.set_symbol_name(orig_main, "_orig_main") + orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value + + # Use static shapes + new_main_input_types = [mlir.aval_to_ir_type(x) for x in operands] + orig_output_types = orig_main.type.results + new_main_ftype = ir.FunctionType.get( + new_main_input_types, orig_output_types + ) + new_main_op = mlir.func_dialect.FuncOp( + "main", + new_main_ftype, + ip=ir.InsertionPoint.at_block_begin(module.body), + ) + + try: + new_main_op.attributes["arg_attrs"] = ir.ArrayAttr(orig_main.arg_attrs) + assert new_main_op.arg_attrs == orig_main.arg_attrs + except KeyError: + pass + try: + new_main_op.attributes["res_attrs"] = ir.ArrayAttr(orig_main.result_attrs) + assert new_main_op.result_attrs == orig_main.result_attrs + except KeyError: + pass + new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public") + symbol_table.insert(new_main_op) + entry_block = new_main_op.add_entry_block() + + with ir.InsertionPoint(entry_block): + orig_main_args: List[ir.Value] = [] + for new_arg, orig_arg_type in jax.util.safe_zip( + new_main_op.arguments, orig_main.type.inputs + ): + # TODO(shaobohou) Why is the ConvertOp needed? 
+ orig_main_args.append(mlir.hlo.ConvertOp(orig_arg_type, new_arg).result) + call = mlir.func_dialect.CallOp( + orig_output_types, + ir.FlatSymbolRefAttr.get(orig_main_name), + orig_main_args, + ) + mlir.func_dialect.ReturnOp(call.results) + symbol_table.set_symbol_name(new_main_op, "main") + + # Refinement passes. + module = mhlo.refine_polymorphic_shapes(module) + return mlir.module_to_string(module) + + @ops.register_operation("XlaCallModule") def _xla_call_module(proto): """Parse a XlaCallModule op.""" @@ -136,10 +200,13 @@ def check_platforms(): module = ir.Module.parse(proto.attr["module"].s) mhlo_text = mlir.module_to_string(module) - mhlo_module = mhlo.MhloModule(module=mhlo_text, fun_name=proto.name) - def _func(*operands: jnp.ndarray) -> Tuple[jnp.ndarray, ...]: check_platforms() + refined_mhlo_text = _refine_with_static_input_shapes( + mhlo_text, + tuple(jax.core.ShapedArray(x.shape, x.dtype) for x in operands), + ) + mhlo_module = mhlo.MhloModule(module=refined_mhlo_text, fun_name=proto.name) return mhlo.mhlo_apply(*operands, module=mhlo_module) return _func
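
A minimal usage sketch of the round trip this change enables, mirroring the new
test_polymorphic_shape_refinement_dot test above. It calls
jax.experimental.jax2tf.convert directly instead of the test's _jax2tf_convert
wrapper, so the exact conversion flags (with_gradient, enable_xla, native
serialization mode) are assumptions; the shapes, specs, and
tf2jax.convert_functional call are taken from the test.

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf
import tf2jax

@jax.jit
def forward(x, w):
  return jnp.dot(x, w)

x = np.arange(12, dtype=np.float32).reshape((3, 4))
w = np.arange(20, dtype=np.float32).reshape((4, 5))

# Export to TF with a symbolic batch dimension "b" on the first argument, and
# trace a concrete function whose batch dimension is left unknown (None).
tf_fn = tf.function(
    jax2tf.convert(forward, polymorphic_shapes=["(b, _)", None]),
    autograph=False,
)
tf_fn.get_concrete_function(
    tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(4, 5)))

# Convert back to JAX. The static shapes of the example arguments are what
# _refine_with_static_input_shapes uses to resolve the dynamic shapes recorded
# in the serialized XlaCallModule payload before it is run via mhlo_apply.
jax_fn = tf2jax.convert_functional(tf_fn, np.zeros_like(x), np.zeros_like(w))
np.testing.assert_allclose(jax.jit(jax_fn)(x, w), forward(x, w), rtol=1e-6)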