diff --git a/pymc_marketing/clv/models/basic.py b/pymc_marketing/clv/models/basic.py index 0a943d24..cd1ba8fd 100644 --- a/pymc_marketing/clv/models/basic.py +++ b/pymc_marketing/clv/models/basic.py @@ -58,7 +58,7 @@ def _validate_cols( raise ValueError(f"Column {required_col} has duplicate entries") def __repr__(self): - if self.model is None: + if not hasattr(self, "model"): return self._model_type else: return f"{self._model_type}\n{self.model.str_repr()}" diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py index bdab7e4f..79266646 100644 --- a/pymc_marketing/mmm/base.py +++ b/pymc_marketing/mmm/base.py @@ -267,14 +267,16 @@ def get_target_transformer(self) -> Pipeline: def prior(self) -> Dataset: if self.idata is None or "prior" not in self.idata: raise RuntimeError( - "The model hasn't been fit yet, call .sample_prior_predictive() with extend_idata=True first" + "The model hasn't been sampled yet, call .sample_prior_predictive() first" ) return self.idata["prior"] @property - def prior_predictive(self) -> az.InferenceData: + def prior_predictive(self) -> Dataset: if self.idata is None or "prior_predictive" not in self.idata: - raise RuntimeError("The model hasn't been fit yet, call .fit() first") + raise RuntimeError( + "The model hasn't been sampled yet, call .sample_prior_predictive() first" + ) return self.idata["prior_predictive"] @property @@ -286,7 +288,9 @@ def fit_result(self) -> Dataset: @property def posterior_predictive(self) -> Dataset: if self.idata is None or "posterior_predictive" not in self.idata: - raise RuntimeError("The model hasn't been fit yet, call .fit() first") + raise RuntimeError( + "The model hasn't been fit yet, call .sample_posterior_predictive() first" + ) return self.idata["posterior_predictive"] def plot_prior_predictive( diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index cbde9f97..9a7ed5f3 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -1829,7 +1829,7 @@ def add_lift_test_measurements( model.add_lift_test_measurements(df_lift_test) """ - if self.model is None: + if not hasattr(self, "model"): raise RuntimeError( "The model has not been built yet. Please, build the model first." ) diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index 4d33ade4..cebb0ebf 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -86,7 +86,7 @@ def __init__( self.model_config = ( self.default_model_config | model_config ) # parameters for priors etc. - self.model: pm.Model | None = None # Set by build_model + self.model: pm.Model self.idata: az.InferenceData | None = None # idata is generated during fitting self.is_fitted_ = False @@ -458,7 +458,7 @@ def fit( if self.X is None or self.y is None: raise ValueError("X and y must be set before calling build_model!") - if self.model is None: + if not hasattr(self, "model"): self.build_model(self.X, self.y) sampler_config = self.sampler_config.copy() @@ -466,11 +466,14 @@ def fit( sampler_config["random_seed"] = random_seed sampler_config.update(**kwargs) - sampler_config.update(**kwargs) - if self.model is not None: - with self.model: - sampler_args = {**self.sampler_config, **kwargs} - self.idata = pm.sample(**sampler_args) + sampler_args = {**self.sampler_config, **kwargs} + with self.model: + idata = pm.sample(**sampler_args) + + if self.idata: + self.idata.extend(idata, join="right") + else: + self.idata = idata X_df = pd.DataFrame(X, columns=X.columns) combined_data = pd.concat([X_df, y_df], axis=1) @@ -537,7 +540,7 @@ def sample_prior_predictive( X_pred, y_pred=None, samples: int | None = None, - extend_idata: bool = False, + extend_idata: bool = True, combined: bool = True, **kwargs, ): @@ -552,7 +555,7 @@ def sample_prior_predictive( Number of samples from the prior parameter distributions to generate. If not set, uses sampler_config['draws'] if that is available, otherwise defaults to 500. extend_idata : Boolean determining whether the predictions should be added to inference data object. - Defaults to False. + Defaults to True. combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists. Defaults to True. **kwargs: Additional arguments to pass to pymc.sample_prior_predictive @@ -567,21 +570,19 @@ def sample_prior_predictive( if samples is None: samples = self.sampler_config.get("draws", 500) - if self.model is None: + if not hasattr(self, "model"): self.build_model(X_pred, y_pred) self._data_setter(X_pred, y_pred) - if self.model is not None: - with self.model: # sample with new input data - prior_pred: az.InferenceData = pm.sample_prior_predictive( - samples, **kwargs - ) - self.set_idata_attrs(prior_pred) - if extend_idata: - if self.idata is not None: - self.idata.extend(prior_pred, join="right") - else: - self.idata = prior_pred + with self.model: # sample with new input data + prior_pred: az.InferenceData = pm.sample_prior_predictive(samples, **kwargs) + self.set_idata_attrs(prior_pred) + + if extend_idata: + if self.idata is not None: + self.idata.extend(prior_pred, join="right") + else: + self.idata = prior_pred prior_predictive_samples = az.extract( prior_pred, "prior_predictive", combined=combined @@ -590,7 +591,11 @@ def sample_prior_predictive( return prior_predictive_samples def sample_posterior_predictive( - self, X_pred, extend_idata: bool = True, combined: bool = True, **kwargs + self, + X_pred, + extend_idata: bool = True, + combined: bool = True, + **sample_posterior_predictive_kwargs, ): """ Sample from the model's posterior predictive distribution. @@ -603,7 +608,7 @@ def sample_posterior_predictive( Defaults to True. combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists. Defaults to True. - **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive + **sample_posterior_predictive_kwargs: Additional arguments to pass to pymc.sample_posterior_predictive Returns ------- @@ -612,16 +617,21 @@ def sample_posterior_predictive( """ self._data_setter(X_pred) - with self.model: # type: ignore - post_pred = pm.sample_posterior_predictive(self.idata, **kwargs) - if extend_idata: - self.idata.extend(post_pred, join="right") # type: ignore + with self.model: + post_pred = pm.sample_posterior_predictive( + self.idata, **sample_posterior_predictive_kwargs + ) + + if extend_idata: + self.idata.extend(post_pred, join="right") # type: ignore - posterior_predictive_samples = az.extract( - post_pred, "posterior_predictive", combined=combined + variable_name = ( + "predictions" + if sample_posterior_predictive_kwargs.get("predictions") + else "posterior_predictive" ) - return posterior_predictive_samples + return az.extract(post_pred, variable_name, combined=combined) def get_params(self, deep=True): """ diff --git a/tests/conftest.py b/tests/conftest.py index 82c9ead4..10015292 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,7 +77,7 @@ def set_model_fit(model: CLVModel, fit: InferenceData | Dataset): assert "posterior" in fit.groups() else: fit = InferenceData(posterior=fit) - if model.model is None: + if not hasattr(model, "model"): model.build_model() model.idata = fit model.idata.add_groups(fit_data=model.data.to_xarray()) diff --git a/tests/mmm/test_base.py b/tests/mmm/test_base.py index 73fd0111..d4f59424 100644 --- a/tests/mmm/test_base.py +++ b/tests/mmm/test_base.py @@ -270,7 +270,9 @@ def test_calling_prior_predictive_before_fit_raises_error(test_mmm, toy_X, toy_y test_mmm.idata = None with pytest.raises( RuntimeError, - match=re.escape("The model hasn't been fit yet, call .fit() first"), + match=re.escape( + "The model hasn't been sampled yet, call .sample_prior_predictive() first" + ), ): test_mmm.prior_predictive @@ -297,7 +299,7 @@ def test_calling_prior_before_sample_prior_predictive_raises_error( with pytest.raises( RuntimeError, match=re.escape( - "The model hasn't been fit yet, call .sample_prior_predictive() with extend_idata=True first" + "The model hasn't been sampled yet, call .sample_prior_predictive() first", ), ): test_mmm.prior diff --git a/tests/model_builder/test_model_builder.py b/tests/test_model_builder.py similarity index 89% rename from tests/model_builder/test_model_builder.py rename to tests/test_model_builder.py index c8964016..d22c73a7 100644 --- a/tests/model_builder/test_model_builder.py +++ b/tests/test_model_builder.py @@ -135,9 +135,9 @@ def _save_input_params(self, idata): def output_var(self): return "output" - def _data_setter(self, X: pd.Series, y: pd.Series = None): + def _data_setter(self, X: pd.DataFrame, y: pd.Series = None): with self.model: - pm.set_data({"x": X.values}) + pm.set_data({"x": X["input"].values}) if y is not None: y = y.values if isinstance(y, pd.Series) else y pm.set_data({"y_data": y}) @@ -195,8 +195,8 @@ def test_save_load(fitted_model_instance): assert fitted_model_instance.id == test_builder2.id x_pred = rng.uniform(low=0, high=1, size=100) prediction_data = pd.DataFrame({"input": x_pred}) - pred1 = fitted_model_instance.predict(prediction_data["input"]) - pred2 = test_builder2.predict(prediction_data["input"]) + pred1 = fitted_model_instance.predict(prediction_data) + pred2 = test_builder2.predict(prediction_data) assert pred1.shape == pred2.shape temp.close() @@ -230,9 +230,9 @@ def test_fit(fitted_model_instance): assert fitted_model_instance.idata.posterior.dims["draw"] == 100 prediction_data = pd.DataFrame({"input": rng.uniform(low=0, high=1, size=100)}) - fitted_model_instance.predict(prediction_data["input"]) + fitted_model_instance.predict(prediction_data) post_pred = fitted_model_instance.sample_posterior_predictive( - prediction_data["input"], extend_idata=True, combined=True + prediction_data, extend_idata=True, combined=True ) assert ( post_pred[fitted_model_instance.output_var].shape[0] @@ -256,7 +256,7 @@ def test_predict(fitted_model_instance): rng = np.random.default_rng(42) x_pred = rng.uniform(low=0, high=1, size=100) prediction_data = pd.DataFrame({"input": x_pred}) - pred = fitted_model_instance.predict(prediction_data["input"]) + pred = fitted_model_instance.predict(prediction_data) # Perform elementwise comparison using numpy assert type(pred) == np.ndarray assert len(pred) > 0 @@ -269,7 +269,7 @@ def test_sample_posterior_predictive(fitted_model_instance, combined): x_pred = rng.uniform(low=0, high=1, size=n_pred) prediction_data = pd.DataFrame({"input": x_pred}) pred = fitted_model_instance.sample_posterior_predictive( - prediction_data["input"], combined=combined, extend_idata=True + prediction_data, combined=combined, extend_idata=True ) chains = fitted_model_instance.idata.sample_stats.dims["chain"] draws = fitted_model_instance.idata.sample_stats.dims["draw"] @@ -313,7 +313,7 @@ def test_sample_xxx_predictive_keeps_second( method_name = f"sample_{name}" method = getattr(fitted_model_instance, method_name) - X_pred = toy_X["input"] + X_pred = toy_X kwargs = { "X_pred": X_pred, @@ -329,3 +329,26 @@ def test_sample_xxx_predictive_keeps_second( sample = getattr(fitted_model_instance.idata, name) xr.testing.assert_allclose(sample, second_sample) + + +def test_prediction_kwarg(fitted_model_instance, toy_X): + result = fitted_model_instance.sample_posterior_predictive( + toy_X, + extend_idata=True, + predictions=True, + ) + assert "predictions" in fitted_model_instance.idata + assert "predictions_constant_data" in fitted_model_instance.idata + + assert isinstance(result, xr.Dataset) + + +def test_fit_after_prior_keeps_prior(toy_X, toy_y): + model = ModelBuilderTest() + model.sample_prior_predictive(toy_X) + assert "prior" in model.idata + assert "prior_predictive" in model.idata + + model.fit(X=toy_X, y=toy_y, chains=1, draws=100, tune=100) + assert "prior" in model.idata + assert "prior_predictive" in model.idata