diff --git a/lightweight_mmm/core/core_utils.py b/lightweight_mmm/core/core_utils.py index c128baa..ea2bd39 100644 --- a/lightweight_mmm/core/core_utils.py +++ b/lightweight_mmm/core/core_utils.py @@ -72,5 +72,5 @@ def apply_exponent_safe(data: jnp.ndarray, Returns: The result of the exponent operation with the inputs provided. """ - exponent_safe = jnp.where(condition=(data == 0), x=1, y=data)**exponent - return jnp.where(condition=(data == 0), x=0, y=exponent_safe) + exponent_safe = jnp.where(data == 0, 1, data) ** exponent + return jnp.where(data == 0, 0, exponent_safe) diff --git a/lightweight_mmm/core/transformations/saturation.py b/lightweight_mmm/core/transformations/saturation.py index 26a9b81..8a6f4df 100644 --- a/lightweight_mmm/core/transformations/saturation.py +++ b/lightweight_mmm/core/transformations/saturation.py @@ -45,7 +45,7 @@ def _hill( """ save_transform = core_utils.apply_exponent_safe( data=data / half_max_effective_concentration, exponent=-slope) - return jnp.where(save_transform == 0, x=0, y=1. / (1 + save_transform)) + return jnp.where(save_transform == 0, 0, 1.0 / (1 + save_transform)) def hill( diff --git a/lightweight_mmm/media_transforms.py b/lightweight_mmm/media_transforms.py index 6517afa..4edd036 100644 --- a/lightweight_mmm/media_transforms.py +++ b/lightweight_mmm/media_transforms.py @@ -112,7 +112,7 @@ def hill(data: jnp.ndarray, half_max_effective_concentration: jnp.ndarray, """ save_transform = apply_exponent_safe( data=data / half_max_effective_concentration, exponent=-slope) - return jnp.where(save_transform == 0, x=0, y=1. / (1 + save_transform)) + return jnp.where(save_transform == 0, 0, 1.0 / (1 + save_transform)) @functools.partial(jax.vmap, in_axes=(1, 1, None), out_axes=1) @@ -186,5 +186,5 @@ def apply_exponent_safe( Returns: The result of the exponent operation with the inputs provided. """ - exponent_safe = jnp.where(condition=(data == 0), x=1, y=data) ** exponent - return jnp.where(condition=(data == 0), x=0, y=exponent_safe) + exponent_safe = jnp.where(data == 0, 1, data) ** exponent + return jnp.where(data == 0, 0, exponent_safe)