diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index 91463693..71309dd6 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -270,29 +270,30 @@ def build_model( def sample_model( self, - prior_kwargs: Optional[Dict] = None, - posterior_kwargs: Optional[Dict] = None, - **kwargs, - ): + prior_predictive: bool = False, + posterior_predictive: bool = False, + prior_predictive_kwargs: Optional[Dict] = None, + posterior_predictive_kwargs: Optional[Dict] = None, + **sample_kwargs, + ) -> az.InferenceData: """ Sample from the PyMC model. Parameters ---------- - prior_kwargs : dict, optional + prior_predictive : bool, optional + If True, the inference data will be extended with samples drawn from the prior predictive distribution. + Defaults to False. + posterior_predictive : bool, optional + If True, the inference data will be extended with samples drawn from the posterior predictive distribution. + Defaults to False. + prior_predictive_kwargs : dict, optional keyword arguments to pass to the PyMC prior predictive sampler. - posterior_kwargs : dict, optional + posterior_predictive_kwargs : dict, optional keyword arguments to pass to the PyMC posterior predictive sampler. + **kwargs : dict Additional keyword arguments to pass to the PyMC sampler. - - prior_predictive : bool, optional - If True, the inference data will be extended with samples drawn from the prior predictive distribution. - Defaults to False. - - - posterior_predictive : bool, optional - If True, the inference data will be extended with samples drawn from the posterior predictive distribution. - Defaults to False. - Returns ------- xarray.Dataset The PyMC samples dataset. @@ -306,28 +307,29 @@ def sample_model( -------- >>> self.build_model() >>> idata = self.sample_model(draws=100, tune=10) + + Returns + ------- + idata : az.InferenceData + InferenceData object containing the samples. """ if self.model is None: raise RuntimeError( "The model hasn't been built yet, call .build_model() first or call .fit() instead." ) - prior_predictive = kwargs.pop("prior_predictive", False) - posterior_predictive = kwargs.pop("posterior_predictive", False) + if prior_predictive_kwargs is None: + prior_predictive_kwargs = {} + if posterior_predictive_kwargs is None: + posterior_predictive_kwargs = {} with self.model: - sampler_args = {**self.sampler_config, **kwargs} + sampler_args = {**self.sampler_config, **sample_kwargs} idata = pm.sample(**sampler_args) if prior_predictive: - if prior_kwargs is not None: - idata.extend(pm.sample_prior_predictive(**prior_kwargs)) - else: - idata.extend(pm.sample_prior_predictive()) + idata.extend(pm.sample_prior_predictive(**prior_predictive_kwargs)) if posterior_predictive: - if prior_kwargs is not None: - idata.extend( - pm.sample_posterior_predictive(idata, **posterior_kwargs) - ) - else: - idata.extend(pm.sample_posterior_predictive(idata)) + idata.extend( + pm.sample_posterior_predictive(idata, **posterior_predictive_kwargs) + ) idata = self.set_idata_attrs(idata) return idata @@ -509,9 +511,9 @@ def fit( sample_kwargs : Optional[Dict] Allows for passing additional keyword arguments to the sample_model method possible arguments are: - - prior_kwargs : dict, optional + - prior_predictive_kwargs : dict, optional keyword arguments to pass to the PyMC prior predictive sampler. - - posterior_kwargs : dict, optional + - posterior_predictive_kwargs : dict, optional keyword arguments to pass to the PyMC posterior predictive sampler. - prior_predictive : bool, optional If True, the inference data will be extended with samples drawn from the prior predictive distribution. @@ -548,15 +550,17 @@ def fit( sampler_config["random_seed"] = random_seed sampler_config.update(**kwargs) - prior_kwargs = {} - posterior_kwargs = {} + prior_predictive_kwargs = {} + posterior_predictive_kwargs = {} prior_predictive = False posterior_predictive = False if sample_kwargs is not None: - if "prior_kwargs" in sample_kwargs: - prior_kwargs = sample_kwargs.pop("prior_kwargs") - if "posterior_kwargs" in sample_kwargs: - posterior_kwargs = sample_kwargs.pop("posterior_kwargs") + if "prior_predictive_kwargs" in sample_kwargs: + prior_predictive_kwargs = sample_kwargs.pop("prior_predictive_kwargs") + if "posterior_predictive_kwargs" in sample_kwargs: + posterior_predictive_kwargs = sample_kwargs.pop( + "posterior_predictive_kwargs" + ) if "prior_predictive" in sample_kwargs: prior_predictive = sample_kwargs.pop("prior_predictive") if "posterior_predictive" in sample_kwargs: @@ -564,8 +568,8 @@ def fit( # Merge all arguments and pass to sample_model self.idata = self.sample_model( - prior_kwargs=prior_kwargs, - posterior_kwargs=posterior_kwargs, + prior_predictive_kwargs=prior_predictive_kwargs, + posterior_predictive_kwargs=posterior_predictive_kwargs, prior_predictive=prior_predictive, posterior_predictive=posterior_predictive, **sampler_config,