Skip to content

Commit

Permalink
Original hill function definition (pymc-labs#925)
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz authored and radiokosmos committed Sep 1, 2024
1 parent e95e51c commit c96d959
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 28 deletions.
Binary file modified docs/source/uml/classes_mmm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from pymc_marketing.mmm.components.saturation import (
HillSaturation,
HillSaturationSigmoid,
InverseScaledLogisticSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
Expand All @@ -50,6 +51,7 @@
"DelayedSaturatedMMM",
"GeometricAdstock",
"HillSaturation",
"HillSaturationSigmoid",
"LogisticSaturation",
"InverseScaledLogisticSaturation",
"MMM",
Expand Down
42 changes: 39 additions & 3 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def function(self, x, b):

from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
hill_saturation,
hill_function,
hill_saturation_sigmoid,
inverse_scaled_logistic_saturation,
logistic_saturation,
michaelis_menten,
Expand Down Expand Up @@ -343,7 +344,7 @@ class MichaelisMentenSaturation(SaturationTransformation):
class HillSaturation(SaturationTransformation):
"""Wrapper around Hill saturation function.
For more information, see :func:`pymc_marketing.mmm.transformers.hill_saturation`.
For more information, see :func:`pymc_marketing.mmm.transformers.hill_function`.
.. plot::
:context: close-figs
Expand All @@ -364,7 +365,41 @@ class HillSaturation(SaturationTransformation):

lookup_name = "hill"

function = hill_saturation
def function(self, x, slope, kappa, beta):
return beta * hill_function(x, slope, kappa)

default_priors = {
"slope": Prior("HalfNormal", sigma=1.5),
"kappa": Prior("HalfNormal", sigma=1.5),
"beta": Prior("HalfNormal", sigma=1.5),
}


class HillSaturationSigmoid(SaturationTransformation):
"""Wrapper around Hill saturation sigmoid function.
For more information, see :func:`pymc_marketing.mmm.transformers.hill_saturation_sigmoid`.
.. plot::
:context: close-figs
import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import HillSaturationSigmoid
rng = np.random.default_rng(0)
adstock = HillSaturationSigmoid()
prior = adstock.sample_prior(random_seed=rng)
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()
"""

lookup_name = "hill_sigmoid"

function = hill_saturation_sigmoid

default_priors = {
"sigma": Prior("HalfNormal", sigma=1.5),
Expand Down Expand Up @@ -415,6 +450,7 @@ def function(self, x, alpha, beta):
TanhSaturationBaselined,
MichaelisMentenSaturation,
HillSaturation,
HillSaturationSigmoid,
RootSaturation,
]
}
Expand Down
97 changes: 85 additions & 12 deletions pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,13 +899,86 @@ def michaelis_menten(
return alpha * x / (lam + x)


def hill_saturation(
def hill_function(
x: pt.TensorLike, slope: pt.TensorLike, kappa: pt.TensorLike
) -> pt.TensorVariable:
r"""Hill Function
.. math::
f(x) = 1 - \frac{\kappa^s}{\kappa^s + x^s}
where:
- :math:`s` is the slope of the hill.
- :math:`\kappa` is the half-saturation point as :math:`f(\kappa) = 0.5` for any value of :math:`s` and :math:`\kappa`.
- :math:`x` is the independent variable and must be non-negative.
Hill function from Equation (5) in the paper [1]_.
.. plot::
:context: close-figs
import numpy as np
import matplotlib.pyplot as plt
from pymc_marketing.mmm.transformers import hill_function
x = np.linspace(0, 10, 100)
# Varying slope
slopes = [0.3, 0.7, 1.2]
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
for i, slope in enumerate(slopes):
plt.subplot(1, 3, i+1)
y = hill_function(x, slope, 2).eval()
plt.plot(x, y)
plt.xlabel('x')
plt.title(f'Slope = {slope}')
plt.subplot(1,3,1)
plt.ylabel('Hill Saturation Sigmoid')
plt.tight_layout()
plt.show()
# Varying kappa
kappas = [1, 5, 10]
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
for i, kappa in enumerate(kappas):
plt.subplot(1, 3, i+1)
y = hill_function(x, 1, kappa).eval()
plt.plot(x, y)
plt.xlabel('x')
plt.title(f'Kappa = {kappa}')
plt.subplot(1,3,1)
plt.ylabel('Hill Saturation Sigmoid')
plt.tight_layout()
plt.show()
Parameters
----------
x : float or array-like
The independent variable, typically representing the concentration of a
substrate or the intensity of a stimulus.
slope : float
The slope of the hill. Must pe non-positive.
kappa : float
The half-saturation point as :math:`f(\kappa) = 0.5` for any value of :math:`s` and :math:`\kappa`.
Returns
-------
float
The value of the Hill function given the parameters.
References
----------
.. [1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017).
""" # noqa: E501
return pt.as_tensor_variable(
1 - pt.power(kappa, slope) / (pt.power(kappa, slope) + pt.power(x, slope))
)


def hill_saturation_sigmoid(
x: pt.TensorLike,
sigma: pt.TensorLike,
beta: pt.TensorLike,
lam: pt.TensorLike,
) -> pt.TensorVariable:
r"""Hill Saturation Function
r"""Hill Saturation Sigmoid Function
.. math::
f(x) = \frac{\sigma}{1 + e^{-\beta(x - \lambda)}} - \frac{\sigma}{1 + e^{\beta\lambda}}
Expand All @@ -929,45 +1002,45 @@ def hill_saturation(
import numpy as np
import matplotlib.pyplot as plt
from pymc_marketing.mmm.transformers import hill_saturation
from pymc_marketing.mmm.transformers import hill_saturation_sigmoid
x = np.linspace(0, 10, 100)
# Varying sigma
sigmas = [0.5, 1, 1.5]
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
for i, sigma in enumerate(sigmas):
plt.subplot(1, 3, i+1)
y = hill_saturation(x, sigma, 2, 5).eval()
y = hill_saturation_sigmoid(x, sigma, 2, 5).eval()
plt.plot(x, y)
plt.xlabel('x')
plt.title(f'Sigma = {sigma}')
plt.subplot(1,3,1)
plt.ylabel('Hill Saturation')
plt.ylabel('Hill Saturation Sigmoid')
plt.tight_layout()
plt.show()
# Varying beta
betas = [1, 2, 3]
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
for i, beta in enumerate(betas):
plt.subplot(1, 3, i+1)
y = hill_saturation(x, 1, beta, 5).eval()
y = hill_saturation_sigmoid(x, 1, beta, 5).eval()
plt.plot(x, y)
plt.xlabel('x')
plt.title(f'Beta = {beta}')
plt.subplot(1,3,1)
plt.ylabel('Hill Saturation')
plt.ylabel('Hill Saturation Sigmoid')
plt.tight_layout()
plt.show()
# Varying lam
lams = [3, 5, 7]
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
for i, lam in enumerate(lams):
plt.subplot(1, 3, i+1)
y = hill_saturation(x, 1, 2, lam).eval()
y = hill_saturation_sigmoid(x, 1, 2, lam).eval()
plt.plot(x, y)
plt.xlabel('x')
plt.title(f'Lambda = {lam}')
plt.subplot(1,3,1)
plt.ylabel('Hill Saturation')
plt.ylabel('Hill Saturation Sigmoid')
plt.tight_layout()
plt.show()
Expand All @@ -977,8 +1050,8 @@ def hill_saturation(
The independent variable, typically representing the concentration of a
substrate or the intensity of a stimulus.
sigma : float
The upper asymptote of the curve, representing the maximum value the
function will approach as x grows large.
The upper asymptote of the curve, representing the approximate maximum value the
function will approach as x grows large. The true maximum value is at `sigma * (1 - 1 / (1 + exp(beta * lam)))`
beta : float
The slope parameter, determining the steepness of the curve.
lam : float
Expand All @@ -988,7 +1061,7 @@ def hill_saturation(
Returns
-------
float or array-like
The value of the Hill function for each input value of x.
The value of the Hill saturation sigmoid function for each input value of x.
"""
return sigma / (1 + pt.exp(-beta * (x - lam))) - sigma / (1 + pt.exp(beta * lam))

Expand Down
3 changes: 3 additions & 0 deletions tests/mmm/components/test_saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from pymc_marketing.mmm import (
HillSaturation,
HillSaturationSigmoid,
InverseScaledLogisticSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
Expand All @@ -48,6 +49,7 @@ def saturation_functions():
TanhSaturationBaselined(),
MichaelisMentenSaturation(),
HillSaturation(),
HillSaturationSigmoid(),
RootSaturation(),
]

Expand Down Expand Up @@ -104,6 +106,7 @@ def test_support_for_lift_test_integrations(saturation) -> None:
("tanh_baselined", TanhSaturationBaselined),
("michaelis_menten", MichaelisMentenSaturation),
("hill", HillSaturation),
("hill_sigmoid", HillSaturationSigmoid),
("root", RootSaturation),
],
)
Expand Down
Loading

0 comments on commit c96d959

Please sign in to comment.