Skip to content

Commit

Permalink
fixed test runs
Browse files Browse the repository at this point in the history
  • Loading branch information
Thibeau Wouters committed Feb 13, 2024
1 parent 02fcb1f commit ad1a32d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
7 changes: 3 additions & 4 deletions src/flowMC/sampler/Sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
default_hyperparameters = {
"n_loop_training": 3,
"n_loop_production": 3,
"n_loop_pretraining": 0,
"n_local_steps": 50,
"n_global_steps": 50,
"n_chains": 20,
Expand All @@ -25,6 +24,7 @@
"momentum": 0.9,
"batch_size": 10000,
"use_global": True,
"global_sampler": None,
"logging": True,
"keep_quantile": 0,
"local_autotune": None,
Expand All @@ -44,7 +44,6 @@ class Sampler:
Args:
"n_loop_training": "(int): Number of training loops.",
"n_loop_production": "(int): Number of production loops.",
"n_loop_pretraining": "(int): Number of pretraining loops.",
"n_local_steps": "(int) Number of local steps per loop.",
"n_global_steps": "(int) Number of local steps per loop.",
"n_chains": "(int) Number of chains",
Expand All @@ -54,6 +53,7 @@ class Sampler:
"momentum": "(float) Momentum used in the training of the NF model with the Adam optimizer",
"batch_size": "(int) Size of batches used to train the NF",
"use_global": "(bool) Whether to use an NF proposal as global sampler",
"global_sampler": "(NFProposal) Global sampler",
"logging": "(bool) Whether to log the training process",
"keep_quantile": "Quantile of chains to keep when training the normalizing flow model",
"local_autotune": "(Callable) Auto-tune function for the local sampler",
Expand Down Expand Up @@ -107,8 +107,7 @@ def __init__(
)

if self.global_sampler is None:
global_sampler = NFProposal(self.local_sampler.logpdf, jit=self.local_sampler.jit, model=nf_model, n_sample_max=self.n_sample_max)
self.global_sampler = global_sampler
self.global_sampler = NFProposal(self.local_sampler.logpdf, jit=self.local_sampler.jit, model=nf_model, n_sample_max=self.n_sample_max)

self.likelihood_vec = self.local_sampler.logpdf_vmap

Expand Down
12 changes: 8 additions & 4 deletions src/flowMC/utils/postprocessing.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import matplotlib.pyplot as plt
import jax.numpy as jnp
from flowMC.sampler.Sampler import Sampler
# from flowMC.sampler.Sampler import Sampler
from jaxtyping import Float, Array

def plot_summary(sampler: Sampler, which: str = "training", **plotkwargs) -> None:
def plot_summary(sampler: object, training: bool = False, **plotkwargs) -> None:
"""
Create plots of the most important quantities in the summary.
Args:
which (str, optional): Which summary dictionary to show in plots. Defaults to "training".
training (bool, optional): If True, plot training quantities. If False, plot production quantities. Defaults to False.
"""

# Choose the dataset
data = Sampler.get_sampler_state(which = which)
data = sampler.get_sampler_state(training=training)
# TODO add loss values in plotting
keys = ["local_accs", "global_accs", "log_prob"]
if sampler.track_gelman_rubin:
Expand All @@ -25,6 +25,10 @@ def plot_summary(sampler: Sampler, which: str = "training", **plotkwargs) -> Non
outdir = "./outdir/"

for key in keys:
if training:
which = "training"
else:
which = "production"
_single_plot(data, key, which, outdir=outdir, **plotkwargs)

def _single_plot(data: dict, name: str, which: str = "training", outdir: str = "./outdir/", **plotkwargs):
Expand Down

0 comments on commit ad1a32d

Please sign in to comment.