Skip to content

Commit

Permalink
[JAX] Change users of jnp.where() to pass the condition, x, and y arg…
Browse files Browse the repository at this point in the history
…uments as positional arguments.

Support for passing the condition, x, and y arguments via keyword arguments is being removed from jax.numpy.where() to match numpy.where().

PiperOrigin-RevId: 582586452
Change-Id: I5e8711685b9e50849e5941df20805ca6afaf2355
  • Loading branch information
hawkinsp authored and copybara-github committed Nov 15, 2023
1 parent 8152c43 commit 24cca58
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions lightweight_mmm/core/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,5 @@ def apply_exponent_safe(data: jnp.ndarray,
Returns:
The result of the exponent operation with the inputs provided.
"""
exponent_safe = jnp.where(condition=(data == 0), x=1, y=data)**exponent
return jnp.where(condition=(data == 0), x=0, y=exponent_safe)
exponent_safe = jnp.where(data == 0, 1, data) ** exponent
return jnp.where(data == 0, 0, exponent_safe)
2 changes: 1 addition & 1 deletion lightweight_mmm/core/transformations/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _hill(
"""
save_transform = core_utils.apply_exponent_safe(
data=data / half_max_effective_concentration, exponent=-slope)
return jnp.where(save_transform == 0, x=0, y=1. / (1 + save_transform))
return jnp.where(save_transform == 0, 0, 1.0 / (1 + save_transform))


def hill(
Expand Down
6 changes: 3 additions & 3 deletions lightweight_mmm/media_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def hill(data: jnp.ndarray, half_max_effective_concentration: jnp.ndarray,
"""
save_transform = apply_exponent_safe(
data=data / half_max_effective_concentration, exponent=-slope)
return jnp.where(save_transform == 0, x=0, y=1. / (1 + save_transform))
return jnp.where(save_transform == 0, 0, 1.0 / (1 + save_transform))


@functools.partial(jax.vmap, in_axes=(1, 1, None), out_axes=1)
Expand Down Expand Up @@ -186,5 +186,5 @@ def apply_exponent_safe(
Returns:
The result of the exponent operation with the inputs provided.
"""
exponent_safe = jnp.where(condition=(data == 0), x=1, y=data) ** exponent
return jnp.where(condition=(data == 0), x=0, y=exponent_safe)
exponent_safe = jnp.where(data == 0, 1, data) ** exponent
return jnp.where(data == 0, 0, exponent_safe)

0 comments on commit 24cca58

Please sign in to comment.