Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added plotting utilities and refactored hyperparams #143

Merged
merged 16 commits into from
Feb 16, 2024
120 changes: 62 additions & 58 deletions src/flowMC/sampler/Sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,18 @@
import optax
from flowMC.sampler.Proposal_Base import ProposalBase
from flowMC.nfmodel.base import NFModel
from flowMC.utils import initialize_summary_dict
from tqdm import tqdm
import equinox as eqx
import numpy as np

import matplotlib.pyplot as plt
ThibeauWouters marked this conversation as resolved.
Show resolved Hide resolved
from flowMC.utils.hyperparameters import flowmc_default_hyperparameters

class Sampler:
"""
Sampler class that host configuration parameters, NF model, and local sampler

Args:
ThibeauWouters marked this conversation as resolved.
Show resolved Hide resolved
n_dim (int): Dimension of the problem.
rng_key_set (Tuple): Tuple of random number generator keys.
data (Device Array): Extra data to be passed to the likelihood function.
local_sampler (Callable): Local sampler maker
nf_model (NFModel): Normalizing flow model.
n_loop_training (int, optional): Number of training loops. Defaults to 3.
n_loop_production (int, optional): Number of production loops. Defaults to 3.
n_local_steps (int, optional): Number of local steps per loop. Defaults to 50.
n_global_steps (int, optional): Number of global steps per loop. Defaults to 50.
n_chains (int, optional): Number of chains. Defaults to 20.
n_epochs (int, optional): Number of epochs per training loop. Defaults to 30.
learning_rate (float, optional): Learning rate for the NF model. Defaults to 0.01.
max_samples (int, optional): Maximum number of samples fed to training the NF model. Defaults to 10000.
momentum (float, optional): Momentum for the NF model. Defaults to 0.9.
batch_size (int, optional): Batch size for the NF model. Defaults to 10000.
use_global (bool, optional): Whether to use global sampler. Defaults to True.
logging (bool, optional): Whether to log the training process. Defaults to True.
keep_quantile (float, optional): Quantile of chains to keep when training the normalizing flow model. Defaults to 0..
local_autotune (None, optional): Auto-tune function for the local sampler. Defaults to None.
train_thinning (int, optional): Thinning for the data used to train the normalizing flow. Defaults to 1.
For information regarding the hyperparameters to be passed, see flowMC.utils.hyperparameters
"""

