Skip to content

Commit

Permalink
Update example outputs and test that it works
Browse files Browse the repository at this point in the history
  • Loading branch information
kiramclean authored and twiecki committed Jun 15, 2024
1 parent c657fef commit 33b24cc
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ Imagine we conduct an experiment to predict the growth of a plant based on diffe
}
with pm.do(generative_model, fixed_parameters) as synthetic_model:
idata = pm.sample_prior_predictive(random_seed=seed) # Sample from prior predictive distribution.
synthetic_y = idata.prior["plant growth (z-scored)"].sel(draw=0, chain=0)
synthetic_y = idata.prior["plant growth"].sel(draw=0, chain=0)
# Infer parameters conditioned on observed data
with pm.observe(generative_model, {"plant growth (z-scored)": synthetic_y}) as inference_model:
idata = pm.sample(random_seed=seed)
with pm.observe(generative_model, {"plant growth": synthetic_y}) as inference_model:
idata = pm.sample(random_seed=seed)
summary = pm.stats.summary(idata, var_names=["betas", "sigma"]))
print(summary)
summary = pm.stats.summary(idata, var_names=["betas", "sigma"])
print(summary)
From the summary, we can see that the mean of the inferred parameters are very close to the fixed parameters
Expand All @@ -116,14 +116,14 @@ sigma 0.511 0.037 0.438 0.575 0.001 0
# Simulate new data conditioned on inferred parameters
new_x_data = pm.draw(
pm.Normal.dist(shape=(3, 3)),
random_seed=seed,
pm.Normal.dist(shape=(3, 3)),
random_seed=seed,
)
new_coords = coords | {"trial": [0, 1, 2]}
with inference_model:
pm.set_data({"x": new_x_data}, coords=new_coords)
idata = pm.sample_posterior_predictive(
pm.sample_posterior_predictive(
idata,
predictions=True,
extend_inferencedata=True,
Expand All @@ -134,13 +134,13 @@ sigma 0.511 0.037 0.438 0.575 0.001 0
The new data conditioned on inferred parameters would look like:

========================== ====== ===== ======== =========
Output mean sd hdi_3% hdi_97%
========================== ====== ===== ======== =========
plant growth (z-scored)[0] 14.21 0.509 13.232 15.144
plant growth (z-scored)[1] 24.43 0.518 23.347 25.32
plant growth (z-scored)[2] -6.743 0.515 -7.778 -5.834
========================== ====== ===== ======== =========
================ ======== ======= ======== =========
Output mean sd hdi_3% hdi_97%
================ ======== ======= ======== =========
plant growth[0] 14.229 0.515 13.325 15.272
plant growth[1] 24.418 0.511 23.428 25.326
plant growth[2] -6.747 0.511 -7.740 -5.797
================ ======== ======= ======== =========

.. code-block:: python
Expand All @@ -159,13 +159,13 @@ plant growth (z-scored)[2] -6.743 0.515 -7.778 -5.834
The new data, under the above scenario would look like:

========================== ====== ===== ======== =========
Output mean sd hdi_3% hdi_97%
========================== ====== ===== ======== =========
plant growth (z-scored)[0] 14.153 0.509 13.181 15.096
plant growth (z-scored)[1] 23.85 0.517 22.915 24.878
plant growth (z-scored)[2] -7.302 0.515 -8.315 -6.374
========================== ====== ===== ======== =========
================ ======== ======= ======== =========
Output mean sd hdi_3% hdi_97%
================ ======== ======= ======== =========
plant growth[0] 12.149 0.515 11.193 13.135
plant growth[1] 29.809 0.508 28.832 30.717
plant growth[2] -0.131 0.507 -1.121 0.791
================ ======== ======= ======== =========

Getting started
===============
Expand Down

0 comments on commit 33b24cc

Please sign in to comment.