From 0c49ac9b57667c2beee2d39dbd3a50be002e2e9c Mon Sep 17 00:00:00 2001 From: Thibeau Wouters Date: Wed, 8 May 2024 03:49:12 -0700 Subject: [PATCH] update example and add print summary to Sampler --- example/dualmoon.py | 7 +++++-- src/flowMC/Sampler.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/example/dualmoon.py b/example/dualmoon.py index 9c518f4..2cecfc9 100644 --- a/example/dualmoon.py +++ b/example/dualmoon.py @@ -1,5 +1,6 @@ import corner import jax +print(jax.devices()) import jax.numpy as jnp # JAX NumPy import matplotlib.pyplot as plt import numpy as np @@ -42,7 +43,7 @@ def target_dualmoon(x, data): rng_key, subkey = jax.random.split(rng_key) initial_position = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 1 -MALA_Sampler = MALA(target_dualmoon, True, {"step_size": 0.1}) +MALA_Sampler = MALA(target_dualmoon, True, 0.1) print("Initializing sampler class") @@ -80,7 +81,7 @@ def target_dualmoon(x, data): ) chains = np.array(chains) -nf_samples = np.array(nf_samples[1]) +nf_samples = np.array(nf_samples) loss_vals = np.array(loss_vals) # Plot one chain to show the jump @@ -123,3 +124,5 @@ def target_dualmoon(x, data): figure.set_size_inches(7, 7) figure.suptitle("Visualize NF samples") plt.show() + +nf_sampler.print_summary() \ No newline at end of file diff --git a/src/flowMC/Sampler.py b/src/flowMC/Sampler.py index 7d4af2e..1ef8c3b 100644 --- a/src/flowMC/Sampler.py +++ b/src/flowMC/Sampler.py @@ -388,3 +388,46 @@ def save_summary(self, path: str): """ with open(path, "wb") as f: pickle.dump(self.summary, f) + + def print_summary(self) -> None: + """ + Print summary to the screen about log probabilities and local/global acceptance rates. + """ + train_summary = self.get_sampler_state(training=True) + production_summary = self.get_sampler_state(training=False) + + training_log_prob = train_summary["log_prob"] + training_local_acceptance = train_summary["local_accs"] + training_global_acceptance = train_summary["global_accs"] + training_loss = train_summary["loss_vals"] + + production_log_prob = production_summary["log_prob"] + production_local_acceptance = production_summary["local_accs"] + production_global_acceptance = production_summary["global_accs"] + + print("Training summary") + print("=" * 10) + print( + f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}" + ) + print( + f"Local acceptance: {training_local_acceptance.mean():.3f} +/- {training_local_acceptance.std():.3f}" + ) + print( + f"Global acceptance: {training_global_acceptance.mean():.3f} +/- {training_global_acceptance.std():.3f}" + ) + print( + f"Max loss: {training_loss.max():.3f}, Min loss: {training_loss.min():.3f}" + ) + + print("Production summary") + print("=" * 10) + print( + f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}" + ) + print( + f"Local acceptance: {production_local_acceptance.mean():.3f} +/- {production_local_acceptance.std():.3f}" + ) + print( + f"Global acceptance: {production_global_acceptance.mean():.3f} +/- {production_global_acceptance.std():.3f}" + ) \ No newline at end of file