Added plotting utilities and refactored hyperparams #143

Merged · 16 commits · Feb 16, 2024
94 changes: 54 additions & 40 deletions src/flowMC/sampler/Sampler.py
@@ -11,32 +11,56 @@
import equinox as eqx
import numpy as np

default_hyperparameters = {
    "n_loop_training": 3,
    "n_loop_production": 3,
    "n_local_steps": 50,
    "n_global_steps": 50,
    "n_chains": 20,
    "n_epochs": 30,
    "learning_rate": 0.001,
    "max_samples": 100000,
    "momentum": 0.9,
    "batch_size": 10000,
    "use_global": True,
    "global_sampler": None,
    "logging": True,
    "keep_quantile": 0,
    "local_autotune": None,
    "train_thinning": 1,
    "output_thinning": 1,
    "n_sample_max": 10000,
    "precompile": False,
    "verbose": False,
    "outdir": "./outdir/",
}

class Sampler:
    """
    Sampler class that hosts the configuration parameters, the NF model, and the local sampler.

    Args:
        n_dim (int): Dimension of the problem.
        rng_key_set (Tuple): Tuple of random number generator keys.
        data (Device Array): Extra data to be passed to the likelihood function.
        local_sampler (Callable): Local sampler maker.
        nf_model (NFModel): Normalizing flow model.
        n_loop_training (int, optional): Number of training loops. Defaults to 3.
        n_loop_production (int, optional): Number of production loops. Defaults to 3.
        n_local_steps (int, optional): Number of local steps per loop. Defaults to 50.
        n_global_steps (int, optional): Number of global steps per loop. Defaults to 50.
        n_chains (int, optional): Number of chains. Defaults to 20.
        n_epochs (int, optional): Number of epochs per training loop. Defaults to 30.
        learning_rate (float, optional): Learning rate for the NF model. Defaults to 0.01.
        max_samples (int, optional): Maximum number of samples fed to training the NF model. Defaults to 10000.
        momentum (float, optional): Momentum for the NF model. Defaults to 0.9.
        batch_size (int, optional): Batch size for the NF model. Defaults to 10000.
        use_global (bool, optional): Whether to use the global sampler. Defaults to True.
        logging (bool, optional): Whether to log the training process. Defaults to True.
        keep_quantile (float, optional): Quantile of chains to keep when training the normalizing flow model. Defaults to 0.
        local_autotune (None, optional): Auto-tune function for the local sampler. Defaults to None.
        train_thinning (int, optional): Thinning for the data used to train the normalizing flow. Defaults to 1.
"n_loop_training": "(int): Number of training loops.",
"n_loop_production": "(int): Number of production loops.",
"n_local_steps": "(int) Number of local steps per loop.",
"n_global_steps": "(int) Number of local steps per loop.",
"n_chains": "(int) Number of chains",
"n_epochs": "(int) Number of epochs to train the NF per training loop",
"learning_rate": "(float) Learning rate used in the training of the NF",
"max_samples": "(int) Maximum number of samples fed to training the NF model",
"momentum": "(float) Momentum used in the training of the NF model with the Adam optimizer",
"batch_size": "(int) Size of batches used to train the NF",
"use_global": "(bool) Whether to use an NF proposal as global sampler",
"global_sampler": "(NFProposal) Global sampler",
"logging": "(bool) Whether to log the training process",
"keep_quantile": "Quantile of chains to keep when training the normalizing flow model",
"local_autotune": "(Callable) Auto-tune function for the local sampler",
"train_thinning": "(int) Thinning parameter for the data used to train the normalizing flow",
"output_thinning": "(int) Thinning parameter with which to save the data ",
"n_sample_max": "(int) Maximum number of samples fed to training the NF model",
"precompile": "(bool) Whether to precompile",
"verbose": "(bool) Show steps of algorithm in detail",
"outdir": "(str) Location to which to save plots, samples and hyperparameter settings. Note: should ideally start with `./` and also end with `/`"
"""

    @property
@@ -60,26 +84,15 @@ def __init__(
        self.rng_keys_mcmc = rng_keys_mcmc
        self.n_dim = n_dim

        self.n_loop_training = kwargs.get("n_loop_training", 3)
        self.n_loop_production = kwargs.get("n_loop_production", 3)
        self.n_local_steps = kwargs.get("n_local_steps", 50)
        self.n_global_steps = kwargs.get("n_global_steps", 50)
        self.n_chains = kwargs.get("n_chains", 20)
        self.n_epochs = kwargs.get("n_epochs", 30)
        self.learning_rate = kwargs.get("learning_rate", 0.01)
        self.max_samples = kwargs.get("max_samples", 10000)
        self.momentum = kwargs.get("momentum", 0.9)
        self.batch_size = kwargs.get("batch_size", 10000)
        self.use_global = kwargs.get("use_global", True)
        self.global_sampler = kwargs.get("global_sampler", None)
        self.logging = kwargs.get("logging", True)
        self.keep_quantile = kwargs.get("keep_quantile", 0)
        self.local_autotune = kwargs.get("local_autotune", None)
        self.train_thinning = kwargs.get("train_thinning", 1)
        self.output_thinning = kwargs.get("output_thinning", 1)
        self.n_sample_max = kwargs.get("n_sample_max", 10000)
        self.precompile = kwargs.get("precompile", False)
        self.verbose = kwargs.get("verbose", False)

        # Set and override any given hyperparameters
        # (copy the module-level dict so overrides don't mutate the shared defaults)
        self.hyperparameters = default_hyperparameters.copy()
        hyperparameter_names = list(default_hyperparameters.keys())
        for key, value in kwargs.items():
            if key in hyperparameter_names:
                self.hyperparameters[key] = value
        for key, value in self.hyperparameters.items():
            setattr(self, key, value)

        self.variables = {"mean": None, "var": None}

@@ -118,6 +131,7 @@ def __init__(
self.summary["training"] = training
self.summary["production"] = production


    def sample(self, initial_position: Array, data: dict):
        """
        Sample from the posterior using the local sampler.
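For context, this is how the refactored mechanism is meant to be used: any key of `default_hyperparameters` passed as a keyword argument overrides the default and is exposed both in `self.hyperparameters` and as an attribute. A minimal sketch, assuming the `MALA`, `MaskedCouplingRQSpline`, and `initialize_rng_keys` constructor signatures from the flowMC README of this period; the toy target and all numeric values are illustrative, not taken from this diff.

import jax
import jax.numpy as jnp
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.sampler.MALA import MALA
from flowMC.sampler.Sampler import Sampler
from flowMC.utils.PRNG_keys import initialize_rng_keys

def log_posterior(x, data):
    # Toy standard-normal target; `data` is unused here.
    return -0.5 * jnp.sum(x**2)

n_dim, n_chains = 2, 20
rng_key_set = initialize_rng_keys(n_chains, seed=42)

local_sampler = MALA(log_posterior, True, {"step_size": 0.1})
nf_model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(21))

# Keys matching `default_hyperparameters` override the defaults;
# any other keyword is silently ignored by the override loop.
sampler = Sampler(
    n_dim,
    rng_key_set,
    None,
    local_sampler,
    nf_model,
    n_loop_training=5,
    learning_rate=3e-4,
    outdir="./outdir/",
)

assert sampler.learning_rate == 3e-4              # overrides become attributes
assert sampler.hyperparameters["n_chains"] == 20  # untouched default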
61 changes: 61 additions & 0 deletions src/flowMC/utils/postprocessing.py
@@ -0,0 +1,61 @@
import matplotlib.pyplot as plt
import jax.numpy as jnp
# from flowMC.sampler.Sampler import Sampler
from jaxtyping import Float, Array

def plot_summary(sampler: object, training: bool = False, **plotkwargs) -> None:
    """
    Create plots of the most important quantities in the summary.

    Args:
        sampler (Sampler): Sampler whose summary data is plotted.
        training (bool, optional): If True, plot training quantities. If False, plot production quantities. Defaults to False.
    """

    # Choose the dataset
    data = sampler.get_sampler_state(training=training)
    # TODO add loss values in plotting
    keys = ["local_accs", "global_accs", "log_prob"]
    # Guard the lookup: not every Sampler carries this flag
    if getattr(sampler, "track_gelman_rubin", False):
        keys.append("gelman_rubin")

    # Check if outdir is a property of the sampler
    if hasattr(sampler, "outdir"):
        outdir = sampler.outdir
    else:
        outdir = "./outdir/"

    which = "training" if training else "production"
    for key in keys:
        _single_plot(data, key, which, outdir=outdir, **plotkwargs)

def _single_plot(data: dict, name: str, which: str = "training", outdir: str = "./outdir/", **plotkwargs):
    """
    Create a single plot of a quantity in the summary.

    Args:
        data (dict): Dictionary with the summary data.
        name (str): Name of the quantity to plot.
        which (str, optional): Name of this summary dict. Defaults to "training".
        outdir (str, optional): Directory to which the plot is saved. Defaults to "./outdir/".
    """
    # Get plot kwargs
    figsize = plotkwargs.get("figsize", (12, 8))
    alpha = plotkwargs.get("alpha", 1)
    eps = 1e-3

    # Prepare plot data: average the quantity over the chains
    plotdata = data[name]
    mean = jnp.mean(plotdata, axis=0)
    x = [i + 1 for i in range(len(mean))]

    # Plot
    plt.figure(figsize=figsize)
    plt.plot(x, mean, linestyle="-", color="blue", alpha=alpha)
    plt.xlabel("Iteration")
    plt.ylabel(f"{name} ({which})")
    # Extras for some variables: acceptance rates live in [0, 1]
    if "acc" in name:
        plt.ylim(0 - eps, 1 + eps)
    plt.savefig(f"{outdir}{name}_{which}.png", bbox_inches="tight")
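A usage sketch for the new plotting entry point, assuming a `sampler` that has finished running as in the earlier example; note that `figsize` and `alpha` are the only keys `_single_plot` currently reads from `plotkwargs`.

from flowMC.utils.postprocessing import plot_summary

# After sampler.sample(initial_position, data) has completed:
plot_summary(sampler)  # production plots, saved under sampler.outdir
plot_summary(sampler, training=True, figsize=(8, 5), alpha=0.8)  # training plots, custom styling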