Skip to content

Commit

Permalink
Don't erase previous sampling results in ModelBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 27, 2023
1 parent 916dce4 commit d5f7220
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 57 deletions.
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def get_target_transformer(self) -> Pipeline:
@property
def prior_predictive(self) -> az.InferenceData:
if self.idata is None or "prior_predictive" not in self.idata:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
raise RuntimeError("Sample Prior predictive hasn't been called yet")
return self.idata["prior_predictive"]

@property
Expand Down
5 changes: 3 additions & 2 deletions pymc_marketing/mmm/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, List, Tuple, Union, cast

import numpy as np
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MaxAbsScaler, StandardScaler
Expand Down Expand Up @@ -32,7 +33,7 @@ class MaxAbsScaleTarget:

@preprocessing_method_y
def max_abs_scale_target_data(self, data: pd.Series) -> pd.Series:
target_vector = data.reshape(-1, 1)
target_vector = cast(np.ndarray, data.values).reshape(-1, 1)
transformers = [("scaler", MaxAbsScaler())]
pipeline = Pipeline(steps=transformers)
self.target_transformer: Pipeline = pipeline.fit(X=target_vector)
Expand Down
87 changes: 50 additions & 37 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

import arviz as az
import numpy as np
Expand Down Expand Up @@ -414,6 +414,20 @@ def load(cls, fname: str):

return model

def _add_fit_data_group(self, X, y) -> None:
    """Attach the data used for fitting to ``self.idata`` as a ``fit_data`` group.

    The predictors ``X`` and the target ``y`` (stored in a column named
    ``self.output_var``) are concatenated column-wise, converted to xarray and
    added to the inference data. A ``fit_data`` group that is already present
    is replaced, so the method can be called again on the same ``idata``.

    Parameters
    ----------
    X : pd.DataFrame
        Predictor data; its column names are preserved.
    y : array-like
        Target values, stored under ``self.output_var``.

    Raises
    ------
    ValueError
        If any column of the combined data has an empty (falsy) name.
    """
    y_df = pd.DataFrame({self.output_var: y})
    X_df = pd.DataFrame(X, columns=X.columns)
    combined_data = pd.concat([X_df, y_df], axis=1)
    # Raise explicitly rather than assert: asserts are stripped under ``-O``.
    if not all(combined_data.columns):
        raise ValueError("All columns must have non-empty names")
    with warnings.catch_warnings():
        # ``fit_data`` is a custom group, so ArviZ warns that it is not part
        # of the InferenceData schema; that warning is expected here.
        warnings.filterwarnings(
            "ignore",
            category=UserWarning,
            message="The group fit_data is not defined in the InferenceData scheme",
        )
        # ``add_groups`` refuses to overwrite an existing group, so drop a
        # stale ``fit_data`` (e.g. from an earlier fit) before adding the new one.
        if "fit_data" in self.idata:  # type: ignore
            del self.idata.fit_data  # type: ignore
        self.idata.add_groups(fit_data=combined_data.to_xarray())  # type: ignore

