Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Time varying intercept #628

Merged
merged 53 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
faaba0f
Add time-varying prior functionality to DelayedSaturatedMMM
ulfaslak Mar 19, 2024
3ce5ac4
resolve wd's comments
ulfaslak Mar 19, 2024
dac69da
Merge branch 'main' into time-varying-prior
wd60622 Mar 19, 2024
ec316b4
Merge branch 'main' of github.com:pymc-labs/pymc-marketing into time-…
ulfaslak Apr 4, 2024
7e8aee7
resolve failing pre-commits
ulfaslak Apr 15, 2024
b0aaa2d
merge previous PR into this
ulfaslak Apr 15, 2024
16f9414
add tvp_kwargs to model_config
ulfaslak Apr 15, 2024
8ab0532
fix typo
ulfaslak Apr 15, 2024
c06c06d
replace softplus
ulfaslak Apr 16, 2024
b70cd06
resolve minor review comments
ulfaslak Apr 16, 2024
9b1287b
Merge branch 'main' of github.com:pymc-labs/pymc-marketing into time-…
ulfaslak Apr 16, 2024
e01f9ac
Add option to supply `ax` to `plot_posterior_predictive`
ulfaslak Apr 17, 2024
59708cf
bugfix: time_index was not set correctly for OOS
ulfaslak Apr 17, 2024
83e7103
Clean up example notebook
ulfaslak Apr 17, 2024
196d720
Make utility function `transform_1d_array`
ulfaslak Apr 17, 2024
c77ddce
'tvp_kwargs' -> 'intercept_tvp_kwargs'
ulfaslak Apr 17, 2024
f25bc6e
move `infer_time_index` into utils
ulfaslak Apr 18, 2024
1b3e73e
add tests for new utils
ulfaslak Apr 18, 2024
32f11e1
small fixes (found in tests)
ulfaslak Apr 18, 2024
b35df87
add tests to cover all added cases
ulfaslak Apr 18, 2024
293e752
pull from main
ulfaslak Apr 18, 2024
910d223
fix ruff check
ulfaslak Apr 18, 2024
15308ab
update typehints
ulfaslak Apr 19, 2024
19d13d8
resolve review comments
ulfaslak Apr 19, 2024
0fad32b
refactor model logic for tv intercept
ulfaslak Apr 19, 2024
973a921
address review comment for util test
ulfaslak Apr 19, 2024
95c7ee8
.
ulfaslak Apr 19, 2024
4158768
fix documentation link
ulfaslak Apr 19, 2024
843ec21
change variable name
ulfaslak Apr 19, 2024
1c90255
fix hsgp_dims
ulfaslak Apr 19, 2024
d5e1699
update time_varying_prior to be centered on 1
ulfaslak Apr 19, 2024
74c09c0
review fixes
ulfaslak Apr 23, 2024
a6c5972
fix broken test
ulfaslak Apr 24, 2024
374303c
add final tests
ulfaslak Apr 24, 2024
3ac1f6a
Merge branch 'main' of github.com:pymc-labs/pymc-marketing into time-…
ulfaslak Apr 24, 2024
0bfbbe4
Merge branch 'main' into time-varying-intercept
ulfaslak Apr 24, 2024
665d1d2
fix coverage issues
ulfaslak Apr 24, 2024
6defce8
Merge branch 'time-varying-intercept' of github.com:pymc-labs/pymc-ma…
ulfaslak Apr 24, 2024
5f5be67
Update tests/mmm/test_tvp.py
ulfaslak Apr 25, 2024
33f8f7b
Update pymc_marketing/mmm/tvp.py
ulfaslak Apr 25, 2024
7677015
Update tests/mmm/test_tvp.py
ulfaslak Apr 29, 2024
e0b8ad6
Update tests/mmm/test_tvp.py
ulfaslak Apr 29, 2024
3ac0585
significant improvements to notebook
ulfaslak Apr 29, 2024
3d207b2
Merge branch 'time-varying-intercept' of github.com:pymc-labs/pymc-ma…
ulfaslak Apr 29, 2024
ba1de7b
fix heading
ulfaslak Apr 29, 2024
85cf5aa
Merge branch 'main' into time-varying-intercept
ulfaslak Apr 29, 2024
0f89bd4
update notebook to make it EVEN better
ulfaslak Apr 30, 2024
23ea9ec
Merge branch 'time-varying-intercept' of github.com:pymc-labs/pymc-ma…
ulfaslak Apr 30, 2024
d4688be
update legend, add watermark
ulfaslak Apr 30, 2024
d3ced36
fix intro
ulfaslak Apr 30, 2024
f8c0b3c
fix broken test
ulfaslak May 1, 2024
3f88589
copy sweep with grammarly
ulfaslak May 1, 2024
f0e59e6
Merge branch 'main' into time-varying-intercept
ulfaslak May 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/notebooks/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
mmm/mmm_example
mmm/mmm_budget_allocation_example
mmm/mmm_lift_test
mmm/mmm_tvp_example
:::

