Remove extra definitions and wrap as_tensor_variable
PabloRoque committed Sep 23, 2024
1 parent 245f771 commit 04e6367
Showing 4 changed files with 8 additions and 32 deletions.
5 changes: 4 additions & 1 deletion pymc_marketing/mmm/components/saturation.py
@@ -72,6 +72,7 @@ def function(self, x, b):
 """
 
 import numpy as np
+import pytensor.tensor as pt
 import xarray as xr
 from pydantic import Field, InstanceOf, validate_call
 
@@ -337,7 +338,9 @@ class MichaelisMentenSaturation(SaturationTransformation):
 
     lookup_name = "michaelis_menten"
 
-    function = michaelis_menten
+    def function(self, x, alpha, lam):
+        """Michaelis-Menten saturation function."""
+        return pt.as_tensor_variable(michaelis_menten(x, alpha, lam))
 
     default_priors = {
         "alpha": Prior("Gamma", mu=2, sigma=1),
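For orientation, a minimal sketch of how the wrapped method behaves after this change: `function` now returns a symbolic `TensorVariable` instead of exposing the bare NumPy callable. The constructor call and parameter values below are illustrative assumptions, not taken from this diff.

```python
import numpy as np

from pymc_marketing.mmm.components.saturation import MichaelisMentenSaturation

# Assumed usage: default priors, illustrative parameter values.
saturation = MichaelisMentenSaturation()
x = np.linspace(0, 10, 5)
curve = saturation.function(x, alpha=2.0, lam=1.0)  # now a pt.TensorVariable
print(curve.eval())  # concrete values of alpha * x / (lam + x)
```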
27 changes: 1 addition & 26 deletions pymc_marketing/mmm/transformers.py
@@ -826,7 +826,7 @@ def tanh_saturation_baselined(
     return gain * x0 * pt.tanh(x * pt.arctanh(r) / x0) / r
 
 
-def michaelis_menten_function(
+def michaelis_menten(
     x: float | np.ndarray | npt.NDArray[np.float64],
     alpha: float,
     lam: float,
@@ -914,31 +914,6 @@ def michaelis_menten_function(
     return alpha * x / (lam + x)
 
 
-def michaelis_menten(
-    x: float | np.ndarray | npt.NDArray[np.float64],
-    alpha: float,
-    lam: float,
-) -> pt.TensorVariable:
-    r"""TensorVariable wrap over the Michaelis-Menten transformation.
-
-    Parameters
-    ----------
-    x : float
-        The spent on a channel.
-    alpha : float
-        The maximum contribution a channel can make.
-    lam : float
-        The Michaelis constant for the given enzyme-substrate system.
-
-    Returns
-    -------
-    pt.TensorVariable
-        The value of the Michaelis-Menten function given the parameters as a TensorVariable.
-    """
-    return pt.as_tensor_variable(michaelis_menten_function(x, alpha, lam))
-
-
 def hill_function(
     x: pt.TensorLike, slope: pt.TensorLike, kappa: pt.TensorLike
 ) -> pt.TensorVariable:
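After the rename, `michaelis_menten` is the plain NumPy implementation, alpha * x / (lam + x); the `TensorVariable` wrapping now happens in one place, inside `MichaelisMentenSaturation.function`. A quick sketch with illustrative values:

```python
import numpy as np

from pymc_marketing.mmm.transformers import michaelis_menten

x = np.array([0.0, 1.0, 5.0, 10.0])
# Plain NumPy evaluation, no .eval() needed: alpha * x / (lam + x)
y = michaelis_menten(x, alpha=2.0, lam=1.0)
print(y)  # approximately [0.0, 1.0, 1.6667, 1.8182]
```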
6 changes: 2 additions & 4 deletions pymc_marketing/mmm/utils.py
@@ -22,7 +22,7 @@
 import xarray as xr
 from scipy.optimize import curve_fit, minimize_scalar
 
-from pymc_marketing.mmm.transformers import michaelis_menten_function
+from pymc_marketing.mmm.transformers import michaelis_menten
 
 
 def estimate_menten_parameters(
@@ -67,9 +67,7 @@ def estimate_menten_parameters(
     # Initial guess for L and k
     initial_guess = [alpha_initial_estimate, lam_initial_estimate]
     # Curve fitting
-    popt, _ = curve_fit(
-        michaelis_menten_function, x, y, p0=initial_guess, maxfev=maxfev
-    )
+    popt, _ = curve_fit(michaelis_menten, x, y, p0=initial_guess, maxfev=maxfev)
 
     # Save the parameters
     return popt
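Because `michaelis_menten` keeps the plain `f(x, *params)` signature, it can be passed straight to `scipy.optimize.curve_fit`, as `estimate_menten_parameters` does above. A self-contained sketch with synthetic data (the parameter values and seed are illustrative assumptions):

```python
import numpy as np
from scipy.optimize import curve_fit

from pymc_marketing.mmm.transformers import michaelis_menten

# Synthetic spend/response data generated from known parameters.
rng = np.random.default_rng(42)
x = np.linspace(0.1, 10, 50)
y = michaelis_menten(x, alpha=3.0, lam=2.0) + rng.normal(0, 0.05, x.size)

initial_guess = [y.max(), np.median(x)]  # rough starts for alpha and lam
popt, _ = curve_fit(michaelis_menten, x, y, p0=initial_guess, maxfev=2000)
print(popt)  # recovered (alpha, lam), close to (3.0, 2.0)
```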
2 changes: 1 addition & 1 deletion tests/mmm/test_transformers.py
@@ -475,7 +475,7 @@ def test_tanh_saturation_parameterization_transformation(self, x, b, c):
         ],
     )
     def test_michaelis_menten(self, x, alpha, lam, expected):
-        assert np.isclose(michaelis_menten(x, alpha, lam).eval(), expected, atol=0.01)
+        assert np.isclose(michaelis_menten(x, alpha, lam), expected, atol=0.01)
 
     @pytest.mark.parametrize(
         "sigma, beta, lam",
