From 24cca5836bdb2958b7962e044dcc0e60c4b361ae Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 15 Nov 2023 01:40:31 -0800 Subject: [PATCH] [JAX] Change users of jnp.where() to pass the condition, x, and y arguments as positional arguments. Support for passing the condition, x, and y arguments via keyword arguments is being removed from jax.numpy.where() to match numpy.where(). PiperOrigin-RevId: 582586452 Change-Id: I5e8711685b9e50849e5941df20805ca6afaf2355 --- lightweight_mmm/core/core_utils.py | 4 ++-- lightweight_mmm/core/transformations/saturation.py | 2 +- lightweight_mmm/media_transforms.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lightweight_mmm/core/core_utils.py b/lightweight_mmm/core/core_utils.py index c128baa..ea2bd39 100644 --- a/lightweight_mmm/core/core_utils.py +++ b/lightweight_mmm/core/core_utils.py @@ -72,5 +72,5 @@ def apply_exponent_safe(data: jnp.ndarray, Returns: The result of the exponent operation with the inputs provided. """ - exponent_safe = jnp.where(condition=(data == 0), x=1, y=data)**exponent - return jnp.where(condition=(data == 0), x=0, y=exponent_safe) + exponent_safe = jnp.where(data == 0, 1, data) ** exponent + return jnp.where(data == 0, 0, exponent_safe) diff --git a/lightweight_mmm/core/transformations/saturation.py b/lightweight_mmm/core/transformations/saturation.py index 26a9b81..8a6f4df 100644 --- a/lightweight_mmm/core/transformations/saturation.py +++ b/lightweight_mmm/core/transformations/saturation.py @@ -45,7 +45,7 @@ def _hill( """ save_transform = core_utils.apply_exponent_safe( data=data / half_max_effective_concentration, exponent=-slope) - return jnp.where(save_transform == 0, x=0, y=1. / (1 + save_transform)) + return jnp.where(save_transform == 0, 0, 1.0 / (1 + save_transform)) def hill( diff --git a/lightweight_mmm/media_transforms.py b/lightweight_mmm/media_transforms.py index 6517afa..4edd036 100644 --- a/lightweight_mmm/media_transforms.py +++ b/lightweight_mmm/media_transforms.py @@ -112,7 +112,7 @@ def hill(data: jnp.ndarray, half_max_effective_concentration: jnp.ndarray, """ save_transform = apply_exponent_safe( data=data / half_max_effective_concentration, exponent=-slope) - return jnp.where(save_transform == 0, x=0, y=1. / (1 + save_transform)) + return jnp.where(save_transform == 0, 0, 1.0 / (1 + save_transform)) @functools.partial(jax.vmap, in_axes=(1, 1, None), out_axes=1) @@ -186,5 +186,5 @@ def apply_exponent_safe( Returns: The result of the exponent operation with the inputs provided. """ - exponent_safe = jnp.where(condition=(data == 0), x=1, y=data) ** exponent - return jnp.where(condition=(data == 0), x=0, y=exponent_safe) + exponent_safe = jnp.where(data == 0, 1, data) ** exponent + return jnp.where(data == 0, 0, exponent_safe)