Skip to content

Commit

Permalink
Update quickstart.md
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong authored Apr 9, 2024
1 parent 099e8c7 commit e3372bf
Showing 1 changed file with 30 additions and 26 deletions.
56 changes: 30 additions & 26 deletions docs/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,39 +38,43 @@ To sample a N dimensional Gaussian, you would do something like:
import jax
import jax.numpy as jnp
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.sampler.MALA import MALA
from flowMC.sampler.Sampler import Sampler
from flowMC.utils.PRNG_keys import initialize_rng_keys
from flowMC.nfmodel.utils import *
from flowMC.proposal.MALA import MALA
from flowMC.Sampler import Sampler
def log_posterior(x, data):
return -0.5 * jnp.sum((x-data) ** 2)
data = jnp.arange(5)
def log_posterior(x, data: dict):
return -0.5 * jnp.sum((x - data['data']) ** 2)
n_dim = 5
data = {'data':jnp.arange(5)}
n_dim = 1
n_chains = 10
rng_key_set = initialize_rng_keys(n_chains, seed=42)
initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1
model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, PRNGKeyArray(21))
rng_key = jax.random.PRNGKey(42)
rng_key, subkey = jax.random.split(rng_key)
initial_position = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 1
rng_key, subkey = jax.random.split(rng_key)
model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, subkey)
step_size = 1e-1
local_sampler = MALA(log_posterior, True, {"step_size": step_size})
nf_sampler = Sampler(n_dim,
rng_key_set,
jnp.arange(n_dim),
local_sampler,
model,
n_local_steps = 50,
n_global_steps = 50,
n_epochs = 30,
learning_rate = 1e-2,
batch_size = 1000,
n_chains = n_chains)
local_sampler = MALA(log_posterior, True, step_size=step_size)
nf_sampler = Sampler(
n_dim,
rng_key,
data,
local_sampler,
model,
n_local_steps=50,
n_global_steps=50,
n_epochs=30,
learning_rate=1e-2,
batch_size=10000,
n_chains=n_chains,
)
nf_sampler.sample(initial_position, data)
chains,log_prob,local_accs, global_accs = nf_sampler.get_sampler_state().values()
chains, log_prob, local_accs, global_accs = nf_sampler.get_sampler_state().values()
```

For more examples, have a look at the [tutorials](https://github.com/kazewong/flowMC/tree/main/example) on GitHub.
Expand Down Expand Up @@ -99,4 +103,4 @@ Being able to run many chains in parallel helps training the normalizing flow mo

For the global sampler to be effective, the normalizing flow needs to learn where there is mass in the target distribution. Once the flow overlaps with the target, non-local jumps will start to be accepted and the MCMC chains will mix quickly.

As the flow learns from the chains, starting the chains in regions of interest will speed up the convergence of the algorithm. If these regions are not known, a good rule of thumb is to start from random draws from the prior provided the prior is spread enough to cover high density regions of the posterior.
As the flow learns from the chains, starting the chains in regions of interest will speed up the convergence of the algorithm. If these regions are not known, a good rule of thumb is to start from random draws from the prior provided the prior is spread enough to cover high density regions of the posterior.

0 comments on commit e3372bf

Please sign in to comment.