Skip to content

Commit

Permalink
Merge branch 'main' into mmm_casestudy
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz authored Oct 17, 2024
2 parents 283d0c1 + 429b955 commit b9898ca
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/source/notebooks/mmm/mmm_components.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
"\n",
"alpha = 1\n",
"lam = 1 / 10\n",
"yy = saturation.function(xx, alpha=alpha, lam=lam)\n",
"yy = saturation.function(xx, alpha=alpha, lam=lam).eval()\n",
"\n",
"fig, ax = plt.subplots()\n",
"fig.suptitle(\"Example Saturation Curve\")\n",
Expand Down
5 changes: 4 additions & 1 deletion pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def function(self, x, b):
"""

import numpy as np
import pytensor.tensor as pt
import xarray as xr
from pydantic import Field, InstanceOf, validate_call

Expand Down Expand Up @@ -337,7 +338,9 @@ class MichaelisMentenSaturation(SaturationTransformation):

lookup_name = "michaelis_menten"

function = michaelis_menten
def function(self, x, alpha, lam):
"""Michaelis-Menten saturation function."""
return pt.as_tensor_variable(michaelis_menten(x, alpha, lam))

default_priors = {
"alpha": Prior("Gamma", mu=2, sigma=1),
Expand Down
17 changes: 14 additions & 3 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from matplotlib import pyplot as plt

from pymc_marketing.mmm.components.adstock import GeometricAdstock
from pymc_marketing.mmm.components.saturation import LogisticSaturation
from pymc_marketing.mmm.components.saturation import (
LogisticSaturation,
MichaelisMentenSaturation,
)
from pymc_marketing.mmm.mmm import MMM, BaseMMM
from pymc_marketing.mmm.preprocessing import MaxAbsScaleTarget

Expand Down Expand Up @@ -220,10 +223,18 @@ def test_plots(self, plotting_mmm, func_plot_name, kwargs_plot) -> None:
plt.close("all")


@pytest.fixture(
scope="module",
params=[LogisticSaturation(), MichaelisMentenSaturation()],
ids=["LogisticSaturation", "MichaelisMentenSaturation"],
)
def saturation(request):
return request.param


@pytest.fixture(scope="module")
def mock_mmm() -> MMM:
def mock_mmm(saturation) -> MMM:
adstock = GeometricAdstock(l_max=4)
saturation = LogisticSaturation()
return MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
Expand Down

0 comments on commit b9898ca

Please sign in to comment.