Skip to content

Commit

Permalink
Merge pull request #169 from ThibeauWouters/main
Browse files Browse the repository at this point in the history
update example and add print summary to Sampler
  • Loading branch information
kazewong authored May 9, 2024
2 parents 6e4da45 + 0c49ac9 commit 22e0dda
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
7 changes: 5 additions & 2 deletions example/dualmoon.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import corner
import jax
print(jax.devices())
import jax.numpy as jnp # JAX NumPy
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -42,7 +43,7 @@ def target_dualmoon(x, data):
rng_key, subkey = jax.random.split(rng_key)
initial_position = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 1

MALA_Sampler = MALA(target_dualmoon, True, {"step_size": 0.1})
MALA_Sampler = MALA(target_dualmoon, True, 0.1)

print("Initializing sampler class")

Expand Down Expand Up @@ -80,7 +81,7 @@ def target_dualmoon(x, data):
)

chains = np.array(chains)
nf_samples = np.array(nf_samples[1])
nf_samples = np.array(nf_samples)
loss_vals = np.array(loss_vals)

# Plot one chain to show the jump
Expand Down Expand Up @@ -123,3 +124,5 @@ def target_dualmoon(x, data):
figure.set_size_inches(7, 7)
figure.suptitle("Visualize NF samples")
plt.show()

nf_sampler.print_summary()
43 changes: 43 additions & 0 deletions src/flowMC/Sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,46 @@ def save_summary(self, path: str):
"""
with open(path, "wb") as f:
pickle.dump(self.summary, f)

def print_summary(self) -> None:
"""
Print summary to the screen about log probabilities and local/global acceptance rates.
"""
train_summary = self.get_sampler_state(training=True)
production_summary = self.get_sampler_state(training=False)

training_log_prob = train_summary["log_prob"]
training_local_acceptance = train_summary["local_accs"]
training_global_acceptance = train_summary["global_accs"]
training_loss = train_summary["loss_vals"]

production_log_prob = production_summary["log_prob"]
production_local_acceptance = production_summary["local_accs"]
production_global_acceptance = production_summary["global_accs"]

print("Training summary")
print("=" * 10)
print(
f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}"
)
print(
f"Local acceptance: {training_local_acceptance.mean():.3f} +/- {training_local_acceptance.std():.3f}"
)
print(
f"Global acceptance: {training_global_acceptance.mean():.3f} +/- {training_global_acceptance.std():.3f}"
)
print(
f"Max loss: {training_loss.max():.3f}, Min loss: {training_loss.min():.3f}"
)

print("Production summary")
print("=" * 10)
print(
f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}"
)
print(
f"Local acceptance: {production_local_acceptance.mean():.3f} +/- {production_local_acceptance.std():.3f}"
)
print(
f"Global acceptance: {production_global_acceptance.mean():.3f} +/- {production_global_acceptance.std():.3f}"
)

0 comments on commit 22e0dda

Please sign in to comment.