@property
Expand All @@ -60,25 +42,14 @@ def __init__(
self.rng_keys_mcmc = rng_keys_mcmc
self.n_dim = n_dim

self.n_loop_training = kwargs.get("n_loop_training", 3)
self.n_loop_production = kwargs.get("n_loop_production", 3)
self.n_local_steps = kwargs.get("n_local_steps", 50)
self.n_global_steps = kwargs.get("n_global_steps", 50)
self.n_chains = kwargs.get("n_chains", 20)
self.n_epochs = kwargs.get("n_epochs", 30)
self.learning_rate = kwargs.get("learning_rate", 0.01)
self.max_samples = kwargs.get("max_samples", 100000)
self.momentum = kwargs.get("momentum", 0.9)
self.batch_size = kwargs.get("batch_size", 10000)
self.use_global = kwargs.get("use_global", True)
self.logging = kwargs.get("logging", True)
self.keep_quantile = kwargs.get("keep_quantile", 0)
self.local_autotune = kwargs.get("local_autotune", None)
self.train_thinning = kwargs.get("train_thinning", 1)
self.output_thinning = kwargs.get("output_thinning", 1)
self.n_sample_max = kwargs.get("n_sample_max", 10000)
self.precompile = kwargs.get("precompile", False)
self.verbose = kwargs.get("verbose", False)
# Set and override any given hyperparameters
ThibeauWouters marked this conversation as resolved.
Show resolved Hide resolved
self.hyperparameters = flowmc_default_hyperparameters
hyperparameter_names = list(flowmc_default_hyperparameters.keys())
for key, value in kwargs.items():
if key in hyperparameter_names:
self.hyperparameters[key] = value
for key, value in self.hyperparameters.items():
setattr(self, key, value)

self.variables = {"mean": None, "var": None}

Expand All @@ -98,19 +69,10 @@ def __init__(
self.optim_state = tx.init(eqx.filter(self.nf_model, eqx.is_array))
self.nf_training_loop, train_epoch, train_step = make_training_loop(tx)

# Initialized result dictionary
ThibeauWouters marked this conversation as resolved.
Show resolved Hide resolved
training = {}
training["chains"] = jnp.empty((self.n_chains, 0, self.n_dim))
training["log_prob"] = jnp.empty((self.n_chains, 0))
training["local_accs"] = jnp.empty((self.n_chains, 0))
training["global_accs"] = jnp.empty((self.n_chains, 0))
# Initialized result dictionaries
training = initialize_summary_dict(self)
training["loss_vals"] = jnp.empty((0, self.n_epochs))

production = {}
production["chains"] = jnp.empty((self.n_chains, 0, self.n_dim))
production["log_prob"] = jnp.empty((self.n_chains, 0))
production["local_accs"] = jnp.empty((self.n_chains, 0))
production["global_accs"] = jnp.empty((self.n_chains, 0))
production = initialize_summary_dict(self)

self.summary = {}
self.summary["training"] = training
Expand Down Expand Up @@ -341,7 +303,7 @@ def production_run(
last_step = self.sampling_loop(last_step, data)
return last_step

def get_sampler_state(self, training: bool = False) -> dict:
def get_sampler_state(self, which: str = "training") -> dict:
"""
Get the sampler state. There are two sets of sampler outputs one can get,
the training set and the production set.
Expand All @@ -354,10 +316,7 @@ def get_sampler_state(self, training: bool = False) -> dict:
training (bool): Whether to get the training set sampler state. Defaults to False.

"""
if training == True:
return self.summary["training"]
else:
return self.summary["production"]
return self.summary[which]

def sample_flow(self, n_samples: int) -> jnp.ndarray:
"""
Expand Down Expand Up @@ -495,3 +454,48 @@ def save_summary(self, path: str):
"""
with open(path, "wb") as f:
pickle.dump(self.summary, f)

kazewong marked this conversation as resolved.
Show resolved Hide resolved
def plot_summary(self, which: str = "training", **plotkwargs) -> None:
    """
    Create plots of the most important quantities in the summary.

    Args:
        which (str, optional): Which summary dictionary to show in plots.
            Defaults to "training".
    """
    # Fetch the requested summary dictionary ("training" or "production").
    summary_data = self.get_sampler_state(which)

    # TODO add loss values in plotting
    for quantity in ("local_accs", "global_accs", "log_prob"):
        self._single_plot(summary_data, quantity, which, **plotkwargs)

def _single_plot(self, data: dict, name: str, which: str = "training", **plotkwargs):
    """
    Create a single plot of a quantity in the summary and save it to
    ``{self.outdir}{name}_{which}.png``.

    Args:
        data (dict): Dictionary with the summary data.
        name (str): Name of the quantity to plot (a key of ``data``).
        which (str, optional): Name of this summary dict. Defaults to "training".

    Keyword Args:
        figsize (tuple, optional): Figure size. Defaults to (12, 8).
        alpha (float, optional): Line transparency. Defaults to 1.
    """
    # Get plot kwargs (dict.get is the idiomatic form of the old membership tests)
    figsize = plotkwargs.get("figsize", (12, 8))
    alpha = plotkwargs.get("alpha", 1)
    eps = 1e-3  # small margin so the y-limits don't clip values at exactly 0 or 1

    # Prepare plot data: average over the chain axis, one point per iteration
    mean = jnp.mean(data[name], axis=0)
    x = range(1, len(mean) + 1)

    # Plot
    fig = plt.figure(figsize=figsize)
    plt.plot(x, mean, linestyle="-", color="blue", alpha=alpha)
    plt.xlabel("Iteration")
    plt.ylabel(f"{name} ({which})")
    # Acceptance rates are fractions, so pin the y-axis to [0, 1]:
    if "acc" in name:
        plt.ylim(0 - eps, 1 + eps)
    # NOTE(review): path concatenation assumes self.outdir ends with a
    # separator (the documented default is "./outdir/") — confirm callers.
    plt.savefig(f"{self.outdir}{name}_{which}.png", bbox_inches='tight')
    # Close the figure so repeated calls don't accumulate open figures (leak fix).
    plt.close(fig)
13 changes: 13 additions & 0 deletions src/flowMC/utils/__init__.py
ThibeauWouters marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import jax.numpy as jnp

# TODO - add loss values?
def initialize_summary_dict(sampler):
    """
    Build an empty summary dictionary for a sampler run.

    Each entry is an empty jax array shaped so results can be appended
    along the iteration axis as sampling proceeds.

    Args:
        sampler: Object exposing ``n_chains`` and ``n_dim`` attributes.

    Returns:
        dict: Keys "chains", "log_prob", "local_accs" and "global_accs"
        mapped to empty arrays with a leading ``n_chains`` axis.
    """
    n_chains, n_dim = sampler.n_chains, sampler.n_dim
    return {
        "chains": jnp.empty((n_chains, 0, n_dim)),
        "log_prob": jnp.empty((n_chains, 0)),
        "local_accs": jnp.empty((n_chains, 0)),
        "global_accs": jnp.empty((n_chains, 0)),
    }
47 changes: 47 additions & 0 deletions src/flowMC/utils/hyperparameters.py
ThibeauWouters marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Default values for every tunable flowMC hyperparameter.
# See flowmc_hyperparameters_explanation for a description of each entry.
flowmc_default_hyperparameters = dict(
    n_loop_training=3,
    n_loop_production=3,
    n_loop_pretraining=0,
    n_local_steps=50,
    n_global_steps=50,
    n_chains=20,
    n_epochs=30,
    learning_rate=0.001,
    max_samples=100000,
    momentum=0.9,
    batch_size=10000,
    use_global=True,
    logging=True,
    keep_quantile=0,
    local_autotune=None,
    train_thinning=1,
    output_thinning=1,
    n_sample_max=10000,
    precompile=False,
    verbose=False,
    outdir="./outdir/",
)

# Human-readable description of each hyperparameter in
# flowmc_default_hyperparameters; the keys match that dict one-to-one.
flowmc_hyperparameters_explanation = {
    "n_loop_training": "(int): Number of training loops.",
    "n_loop_production": "(int): Number of production loops.",
    "n_loop_pretraining": "(int): Number of pretraining loops.",
    "n_local_steps": "(int) Number of local steps per loop.",
    # Fixed copy-paste error: this previously read "local steps".
    "n_global_steps": "(int) Number of global steps per loop.",
    "n_chains": "(int) Number of chains",
    "n_epochs": "(int) Number of epochs to train the NF per training loop",
    "learning_rate": "(float) Learning rate used in the training of the NF",
    "max_samples": "(int) Maximum number of samples fed to training the NF model",
    "momentum": "(float) Momentum used in the training of the NF model with the Adam optimizer",
    "batch_size": "(int) Size of batches used to train the NF",
    "use_global": "(bool) Whether to use an NF proposal as global sampler",
    "logging": "(bool) Whether to log the training process",
    "keep_quantile": "Quantile of chains to keep when training the normalizing flow model",
    "local_autotune": "(Callable) Auto-tune function for the local sampler",
    "train_thinning": "(int) Thinning parameter for the data used to train the normalizing flow",
    "output_thinning": "(int) Thinning parameter with which to save the data ",
    "n_sample_max": "(int) Maximum number of samples fed to training the NF model",
    "precompile": "(bool) Whether to precompile",
    "verbose": "(bool) Show steps of algorithm in detail",
    # Key renamed from "outdir_name" to "outdir" to match the actual
    # hyperparameter key used in flowmc_default_hyperparameters.
    "outdir": "(str) Location to which to save plots, samples and hyperparameter settings. Note: should ideally start with `./` and also end with `/`"
}