Skip to content

Commit

Permalink
Merge branch 'main' into model_evaluation2
Browse files Browse the repository at this point in the history
  • Loading branch information
louismagowan authored Oct 18, 2024
2 parents 2b9768a + e46f690 commit 791fc47
Show file tree
Hide file tree
Showing 8 changed files with 7,232 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/notebooks/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mmm/mmm_time_varying_media_example
mmm/mmm_components
mmm/mmm_roas
mmm/mmm_time_slice_cross_validation
mmm/mmm_case_study
:::

:::{toctree}
Expand Down
7,174 changes: 7,174 additions & 0 deletions docs/source/notebooks/mmm/mmm_case_study.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/notebooks/mmm/mmm_components.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
"\n",
"alpha = 1\n",
"lam = 1 / 10\n",
"yy = saturation.function(xx, alpha=alpha, lam=lam)\n",
"yy = saturation.function(xx, alpha=alpha, lam=lam).eval()\n",
"\n",
"fig, ax = plt.subplots()\n",
"fig.suptitle(\"Example Saturation Curve\")\n",
Expand Down
5 changes: 4 additions & 1 deletion pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def function(self, x, b):
"""

import numpy as np
import pytensor.tensor as pt
import xarray as xr
from pydantic import Field, InstanceOf, validate_call

Expand Down Expand Up @@ -337,7 +338,9 @@ class MichaelisMentenSaturation(SaturationTransformation):

lookup_name = "michaelis_menten"

function = michaelis_menten
def function(self, x, alpha, lam):
"""Michaelis-Menten saturation function."""
return pt.as_tensor_variable(michaelis_menten(x, alpha, lam))

default_priors = {
"alpha": Prior("Gamma", mu=2, sigma=1),
Expand Down
19 changes: 19 additions & 0 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2222,6 +2222,10 @@ def _create_synth_dataset(
) -> pd.DataFrame:
"""Create a synthetic dataset based on the given allocation strategy (Budget) and time granularity.
**Important**: When generating the posterior predicive distribution for the target with the optimized budget,
we are setting the control variables to zero! This is done because in many situations we do not have all the
control variables in the future (e.g. outlier control, special events).
Parameters
----------
df : pd.DataFrame
Expand Down Expand Up @@ -2316,6 +2320,7 @@ def allocate_budget_to_maximize_response(
custom_constraints: dict[str, float] | None = None,
quantile: float = 0.5,
noise_level: float = 0.01,
**minimize_kwargs,
) -> az.InferenceData:
"""Allocate the given budget to maximize the response over a specified time period.
Expand All @@ -2329,6 +2334,10 @@ def allocate_budget_to_maximize_response(
budget, and creates a synthetic dataset based on the optimal allocation. Finally,
it performs posterior predictive sampling on the synthetic dataset.
**Important**: When generating the posterior predicive distribution for the target with the optimized budget,
we are setting the control variables to zero! This is done because in many situations we do not have all the
control variables in the future (e.g. outlier control, special events).
Parameters
----------
budget : float or int
Expand All @@ -2344,6 +2353,10 @@ def allocate_budget_to_maximize_response(
Custom constraints for the optimization. If None, no custom constraints are applied.
quantile : float, optional
The quantile to use for recovering transformation parameters. Default is 0.5.
noise_level : float, optional
The level of noise added to the allocation strategy (by default 1%).
**minimize_kwargs
Additional arguments to pass to the `BudgetOptimizer`.
Returns
-------
Expand All @@ -2355,7 +2368,12 @@ def allocate_budget_to_maximize_response(
ValueError
If the time granularity is not supported.
ValueError
If the noise level is not a float.
"""
if not isinstance(noise_level, float):
raise ValueError("noise_level must be a float")

parameters_mid = self.format_recovered_transformation_parameters(
quantile=quantile
)
Expand All @@ -2373,6 +2391,7 @@ def allocate_budget_to_maximize_response(
total_budget=budget,
budget_bounds=budget_bounds,
custom_constraints=custom_constraints,
**minimize_kwargs,
)

synth_dataset = self._create_synth_dataset(
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.
"""Version of the package."""

__version__ = "0.9.0"
__version__ = "0.10.0"
18 changes: 18 additions & 0 deletions tests/mmm/test_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,24 @@ def test_allocate_budget_to_maximize_response(self, mmm_fitted: MMM) -> None:
inference_periods == num_periods
), f"Number of periods in the data {inference_periods} does not match the expected {num_periods}"

def test_allocate_budget_to_maximize_response_bad_noise_level(
self, mmm_fitted: MMM
) -> None:
budget = 2.0
num_periods = 8
time_granularity = "weekly"
budget_bounds = {"channel_1": [0.5, 1.2], "channel_2": [0.5, 1.5]}
noise_level = "bad_noise_level"

with pytest.raises(ValueError, match="noise_level must be a float"):
mmm_fitted.allocate_budget_to_maximize_response(
budget=budget,
time_granularity=time_granularity,
num_periods=num_periods,
budget_bounds=budget_bounds,
noise_level=noise_level,
)

@pytest.mark.parametrize(
argnames="original_scale",
argvalues=[False, True],
Expand Down
17 changes: 14 additions & 3 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from matplotlib import pyplot as plt

from pymc_marketing.mmm.components.adstock import GeometricAdstock
from pymc_marketing.mmm.components.saturation import LogisticSaturation
from pymc_marketing.mmm.components.saturation import (
LogisticSaturation,
MichaelisMentenSaturation,
)
from pymc_marketing.mmm.mmm import MMM, BaseMMM
from pymc_marketing.mmm.preprocessing import MaxAbsScaleTarget

Expand Down Expand Up @@ -220,10 +223,18 @@ def test_plots(self, plotting_mmm, func_plot_name, kwargs_plot) -> None:
plt.close("all")


@pytest.fixture(
scope="module",
params=[LogisticSaturation(), MichaelisMentenSaturation()],
ids=["LogisticSaturation", "MichaelisMentenSaturation"],
)
def saturation(request):
return request.param


@pytest.fixture(scope="module")
def mock_mmm() -> MMM:
def mock_mmm(saturation) -> MMM:
adstock = GeometricAdstock(l_max=4)
saturation = LogisticSaturation()
return MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
Expand Down

0 comments on commit 791fc47

Please sign in to comment.