Skip to content

Commit

Permalink
Merge pull request #176 from kazewong/adjusting-non-jax-example
Browse files Browse the repository at this point in the history
minor fix of the example using a non-jax likelihood
  • Loading branch information
kazewong authored Jun 5, 2024
2 parents 193ddc2 + 46d16a5 commit 90480d0
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions example/non_jax_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.proposal.Gaussian_random_walk import GaussianRandomWalk
from flowMC.Sampler import Sampler
from flowMC.utils.PRNG_keys import initialize_rng_keys
from flowMC.utils.PythonFunctionWrap import wrap_python_log_prob_fn

"""
Expand All @@ -30,34 +29,35 @@ def neal_funnel(x):
x_pdf = norm.logpdf(x["params"][1:], loc=0, scale=np.exp(x["params"][0] / 2))
return y_pdf + np.sum(x_pdf)


print("Using minimal settings for demonstration purposes.")
n_dim = 5
n_chains = 20
n_loop_training = 5
n_loop_production = 5
n_loop_training = 2
n_loop_production = 2
n_local_steps = 20
n_global_steps = 100
n_global_steps = 10
n_chains = 100
learning_rate = 0.01
momentum = 0.9
num_epochs = 100
batch_size = 1000
num_epochs = 10
batch_size = 100

data = jnp.zeros(n_dim)

rng_key_set = initialize_rng_keys(n_chains, 42)
model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))

initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1
rng_key = jax.random.PRNGKey(42)
rng_key, subkey = jax.random.split(rng_key)
model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, subkey)

RW_Sampler = GaussianRandomWalk(neal_funnel, False, {"step_size": 0.1})
rng_key, subkey = jax.random.split(rng_key)
initial_position = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 1

RW_Sampler = GaussianRandomWalk(neal_funnel, False, 0.1)

print("Initializing sampler class")

nf_sampler = Sampler(
n_dim,
rng_key_set,
rng_key,
jnp.zeros(5),
RW_Sampler,
model,
Expand All @@ -76,7 +76,8 @@ def neal_funnel(x):
nf_sampler.sample(initial_position, data)
summary = nf_sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = summary.values()
nf_samples = nf_sampler.sample_flow(10000)
rng_key, subkey = jax.random.split(rng_key)
nf_samples = nf_sampler.sample_flow(subkey, 10000)

print(
"chains shape: ",
Expand All @@ -88,7 +89,7 @@ def neal_funnel(x):
)

chains = np.array(chains)
nf_samples = np.array(nf_samples[1])
nf_samples = np.array(nf_samples)
loss_vals = np.array(loss_vals)
import corner
import matplotlib.pyplot as plt
Expand Down

0 comments on commit 90480d0

Please sign in to comment.