Support refinement of polymorphic shapes given static input shapes.
PiperOrigin-RevId: 540862936
shaobohou authored and TF2JAXDev committed Jun 16, 2023
1 parent 8f87f82 commit 475aeff
Showing 3 changed files with 171 additions and 3 deletions.
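
In practice, the change lets a shape-polymorphic jax2tf function be round-tripped back to JAX once concrete example inputs are supplied. A minimal sketch, condensed from the tests below (jax2tf.convert stands in for the test helper _jax2tf_convert; shapes as in the dot test):

import numpy as np
import tensorflow as tf
import jax.numpy as jnp
import tf2jax
from jax.experimental import jax2tf

def forward(x, w):
  return jnp.dot(x, w)

x = np.arange(12, dtype=np.float32).reshape((3, 4))
w = np.arange(20, dtype=np.float32).reshape((4, 5))

# The TF side keeps the batch dimension "b" symbolic.
tf_fn = tf.function(
    jax2tf.convert(forward, polymorphic_shapes=["(b, _)", None]),
    autograph=False)

# Converting back with static example inputs now triggers shape
# refinement, so the dynamic dims in the stored module are resolved.
jax_fn = tf2jax.convert_functional(tf_fn, np.zeros_like(x), np.zeros_like(w))
np.testing.assert_allclose(jax_fn(x, w), forward(x, w))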
76 changes: 76 additions & 0 deletions tf2jax/_src/roundtrip_test.py
@@ -692,6 +692,82 @@ def forward(x):
new_tf_outputs = concrete_new_tf_forward(inputs)
self.assertAllClose(new_tf_outputs, jax_outputs)

@chex.variants(with_jit=True)
@parameterized.named_parameters(
chex.params_product(
(("without_gradient", False), ("with_gradient", True)),
(("disable_xla", False), ("enable_xla", True)),
named=True,
)
)
def test_polymorphic_shape_refinement_dot(self, with_grad, enable_xla):
if uses_native_serialization():
if not enable_xla:
self.skipTest("native_serialization does not support enable_xla=False.")

@jax.jit
def forward(x, w):
return jnp.dot(x, w)

x = np.array(range(12), dtype=np.float32).reshape((3, 4))
w = np.array(range(20), dtype=np.float32).reshape((4, 5))
expected_outputs = forward(x, w)

tf_fn = _jax2tf_convert(
forward,
polymorphic_shapes=["(b, _)", None],
with_gradient=with_grad,
enable_xla=enable_xla)
tf_fn = tf.function(tf_fn, autograph=False)
concrete_tf_fn = tf_fn.get_concrete_function(
tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(4, 5)))
tf_outputs = concrete_tf_fn(x, w)
self.assertAllClose(expected_outputs, tf_outputs)

jax_fn = tf2jax.convert_functional(
tf_fn, np.zeros_like(x), np.zeros_like(w)
)
jax_outputs = self.variant(jax_fn)(x, w)
self.assertAllClose(expected_outputs, jax_outputs)

@chex.variants(with_jit=True)
@parameterized.named_parameters(
chex.params_product(
(("without_gradient", False), ("with_gradient", True)),
(("disable_xla", False), ("enable_xla", True)),
named=True,
)
)
def test_polymorphic_shape_refinement_broadcast(self, with_grad, enable_xla):
if uses_native_serialization():
if not enable_xla:
self.skipTest("native_serialization does not support enable_xla=False.")

@jax.jit
def forward(x, y):
return (jnp.broadcast_to(x, y.shape), x + y)

x = np.array(range(12), dtype=np.float32).reshape((3, 4))
y = np.array(range(24), dtype=np.float32).reshape((2, 3, 4))
expected_outputs = forward(x, y)

tf_fn = _jax2tf_convert(
forward,
polymorphic_shapes=["(b, _)", "(_, b, _)"],
with_gradient=with_grad,
enable_xla=enable_xla)
tf_fn = tf.function(tf_fn, autograph=False)
concrete_tf_fn = tf_fn.get_concrete_function(
tf.TensorSpec(shape=(None, 4)), tf.TensorSpec(shape=(2, None, 4)))
tf_outputs = concrete_tf_fn(x, y)
self.assertAllClose(expected_outputs, tf_outputs)

jax_fn = tf2jax.convert_functional(
tf_fn, np.zeros_like(x), np.zeros_like(y)
)
jax_outputs = self.variant(jax_fn)(x, y)
self.assertAllClose(expected_outputs, jax_outputs)

@chex.variants(with_jit=True)
@parameterized.named_parameters(
chex.params_product(
25 changes: 25 additions & 0 deletions tf2jax/experimental/mhlo.py
@@ -104,6 +104,31 @@ def mhlo_apply_abstract_eval(*args, module: MhloModule):
mhlo_apply_p.def_abstract_eval(mhlo_apply_abstract_eval)


# Taken from
# github.com/google/jax/blob/main/jax/experimental/jax2tf/jax_export.py#L859
def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
"""Refine the polymorphic shapes inside a module.
Given a module with static input shapes, but using dynamic shapes due to
shape polymorphism, run shape refinement to resolve all the dynamic shapes.
Args:
module: A module with static input shapes but dynamic shapes inside.
Returns:
The refined module.
"""
if xc.mlir_api_version < 50:
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")

refined_module_str = xc._xla.mlir.refine_polymorphic_shapes( # pylint: disable=protected-access
mlir.module_to_bytecode(module)
)
context = mlir.make_ir_context()
with context:
return ir.Module.parse(refined_module_str)


def mhlo_apply_lowering(
ctx: mlir.LoweringRuleContext, *args, module: MhloModule
):
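A hedged sketch of how refine_polymorphic_shapes is driven (mirroring its use in ops.py below; module_text is assumed to already have a "main" with static input types, as produced by the wrapping step there):

with mlir.make_ir_context():
  module = ir.Module.parse(module_text)        # dynamic dims (?) inside
  refined = refine_polymorphic_shapes(module)  # all dims now concrete
  refined_text = mlir.module_to_string(refined)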
73 changes: 70 additions & 3 deletions tf2jax/experimental/ops.py
@@ -14,7 +14,8 @@
# ==============================================================================
"""Experimental ops"."""

from typing import Tuple
import functools
from typing import List, Tuple

from absl import logging

@@ -39,6 +40,69 @@ def _platform_to_alias(platform: str) -> str:
return aliases.get(platform, platform)


# Adapted from
# https://github.com/google/jax/commit/ec8b855fa16962b1394716622c8cbc006ce76b1c
@functools.lru_cache(None)
def _refine_with_static_input_shapes(
module_text: str, operands: Tuple[jax.core.ShapedArray, ...]
) -> str:
"""Refine the polymorphic shapes inside a module."""
# Wrap original main within another function with static input shapes.
context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
module = ir.Module.parse(module_text)
symbol_table = ir.SymbolTable(module.operation)
orig_main = symbol_table["main"]
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
symbol_table.set_symbol_name(orig_main, "_orig_main")
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value

# Use static shapes
new_main_input_types = [mlir.aval_to_ir_type(x) for x in operands]
orig_output_types = orig_main.type.results
new_main_ftype = ir.FunctionType.get(
new_main_input_types, orig_output_types
)
new_main_op = mlir.func_dialect.FuncOp(
"main",
new_main_ftype,
ip=ir.InsertionPoint.at_block_begin(module.body),
)

try:
new_main_op.attributes["arg_attrs"] = ir.ArrayAttr(orig_main.arg_attrs)
assert new_main_op.arg_attrs == orig_main.arg_attrs
except KeyError:
pass
try:
new_main_op.attributes["res_attrs"] = ir.ArrayAttr(orig_main.result_attrs)
assert new_main_op.result_attrs == orig_main.result_attrs
except KeyError:
pass
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
symbol_table.insert(new_main_op)
entry_block = new_main_op.add_entry_block()

with ir.InsertionPoint(entry_block):
orig_main_args: List[ir.Value] = []
for new_arg, orig_arg_type in jax.util.safe_zip(
new_main_op.arguments, orig_main.type.inputs
):
        # TODO(shaobohou): Why is the ConvertOp needed?
orig_main_args.append(mlir.hlo.ConvertOp(orig_arg_type, new_arg).result)
call = mlir.func_dialect.CallOp(
orig_output_types,
ir.FlatSymbolRefAttr.get(orig_main_name),
orig_main_args,
)
mlir.func_dialect.ReturnOp(call.results)
symbol_table.set_symbol_name(new_main_op, "main")

# Refinement passes.
module = mhlo.refine_polymorphic_shapes(module)
return mlir.module_to_string(module)
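
# Example with hypothetical shapes: for a module exported with
# polymorphic_shapes=["(b, _)", None], the call
#   _refine_with_static_input_shapes(
#       module_text,
#       (jax.core.ShapedArray((3, 4), jnp.float32),
#        jax.core.ShapedArray((4, 5), jnp.float32)))
# returns module text in which the symbolic dim "b" is resolved to 3;
# the lru_cache means each (module_text, avals) pair is refined once.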


@ops.register_operation("XlaCallModule")
def _xla_call_module(proto):
"""Parse a XlaCallModule op."""
@@ -136,10 +200,13 @@ def check_platforms():
module = ir.Module.parse(proto.attr["module"].s)
mhlo_text = mlir.module_to_string(module)

mhlo_module = mhlo.MhloModule(module=mhlo_text, fun_name=proto.name)

def _func(*operands: jnp.ndarray) -> Tuple[jnp.ndarray, ...]:
check_platforms()
refined_mhlo_text = _refine_with_static_input_shapes(
mhlo_text,
tuple(jax.core.ShapedArray(x.shape, x.dtype) for x in operands),
)
mhlo_module = mhlo.MhloModule(module=refined_mhlo_text, fun_name=proto.name)
return mhlo.mhlo_apply(*operands, module=mhlo_module)

return _func
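
Net effect in _xla_call_module: the MhloModule is no longer built once from the unrefined mhlo_text at parse time; it is now built inside _func, where the operands' static shapes are available to drive refinement (and the lru_cache keeps repeated traces with the same shapes cheap). A condensed sketch of the new flow, as a reading aid rather than a verbatim excerpt:

def _func(*operands: jnp.ndarray) -> Tuple[jnp.ndarray, ...]:
  check_platforms()
  avals = tuple(jax.core.ShapedArray(x.shape, x.dtype) for x in operands)
  refined_text = _refine_with_static_input_shapes(mhlo_text, avals)  # cached
  module = mhlo.MhloModule(module=refined_text, fun_name=proto.name)
  return mhlo.mhlo_apply(*operands, module=module)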
