Skip to content

Commit

Permalink
Add the MulNoNan op.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 669469935
  • Loading branch information
TF2JAXDev authored and TF2JAXDev committed Aug 30, 2024
1 parent 88f4665 commit 48a6a80
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,46 @@ def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray:
return _func


def _mul_no_nan_forward(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Returns zero if y is zero, even if x if infinite or NaN."""
mul = anp.multiply(x, y)
return jnp.where(y == 0.0, jnp.zeros_like(mul), mul)


@register_operation("MulNoNan")
def _mul_no_nan(proto):
"""Parse a MulNoNan op."""
_check_attrs(proto, {"T"})

@jax.custom_gradient
def _func(
x: jnp.ndarray, y: jnp.ndarray
) -> Tuple[
jnp.ndarray,
Callable[[jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]],
]:
def _grad(g: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Given upstream grad G and a Div op: Z = X*Y, the gradients are.
dX = G * Y
dY = X * G
Args:
g: Upstream gradient.
Returns:
forward value: mul_no_nan(x, y)
grad: Gradient information in TF format.
"""
dx = _mul_no_nan_forward(g, y)
dy = _mul_no_nan_forward(x, g)
return dx, dy

return _mul_no_nan_forward(x, y), _grad

return _func


@register_operation("OneHot")
def _one_hot(proto):
"""Parse a OneHot Op."""
Expand Down
31 changes: 31 additions & 0 deletions tf2jax/_src/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2276,6 +2276,37 @@ def div_no_nan(inputs):
with self.subTest("check_backward_pass"):
self.assertAllClose(jax_gradient, np.asarray(tf_gradient))

@chex.variants(with_jit=True, without_jit=True)
@parameterized.parameters(
(2., 2.), (0., 0.), (np.nan, 0.), (np.inf, 0.), (0., np.nan), (0., np.inf)
)
def test_mul_no_nan(self, x, y):

@tf.function
def mul_no_nan(inputs):
x, y = inputs
return tf.raw_ops.MulNoNan(x=x, y=y)

x = np.array(x)
y = np.array(y)
tf_x = tf.convert_to_tensor(x)
tf_y = tf.convert_to_tensor(y)

with tf.GradientTape() as g:
g.watch(tf_x)
g.watch(tf_y)
output = mul_no_nan([tf_x, tf_y])
tf_gradient = g.gradient(output, [tf_x, tf_y])

jax_func = tf2jax.convert_functional(mul_no_nan, [x, y])
jax_func = self.variant(jax_func)
jax_gradient = jax.grad(jax_func)([x, y])

with self.subTest("check_forward_pass"):
self.assertAllClose(jax_func([x, y]), np.asarray(output))
with self.subTest("check_backward_pass"):
self.assertAllClose(jax_gradient, np.asarray(tf_gradient))

@chex.variants(with_jit=True, without_jit=True)
def test_angle(self):
inputs = np.array([-2.25 + 4.75j, 3.25 + 5.75j], dtype=np.csingle)
Expand Down

0 comments on commit 48a6a80

Please sign in to comment.