Add time-varying prior functionality to DelayedSaturatedMMM
ulfaslak committed Mar 19, 2024
1 parent 9480869 commit faaba0f
Showing 6 changed files with 4,870 additions and 27 deletions.
4,250 changes: 4,250 additions & 0 deletions docs/source/notebooks/mmm/mmm_tvp_example.ipynb

Large diffs are not rendered by default.

401 changes: 401 additions & 0 deletions docs/source/notebooks/mmm/mock_cgp_data-no-target.csv

Large diffs are not rendered by default.

45 changes: 33 additions & 12 deletions pymc_marketing/mmm/base.py
@@ -52,11 +52,17 @@ def __init__(
sampler_config: Optional[Dict] = None,
**kwargs,
) -> None:
self.X: Optional[pd.DataFrame] = None
self.y: Optional[Union[pd.Series, np.ndarray]] = None
self.date_column: str = date_column
self.channel_columns: Union[List[str], Tuple[str]] = channel_columns

self.n_channel: int = len(channel_columns)

self.X: Optional[pd.DataFrame] = None
self.y: Optional[Union[pd.Series, np.ndarray]] = None

self._time_resolution: Optional[int] = None
self._time_index: Optional[np.ndarray[int]] = None
self._time_index_mid: Optional[int] = None
self._fit_result: Optional[az.InferenceData] = None
self._posterior_predictive: Optional[az.InferenceData] = None
super().__init__(model_config=model_config, sampler_config=sampler_config)
@@ -319,7 +325,7 @@ def plot_posterior_predictive(
fig, ax = plt.subplots(**plt_kwargs)
if self.X is not None and self.y is not None:
ax.fill_between(
x=self.X[self.date_column],
x=posterior_predictive_data.date,
y1=likelihood_hdi_94[:, 0],
y2=likelihood_hdi_94[:, 1],
color="C0",
@@ -328,19 +334,26 @@
)

ax.fill_between(
x=self.X[self.date_column],
x=posterior_predictive_data.date,
y1=likelihood_hdi_50[:, 0],
y2=likelihood_hdi_50[:, 1],
color="C0",
alpha=0.3,
label="$50\%$ HDI", # noqa: W605
)

target_to_plot: np.ndarray = np.asarray(
self.y if original_scale else self.preprocessed_data["y"] # type: ignore
target_to_plot = np.asarray(
self.y if original_scale else self.get_target_transformer().transform(self.y[:, None]).flatten() # type: ignore
)

assert len(target_to_plot) == len(posterior_predictive_data.date), (
"The length of the target variable doesn't match the length of the date column. "
"If you are predicting out-of-sample, please overwrite `self.y` with the "
"corresponding (non-transformed) target variable."
)

ax.plot(
np.asarray(self.X[self.date_column]),
np.asarray(posterior_predictive_data.date),
target_to_plot,
color="black",
)
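
A hedged sketch of the out-of-sample workflow that the assertion above refers to (`X_new` and `y_new` are illustrative objects, and it is assumed the class exposes the usual `sample_posterior_predictive` method):

mmm.sample_posterior_predictive(X_new, extend_idata=True)
mmm.y = y_new  # non-transformed target aligned with the predicted dates
mmm.plot_posterior_predictive(original_scale=True)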
@@ -417,11 +430,18 @@ def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure:
intercept = az.extract(
self.fit_result, var_names=["intercept"], combined=False
)
intercept_hdi = np.repeat(
a=az.hdi(intercept).intercept.data[None, ...],
repeats=self.X[self.date_column].shape[0],
axis=0,
)

if intercept.ndim == 2:
# Intercept has a stationary prior
intercept_hdi = np.repeat(
a=az.hdi(intercept).intercept.data[None, ...],
repeats=self.X[self.date_column].shape[0],
axis=0,
)
elif intercept.ndim == 3:
# Intercept has a time-varying prior
intercept_hdi = az.hdi(intercept).intercept.data
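# Shape note (assumption: az.extract with combined=False keeps the chain/draw dims):
#   stationary intercept    -> dims (chain, draw)        -> ndim == 2
#   time-varying intercept  -> dims (chain, draw, date)  -> ndim == 3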

ax.plot(
np.asarray(self.X[self.date_column]),
np.full(len(self.X[self.date_column]), intercept.mean().data),
@@ -992,6 +1012,7 @@ def label_func(channel):

def legend_title_func(channel):
return "Legend"

else:
nrows = len(channels_to_plot)
figsize = (12, 4 * len(channels_to_plot))
123 changes: 108 additions & 15 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
@@ -15,6 +15,7 @@
from pymc_marketing.mmm.base import MMM
from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget
from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation
from pymc_marketing.mmm.tvp import time_varying_prior
from pymc_marketing.mmm.utils import (
apply_sklearn_transformer_across_date,
generate_fourier_modes,
@@ -33,6 +34,8 @@ def __init__(
date_column: str,
channel_columns: List[str],
adstock_max_lag: int,
time_varying_media_effect: bool = False,
time_varying_intercept: bool = False,
model_config: Optional[Dict] = None,
sampler_config: Optional[Dict] = None,
validate_data: bool = True,
@@ -48,6 +51,12 @@ def __init__(
Column name of the date variable.
channel_columns : List[str]
Column names of the media channel variables.
adstock_max_lag : int
Number of lags to consider in the adstock transformation.
time_varying_media_effect : bool, optional
Whether to apply a time-varying (HSGP-based) multiplier to the media effects, by default False.
time_varying_intercept : bool, optional
Whether to apply a time-varying (HSGP-based) multiplier to the intercept, by default False.
model_config : Dictionary, optional
dictionary of parameters that initialise model configuration. Class-default defined by the user default_model_config method.
sampler_config : Dictionary, optional
@@ -67,6 +76,8 @@ def __init__(
"""
self.control_columns = control_columns
self.adstock_max_lag = adstock_max_lag
self.time_varying_media_effect = time_varying_media_effect
self.time_varying_intercept = time_varying_intercept
self.yearly_seasonality = yearly_seasonality
self.date_column = date_column
self.validate_data = validate_data
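
A minimal usage sketch of the new flags (column names and data are illustrative, not taken from the commit):

mmm = DelayedSaturatedMMM(
    date_column="date_week",
    channel_columns=["x1", "x2"],
    adstock_max_lag=8,
    time_varying_intercept=True,      # intercept scaled by an HSGP multiplier over time
    time_varying_media_effect=True,   # media contributions scaled by a shared HSGP multiplier
)
mmm.fit(X, y)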
@@ -91,15 +102,34 @@ def output_var(self):
def _generate_and_preprocess_model_data( # type: ignore
self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray]
) -> None:
"""
Applies preprocessing to the data before fitting the model.
if validate is True, it will check if the data is valid for the model.
sets self.model_coords based on provided dataset
"""Preprocess data and set model state variables.
Applies preprocessing to the data before fitting the model. If validate
is True, it will check if the data is valid for the model. *Only* gets
called before fitting the model.
Parameters
----------
X : Union[pd.DataFrame, pd.Series], shape (n_obs, n_features)
y : Union[pd.Series, np.ndarray], shape (n_obs,)
Sets
----
preprocessed_data : Dict[str, Union[pd.DataFrame, pd.Series]]
Preprocessed data for the model.
X : pd.DataFrame
A filtered version of the input `X`, such that it is guaranteed that
it contains only the `date_column`, the columns that are specified
in the `channel_columns` and `control_columns`, and fourier features
if `yearly_seasonality=True`.
y : Union[pd.Series, np.ndarray]
The target variable for the model (as provided).
_time_index : np.ndarray
The integer index of the date column. Used by the time-varying prior (TVP).
_time_index_mid : int
The middle index of the date index. Used by the TVP.
_time_resolution : int
The number of days between consecutive dates. Used by the TVP.
"""
date_data = X[self.date_column]
channel_data = X[self.channel_columns]
Expand Down Expand Up @@ -139,6 +169,11 @@ def _generate_and_preprocess_model_data( # type: ignore
}
self.X: pd.DataFrame = X_data
self.y: Union[pd.Series, np.ndarray] = y
self._time_index = np.arange(0, X.shape[0])
self._time_index_mid = X.shape[0] // 2
self._time_resolution = (
self.X[self.date_column].iloc[1] - self.X[self.date_column].iloc[0]
).days
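
For illustration, with regular weekly data these attributes would come out as follows (hypothetical dates):

dates = pd.date_range("2023-01-02", periods=4, freq="W-MON")
# self._time_index      -> array([0, 1, 2, 3])
# self._time_index_mid  -> 2
# self._time_resolution -> (dates[1] - dates[0]).days == 7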

def _save_input_params(self, idata) -> None:
"""Saves input parameters to the attrs of idata."""
@@ -337,9 +372,52 @@ def build_model(
dims="date",
)

intercept = self.intercept_dist(
name="intercept", **self.model_config["intercept"]["kwargs"]
)
if self.time_varying_intercept or self.time_varying_media_effect:
time_index = pm.MutableData(
"time_index",
self._time_index,
dims="date",
)

if self.time_varying_intercept:
tv_multiplier_intercept = time_varying_prior(
name="tv_multiplier_intercept",
X=time_index,
X_mid=self._time_index_mid,
positive=True,
m=200,
L=[self._time_index_mid + 365 / self._time_resolution],
ls_mu=365 / self._time_resolution * 2,
ls_sigma=10,
eta_lam=1,
dims="date",
)
intercept_base = self.intercept_dist(
name="intercept_base", **self.model_config["intercept"]["kwargs"]
)
intercept = pm.Deterministic(
name="intercept",
var=intercept_base * tv_multiplier_intercept,
dims="date",
)
else:
intercept = self.intercept_dist(
name="intercept", **self.model_config["intercept"]["kwargs"]
)

if self.time_varying_media_effect:
tv_multiplier_media = time_varying_prior(
name="tv_multiplier_media",
X=time_index,
X_mid=self._time_index_mid,
positive=True,
m=200,
L=[self._time_index_mid + 365 / self._time_resolution],
ls_mu=365 / self._time_resolution * 2,
ls_sigma=10,
eta_lam=1,
dims="date",
)

beta_channel = self.beta_channel_dist(
name="beta_channel",
@@ -373,11 +451,21 @@ def identity(x):
var=logistic_saturation(x=channel_adstock, lam=lam),
dims=("date", "channel"),
)
channel_contributions = pm.Deterministic(
name="channel_contributions",
var=channel_adstock_saturated * beta_channel,
dims=("date", "channel"),
)

if self.time_varying_media_effect:
channel_contributions = pm.Deterministic(
name="channel_contributions",
var=channel_adstock_saturated
* beta_channel
* tv_multiplier_media[:, None],
dims=("date", "channel"),
)
else:
channel_contributions = pm.Deterministic(
name="channel_contributions",
var=channel_adstock_saturated * beta_channel,
dims=("date", "channel"),
)

mu_var = intercept + channel_contributions.sum(axis=-1)
if (
@@ -657,11 +745,16 @@ def identity(x):
if hasattr(self, "fourier_columns"):
data["fourier_data"] = self._get_fourier_models_data(X)

if self.time_varying_intercept or self.time_varying_media_effect:
data["time_index"] = np.arange(
self._time_index[-1], self._time_index[-1] + X.shape[0]
)
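# e.g. with a training index ending at 99 and 10 new rows, this yields
# array([99, 100, ..., 108]) - the index continues from the last training value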

if y is not None:
if isinstance(y, pd.Series):
data[
"target"
] = y.to_numpy() # convert Series to numpy array explicitly
data["target"] = (
y.to_numpy()
) # convert Series to numpy array explicitly
elif isinstance(y, np.ndarray):
data["target"] = y
else:
72 changes: 72 additions & 0 deletions pymc_marketing/mmm/tvp.py
@@ -0,0 +1,72 @@
from typing import Optional

import pymc as pm
from pymc_marketing.mmm.utils import softplus


def time_varying_prior(
name: str,
X: pm.Deterministic,
X_mid: int | float,
positive: bool = False,
dims: Optional[tuple[str, str] | str] = None,
m: int = 40,
L: int = 100,
eta_lam: float = 1,
ls_mu: float = 5,
ls_sigma: float = 5,
cov_func: Optional[pm.gp.cov.Prod] = None,
model: Optional[pm.Model] = None,
) -> pm.Deterministic:
"""Time varying prior, based the Hilbert Space Gaussian Process (HSGP).
Parameters
----------
name : str
Name of the prior.
X : 1d array-like of int or float
Time points.
X_mid : int or float
Midpoint of the time points.
positive : bool
Whether the prior should be positive.
dims : tuple of str or str
Dimensions of the prior.
m : int
Number of basis functions.
L : int
Boundary (extent) of the Hilbert space approximation domain.
eta_lam : float
Rate of the exponential prior on the GP amplitude `eta`.
ls_mu : float
Mean of the inverse gamma prior for the lengthscale.
ls_sigma : float
Standard deviation of the inverse gamma prior for the lengthscale.
cov_func : pm.gp.cov.Prod
Covariance function.
model : pm.Model
PyMC model.
Returns
-------
pm.Deterministic
Time-varying prior.
""" # noqa: W605
if cov_func is None:
eta = pm.Exponential(f"eta_{name}", lam=eta_lam)
ls = pm.InverseGamma(f"ls_{name}", mu=ls_mu, sigma=ls_sigma)
cov_func = eta**2 * pm.gp.cov.Matern52(1, ls=ls)

with pm.modelcontext(model) as model:
if type(dims) is tuple:
n_columns = len(model.coords[dims[1]])
hsgp_size = (n_columns, m)
else:
hsgp_size = m
gp = pm.gp.HSGP(m=[m], L=[L], cov_func=cov_func)
phi, sqrt_psd = gp.prior_linearized(Xs=X[:, None] - X_mid)
hsgp_coefs = pm.Normal(f"_hsgp_coefs_{name}", size=hsgp_size)
f = phi @ (hsgp_coefs * sqrt_psd).T
if positive:
f = softplus(f)
return pm.Deterministic(name, f, dims=dims)
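
A minimal usage sketch of `time_varying_prior` (dates, coords, and hyperparameters are illustrative, mirroring how `build_model` calls it above):

import numpy as np
import pandas as pd
import pymc as pm

from pymc_marketing.mmm.tvp import time_varying_prior

dates = pd.date_range("2022-01-03", periods=104, freq="W-MON")

with pm.Model(coords={"date": dates}) as model:
    # Integer time index registered as data, as in DelayedSaturatedMMM.build_model
    time_index = pm.MutableData("time_index", np.arange(len(dates)), dims="date")
    # Smooth, positive multiplier over time (softplus-transformed HSGP)
    multiplier = time_varying_prior(
        name="tv_multiplier",
        X=time_index,
        X_mid=len(dates) // 2,
        positive=True,
        m=40,
        L=100,
        dims="date",
    )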
6 changes: 6 additions & 0 deletions pymc_marketing/mmm/utils.py
@@ -4,6 +4,8 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from scipy.optimize import curve_fit, minimize_scalar

@@ -326,3 +328,7 @@ def apply_sklearn_transformer_across_date(
data.attrs = attrs

return data


def softplus(x: pt.TensorVariable) -> pt.TensorVariable:
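"""Softplus transform, log(1 + exp(x)): maps real values to positive values."""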
return pm.math.log(1 + pm.math.exp(x))