:::{toctree}
Expand Down
8,321 changes: 8,321 additions & 0 deletions docs/source/notebooks/mmm/mmm_tvp_example.ipynb

Large diffs are not rendered by default.

401 changes: 401 additions & 0 deletions docs/source/notebooks/mmm/mock_cgp_data-no-target.csv

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pymc_marketing/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DAYS_IN_YEAR = 365.25
62 changes: 46 additions & 16 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pandas as pd
import pymc as pm
import seaborn as sns
from numpy.typing import NDArray
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from xarray import DataArray, Dataset
Expand All @@ -31,6 +32,7 @@
find_sigmoid_inflection_point,
sigmoid_saturation,
standardize_scenarios_dict_keys,
transform_1d_array,
)
from pymc_marketing.mmm.validating import (
ValidateChannelColumns,
Expand All @@ -55,13 +57,19 @@
sampler_config: dict | None = None,
**kwargs,
) -> None:
self.X: pd.DataFrame | None = None
self.y: pd.Series | np.ndarray | None = None
self.date_column: str = date_column
self.channel_columns: list[str] | tuple[str] = channel_columns

self.n_channel: int = len(channel_columns)
self._fit_result: az.InferenceData | None = None
self._posterior_predictive: az.InferenceData | None = None

self.X: pd.DataFrame
self.y: pd.Series | np.ndarray

self._time_resolution: int
self._time_index: NDArray[np.int_]
self._time_index_mid: int
self._fit_result: az.InferenceData
self._posterior_predictive: az.InferenceData
super().__init__(model_config=model_config, sampler_config=sampler_config)

@property
Expand Down Expand Up @@ -314,7 +322,7 @@
return fig

def plot_posterior_predictive(
self, original_scale: bool = False, **plt_kwargs: Any
self, original_scale: bool = False, ax: plt.Axes = None, **plt_kwargs: Any
) -> plt.Figure:
posterior_predictive_data: Dataset = self.posterior_predictive
likelihood_hdi_94: DataArray = az.hdi(
Expand All @@ -332,10 +340,14 @@
Xt=likelihood_hdi_50
)

fig, ax = plt.subplots(**plt_kwargs)
if ax is None:
fig, ax = plt.subplots(**plt_kwargs)
else:
fig = ax.figure

if self.X is not None and self.y is not None:
ax.fill_between(
x=self.X[self.date_column],
x=posterior_predictive_data.date,
y1=likelihood_hdi_94[:, 0],
y2=likelihood_hdi_94[:, 1],
color="C0",
Expand All @@ -344,19 +356,29 @@
)

ax.fill_between(
x=self.X[self.date_column],
x=posterior_predictive_data.date,
y1=likelihood_hdi_50[:, 0],
y2=likelihood_hdi_50[:, 1],
color="C0",
alpha=0.3,
label="$50\%$ HDI", # noqa: W605
)

