From d5f72205fc5e98fdf68ecec0325d970a2b33359c Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 27 Oct 2023 16:05:30 +0200
Subject: [PATCH] Don't erase previous sampling results in ModelBuilder

---
 pymc_marketing/mmm/base.py          |  2 +-
 pymc_marketing/mmm/preprocessing.py |  5 +-
 pymc_marketing/model_builder.py     | 87 +++++++++++++++++------------
 tests/mmm/test_plotting.py          | 57 +++++++++++++------
 4 files changed, 94 insertions(+), 57 deletions(-)

diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py
index 99a8501c..0fefb3c9 100644
--- a/pymc_marketing/mmm/base.py
+++ b/pymc_marketing/mmm/base.py
@@ -235,7 +235,7 @@ def get_target_transformer(self) -> Pipeline:
     @property
     def prior_predictive(self) -> az.InferenceData:
         if self.idata is None or "prior_predictive" not in self.idata:
-            raise RuntimeError("The model hasn't been fit yet, call .fit() first")
+            raise RuntimeError("The prior predictive hasn't been sampled yet, call .sample_prior_predictive() first")
         return self.idata["prior_predictive"]
 
     @property
diff --git a/pymc_marketing/mmm/preprocessing.py b/pymc_marketing/mmm/preprocessing.py
index 470c870c..602c8177 100644
--- a/pymc_marketing/mmm/preprocessing.py
+++ b/pymc_marketing/mmm/preprocessing.py
@@ -1,5 +1,6 @@
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, List, Tuple, Union, cast
 
+import numpy as np
 import pandas as pd
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import MaxAbsScaler, StandardScaler
@@ -32,7 +33,7 @@ class MaxAbsScaleTarget:
 
     @preprocessing_method_y
     def max_abs_scale_target_data(self, data: pd.Series) -> pd.Series:
-        target_vector = data.reshape(-1, 1)
+        target_vector = cast(np.ndarray, data.values).reshape(-1, 1)
         transformers = [("scaler", MaxAbsScaler())]
         pipeline = Pipeline(steps=transformers)
         self.target_transformer: Pipeline = pipeline.fit(X=target_vector)
diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py
index b846d7fe..efa39944 100644
--- a/pymc_marketing/model_builder.py
+++ b/pymc_marketing/model_builder.py
@@ -18,7 +18,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, cast
 
 import arviz as az
 import numpy as np
@@ -414,6 +414,20 @@ def load(cls, fname: str):
 
         return model
 
+    def _add_fit_data_group(self, X, y) -> None:
+        y_df = pd.DataFrame({self.output_var: y})
+        X_df = pd.DataFrame(X, columns=X.columns)
+        combined_data = pd.concat([X_df, y_df], axis=1)
+        assert all(combined_data.columns), "All columns must have non-empty names"
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                category=UserWarning,
+                message="The group fit_data is not defined in the InferenceData scheme",
+            )
+            # What if fit_data was already present?
+            self.idata.add_groups(fit_data=combined_data.to_xarray())  # type: ignore
+
     def fit(
         self,
         X: pd.DataFrame,
@@ -436,9 +450,6 @@ def fit(
             The target values (real numbers).
         progressbar : bool
             Specifies whether the fit progressbar should be displayed
-        predictor_names: Optional[List[str]] = None,
-            Allows for custom naming of predictors given in a form of 2dArray
-            allows for naming of predictors when given in a form of np.ndarray, if not provided the predictors will be named like predictor1, predictor2...
         random_seed : Optional[RandomState]
            Provides sampler with initial random seed for obtaining reproducible samples
         **kwargs : Any
@@ -455,12 +466,10 @@ def fit(
         Auto-assigning NUTS sampler...
         Initializing NUTS using jitter+adapt_diag...
         """
-        if predictor_names is None:
-            predictor_names = []
+
         if y is None:
             y = np.zeros(X.shape[0])
-        y_df = pd.DataFrame({self.output_var: y})
-        self._generate_and_preprocess_model_data(X, y_df.values.flatten())
+        self._generate_and_preprocess_model_data(X, np.asarray(y).flatten())
         if self.X is None or self.y is None:
             raise ValueError("X and y must be set before calling build_model!")
         self.build_model(self.X, self.y)
@@ -474,19 +483,16 @@ def fit(
         if self.model is not None:
             with self.model:
                 sampler_args = {**self.sampler_config, **kwargs}
-                self.idata = pm.sample(**sampler_args)
+                idata = pm.sample(**sampler_args)
 
-        X_df = pd.DataFrame(X, columns=X.columns)
-        combined_data = pd.concat([X_df, y_df], axis=1)
-        assert all(combined_data.columns), "All columns must have non-empty names"
-        with warnings.catch_warnings():
-            warnings.filterwarnings(
-                "ignore",
-                category=UserWarning,
-                message="The group fit_data is not defined in the InferenceData scheme",
-            )
-            self.idata.add_groups(fit_data=combined_data.to_xarray())  # type: ignore
+        if self.idata:
+            self.idata.extend(idata, join="right")
+        else:
+            self.idata = idata
+
+        self._add_fit_data_group(X, y)
         self.set_idata_attrs(self.idata)
+
         return self.idata  # type: ignore
 
     def predict(
@@ -522,9 +528,13 @@ def predict(
         """
 
         posterior_predictive_samples = self.sample_posterior_predictive(
-            X_pred, extend_idata, combined=False, **kwargs
+            X_pred, combined=False, predictions=True, **kwargs
         )
 
+        if extend_idata:
+            assert isinstance(self.idata, az.InferenceData)
+            self.idata.extend(posterior_predictive_samples, join="right")
+
         if self.output_var not in posterior_predictive_samples:
             raise KeyError(
                 f"Output variable {self.output_var} not found in posterior predictive samples."
@@ -540,7 +550,7 @@ def sample_prior_predictive(
         X_pred,
         y_pred=None,
         samples: Optional[int] = None,
-        extend_idata: bool = False,
+        extend_idata: bool = True,
         combined: bool = True,
         **kwargs,
     ):
@@ -555,7 +565,7 @@ def sample_prior_predictive(
             Number of samples from the prior parameter distributions to generate.
             If not set, uses sampler_config['draws'] if that is available, otherwise defaults to 500.
         extend_idata : Boolean determining whether the predictions should be added to inference data object.
-            Defaults to False.
+            Defaults to True.
         combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
             Defaults to True.
         **kwargs: Additional arguments to pass to pymc.sample_prior_predictive
@@ -574,17 +584,16 @@ def sample_prior_predictive(
         self.build_model(X_pred, y_pred)
 
         self._data_setter(X_pred, y_pred)
-        if self.model is not None:
-            with self.model:  # sample with new input data
-                prior_pred: az.InferenceData = pm.sample_prior_predictive(
-                    samples, **kwargs
-                )
-            self.set_idata_attrs(prior_pred)
-            if extend_idata:
-                if self.idata is not None:
-                    self.idata.extend(prior_pred)
-                else:
-                    self.idata = prior_pred
+
+        with cast(pm.Model, self.model):  # sample with new input data
+            prior_pred: az.InferenceData = pm.sample_prior_predictive(samples, **kwargs)
+        self.set_idata_attrs(prior_pred)
+
+        if extend_idata:
+            if self.idata is not None:
+                self.idata.extend(prior_pred, join="right")
+            else:
+                self.idata = prior_pred
 
         prior_predictive_samples = az.extract(
             prior_pred, "prior_predictive", combined=combined
@@ -592,7 +601,9 @@ def sample_prior_predictive(
         )
 
         return prior_predictive_samples
 
-    def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
+    def sample_posterior_predictive(
+        self, X_pred, extend_idata: bool = True, combined: bool = True, **kwargs
+    ):
         """
         Sample from the model's posterior predictive distribution.
@@ -613,10 +624,12 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
         """
         self._data_setter(X_pred)
 
-        with self.model:  # sample with new input data
+        with cast(pm.Model, self.model):  # sample with new input data
             post_pred = pm.sample_posterior_predictive(self.idata, **kwargs)
-            if extend_idata:
-                self.idata.extend(post_pred)
+
+        if extend_idata:
+            assert isinstance(self.idata, az.InferenceData)
+            self.idata.extend(post_pred, join="right")
 
         posterior_predictive_samples = az.extract(
             post_pred, "posterior_predictive", combined=combined
diff --git a/tests/mmm/test_plotting.py b/tests/mmm/test_plotting.py
index 35538bcb..5ece2416 100644
--- a/tests/mmm/test_plotting.py
+++ b/tests/mmm/test_plotting.py
@@ -36,6 +36,14 @@ def toy_y(toy_X) -> pd.Series:
     return pd.Series(rng.integers(low=0, high=100, size=toy_X.shape[0]))
 
 
+class ToyMMMDefaultTransform(BaseDelayedSaturatedMMM):
+    pass
+
+
+class ToyMMMCustomTransform(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
+    pass
+
+
 class TestBasePlotting:
     @pytest.fixture(
         scope="module",
@@ -49,38 +57,36 @@ class TestBasePlotting:
     )
     def plotting_mmm(self, request, toy_X, toy_y):
         control, transform = request.param.split("-")
         if transform == "default_transform":
-
-            class ToyMMM(BaseDelayedSaturatedMMM):
-                pass
-
+            mmm_class = ToyMMMDefaultTransform
         elif transform == "target_transform":
-
-            class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
-                pass
+            mmm_class = ToyMMMCustomTransform
+        else:
+            raise ValueError(f"Unexpected transform {transform}")
 
         if control == "without_controls":
-            mmm = ToyMMM(
+            mmm = mmm_class(
                 date_column="date",
                 channel_columns=["channel_1", "channel_2"],
                 adstock_max_lag=4,
             )
         elif control == "with_controls":
-            mmm = ToyMMM(
+            mmm = mmm_class(
                 date_column="date",
                 adstock_max_lag=4,
                 control_columns=["control_1", "control_2"],
                 channel_columns=["channel_1", "channel_2"],
             )
-        # fit the model
-        mmm.fit(
-            X=toy_X,
-            y=toy_y,
-        )
+        else:
+            raise ValueError(f"Unexpected control {control}")
+        mmm.sample_prior_predictive(toy_X, toy_y, extend_idata=True, combined=True)
+
+        # fake-fit the model for speed
+        mmm.idata.add_groups({"posterior": mmm.idata.prior})
+        mmm._add_fit_data_group(toy_X, toy_y)
+        mmm.set_idata_attrs()
+        mmm.sample_posterior_predictive(toy_X, extend_idata=True, combined=True)
 
-        mmm._prior_predictive = mmm.prior_predictive
-        mmm._fit_result = mmm.fit_result
-        mmm._posterior_predictive = mmm.posterior_predictive
 
         return mmm
 
@@ -109,3 +115,20 @@ def test_plots(self, plotting_mmm, func_plot_name, kwargs_plot) -> None:
         func = plotting_mmm.__getattribute__(func_plot_name)
         assert isinstance(func(**kwargs_plot), plt.Figure)
         plt.close("all")
+
+    def test_plot_prior_predictive_without_fit(self, toy_X, toy_y):
+        """Test that plot_prior_predictive works both before and after calling fit()"""
+        mmm = ToyMMMDefaultTransform(
+            date_column="date",
+            channel_columns=["channel_1", "channel_2"],
+            adstock_max_lag=4,
+        )
+        mmm.sample_prior_predictive(toy_X, toy_y)
+        assert isinstance(mmm.plot_prior_predictive(), plt.Figure)
+
+        # We can also plot it after `fit()`
+        mmm.fit(toy_X, toy_y, chains=1, draws=1, tune=10)
+
+        mmm.sample_prior_predictive(toy_X, toy_y)
+        assert isinstance(mmm.plot_prior_predictive(), plt.Figure)
+        plt.close("all")
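
A minimal sketch (not part of the patch) of the workflow this change enables: sampling the prior predictive before fit(), with both sets of draws kept in the same InferenceData. It assumes the public DelayedSaturatedMMM wrapper; the toy columns ("date", "channel_1", "channel_2") mirror the test fixtures above and are purely illustrative.

    import numpy as np
    import pandas as pd

    from pymc_marketing.mmm import DelayedSaturatedMMM

    # Illustrative toy data; any weekly spend-like columns would do.
    rng = np.random.default_rng(seed=42)
    n_weeks = 52
    toy_X = pd.DataFrame(
        {
            "date": pd.date_range("2023-01-01", periods=n_weeks, freq="W-MON"),
            "channel_1": rng.uniform(size=n_weeks),
            "channel_2": rng.uniform(size=n_weeks),
        }
    )
    toy_y = pd.Series(rng.integers(low=0, high=100, size=n_weeks))

    mmm = DelayedSaturatedMMM(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock_max_lag=4,
    )

    # Prior predictive checks no longer require a fitted model; the draws
    # land in mmm.idata under the "prior"/"prior_predictive" groups.
    mmm.sample_prior_predictive(toy_X, toy_y)
    assert "prior_predictive" in mmm.idata.groups()

    # fit() now extends the same InferenceData instead of replacing it, so
    # the prior draws survive next to "posterior" and "fit_data".
    mmm.fit(toy_X, toy_y, chains=1, draws=50, tune=50)
    assert {"prior_predictive", "posterior", "fit_data"} <= set(mmm.idata.groups())

The join="right" passed to InferenceData.extend throughout the patch means that when a group name collides, the freshly sampled group wins, so re-running any of the sampling methods replaces only the corresponding group rather than the whole object.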