diff --git a/pymc_marketing/mmm/__init__.py b/pymc_marketing/mmm/__init__.py index 2d75a1b0..dafc319f 100644 --- a/pymc_marketing/mmm/__init__.py +++ b/pymc_marketing/mmm/__init__.py @@ -23,6 +23,7 @@ ) from pymc_marketing.mmm.components.saturation import ( HillSaturation, + InverseScaledLogisticSaturation, LogisticSaturation, MichaelisMentenSaturation, SaturationTransformation, @@ -45,6 +46,7 @@ "GeometricAdstock", "HillSaturation", "LogisticSaturation", + "InverseScaledLogisticSaturation", "MMM", "MMMModelBuilder", "MichaelisMentenSaturation", diff --git a/pymc_marketing/mmm/components/saturation.py b/pymc_marketing/mmm/components/saturation.py index d93f8ca6..fb9afa04 100644 --- a/pymc_marketing/mmm/components/saturation.py +++ b/pymc_marketing/mmm/components/saturation.py @@ -76,6 +76,7 @@ def function(self, x, b): from pymc_marketing.mmm.components.base import Transformation from pymc_marketing.mmm.transformers import ( hill_saturation, + inverse_scaled_logistic_saturation, logistic_saturation, michaelis_menten, tanh_saturation, @@ -201,6 +202,39 @@ def function(self, x, lam, beta): } +class InverseScaledLogisticSaturation(SaturationTransformation): + """Wrapper around inverse scaled logistic saturation function. + + For more information, see :func:`pymc_marketing.mmm.transformers.inverse_scaled_logistic_saturation`. + + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + from pymc_marketing.mmm import InverseScaledLogisticSaturation + + rng = np.random.default_rng(0) + + adstock = InverseScaledLogisticSaturation() + prior = adstock.sample_prior(random_seed=rng) + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + + """ + + lookup_name = "inverse_scaled_logistic" + + def function(self, x, lam, beta): + return beta * inverse_scaled_logistic_saturation(x, lam) + + default_priors = { + "lam": Prior("Gamma", alpha=0.5, beta=1), + "beta": Prior("HalfNormal", sigma=2), + } + + class TanhSaturation(SaturationTransformation): """Wrapper around tanh saturation function. @@ -339,6 +373,7 @@ class HillSaturation(SaturationTransformation): cls.lookup_name: cls for cls in [ LogisticSaturation, + InverseScaledLogisticSaturation, TanhSaturation, TanhSaturationBaselined, MichaelisMentenSaturation, diff --git a/pymc_marketing/mmm/transformers.py b/pymc_marketing/mmm/transformers.py index 5c036445..58fa6c50 100644 --- a/pymc_marketing/mmm/transformers.py +++ b/pymc_marketing/mmm/transformers.py @@ -478,6 +478,55 @@ def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5): return (1 - pt.exp(-lam * x)) / (1 + pt.exp(-lam * x)) +def inverse_scaled_logistic_saturation( + x, lam: npt.NDArray[np.float64] | float = 0.5, eps: float = np.log(3) +): + """Inverse scaled logistic saturation transformation. + It offers a more intuitive alternative to logistic_saturation, + allowing for lambda to be interpreted as the half saturation point + when using default value for eps. + + .. math:: + f(x) = \\frac{1 - e^{-x*\epsilon/\lambda}}{1 + e^{-x*\epsilon/\lambda}} + + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + import arviz as az + from pymc_marketing.mmm.transformers import inverse_scaled_logistic_saturation + plt.style.use('arviz-darkgrid') + lam = np.array([0.25, 0.5, 1, 2, 4]) + x = np.linspace(0, 5, 100) + ax = plt.subplot(111) + for l in lam: + y = inverse_scaled_logistic_saturation(x, lam=l).eval() + plt.plot(x, y, label=f'lam = {l}') + plt.xlabel('spend', fontsize=12) + plt.ylabel('f(spend)', fontsize=12) + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) + ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) + plt.show() + + Parameters + ---------- + x : tensor + Input tensor. + lam : float or array-like, optional, by default 0.5 + Saturation parameter. + eps : float or array-like, optional, by default ln(3) + Scaling parameter. ln(3) results in halfway saturation at lam + + Returns + ------- + tensor + Transformed tensor. + """ # noqa: W605 + return logistic_saturation(x, eps / lam) + + class TanhSaturationParameters(NamedTuple): """Container for tanh saturation parameters. diff --git a/tests/mmm/components/test_saturation.py b/tests/mmm/components/test_saturation.py index fc78b362..cea75391 100644 --- a/tests/mmm/components/test_saturation.py +++ b/tests/mmm/components/test_saturation.py @@ -22,6 +22,7 @@ from pymc_marketing.mmm.components.saturation import ( HillSaturation, + InverseScaledLogisticSaturation, LogisticSaturation, MichaelisMentenSaturation, TanhSaturation, @@ -40,6 +41,7 @@ def model() -> pm.Model: def saturation_functions(): return [ LogisticSaturation(), + InverseScaledLogisticSaturation(), TanhSaturation(), TanhSaturationBaselined(), MichaelisMentenSaturation(), @@ -93,6 +95,7 @@ def test_support_for_lift_test_integrations(saturation) -> None: @pytest.mark.parametrize( "name, saturation_cls", [ + ("inverse_scaled_logistic", InverseScaledLogisticSaturation), ("logistic", LogisticSaturation), ("tanh", TanhSaturation), ("tanh_baselined", TanhSaturationBaselined), diff --git a/tests/mmm/test_transformers.py b/tests/mmm/test_transformers.py index aa158699..f9c47ddc 100644 --- a/tests/mmm/test_transformers.py +++ b/tests/mmm/test_transformers.py @@ -28,6 +28,7 @@ delayed_adstock, geometric_adstock, hill_saturation, + inverse_scaled_logistic_saturation, logistic_saturation, michaelis_menten, tanh_saturation, @@ -343,6 +344,26 @@ def test_logistic_saturation_min_max_value(self, x, lam): assert y_eval.max() <= 1 assert y_eval.min() >= 0 + def test_inverse_scaled_logistic_saturation_lam_half(self): + x = np.array([0.01, 0.1, 0.5, 1, 100]) + y = inverse_scaled_logistic_saturation(x=x, lam=x) + expected = np.array([0.5] * len(x)) + np.testing.assert_almost_equal( + y.eval(), + expected, + decimal=5, + err_msg="The function does not behave as expected at the default value for eps", + ) + + def test_inverse_scaled_logistic_saturation_min_max_value(self): + x = np.array([0, 1, 100, 500, 5000]) + lam = np.array([0.01, 0.25, 0.75, 1.5, 5.0, 10.0, 15.0])[:, None] + + y = inverse_scaled_logistic_saturation(x=x, lam=lam) + y_eval = y.eval() + assert y_eval.max() <= 1 + assert y_eval.min() >= 0 + @pytest.mark.parametrize( "x, b, c", [