Skip to content

Commit

Permalink
Update the Metropolis-within-Gibbs example (#558)
Browse files Browse the repository at this point in the history
* Update Metropolis-within-Gibbs example

Update the Metropolis-within-Gibbs example markdown notebook to be
compatible with API changes.

Minor changes to text to keep consistency with code.

* Fix trailing whitespace

* simplify changes

---------

Co-authored-by: Tommy Hentschel <[email protected]>
Co-authored-by: Junpeng Lao <[email protected]>
  • Loading branch information
3 people authored Sep 15, 2023
1 parent 655c36b commit 8018fc4
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions docs/examples/howto_metropolis_within_gibbs.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.4
jupytext_version: 1.14.7
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down Expand Up @@ -67,22 +67,22 @@ In this case the conditional distributions $p(\xx \mid \yy)$ and $p(\yy \mid \xx
1. Maintain separate MCMC kernels to update each component of $p(\xx, \yy)$ while holding the other fixed.
2. Apply the kernel updates correctly.

The issue with (2) is that each kernel update for a given MCMC `Algorithm` in BlackJAX refers to an algorithm-specific `AlgorithmState`. For example, `RMHState` is a `typing.NamedTuple` class containing elements `position` and `log_probability`. In our MWG sampling problem at the beginning of step $t$, `RMHState.log_probability` will consist of $\log p(\xx_{t-1}, \yy_{t-1})$. After updating $\xx$, it will consist of $\log p(\xx_{t}, \yy_{t-1})$. This happens automatically when we call `blackjax.mcmc.rmh.build_kernel()`. However, after updating $\yy$ (via HMC), we must manually update `RMHState.log_probability` to consist of $\log p(\xx_{t}, \yy_{t})$.
The issue with (2) is that each kernel update for a given MCMC `Algorithm` in BlackJAX refers to an algorithm-specific `AlgorithmState`. For example, `RWState` is a `typing.NamedTuple` class containing elements `position` and `log_probability`. In our MWG sampling problem at the beginning of step $t$, `RWState.log_probability` will consist of $\log p(\xx_{t-1}, \yy_{t-1})$. After updating $\xx$, it will consist of $\log p(\xx_{t}, \yy_{t-1})$. This happens automatically when we call `blackjax.rmh.build_kernel()`. However, after updating $\yy$ (via HMC), we must manually update `RWState.log_probability` to consist of $\log p(\xx_{t}, \yy_{t})$.

A general way of performing this manual update is to use the `blackjax.mcmc.algorithm.init()` function of the given component's MCMC algorithm to update the `AlgorithmState`. This function has arguments `position` and `logdensity_fn`. For example with the HMC component, after obtaining $\xx_t$ but before drawing $\yy_t$, the `position` would be $\yy_{t-1}$ and the `logdensity_fn` function would be $\log p(\xx_t, \cdot )$.
A general way of performing this manual update is to use the `blackjax.algorithm.init()` function of the given component's MCMC algorithm to update the `AlgorithmState`. This function has arguments `position` and `logdensity_fn`. For example with the HMC component, after obtaining $\xx_t$ but before drawing $\yy_t$, the `position` would be $\yy_{t-1}$ and the `logdensity_fn` function would be $\log p(\xx_t, \cdot )$.

Using this approach, we now are now ready to implement the Gibbs sampling kernel in the code below.

### Construct the MWG Kernel

```{code-cell} ipython3
# MCMC initializers for each set of paramters
mwg_init_x = blackjax.mcmc.rmh.init
mwg_init_y = blackjax.mcmc.hmc.init
mwg_init_x = blackjax.rmh.init
mwg_init_y = blackjax.hmc.init
# MCMC updaters
mwg_step_fn_x = blackjax.mcmc.rmh.build_kernel()
mwg_step_fn_y = blackjax.mcmc.hmc.build_kernel() # default integrator, etc.
mwg_step_fn_x = blackjax.rmh.build_kernel()
mwg_step_fn_y = blackjax.hmc.build_kernel() # default integrator, etc.
def mwg_kernel(rng_key, state, parameters):
Expand Down Expand Up @@ -151,7 +151,7 @@ def mwg_kernel(rng_key, state, parameters):
```{code-cell} ipython3
parameters = {
"x": {
"sigma": .2 * jnp.eye(2)
"transition_generator": blackjax.mcmc.random_walk.normal(.2 * jnp.eye(2))
},
"y": {
"inverse_mass_matrix": jnp.array([1., 1.]),
Expand Down Expand Up @@ -331,6 +331,6 @@ positions_general = sampling_loop_general(

## Developer Notes

- The update method above (using `blackjax.mcmc.algorithm.init()`) should work out-of-the-box for most (if not all) MCMC algorithms in BlackJAX. However, it is not optimally efficient. For example for the RMH update, after obtaining $\yy_{t-1}$ but before drawing $\xx_t$, the method above would calculate `RMHState.log_density` to be $\log p(\xx_{t-1}, \yy_{t-1})$. But we've already calculated this value from the previous HMC update of $\yy_{t-1} \sim p(\yy \mid \xx_{t-1})$. So, we could save ourselves the cost of calculating the log-density twice, at the expense of a deeper understanding of the low-level components of the algorithms at hand and less generalizable code.
- The update method above (using `blackjax.algorithm.init()`) should work out-of-the-box for most (if not all) MCMC algorithms in BlackJAX. However, it is not optimally efficient. For example for the RMH update, after obtaining $\yy_{t-1}$ but before drawing $\xx_t$, the method above would calculate `RWState.log_density` to be $\log p(\xx_{t-1}, \yy_{t-1})$. But we've already calculated this value from the previous HMC update of $\yy_{t-1} \sim p(\yy \mid \xx_{t-1})$. So, we could save ourselves the cost of calculating the log-density twice, at the expense of a deeper understanding of the low-level components of the algorithms at hand and less generalizable code.

- The general MWG kernel prototyped above should be adequate for problems with a small number of components. However, the for-loop over the components of `state` gets unrolled by the JAX JIT compiler (as discussed [here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#structured-control-flow-primitives)), which can cause long compilation times when the number of components is large. To mitigate this problem, the for-loop could be replaced by a `lax.scan()` primitive. For the sake of simplicity this approach is not fully developed here.

0 comments on commit 8018fc4

Please sign in to comment.