Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create inverse_scaled_logistic_saturation and the corresponding class #827

Merged
2 changes: 2 additions & 0 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from pymc_marketing.mmm.components.saturation import (
HillSaturation,
InverseScaledLogisticSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
SaturationTransformation,
Expand All @@ -45,6 +46,7 @@
"GeometricAdstock",
"HillSaturation",
"LogisticSaturation",
"InverseScaledLogisticSaturation",
"MMM",
"MMMModelBuilder",
"MichaelisMentenSaturation",
Expand Down
35 changes: 35 additions & 0 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def function(self, x, b):
from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
hill_saturation,
inverse_scaled_logistic_saturation,
logistic_saturation,
michaelis_menten,
tanh_saturation,
Expand Down Expand Up @@ -201,6 +202,39 @@ def function(self, x, lam, beta):
}


class InverseScaledLogisticSaturation(SaturationTransformation):
"""Wrapper around inverse scaled logistic saturation function.

For more information, see :func:`pymc_marketing.mmm.transformers.inverse_scaled_logistic_saturation`.

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import InverseScaledLogisticSaturation

rng = np.random.default_rng(0)

adstock = InverseScaledLogisticSaturation()
prior = adstock.sample_prior(random_seed=rng)
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()

"""

lookup_name = "inverse_scaled_logistic"

def function(self, x, lam, beta):
return beta * inverse_scaled_logistic_saturation(x, lam)

default_priors = {
"lam": Prior("Gamma", alpha=0.3, beta=0.6),
wd60622 marked this conversation as resolved.
Show resolved Hide resolved
"beta": Prior("HalfNormal", sigma=2),
}


class TanhSaturation(SaturationTransformation):
"""Wrapper around tanh saturation function.

Expand Down Expand Up @@ -339,6 +373,7 @@ class HillSaturation(SaturationTransformation):
cls.lookup_name: cls
for cls in [
LogisticSaturation,
InverseScaledLogisticSaturation,
TanhSaturation,
TanhSaturationBaselined,
MichaelisMentenSaturation,
Expand Down
49 changes: 49 additions & 0 deletions pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,55 @@ def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5):
return (1 - pt.exp(-lam * x)) / (1 + pt.exp(-lam * x))


def inverse_scaled_logistic_saturation(
x, lam: npt.NDArray[np.float64] | float = 0.5, eps: float = np.log(3)
):
"""Inverse scaled logistic saturation transformation.
It offers a more intuitive alternative to logistic_saturation,
allowing for lambda to be interpreted as the half saturation point,
when using default values for lam and eps.
wd60622 marked this conversation as resolved.
Show resolved Hide resolved

wd60622 marked this conversation as resolved.
Show resolved Hide resolved
.. math::
f(x) = \\frac{1 - e^{-x*\epsilon/\lambda}}{1 + e^{-x*\epsilon/\lambda}}

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
import arviz as az
from pymc_marketing.mmm.transformers import inverse_scaled_logistic_saturation
plt.style.use('arviz-darkgrid')
lam = np.array([0.25, 0.5, 1, 2, 4])
x = np.linspace(0, 5, 100)
ax = plt.subplot(111)
for l in lam:
y = inverse_scaled_logistic_saturation(x, lam=l).eval()
plt.plot(x, y, label=f'lam = {l}')
plt.xlabel('spend', fontsize=12)
plt.ylabel('f(spend)', fontsize=12)
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()

Parameters
----------
x : tensor
Input tensor.
lam : float or array-like, optional, by default 0.5
Saturation parameter.
eps : float or array-like, optional, by default ln(3)
Scaling parameter. ln(3) results in halfway saturation for lam = 0.5
wd60622 marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
tensor
Transformed tensor.
""" # noqa: W605
return logistic_saturation(x, eps / lam)


class TanhSaturationParameters(NamedTuple):
"""Container for tanh saturation parameters.

Expand Down
3 changes: 3 additions & 0 deletions tests/mmm/components/test_saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from pymc_marketing.mmm.components.saturation import (
HillSaturation,
InverseScaledLogisticSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
TanhSaturation,
Expand All @@ -40,6 +41,7 @@ def model() -> pm.Model:
def saturation_functions():
return [
LogisticSaturation(),
InverseScaledLogisticSaturation(),
TanhSaturation(),
TanhSaturationBaselined(),
MichaelisMentenSaturation(),
Expand Down Expand Up @@ -93,6 +95,7 @@ def test_support_for_lift_test_integrations(saturation) -> None:
@pytest.mark.parametrize(
"name, saturation_cls",
[
("inverse_scaled_logistic", InverseScaledLogisticSaturation),
("logistic", LogisticSaturation),
("tanh", TanhSaturation),
("tanh_baselined", TanhSaturationBaselined),
Expand Down
27 changes: 27 additions & 0 deletions tests/mmm/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
delayed_adstock,
geometric_adstock,
hill_saturation,
inverse_scaled_logistic_saturation,
logistic_saturation,
michaelis_menten,
tanh_saturation,
Expand Down Expand Up @@ -343,6 +344,32 @@ def test_logistic_saturation_min_max_value(self, x, lam):
assert y_eval.max() <= 1
assert y_eval.min() >= 0

def test_inverse_scaled_logistic_saturation_lam_half(self):
x = np.array([0.5] * 100)
wd60622 marked this conversation as resolved.
Show resolved Hide resolved
y = inverse_scaled_logistic_saturation(x=x, lam=0.5, eps=np.ln(3))
wd60622 marked this conversation as resolved.
Show resolved Hide resolved
expected = np.array([0.5] * 100)
np.testing.assert_almost_equal(
y.eval(),
expected,
decimal=5,
err_msg="The function does not behave as expected at lambda 0.5.",
)

@pytest.mark.parametrize(
"x, lam",
[
(np.ones(shape=(100)), 0.5),
(np.linspace(start=0.0, stop=1.0, num=50), 10),
(np.linspace(start=200, stop=1000, num=50), 0.001),
(np.zeros(shape=(100)), 1),
],
)
def test_inverse_scaled_logistic_saturation_min_max_value(self, x, lam):
y = inverse_scaled_logistic_saturation(x=x, lam=lam)
wd60622 marked this conversation as resolved.
Show resolved Hide resolved
y_eval = y.eval()
assert y_eval.max() <= 1
assert y_eval.min() >= 0

@pytest.mark.parametrize(
"x, b, c",
[
Expand Down