Skip to content

Commit

Permalink
support for predict_posterior method with adstock effects
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Dec 20, 2023
1 parent f99600b commit c49ab8e
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 8 deletions.
52 changes: 51 additions & 1 deletion pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def default_sampler_config(self) -> Dict:
@property
def output_var(self) -> str:
    """Name of the model's target (observed) variable.

    Returns
    -------
    str
        ``"likelihood"``. ``predict_posterior`` uses this name as the key
        it looks up in the posterior predictive samples.
    """
    return "likelihood"

def _generate_and_preprocess_model_data( # type: ignore
self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray]
Expand Down Expand Up @@ -843,3 +843,53 @@ def plot_channel_contributions_grid(
ylabel="contribution",
)
return fig

def predict_posterior(
    self,
    X_pred: Union[np.ndarray, pd.DataFrame, pd.Series],
    extend_idata: bool = True,
    combined: bool = True,
    include_last_observations: bool = False,
    **kwargs,
) -> "DataArray":
    """
    Generate posterior predictive samples on unseen data.

    Parameters
    ----------
    X_pred : pd.DataFrame
        The input data used for prediction. Any other type is rejected
        with a ``ValueError``.
    extend_idata : bool, default True
        Whether the predictions should be added to the inference data object.
    combined : bool, default True
        Combine chain and draw dims into sample. Won't work if a dim named
        sample already exists.
    include_last_observations : bool, default False
        Whether to include last observed data for carryover adstock and
        saturation effect. When True, up to ``adstock_max_lag`` trailing
        rows of ``self.X`` are prepended to ``X_pred`` before sampling and
        the same number of leading ``date`` entries is dropped from the
        result, so only the rows of ``X_pred`` are returned.
    **kwargs
        Additional arguments to pass to pymc.sample_posterior_predictive

    Returns
    -------
    y_pred : DataArray, shape (n_pred, chains * draws) if combined is True, otherwise (chains, draws, n_pred)
        Posterior predictive samples for each input X_pred

    Raises
    ------
    ValueError
        If ``X_pred`` is not a pandas DataFrame.
    KeyError
        If ``self.output_var`` is missing from the posterior predictive samples.
    """
    if not isinstance(X_pred, pd.DataFrame):
        raise ValueError("X_pred must be a pandas DataFrame")

    # Number of historical rows actually prepended. Tracking the real count
    # (rather than assuming ``adstock_max_lag``) keeps the trim below exact
    # when the training data is shorter than the lag.
    n_prepended = 0
    if include_last_observations and self.adstock_max_lag > 0:
        # Guarded on a positive lag because ``iloc[-0:]`` would select the
        # *entire* training frame instead of zero rows.
        history = self.X.iloc[-self.adstock_max_lag :]
        n_prepended = len(history)
        X_pred = pd.concat([history, X_pred])

    posterior_predictive_samples = self.sample_posterior_predictive(
        X_pred, extend_idata, combined, **kwargs
    )

    if self.output_var not in posterior_predictive_samples:
        raise KeyError(
            f"Output variable {self.output_var} not found in posterior predictive samples."
        )

    if n_prepended:
        # Drop the prepended history so only the requested dates remain.
        posterior_predictive_samples = posterior_predictive_samples.isel(
            date=slice(n_prepended, None)
        )

    return posterior_predictive_samples[self.output_var]
18 changes: 11 additions & 7 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def test_new_data_predictions(
) -> None:
mmm = request.getfixturevalue(model_name)
n = new_dates.size
new_X = pd.DataFrame(
X_pred = pd.DataFrame(
{
"date": new_dates,
"channel_1": rng.integers(low=0, high=400, size=n),
Expand All @@ -583,20 +583,24 @@ def test_new_data_predictions(
}
)

with pytest.raises(
TypeError,
match=r"The DType <class 'numpy.dtype\[datetime64\]'> could not be promoted by",
):
mmm.predict_posterior(X_pred=new_X)
pp_without = mmm.predict_posterior(
X_pred=X_pred, include_last_observations=False
)
pp_with = mmm.predict_posterior(X_pred=X_pred, include_last_observations=True)

assert pp_without.coords.equals(pp_with.coords)

posterior_predictive = mmm.sample_posterior_predictive(
X_pred=new_X, extend_idata=False, combined=True
X_pred=X_pred, extend_idata=False, combined=True
)
pd.testing.assert_index_equal(
pd.DatetimeIndex(posterior_predictive.coords["date"]), new_dates
)
assert posterior_predictive["likelihood"].shape[0] == new_dates.size

posterior_predictive_mean = mmm.predict(X_pred=X_pred)
assert posterior_predictive_mean.shape[0] == new_dates.size

@pytest.mark.parametrize(
argnames="model_config",
argvalues=[
Expand Down

0 comments on commit c49ab8e

Please sign in to comment.