diff --git a/pymc_marketing/mmm/components/saturation.py b/pymc_marketing/mmm/components/saturation.py index e480f81f..7cba45aa 100644 --- a/pymc_marketing/mmm/components/saturation.py +++ b/pymc_marketing/mmm/components/saturation.py @@ -342,6 +342,27 @@ class MichaelisMentenSaturation(SaturationTransformation): class HillSaturation(SaturationTransformation): + """Wrapper around Hill saturation function. + + For more information, see :func:`pymc_marketing.mmm.transformers.hill_function`. + + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + from pymc_marketing.mmm import HillSaturation + + rng = np.random.default_rng(0) + + adstock = HillSaturation() + prior = adstock.sample_prior(random_seed=rng) + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + + """ + lookup_name = "hill" def function(self, x, slope, kappa, beta): diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index bf0bf47b..ab5796d6 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -63,11 +63,11 @@ class BaseMMM(BaseValidateMMM): """ - Base class for a media mix model using Delayed Adstock and Logistic Saturation (see [1]_). + Base class for a media mix model using Delayed Adstock and Logistic Saturation (see [1]_). References - ---------- - .. [1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017). + ---------- + .. [1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017). """ _model_name: str = "BaseMMM" diff --git a/pymc_marketing/mmm/transformers.py b/pymc_marketing/mmm/transformers.py index 58f1df07..2ae28d86 100644 --- a/pymc_marketing/mmm/transformers.py +++ b/pymc_marketing/mmm/transformers.py @@ -902,6 +902,71 @@ def michaelis_menten( def hill_function( x: pt.TensorLike, slope: pt.TensorLike, kappa: pt.TensorLike ) -> pt.TensorVariable: + r"""Hill Function + + .. math:: + f(x) = 1 - \frac{\kappa^s}{\kappa^s + x^s} + + where: + - :math:`s` is the slope of the hill. + - :math:`\kappa` is the half-saturation point as :math:`f(\kappa) = 0.5` for any value of :math:`s` and :math:`\kappa`. + - :math:`x` is the independent variable and must be non-negative. + + Hill function from Equation (5) in the paper [1]_. + + .. plot:: + :context: close-figs + + import numpy as np + import matplotlib.pyplot as plt + from pymc_marketing.mmm.transformers import hill_function + x = np.linspace(0, 10, 100) + # Varying slope + slopes = [0.3, 0.7, 1.2] + fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) + for i, slope in enumerate(slopes): + plt.subplot(1, 3, i+1) + y = hill_function(x, slope, 2).eval() + plt.plot(x, y) + plt.xlabel('x') + plt.title(f'Slope = {slope}') + plt.subplot(1,3,1) + plt.ylabel('Hill Saturation Sigmoid') + plt.tight_layout() + plt.show() + # Varying kappa + kappas = [1, 5, 10] + fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) + for i, kappa in enumerate(kappas): + plt.subplot(1, 3, i+1) + y = hill_function(x, 1, kappa).eval() + plt.plot(x, y) + plt.xlabel('x') + plt.title(f'Kappa = {kappa}') + plt.subplot(1,3,1) + plt.ylabel('Hill Saturation Sigmoid') + plt.tight_layout() + plt.show() + + Parameters + ---------- + x : float or array-like + The independent variable, typically representing the concentration of a + substrate or the intensity of a stimulus. + slope : float + The slope of the hill. Must pe non-positive. + kappa : float + The half-saturation point as :math:`f(\kappa) = 0.5` for any value of :math:`s` and :math:`\kappa`. + + Returns + ------- + float + The value of the Hill function given the parameters. + + References + ---------- + .. [1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017). + """ # noqa: E501 return pt.as_tensor_variable( 1 - pt.power(kappa, slope) / (pt.power(kappa, slope) + pt.power(x, slope)) ) @@ -949,7 +1014,7 @@ def hill_saturation_sigmoid( plt.xlabel('x') plt.title(f'Sigma = {sigma}') plt.subplot(1,3,1) - plt.ylabel('Hill Saturation') + plt.ylabel('Hill Saturation Sigmoid') plt.tight_layout() plt.show() # Varying beta @@ -962,7 +1027,7 @@ def hill_saturation_sigmoid( plt.xlabel('x') plt.title(f'Beta = {beta}') plt.subplot(1,3,1) - plt.ylabel('Hill Saturation') + plt.ylabel('Hill Saturation Sigmoid') plt.tight_layout() plt.show() # Varying lam @@ -975,7 +1040,7 @@ def hill_saturation_sigmoid( plt.xlabel('x') plt.title(f'Lambda = {lam}') plt.subplot(1,3,1) - plt.ylabel('Hill Saturation') + plt.ylabel('Hill Saturation Sigmoid') plt.tight_layout() plt.show() diff --git a/tests/mmm/components/test_saturation.py b/tests/mmm/components/test_saturation.py index abc44d7f..cebd9cd6 100644 --- a/tests/mmm/components/test_saturation.py +++ b/tests/mmm/components/test_saturation.py @@ -48,6 +48,7 @@ def saturation_functions(): TanhSaturation(), TanhSaturationBaselined(), MichaelisMentenSaturation(), + HillSaturation(), HillSaturationSigmoid(), RootSaturation(), ]