Skip to content

Commit

Permalink
moving prior and posterior kwargs out of sample kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelraczycki committed Aug 22, 2023
1 parent ee23a8c commit 6eaf896
Showing 1 changed file with 41 additions and 37 deletions.
78 changes: 41 additions & 37 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,29 +270,30 @@ def build_model(

def sample_model(
self,
prior_kwargs: Optional[Dict] = None,
posterior_kwargs: Optional[Dict] = None,
**kwargs,
):
prior_predictive: bool = False,
posterior_predictive: bool = False,
prior_predictive_kwargs: Optional[Dict] = None,
posterior_predictive_kwargs: Optional[Dict] = None,
**sample_kwargs,
) -> az.InferenceData:
"""
Sample from the PyMC model.
Parameters
----------
prior_kwargs : dict, optional
prior_predictive : bool, optional
If True, the inference data will be extended with samples drawn from the prior predictive distribution.
Defaults to False.
posterior_predictive : bool, optional
If True, the inference data will be extended with samples drawn from the posterior predictive distribution.
Defaults to False.
prior_predictive_kwargs : dict, optional
keyword arguments to pass to the PyMC prior predictive sampler.
posterior_kwargs : dict, optional
posterior_predictive_kwargs : dict, optional
keyword arguments to pass to the PyMC posterior predictive sampler.
**kwargs : dict
Additional keyword arguments to pass to the PyMC sampler.
- prior_predictive : bool, optional
If True, the inference data will be extended with samples drawn from the prior predictive distribution.
Defaults to False.
- posterior_predictive : bool, optional
If True, the inference data will be extended with samples drawn from the posterior predictive distribution.
Defaults to False.
Returns
-------
xarray.Dataset
The PyMC samples dataset.
Expand All @@ -306,28 +307,29 @@ def sample_model(
--------
>>> self.build_model()
>>> idata = self.sample_model(draws=100, tune=10)
Returns
-------
idata : az.InferenceData
InferenceData object containing the samples.
"""
if self.model is None:
raise RuntimeError(
"The model hasn't been built yet, call .build_model() first or call .fit() instead."
)
prior_predictive = kwargs.pop("prior_predictive", False)
posterior_predictive = kwargs.pop("posterior_predictive", False)
if prior_predictive_kwargs is None:
prior_predictive_kwargs = {}
if posterior_predictive_kwargs is None:
posterior_predictive_kwargs = {}
with self.model:
sampler_args = {**self.sampler_config, **kwargs}
sampler_args = {**self.sampler_config, **sample_kwargs}
idata = pm.sample(**sampler_args)
if prior_predictive:
if prior_kwargs is not None:
idata.extend(pm.sample_prior_predictive(**prior_kwargs))
else:
idata.extend(pm.sample_prior_predictive())
idata.extend(pm.sample_prior_predictive(**prior_predictive_kwargs))
if posterior_predictive:
if prior_kwargs is not None:
idata.extend(
pm.sample_posterior_predictive(idata, **posterior_kwargs)
)
else:
idata.extend(pm.sample_posterior_predictive(idata))
idata.extend(
pm.sample_posterior_predictive(idata, **posterior_predictive_kwargs)
)

idata = self.set_idata_attrs(idata)
return idata
Expand Down Expand Up @@ -509,9 +511,9 @@ def fit(
sample_kwargs : Optional[Dict]
Allows for passing additional keyword arguments to the sample_model method
possible arguments are:
- prior_kwargs : dict, optional
- prior_predictive_kwargs : dict, optional
keyword arguments to pass to the PyMC prior predictive sampler.
- posterior_kwargs : dict, optional
- posterior_predictive_kwargs : dict, optional
keyword arguments to pass to the PyMC posterior predictive sampler.
- prior_predictive : bool, optional
If True, the inference data will be extended with samples drawn from the prior predictive distribution.
Expand Down Expand Up @@ -548,24 +550,26 @@ def fit(
sampler_config["random_seed"] = random_seed
sampler_config.update(**kwargs)

prior_kwargs = {}
posterior_kwargs = {}
prior_predictive_kwargs = {}
posterior_predictive_kwargs = {}
prior_predictive = False
posterior_predictive = False
if sample_kwargs is not None:
if "prior_kwargs" in sample_kwargs:
prior_kwargs = sample_kwargs.pop("prior_kwargs")
if "posterior_kwargs" in sample_kwargs:
posterior_kwargs = sample_kwargs.pop("posterior_kwargs")
if "prior_predictive_kwargs" in sample_kwargs:
prior_predictive_kwargs = sample_kwargs.pop("prior_predictive_kwargs")
if "posterior_predictive_kwargs" in sample_kwargs:
posterior_predictive_kwargs = sample_kwargs.pop(
"posterior_predictive_kwargs"
)
if "prior_predictive" in sample_kwargs:
prior_predictive = sample_kwargs.pop("prior_predictive")
if "posterior_predictive" in sample_kwargs:
posterior_predictive = sample_kwargs.pop("posterior_predictive")

# Merge all arguments and pass to sample_model
self.idata = self.sample_model(
prior_kwargs=prior_kwargs,
posterior_kwargs=posterior_kwargs,
prior_predictive_kwargs=prior_predictive_kwargs,
posterior_predictive_kwargs=posterior_predictive_kwargs,
prior_predictive=prior_predictive,
posterior_predictive=posterior_predictive,
**sampler_config,
Expand Down

0 comments on commit 6eaf896

Please sign in to comment.