deprecate samples arg in prior_predictive. closes #7173
wd60622 committed Jun 17, 2024
1 parent 8a68a5c commit af7127c
Showing 2 changed files with 37 additions and 16 deletions.
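In short: pm.sample_prior_predictive gains a draws keyword, and the old samples keyword is kept only as a deprecated alias whose value is forwarded to draws. A minimal before/after sketch of a call site (the model and draw count below are illustrative, not part of the commit):

import pymc as pm

with pm.Model():
    pm.Normal("x")
    # New spelling introduced by this commit: request 100 prior predictive draws
    idata = pm.sample_prior_predictive(draws=100)
    # Old spelling still runs for now, but emits a DeprecationWarning
    idata_old = pm.sample_prior_predictive(samples=100)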
18 changes: 15 additions & 3 deletions pymc/sampling/forward.py
@@ -338,19 +338,20 @@ def observed_dependent_deterministics(model: Model):


def sample_prior_predictive(
- samples: int = 500,
+ draws: int = 500,
model: Model | None = None,
var_names: Iterable[str] | None = None,
random_seed: RandomState = None,
return_inferencedata: bool = True,
idata_kwargs: dict | None = None,
compile_kwargs: dict | None = None,
+ samples: int | None = None,
) -> InferenceData | dict[str, np.ndarray]:
"""Generate samples from the prior predictive distribution.
Parameters
----------
- samples : int
+ draws : int
Number of samples from the prior predictive to generate. Defaults to 500.
model : Model (optional if in ``with`` context)
var_names : Iterable[str]
@@ -366,13 +367,24 @@ def sample_prior_predictive(
Keyword arguments for :func:`pymc.to_inference_data`
compile_kwargs: dict, optional
Keyword arguments for :func:`pymc.pytensorf.compile_pymc`.
+ samples : int
+ Number of samples from the prior predictive to generate. Deprecated in favor of `draws`.
Returns
-------
arviz.InferenceData or Dict
An ArviZ ``InferenceData`` object containing the prior and prior predictive samples (default),
or a dictionary with variable names as keys and samples as numpy arrays.
"""
+     if samples is not None:
+         warnings.warn(
+             f"The samples argument has been deprecated in favor of draws. Use draws={samples} going forward.",
+             DeprecationWarning,
+             stacklevel=2,
+         )
+
+         draws = samples
+
model = modelcontext(model)

if model.potentials:
@@ -415,7 +427,7 @@ def sample_prior_predictive(

# All model variables have a name, but mypy does not know this
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
- values = zip(*(sampler_fn() for i in range(samples)))
+ values = zip(*(sampler_fn() for i in range(draws)))

data = {k: np.stack(v) for k, v in zip(names, values)}
if data is None:
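The shim above only fires when a caller passes samples explicitly: the new default of None acts as a sentinel, the warning is raised with stacklevel=2 so it points at the caller, and the value is copied into draws. A rough sketch of how downstream code could promote the deprecation to an error while migrating (standard-library warnings usage, not part of this diff):

import warnings

import pymc as pm

with pm.Model():
    pm.Normal("x")
    with warnings.catch_warnings():
        # Escalate the deprecation to an exception so old call sites fail loudly
        warnings.simplefilter("error", DeprecationWarning)
        try:
            pm.sample_prior_predictive(samples=10)
        except DeprecationWarning as err:
            print(f"needs migration: {err}")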
35 changes: 22 additions & 13 deletions tests/sampling/test_forward.py
@@ -794,12 +794,12 @@ def test_logging_sampled_basic_rvs_prior(self, caplog):
z = pm.Normal("z", y, observed=0)

with m:
- pm.sample_prior_predictive(samples=1)
+ pm.sample_prior_predictive(draws=1)
assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")]
caplog.clear()

with m:
- pm.sample_prior_predictive(samples=1, var_names=["x"])
+ pm.sample_prior_predictive(draws=1, var_names=["x"])
assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x]")]
caplog.clear()

@@ -1028,7 +1028,7 @@ def test_observed_data_needed_in_pp(self):
mu = x_data.sum(-1)
pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))

- prior = pm.sample_prior_predictive(samples=25).prior
+ prior = pm.sample_prior_predictive(draws=25).prior

fake_idata = InferenceData(posterior=prior)

@@ -1052,7 +1052,7 @@ def test_observed_data_needed_in_pp(self):
mu = (y_data.sum() * x_data).sum(-1)
pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))

- prior = pm.sample_prior_predictive(samples=25).prior
+ prior = pm.sample_prior_predictive(draws=25).prior

fake_idata = InferenceData(posterior=prior)

@@ -1135,7 +1135,7 @@ def test_multivariate2(self, seeded_test):
compute_convergence_checks=False,
)
sim_priors = pm.sample_prior_predictive(
- return_inferencedata=False, samples=20, model=dm_model
+ return_inferencedata=False, draws=20, model=dm_model
)
sim_ppc = pm.sample_posterior_predictive(
burned_trace, return_inferencedata=False, model=dm_model
@@ -1227,7 +1227,7 @@ def test_zeroinflatedpoisson(self):
mu = pm.Beta("mu", alpha=1, beta=1)
psi = pm.HalfNormal("psi", sigma=1)
pm.ZeroInflatedPoisson("suppliers", psi=psi, mu=mu, size=20)
- gen_data = pm.sample_prior_predictive(samples=5000)
+ gen_data = pm.sample_prior_predictive(draws=5000)
assert gen_data.prior["mu"].shape == (1, 5000)
assert gen_data.prior["psi"].shape == (1, 5000)
assert gen_data.prior["suppliers"].shape == (1, 5000, 20)
@@ -1240,7 +1240,7 @@ def test_potentials_warning(self):

with m:
with pytest.warns(UserWarning, match=warning_msg):
- pm.sample_prior_predictive(samples=5)
+ pm.sample_prior_predictive(draws=5)

def test_transformed_vars_not_supported(self):
with pm.Model() as model:
@@ -1260,7 +1260,7 @@ def test_issue_4490(self):
c = pm.Normal("c")
d = pm.Normal("d")
prior1 = pm.sample_prior_predictive(
- samples=1, var_names=["a", "b", "c", "d"], random_seed=seed
+ draws=1, var_names=["a", "b", "c", "d"], random_seed=seed
)

with pm.Model() as m2:
@@ -1269,7 +1269,7 @@ def test_issue_4490(self):
c = pm.Normal("c")
d = pm.Normal("d")
prior2 = pm.sample_prior_predictive(
- samples=1, var_names=["b", "a", "d", "c"], random_seed=seed
+ draws=1, var_names=["b", "a", "d", "c"], random_seed=seed
)

assert prior1.prior["a"] == prior2.prior["a"]
@@ -1284,7 +1284,7 @@ def test_pytensor_function_kwargs(self):
y = pm.Deterministic("y", x + sharedvar)

prior = pm.sample_prior_predictive(
- samples=5,
+ draws=5,
return_inferencedata=False,
compile_kwargs=dict(
mode=Mode("py"),
@@ -1308,7 +1308,7 @@ def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture):

with pmodel:
prior = pm.sample_prior_predictive(
- samples=20,
+ draws=20,
return_inferencedata=False,
)
idat = pm.to_inference_data(trace, prior=prior)
@@ -1367,7 +1367,7 @@ def test_distinct_rvs():
Y_rv = pm.Normal("y")

pp_samples = pm.sample_prior_predictive(
- samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
+ draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
)

assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0]
@@ -1377,7 +1377,7 @@ def test_distinct_rvs():
Y_rv = pm.Normal("y")

pp_samples_2 = pm.sample_prior_predictive(
- samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
+ draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
)

assert np.array_equal(pp_samples["y"], pp_samples_2["y"])
@@ -1706,3 +1706,12 @@ def test_observed_dependent_deterministics():
det_mixed = pm.Deterministic("det_mixed", free + obs)

assert set(observed_dependent_deterministics(m)) == {det_obs, det_obs2, det_mixed}


+ def test_sample_prior_predictive_samples_deprecated_warns() -> None:
+     with pm.Model() as m:
+         pm.Normal("a")
+
+     match = "The samples argument has been deprecated"
+     with pytest.warns(DeprecationWarning, match=match):
+         pm.sample_prior_predictive(model=m, samples=10)
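For reference, the new test can be run in isolation through pytest's Python entry point (path and test name taken from the diff above; a development install of PyMC is assumed):

import pytest

pytest.main(["tests/sampling/test_forward.py::test_sample_prior_predictive_samples_deprecated_warns", "-v"])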
