Skip to content

Commit

Permalink
MMM: refactor priors to be user defined
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusSagen committed Oct 21, 2023
1 parent ae1c871 commit 439115c
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 60 deletions.
134 changes: 83 additions & 51 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,49 @@ def __init__(
yearly_seasonality : Optional[int], optional
Number of Fourier modes to model yearly seasonality, by default None.
Examples
--------
DelayedSaturatedMMM
.. code-block:: python
import pymc as pm
from pymc_marketing.mmm import DelayedSaturatedMMM
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/datasets/mmm_example.csv"
data = pd.read_csv(data_url, parse_dates=['date_week'])
model = DelayedSaturatedMMM(
date_column="date_week",
channel_columns=["x1", "x2"],
control_columns=[
"event_1",
"event_2",
"t",
],
adstock_max_lag=8,
yearly_seasonality=2,
model_config={
# priors
"intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"beta_channel": {"dist": "HalfNormal", "kwargs": {"sigma": 2}, "dims": ("channel",)},
"alpha": {"dist": "Beta", "kwargs": {"alpha": 1, "beta": 3}, "dims": ("channel",)},
"lam": {"dist": "Gamma", "kwargs": {"alpha": 3, "beta": 1}, "dims": ("channel",)},
"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
"gamma_control": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}, "dims": ("control",)},
"gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0, "b": 1}, "dims": "fourier_mode"},
# params
"mu": {"dims": ("date",)},
"likelihood": {"dims": ("date",)},
},
)
X = data.drop('y',axis=1)
y = data['y']
model.fit(X,y)
model.plot_components_contributions();
References
----------
.. [1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017).
Expand Down Expand Up @@ -174,6 +217,7 @@ def build_model(
"""
model_config = self.model_config
self._generate_and_preprocess_model_data(X, y)

with pm.Model(coords=self.model_coords) as self.model:
channel_data_ = pm.MutableData(
name="channel_data",
Expand All @@ -187,33 +231,13 @@ def build_model(
dims="date",
)

intercept = pm.Normal(
name="intercept",
mu=model_config["intercept"]["mu"],
sigma=model_config["intercept"]["sigma"],
)

beta_channel = pm.HalfNormal(
name="beta_channel",
sigma=model_config["beta_channel"]["sigma"],
dims=model_config["beta_channel"]["dims"],
)
alpha = pm.Beta(
name="alpha",
alpha=model_config["alpha"]["alpha"],
beta=model_config["alpha"]["beta"],
dims=model_config["alpha"]["dims"],
)

lam = pm.Gamma(
name="lam",
alpha=model_config["lam"]["alpha"],
beta=model_config["lam"]["beta"],
dims=model_config["lam"]["dims"],
)

sigma = pm.HalfNormal(name="sigma", sigma=model_config["sigma"]["sigma"])
intercept = self.register_rv(name="intercept")
beta_channel = self.register_rv(name="beta_channel")
alpha = self.register_rv(name="alpha")
lam = self.register_rv(name="lam")
sigma = self.register_rv(name="sigma")

# TODO: register the adstock transforms
channel_adstock = pm.Deterministic(
name="channel_adstock",
var=geometric_adstock(
Expand Down Expand Up @@ -245,19 +269,14 @@ def build_model(
for column in self.control_columns
)
):
gamma_control = self.register_rv(name="gamma_control")

control_data_ = pm.MutableData(
name="control_data",
value=self.preprocessed_data["X"][self.control_columns],
dims=("date", "control"),
)

gamma_control = pm.Normal(
name="gamma_control",
mu=model_config["gamma_control"]["mu"],
sigma=model_config["gamma_control"]["sigma"],
dims=model_config["gamma_control"]["dims"],
)

control_contributions = pm.Deterministic(
name="control_contributions",
var=control_data_ * gamma_control,
Expand All @@ -274,19 +293,14 @@ def build_model(
for column in self.fourier_columns
)
):
gamma_fourier = self.register_rv(name="gamma_fourier")

fourier_data_ = pm.MutableData(
name="fourier_data",
value=self.preprocessed_data["X"][self.fourier_columns],
dims=("date", "fourier_mode"),
)

gamma_fourier = pm.Laplace(
name="gamma_fourier",
mu=model_config["gamma_fourier"]["mu"],
b=model_config["gamma_fourier"]["b"],
dims=model_config["gamma_fourier"]["dims"],
)

fourier_contribution = pm.Deterministic(
name="fourier_contributions",
var=fourier_data_ * gamma_fourier,
Expand All @@ -308,23 +322,41 @@ def build_model(
)

@property
def default_model_config(self) -> Dict:
model_config: Dict = {
"intercept": {"mu": 0, "sigma": 2},
"beta_channel": {"sigma": 2, "dims": ("channel",)},
"alpha": {"alpha": 1, "beta": 3, "dims": ("channel",)},
"lam": {"alpha": 3, "beta": 1, "dims": ("channel",)},
"sigma": {"sigma": 2},
def default_model_config(self) -> Dict[str, Dict]:
return {
# Prior
"intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"beta_channel": {
"dist": "HalfNormal",
"kwargs": {"sigma": 2},
"dims": ("channel",),
},
"alpha": {
"dist": "Beta",
"kwargs": {"alpha": 1, "beta": 3},
"dims": ("channel",),
},
"lam": {
"dist": "Gamma",
"kwargs": {"alpha": 3, "beta": 1},
"dims": ("channel",),
},
"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
"gamma_control": {
"mu": 0,
"sigma": 2,
"dist": "Normal",
"kwargs": {"mu": 0, "sigma": 2},
"dims": ("control",),
},
"gamma_fourier": {
"dist": "Laplace",
"kwargs": {"mu": 0, "b": 1},
"dims": "fourier_mode",
},
# Deterministic
"mu": {"dims": ("date",)},
# Likelihood
"likelihood": {"dims": ("date",)},
"gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"},
}
return model_config

def _get_fourier_models_data(self, X) -> pd.DataFrame:
"""Generates fourier modes to model seasonality.
Expand Down
33 changes: 24 additions & 9 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,41 @@ def toy_X() -> pd.DataFrame:
@pytest.fixture(scope="class")
def model_config_requiring_serialization() -> dict:
model_config = {
"intercept": {"mu": 0, "sigma": 2},
"intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"beta_channel": {
"sigma": np.array([0.4533017, 0.25488063]),
"dist": "HalfNormal",
"kwargs": {"sigma": np.array([0.4533017, 0.25488063])},
"dims": ("channel",),
},
"alpha": {
"alpha": np.array([3, 3]),
"beta": np.array([3.55001301, 2.87092431]),
"dist": "Beta",
"kwargs": {
"alpha": np.array([3, 3]),
"beta": np.array([3.55001301, 2.87092431]),
},
"dims": ("channel",),
},
"lam": {
"alpha": np.array([3, 3]),
"beta": np.array([4.12231653, 5.02896872]),
"dist": "Gamma",
"kwargs": {
"alpha": np.array([3, 3]),
"beta": np.array([4.12231653, 5.02896872]),
},
"dims": ("channel",),
},
"sigma": {"sigma": 2},
"gamma_control": {"mu": 0, "sigma": 2, "dims": ("control",)},
"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
"gamma_control": {
"dist": "Normal",
"kwargs": {"mu": 0, "sigma": 2},
"dims": ("control",),
},
"gamma_fourier": {
"dist": "Laplace",
"kwargs": {"mu": 0, "b": 1},
"dims": "fourier_mode",
},
"mu": {"dims": ("date",)},
"likelihood": {"dims": ("date",)},
"gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"},
}
return model_config

Expand Down

0 comments on commit 439115c

Please sign in to comment.