From c65c32decedd09beb6d0b29d557934f80594d249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marylou=20Gabri=C3=A9?= Date: Wed, 5 Jun 2024 15:40:07 +0200 Subject: [PATCH 1/2] update the example --- example/non_jax_likelihood.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/example/non_jax_likelihood.py b/example/non_jax_likelihood.py index 728e9ac..4b50a34 100644 --- a/example/non_jax_likelihood.py +++ b/example/non_jax_likelihood.py @@ -6,7 +6,7 @@ from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline from flowMC.proposal.Gaussian_random_walk import GaussianRandomWalk from flowMC.Sampler import Sampler -from flowMC.utils.PRNG_keys import initialize_rng_keys +# from flowMC.utils.PRNG_keys import initialize_rng_keys from flowMC.utils.PythonFunctionWrap import wrap_python_log_prob_fn """ @@ -30,34 +30,35 @@ def neal_funnel(x): x_pdf = norm.logpdf(x["params"][1:], loc=0, scale=np.exp(x["params"][0] / 2)) return y_pdf + np.sum(x_pdf) - +print("Using minimal settings for demonstration purposes.") n_dim = 5 n_chains = 20 -n_loop_training = 5 -n_loop_production = 5 +n_loop_training = 2 +n_loop_production = 2 n_local_steps = 20 -n_global_steps = 100 +n_global_steps = 10 n_chains = 100 learning_rate = 0.01 momentum = 0.9 -num_epochs = 100 -batch_size = 1000 +num_epochs = 10 +batch_size = 100 data = jnp.zeros(n_dim) -rng_key_set = initialize_rng_keys(n_chains, 42) -model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10)) - -initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 +rng_key = jax.random.PRNGKey(42) +rng_key, subkey = jax.random.split(rng_key) +model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, subkey) -RW_Sampler = GaussianRandomWalk(neal_funnel, False, {"step_size": 0.1}) +rng_key, subkey = jax.random.split(rng_key) +initial_position = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 1 +RW_Sampler = GaussianRandomWalk(neal_funnel, False, 0.1) print("Initializing sampler class") nf_sampler = Sampler( n_dim, - rng_key_set, + rng_key, jnp.zeros(5), RW_Sampler, model, @@ -76,7 +77,8 @@ def neal_funnel(x): nf_sampler.sample(initial_position, data) summary = nf_sampler.get_sampler_state(training=True) chains, log_prob, local_accs, global_accs, loss_vals = summary.values() -nf_samples = nf_sampler.sample_flow(10000) +rng_key, subkey = jax.random.split(rng_key) +nf_samples = nf_sampler.sample_flow(subkey, 10000) print( "chains shape: ", @@ -88,7 +90,7 @@ def neal_funnel(x): ) chains = np.array(chains) -nf_samples = np.array(nf_samples[1]) +nf_samples = np.array(nf_samples) loss_vals = np.array(loss_vals) import corner import matplotlib.pyplot as plt From 46d16a58a29edb22899c3b7c0c80cf87bc897aab Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Wed, 5 Jun 2024 14:00:52 -0400 Subject: [PATCH 2/2] Update non_jax_likelihood.py --- example/non_jax_likelihood.py | 1 - 1 file changed, 1 deletion(-) diff --git a/example/non_jax_likelihood.py b/example/non_jax_likelihood.py index 4b50a34..2a22f02 100644 --- a/example/non_jax_likelihood.py +++ b/example/non_jax_likelihood.py @@ -6,7 +6,6 @@ from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline from flowMC.proposal.Gaussian_random_walk import GaussianRandomWalk from flowMC.Sampler import Sampler -# from flowMC.utils.PRNG_keys import initialize_rng_keys from flowMC.utils.PythonFunctionWrap import wrap_python_log_prob_fn """