From c49ab8e65155c46727d4c181a388ca32e9efd560 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 20 Dec 2023 19:37:38 +0100 Subject: [PATCH] support for predict_posterior method with adstock effects --- pymc_marketing/mmm/delayed_saturated_mmm.py | 52 ++++++++++++++++++++- tests/mmm/test_delayed_saturated_mmm.py | 18 ++++--- 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index ab841370..fc813af5 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -83,7 +83,7 @@ def default_sampler_config(self) -> Dict: @property def output_var(self): """Defines target variable for the model""" - return "y" + return "likelihood" def _generate_and_preprocess_model_data( # type: ignore self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray] @@ -843,3 +843,53 @@ def plot_channel_contributions_grid( ylabel="contribution", ) return fig + + def predict_posterior( + self, + X_pred: Union[np.ndarray, pd.DataFrame, pd.Series], + extend_idata: bool = True, + combined: bool = True, + include_last_observations: bool = False, + **kwargs, + ) -> DataArray: + """ + Generate posterior predictive samples on unseen data. + + Parameters + --------- + X_pred : array-like if sklearn is available, otherwise array, shape (n_pred, n_features) + The input data used for prediction. + extend_idata : Boolean determining whether the predictions should be added to inference data object. + Defaults to True. + combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists. + Defaults to True. + include_last_observations: Whether to include last observed data for carryover adstock and saturation effect. + Defaults to False. + **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive + + Returns + ------- + y_pred : DataArray, shape (n_pred, chains * draws) if combined is True, otherwise (chains, draws, n_pred) + Posterior predictive samples for each input X_pred + """ + if not isinstance(X_pred, pd.DataFrame): + raise ValueError("X_pred must be a pandas DataFrame") + + if include_last_observations: + X_pred = pd.concat([self.X.iloc[-self.adstock_max_lag :], X_pred]) + + posterior_predictive_samples = self.sample_posterior_predictive( + X_pred, extend_idata, combined, **kwargs + ) + + if self.output_var not in posterior_predictive_samples: + raise KeyError( + f"Output variable {self.output_var} not found in posterior predictive samples." + ) + + if include_last_observations: + posterior_predictive_samples = posterior_predictive_samples.isel( + date=slice(self.adstock_max_lag, None) + ) + + return posterior_predictive_samples[self.output_var] diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index e5801253..9999a346 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -571,7 +571,7 @@ def test_new_data_predictions( ) -> None: mmm = request.getfixturevalue(model_name) n = new_dates.size - new_X = pd.DataFrame( + X_pred = pd.DataFrame( { "date": new_dates, "channel_1": rng.integers(low=0, high=400, size=n), @@ -583,20 +583,24 @@ def test_new_data_predictions( } ) - with pytest.raises( - TypeError, - match=r"The DType could not be promoted by", - ): - mmm.predict_posterior(X_pred=new_X) + pp_without = mmm.predict_posterior( + X_pred=X_pred, include_last_observations=False + ) + pp_with = mmm.predict_posterior(X_pred=X_pred, include_last_observations=True) + + assert pp_without.coords.equals(pp_with.coords) posterior_predictive = mmm.sample_posterior_predictive( - X_pred=new_X, extend_idata=False, combined=True + X_pred=X_pred, extend_idata=False, combined=True ) pd.testing.assert_index_equal( pd.DatetimeIndex(posterior_predictive.coords["date"]), new_dates ) assert posterior_predictive["likelihood"].shape[0] == new_dates.size + posterior_predictive_mean = mmm.predict(X_pred=X_pred) + assert posterior_predictive_mean.shape[0] == new_dates.size + @pytest.mark.parametrize( argnames="model_config", argvalues=[