Skip to content

Commit

Permalink
Create inverse_scaled_logistic_saturation and the corresponding class (
Browse files Browse the repository at this point in the history
  • Loading branch information
arthurmello authored Jul 18, 2024
1 parent 717702a commit 6049ae8
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from pymc_marketing.mmm.components.saturation import (
HillSaturation,
InverseScaledLogisticSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
SaturationTransformation,
Expand All @@ -45,6 +46,7 @@
"GeometricAdstock",
"HillSaturation",
"LogisticSaturation",
"InverseScaledLogisticSaturation",
"MMM",
"MMMModelBuilder",
"MichaelisMentenSaturation",
Expand Down
35 changes: 35 additions & 0 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def function(self, x, b):
from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
hill_saturation,
inverse_scaled_logistic_saturation,
logistic_saturation,
michaelis_menten,
tanh_saturation,
Expand Down Expand Up @@ -201,6 +202,39 @@ def function(self, x, lam, beta):
}


class InverseScaledLogisticSaturation(SaturationTransformation):
"""Wrapper around inverse scaled logistic saturation function.
For more information, see :func:`pymc_marketing.mmm.transformers.inverse_scaled_logistic_saturation`.
.. plot::
:context: close-figs
import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import InverseScaledLogisticSaturation
rng = np.random.default_rng(0)
adstock = InverseScaledLogisticSaturation()
prior = adstock.sample_prior(random_seed=rng)
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()
"""

lookup_name = "inverse_scaled_logistic"

def function(self, x, lam, beta):
return beta * inverse_scaled_logistic_saturation(x, lam)

default_priors = {
"lam": Prior("Gamma", alpha=0.5, beta=1),
"beta": Prior("HalfNormal", sigma=2),
}


class TanhSaturation(SaturationTransformation):
"""Wrapper around tanh saturation function.
Expand Down Expand Up @@ -339,6 +373,7 @@ class HillSaturation(SaturationTransformation):
cls.lookup_name: cls
for cls in [
LogisticSaturation,
InverseScaledLogisticSaturation,
TanhSaturation,
TanhSaturationBaselined,
MichaelisMentenSaturation,
Expand Down
49 changes: 49 additions & 0 deletions pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,55 @@ def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5):
return (1 - pt.exp(-lam * x)) / (1 + pt.exp(-lam * x))


def inverse_scaled_logistic_saturation(
x, lam: npt.NDArray[np.float64] | float = 0.5, eps: float = np.log(3)
):
"""Inverse scaled logistic saturation transformation.
It offers a more intuitive alternative to logistic_saturation,
allowing for lambda to be interpreted as the half saturation point
when using default value for eps.
.. math::
f(x) = \\frac{1 - e^{-x*\epsilon/\lambda}}{1 + e^{-x*\epsilon/\lambda}}
.. plot::
:context: close-figs
import matplotlib.pyplot as plt
import numpy as np
import arviz as az
from pymc_marketing.mmm.transformers import inverse_scaled_logistic_saturation
plt.style.use('arviz-darkgrid')
lam = np.array([0.25, 0.5, 1, 2, 4])
x = np.linspace(0, 5, 100)
ax = plt.subplot(111)
for l in lam:
y = inverse_scaled_logistic_saturation(x, lam=l).eval()
plt.plot(x, y, label=f'lam = {l}')
plt.xlabel('spend', fontsize=12)
plt.ylabel('f(spend)', fontsize=12)
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()
Parameters
----------
x : tensor
Input tensor.
lam : float or array-like, optional, by default 0.5
Saturation parameter.
eps : float or array-like, optional, by default ln(3)
Scaling parameter. ln(3) results in halfway saturation at lam
Returns
-------
tensor
Transformed tensor.
""" # noqa: W605
return logistic_saturation(x, eps / lam)


class TanhSaturationParameters(NamedTuple):
"""Container for tanh saturation parameters.
Expand Down
3 changes: 3 additions & 0 deletions tests/mmm/components/test_saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from pymc_marketing.mmm.components.saturation import (
HillSaturation,
InverseScaledLogisticSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
TanhSaturation,
Expand All @@ -40,6 +41,7 @@ def model() -> pm.Model:
def saturation_functions():
return [
LogisticSaturation(),
InverseScaledLogisticSaturation(),
TanhSaturation(),
TanhSaturationBaselined(),
MichaelisMentenSaturation(),
Expand Down Expand Up @@ -93,6 +95,7 @@ def test_support_for_lift_test_integrations(saturation) -> None:
@pytest.mark.parametrize(
"name, saturation_cls",
[
("inverse_scaled_logistic", InverseScaledLogisticSaturation),
("logistic", LogisticSaturation),
("tanh", TanhSaturation),
("tanh_baselined", TanhSaturationBaselined),
Expand Down
21 changes: 21 additions & 0 deletions tests/mmm/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
delayed_adstock,
geometric_adstock,
hill_saturation,
inverse_scaled_logistic_saturation,
logistic_saturation,
michaelis_menten,
tanh_saturation,
Expand Down Expand Up @@ -343,6 +344,26 @@ def test_logistic_saturation_min_max_value(self, x, lam):
assert y_eval.max() <= 1
assert y_eval.min() >= 0

def test_inverse_scaled_logistic_saturation_lam_half(self):
x = np.array([0.01, 0.1, 0.5, 1, 100])
y = inverse_scaled_logistic_saturation(x=x, lam=x)
expected = np.array([0.5] * len(x))
np.testing.assert_almost_equal(
y.eval(),
expected,
decimal=5,
err_msg="The function does not behave as expected at the default value for eps",
)

def test_inverse_scaled_logistic_saturation_min_max_value(self):
x = np.array([0, 1, 100, 500, 5000])
lam = np.array([0.01, 0.25, 0.75, 1.5, 5.0, 10.0, 15.0])[:, None]

y = inverse_scaled_logistic_saturation(x=x, lam=lam)
y_eval = y.eval()
assert y_eval.max() <= 1
assert y_eval.min() >= 0

@pytest.mark.parametrize(
"x, b, c",
[
Expand Down

0 comments on commit 6049ae8

Please sign in to comment.