target_to_plot: np.ndarray = np.asarray(
self.y if original_scale else self.preprocessed_data["y"] # type: ignore
target_to_plot = np.asarray(
self.y
if original_scale
else transform_1d_array(self.get_target_transformer().transform, self.y)
)

if len(target_to_plot) != len(posterior_predictive_data.date):
raise ValueError(

Check warning on line 374 in pymc_marketing/mmm/base.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/base.py#L374

Added line #L374 was not covered by tests
"The length of the target variable doesn't match the length of the date column. "
"If you are predicting out-of-sample, please overwrite `self.y` with the "
"corresponding (non-transformed) target variable."
)

ax.plot(
np.asarray(self.X[self.date_column]),
np.asarray(posterior_predictive_data.date),
target_to_plot,
color="black",
)
Expand Down Expand Up @@ -435,11 +457,18 @@
intercept = az.extract(
self.fit_result, var_names=["intercept"], combined=False
)
intercept_hdi = np.repeat(
a=az.hdi(intercept).intercept.data[None, ...],
repeats=self.X[self.date_column].shape[0],
axis=0,
)

if intercept.ndim == 2:
# Intercept has a stationary prior
intercept_hdi = np.repeat(
a=az.hdi(intercept).intercept.data[None, ...],
repeats=self.X[self.date_column].shape[0],
axis=0,
)
elif intercept.ndim == 3:

Check warning on line 468 in pymc_marketing/mmm/base.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/base.py#L468

Added line #L468 was not covered by tests
# Intercept has a time-varying prior
intercept_hdi = az.hdi(intercept).intercept.data

Check warning on line 470 in pymc_marketing/mmm/base.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/base.py#L470

Added line #L470 was not covered by tests

ax.plot(
np.asarray(self.X[self.date_column]),
np.full(len(self.X[self.date_column]), intercept.mean().data),
Expand Down Expand Up @@ -1028,6 +1057,7 @@

def legend_title_func(channel):
return "Legend"

else:
nrows = len(channels_to_plot)
figsize = (12, 4 * len(channels_to_plot))
Expand Down
79 changes: 73 additions & 6 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from pytensor.tensor import TensorVariable
from xarray import DataArray, Dataset

from pymc_marketing.constants import DAYS_IN_YEAR
from pymc_marketing.mmm.base import MMM
from pymc_marketing.mmm.lift_test import (
add_logistic_empirical_lift_measurements_to_likelihood,
scale_lift_measurements,
)
from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget
from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation
from pymc_marketing.mmm.tvp import create_time_varying_intercept, infer_time_index
from pymc_marketing.mmm.utils import (
apply_sklearn_transformer_across_dim,
create_new_spend_data,
Expand All @@ -47,6 +49,7 @@
date_column: str,
channel_columns: list[str],
adstock_max_lag: int,
time_varying_intercept: bool = False,
model_config: dict | None = None,
sampler_config: dict | None = None,
validate_data: bool = True,
Expand All @@ -62,6 +65,10 @@
Column name of the date variable.
channel_columns : List[str]
Column names of the media channel variables.
adstock_max_lag : int
Number of lags to consider in the adstock transformation.
time_varying_intercept : bool, optional
Whether to consider time-varying intercept, by default False.
model_config : Dictionary, optional
dictionary of parameters that initialise model configuration.
Class-default defined by the user default_model_config method.
Expand All @@ -79,6 +86,7 @@
"""
self.control_columns = control_columns
self.adstock_max_lag = adstock_max_lag
self.time_varying_intercept = time_varying_intercept
self.yearly_seasonality = yearly_seasonality
self.date_column = date_column
self.validate_data = validate_data
Expand Down Expand Up @@ -112,6 +120,24 @@
----------
X : Union[pd.DataFrame, pd.Series], shape (n_obs, n_features)
y : Union[pd.Series, np.ndarray], shape (n_obs,)

Sets
----
preprocessed_data : Dict[str, Union[pd.DataFrame, pd.Series]]
Preprocessed data for the model.
X : pd.DataFrame
A filtered version of the input `X`, such that it is guaranteed that
it contains only the `date_column`, the columns that are specified
in the `channel_columns` and `control_columns`, and fourier features
if `yearly_seasonality=True`.
y : Union[pd.Series, np.ndarray]
The target variable for the model (as provided).
_time_index : np.ndarray
The index of the date column. Used by TVP
_time_index_mid : int
The middle index of the date index. Used by TVP.
_time_resolution: int
The time resolution of the date index. Used by TVP.
ulfaslak marked this conversation as resolved.
Show resolved Hide resolved
"""
date_data = X[self.date_column]
channel_data = X[self.channel_columns]
Expand Down Expand Up @@ -152,6 +178,13 @@
self.X: pd.DataFrame = X_data
self.y: pd.Series | np.ndarray = y

if self.time_varying_intercept:
self._time_index = np.arange(0, X.shape[0])
self._time_index_mid = X.shape[0] // 2
self._time_resolution = (
self.X[self.date_column].iloc[1] - self.X[self.date_column].iloc[0]
).days

def _save_input_params(self, idata) -> None:
"""Saves input parameters to the attrs of idata."""
idata.attrs["date_column"] = json.dumps(self.date_column)
Expand Down Expand Up @@ -355,9 +388,23 @@
dims="date",
)

intercept = self.intercept_dist(
name="intercept", **self.model_config["intercept"]["kwargs"]
)
if self.time_varying_intercept:
ulfaslak marked this conversation as resolved.
Show resolved Hide resolved
time_index = pm.Data(
"time_index",
self._time_index,
dims="date",
)
Copy link
Contributor

@cetagostini cetagostini Apr 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe time_index should be mutable=True no?

time_index = pm.MutableData(
    "time_index",
    value=np.arange(self.x_channel_data.shape[0]),
    dims="date",
)

intercept = create_time_varying_intercept(
time_index,
self._time_index_mid,
self._time_resolution,
self.intercept_dist,
self.model_config,
)
else:
intercept = self.intercept_dist(
name="intercept", **self.model_config["intercept"]["kwargs"]
)

beta_channel = self.beta_channel_dist(
name="beta_channel",
Expand Down Expand Up @@ -391,9 +438,11 @@
var=logistic_saturation(x=channel_adstock, lam=lam),
dims=("date", "channel"),
)

channel_contributions_var = channel_adstock_saturated * beta_channel
channel_contributions = pm.Deterministic(
name="channel_contributions",
var=channel_adstock_saturated * beta_channel,
var=channel_contributions_var,
dims=("date", "channel"),
)

Expand Down Expand Up @@ -468,7 +517,10 @@
@property
def default_model_config(self) -> dict:
return {
"intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"intercept": {
"dist": "Normal",
"kwargs": {"mu": 0, "sigma": 2},
},
"beta_channel": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
"alpha": {"dist": "Beta", "kwargs": {"alpha": 1, "beta": 3}},
"lam": {"dist": "Gamma", "kwargs": {"alpha": 3, "beta": 1}},
Expand All @@ -480,6 +532,14 @@
},
"gamma_control": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0, "b": 1}},
"intercept_tvp_kwargs": {
"m": 200,
"L": None,
"eta_lam": 1,
"ls_mu": None,
"ls_sigma": 10,
"cov_func": None,
Comment on lines +536 to +541
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do we expect to happen with these None defaults?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's are checks starting on like 399, which set them if they are None. The reason that we wan't directly set them here is that the best defaults are estimated from the data, and we can't know here.

},
}

def _get_fourier_models_data(self, X) -> pd.DataFrame:
Expand All @@ -494,7 +554,9 @@
date_data: pd.Series = pd.to_datetime(
arg=X[self.date_column], format="%Y-%m-%d"
)
periods: npt.NDArray[np.float_] = date_data.dt.dayofyear.to_numpy() / 365.25
periods: npt.NDArray[np.float_] = (
date_data.dt.dayofyear.to_numpy() / DAYS_IN_YEAR
)
return generate_fourier_modes(
periods=periods,
n_order=self.yearly_seasonality,
Expand Down Expand Up @@ -678,6 +740,11 @@
if hasattr(self, "fourier_columns"):
data["fourier_data"] = self._get_fourier_models_data(X)

if self.time_varying_intercept:
data["time_index"] = infer_time_index(

Check warning on line 744 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L744

Added line #L744 was not covered by tests
X[self.date_column], self.X[self.date_column], self._time_resolution
)

if y is not None:
if isinstance(y, pd.Series):
data["target"] = (
Expand Down
Loading
Loading