def fit(
self,
X: pd.DataFrame,
Expand All @@ -436,9 +450,6 @@ def fit(
The target values (real numbers).
progressbar : bool
Specifies whether the fit progressbar should be displayed
predictor_names: Optional[List[str]] = None,
Allows for custom naming of predictors given in a form of 2dArray
allows for naming of predictors when given in a form of np.ndarray, if not provided the predictors will be named like predictor1, predictor2...
random_seed : Optional[RandomState]
Provides sampler with initial random seed for obtaining reproducible samples
**kwargs : Any
Expand All @@ -455,12 +466,10 @@ def fit(
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
"""
if predictor_names is None:
predictor_names = []

if y is None:
y = np.zeros(X.shape[0])
y_df = pd.DataFrame({self.output_var: y})
self._generate_and_preprocess_model_data(X, y_df.values.flatten())
self._generate_and_preprocess_model_data(X, np.asarray(y).flatten())
if self.X is None or self.y is None:
raise ValueError("X and y must be set before calling build_model!")
self.build_model(self.X, self.y)
Expand All @@ -474,19 +483,16 @@ def fit(
if self.model is not None:
with self.model:
sampler_args = {**self.sampler_config, **kwargs}
self.idata = pm.sample(**sampler_args)
idata = pm.sample(**sampler_args)

X_df = pd.DataFrame(X, columns=X.columns)
combined_data = pd.concat([X_df, y_df], axis=1)
assert all(combined_data.columns), "All columns must have non-empty names"
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="The group fit_data is not defined in the InferenceData scheme",
)
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore
if self.idata:
self.idata.extend(idata, join="right")
else:
self.idata = idata

self._add_fit_data_group(X, y)
self.set_idata_attrs(self.idata)

return self.idata # type: ignore

def predict(
Expand Down Expand Up @@ -522,9 +528,13 @@ def predict(
"""

posterior_predictive_samples = self.sample_posterior_predictive(
X_pred, extend_idata, combined=False, **kwargs
X_pred, combined=False, predictions=True, **kwargs
)

if extend_idata:
assert isinstance(self.idata, az.InferenceData)
self.idata.extend(posterior_predictive_samples, join="right")

if self.output_var not in posterior_predictive_samples:
raise KeyError(
f"Output variable {self.output_var} not found in posterior predictive samples."
Expand All @@ -540,7 +550,7 @@ def sample_prior_predictive(
X_pred,
y_pred=None,
samples: Optional[int] = None,
extend_idata: bool = False,
extend_idata: bool = True,
combined: bool = True,
**kwargs,
):
Expand All @@ -555,7 +565,7 @@ def sample_prior_predictive(
Number of samples from the prior parameter distributions to generate.
If not set, uses sampler_config['draws'] if that is available, otherwise defaults to 500.
extend_idata : Boolean determining whether the predictions should be added to inference data object.
Defaults to False.
Defaults to True.
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
Defaults to True.
**kwargs: Additional arguments to pass to pymc.sample_prior_predictive
Expand All @@ -574,25 +584,26 @@ def sample_prior_predictive(
self.build_model(X_pred, y_pred)

self._data_setter(X_pred, y_pred)
if self.model is not None:
with self.model: # sample with new input data
prior_pred: az.InferenceData = pm.sample_prior_predictive(
samples, **kwargs
)
self.set_idata_attrs(prior_pred)
if extend_idata:
if self.idata is not None:
self.idata.extend(prior_pred)
else:
self.idata = prior_pred

with cast(pm.Model, self.model): # sample with new input data
prior_pred: az.InferenceData = pm.sample_prior_predictive(samples, **kwargs)
self.set_idata_attrs(prior_pred)

if extend_idata:
if self.idata is not None:
self.idata.extend(prior_pred, join="right")
else:
self.idata = prior_pred

prior_predictive_samples = az.extract(
prior_pred, "prior_predictive", combined=combined
)

return prior_predictive_samples

def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
def sample_posterior_predictive(
self, X_pred, extend_idata: bool = True, combined: bool = True, **kwargs
):
"""
Sample from the model's posterior predictive distribution.
Expand All @@ -613,10 +624,12 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
"""
self._data_setter(X_pred)

with self.model: # sample with new input data
with cast(pm.Model, self.model): # sample with new input data
post_pred = pm.sample_posterior_predictive(self.idata, **kwargs)
if extend_idata:
self.idata.extend(post_pred)

if extend_idata:
assert isinstance(self.idata, az.InferenceData)
self.idata.extend(post_pred, join="right")

posterior_predictive_samples = az.extract(
post_pred, "posterior_predictive", combined=combined
Expand Down
57 changes: 40 additions & 17 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def toy_y(toy_X) -> pd.Series:
return pd.Series(rng.integers(low=0, high=100, size=toy_X.shape[0]))


class ToyMMMDefaultTransform(BaseDelayedSaturatedMMM):
    """Toy MMM for the plotting tests: the base model with no extra mixins."""


class ToyMMMCustomTransform(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
    """Toy MMM for the plotting tests: the base model plus ``MaxAbsScaleTarget``."""


class TestBasePlotting:
@pytest.fixture(
scope="module",
Expand All @@ -49,38 +57,36 @@ class TestBasePlotting:
def plotting_mmm(self, request, toy_X, toy_y):
control, transform = request.param.split("-")
if transform == "default_transform":

class ToyMMM(BaseDelayedSaturatedMMM):
pass

mmm_class = ToyMMMDefaultTransform
elif transform == "target_transform":

class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
pass
mmm_class = ToyMMMCustomTransform
else:
raise ValueError(f"Unexpected transform {transform}")

if control == "without_controls":
mmm = ToyMMM(
mmm = mmm_class(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
)
elif control == "with_controls":
mmm = ToyMMM(
mmm = mmm_class(
date_column="date",
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
channel_columns=["channel_1", "channel_2"],
)
# fit the model
mmm.fit(
X=toy_X,
y=toy_y,
)
else:
raise ValueError(f"Unexpected control {control}")

mmm.sample_prior_predictive(toy_X, toy_y, extend_idata=True, combined=True)

# fake-fit the model for speed
mmm.idata.add_groups({"posterior": mmm.idata.prior})
mmm._add_fit_data_group(toy_X, toy_y)
mmm.set_idata_attrs()

mmm.sample_posterior_predictive(toy_X, extend_idata=True, combined=True)
mmm._prior_predictive = mmm.prior_predictive
mmm._fit_result = mmm.fit_result
mmm._posterior_predictive = mmm.posterior_predictive

return mmm

Expand Down Expand Up @@ -109,3 +115,20 @@ def test_plots(self, plotting_mmm, func_plot_name, kwargs_plot) -> None:
func = plotting_mmm.__getattribute__(func_plot_name)
assert isinstance(func(**kwargs_plot), plt.Figure)
plt.close("all")

def test_plot_prior_predictive_without_fit(self, toy_X, toy_y):
    """Prior-predictive plotting must work both before and after ``fit()``."""
    model = ToyMMMDefaultTransform(
        channel_columns=["channel_1", "channel_2"],
        date_column="date",
        adstock_max_lag=4,
    )

    def sample_prior_and_plot():
        # Draw prior predictive samples, then render them.
        model.sample_prior_predictive(toy_X, toy_y)
        return model.plot_prior_predictive()

    # No posterior has been sampled yet at this point.
    figure_before_fit = sample_prior_and_plot()
    assert isinstance(figure_before_fit, plt.Figure)

    # Fitting must not prevent sampling and plotting the prior again.
    model.fit(toy_X, toy_y, chains=1, draws=1, tune=10)

    figure_after_fit = sample_prior_and_plot()
    assert isinstance(figure_after_fit, plt.Figure)
    plt.close("all")

0 comments on commit d5f7220

Please sign in to comment.