Fix gradient function lookup error when higher order gradient functions are referenced in a jax2tf model but not serialized.

PiperOrigin-RevId: 568665145
shaobohou authored and TF2JAXDev committed Sep 27, 2023
1 parent b7a1b39 commit c089dde
Showing 8 changed files with 78 additions and 3 deletions.
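
A hedged usage sketch of the behavior this commit enables; the SavedModel path and the `model.f` attribute below are illustrative, not from this repository. Conversion and the forward pass now succeed even when a referenced custom gradient function was never serialized: with `convert_custom_gradient` enabled, `jax.grad` raises a descriptive LookupError at gradient time instead of the lookup error surfacing during conversion; with it disabled, the node logs a warning and falls back to a plain identity.

import jax
import numpy as np
import tensorflow as tf
import tf2jax

# Hypothetical path to a jax2tf-exported SavedModel.
model = tf.saved_model.load("/path/to/jax2tf_exported_model")
jax_fn = tf2jax.convert_functional(model.f, np.zeros((), np.float32))

# The forward pass works whether or not the custom gradient was serialized.
y = jax_fn(np.float32(42.0))

# Uses the serialized custom gradient when it was found; otherwise raises a
# descriptive LookupError (with `convert_custom_gradient` enabled) rather than
# failing earlier, at conversion time.
dy_dx = jax.grad(jax_fn)(np.float32(42.0))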
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -3,3 +3,4 @@ include LICENSE
 include requirements.txt
 include requirements_tests.txt
 include tf2jax/py.typed
+recursive-include tf2jax/test_data *
38 changes: 36 additions & 2 deletions tf2jax/_src/ops.py
@@ -1044,10 +1044,38 @@ def _func(
 class _IdentityN(_HigherOrderFunction):
   """Represents a IdentityN Op."""
 
+  name: str
   gradient_op_type: str  # For debug, custom_gradient is handled by _Subgraph.
+  with_custom_gradient: bool
 
   def __call__(self, *args):
-    return args
+    @jax.custom_gradient
+    def _raise_func(
+        *operands: jnp.ndarray,
+    ) -> Tuple[Tuple[jnp.ndarray, ...], Callable[..., Any]]:
+      def grad_fn(_):
+        raise LookupError(
+            f"Custom gradient `{self.gradient_op_type}` was expected but not"
+            f" found for the node `{self.name}` (op type: IdentityN). This"
+            " function is just a placeholder. The subgraph corresponding to"
+            " this IdentityN node should have been wrapped in"
+            " jax.custom_gradient with the actual gradient function found in"
+            " the TensorFlow gradient registry."
+        )
+      return operands, grad_fn
+
+    def _warn_func(*operands: jnp.ndarray) -> Tuple[jnp.ndarray, ...]:
+      logging.warn(
+          "Ignored custom gradient `%s` on the node `%s` (op type: IdentityN).",
+          self.gradient_op_type,
+          self.name,
+      )
+      return operands
+
+    if self.with_custom_gradient:
+      return _raise_func(*args)
+    else:
+      return _warn_func(*args)
 
 
 @register_operation("IdentityN")
@@ -1059,7 +1087,13 @@ def _identity_n(proto):
   if gradient_op_type:
     logging.info("Found custom gradient %s", gradient_op_type)
 
-  return _IdentityN({}, gradient_op_type=gradient_op_type)
+  return _IdentityN(
+      {},
+      name=proto.name,
+      gradient_op_type=gradient_op_type,
+      # Caching the config at conversion time.
+      with_custom_gradient=config.get_config("convert_custom_gradient"),
+  )
 
 
 class _IfOp(_HigherOrderFunction):
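For reference, a minimal standalone sketch (with a hypothetical gradient name) of the placeholder pattern that `_IdentityN` now uses when `with_custom_gradient` is set: the forward pass is a plain identity, while requesting a gradient raises a descriptive LookupError rather than an opaque registry failure. If the surrounding subgraph was wrapped with the real custom gradient, this placeholder gradient should never be reached.

import jax
import jax.numpy as jnp

@jax.custom_gradient
def identity_with_missing_gradient(x):
  def grad_fn(dy):
    # Only reached if autodiff falls through to this placeholder.
    raise LookupError(
        "Custom gradient `HypotheticalGrad` was expected but not found."
    )
  return x, grad_fn

y = identity_with_missing_gradient(jnp.float32(3.0))  # Forward pass returns 3.0
# jax.grad(identity_with_missing_gradient)(jnp.float32(3.0))  # Raises LookupError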
27 changes: 27 additions & 0 deletions tf2jax/_src/roundtrip_test.py
@@ -869,6 +869,33 @@ def grad(dy):
     self.assertAllClose(expected_outputs, re_jax_outputs)
     self.assertAllClose(expected_grads, re_jax_grads)
 
+  @chex.variants(with_jit=True)
+  def test_custom_gradient_saved_model(self):
+    model = tf.saved_model.load(
+        os.path.join(os.path.dirname(os.path.split(__file__)[0]), "test_data/custom_gradient_cubed")
+    )
+
+    x = np.array(42.0, dtype=np.float32)
+    jax_fn = tf2jax.convert_functional(model.f, np.array(0.0, dtype=np.float32))
+    jax_fn = self.variant(jax_fn)
+    jax_y = jax_fn(x)
+    jax_dy_dx = jax.grad(jax_fn)(x)
+
+    # TODO(b/302195165) This has to happen after the convert, otherwise there
+    # is an input lookup error in the gradient function.
+    x = tf.constant(42.0, dtype=tf.float32)
+    with tf.GradientTape() as tape:
+      tape.watch(x)
+      with tf.GradientTape() as tape2:
+        tape2.watch(x)
+        tf_y = model.f(x)
+      with self.assertRaises(LookupError):
+        _ = tape2.gradient(tf_y, x)
+    tf_dy_dx = tape.gradient(tf_y, x)
+
+    self.assertAllClose(jax_y, tf_y)
+    self.assertAllClose(jax_dy_dx, tf_dy_dx)
+
   @chex.variants(with_jit=True)
   @parameterized.named_parameters(
       chex.params_product(
14 changes: 13 additions & 1 deletion tf2jax/_src/tf2jax.py
@@ -790,6 +790,10 @@ def _extract_subgraphs(graphdef, nodes, library):
     grad_fn_name = str(node.attr["_gradient_op_type"].s, "utf-8")
     grad_fn = library[grad_fn_name]
 
+    # The gradient is not available (not serialized in a saved model?)
+    if grad_fn is None:
+      continue
+
     output_node = op_map[node.name][1]
     assert len(node.input) == len(output_node.inputs)

@@ -1384,7 +1388,7 @@ def _convert_all_gradient_functions(
 def _convert_gradient_function(
     proto: tf.compat.v1.NodeDef,
     graph: Any,
-    library: Dict[str, _LibraryFunction],
+    library: Dict[str, Optional[_LibraryFunction]],
 ) -> None:
   """Convert a custom_gradient function."""
   op = graph.as_graph_element(proto.name)
@@ -1393,6 +1397,14 @@ def _convert_gradient_function(
   if grad_fn_name in library:
     return
 
+  # Higher order gradient functions may be referenced in JAX2TF produced models
+  # but not actually serialized.
+  try:
+    tf_ops.gradient_registry.lookup(grad_fn_name)
+  except LookupError:
+    library[grad_fn_name] = None
+    return
+
   @tf.function
   def tf_grad_fn(*grad_args, **grad_kwargs):
     fn = tf_ops.gradient_registry.lookup(grad_fn_name)
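A minimal sketch of the two-step pattern these hunks introduce, with hypothetical helper names and a stand-in value type (the real code stores `_LibraryFunction` objects, and `tf_ops` is assumed to be `tensorflow.python.framework.ops`, as in this file): a gradient name that is not in TensorFlow's gradient registry is recorded as None rather than raising during conversion, and the subgraph-wrapping pass then skips those entries, leaving the `_IdentityN` placeholder above to report the problem if a gradient is ever requested.

from typing import Dict, Optional

from tensorflow.python.framework import ops as tf_ops


def record_gradient_if_registered(
    grad_fn_name: str, library: Dict[str, Optional[object]]
) -> None:
  """Stores None when a referenced custom gradient is not in TF's registry."""
  if grad_fn_name in library:
    return
  try:
    tf_ops.gradient_registry.lookup(grad_fn_name)
  except LookupError:
    # Referenced (e.g. by a jax2tf-produced model) but never serialized.
    library[grad_fn_name] = None
    return
  library[grad_fn_name] = object()  # Stand-in for the converted gradient fn.


def wrap_custom_gradient_subgraphs(library: Dict[str, Optional[object]]) -> None:
  for grad_fn in library.values():
    if grad_fn is None:
      continue  # Leave the subgraph unwrapped; the placeholder reports errors.
    # ... wrap the corresponding subgraph with jax.custom_gradient here.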
1 change: 1 addition & 0 deletions tf2jax/test_data/custom_gradient_cubed/fingerprint.pb
@@ -0,0 +1 @@
(binary contents not displayed)
Binary file not shown.
Binary file not shown.
Binary file not shown.
