From dc2319ff4b8341f29a390ebc87d2ff58bc3ca2f7 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 16 Feb 2024 14:05:10 -0500 Subject: [PATCH] Update Sampler.py --- src/flowMC/sampler/Sampler.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/flowMC/sampler/Sampler.py b/src/flowMC/sampler/Sampler.py index 87782fc..ab6d8d8 100644 --- a/src/flowMC/sampler/Sampler.py +++ b/src/flowMC/sampler/Sampler.py @@ -7,7 +7,6 @@ import optax from flowMC.sampler.Proposal_Base import ProposalBase from flowMC.nfmodel.base import NFModel -from flowMC.utils.postprocessing import gelman_rubin from tqdm import tqdm import equinox as eqx import numpy as np @@ -34,7 +33,6 @@ "precompile": False, "verbose": False, "outdir": "./outdir/", - "track_gelman_rubin": False, } class Sampler: @@ -128,10 +126,6 @@ def __init__( production["log_prob"] = jnp.empty((self.n_chains, 0)) production["local_accs"] = jnp.empty((self.n_chains, 0)) production["global_accs"] = jnp.empty((self.n_chains, 0)) - - if self.track_gelman_rubin: - training["gelman_rubin"] = jnp.empty((self.n_dim, 0)) - production["gelman_rubin"] = jnp.empty((self.n_dim, 0)) self.summary = {} self.summary["training"] = training @@ -284,15 +278,6 @@ def sampling_loop( global_acceptance[:, 1::self.output_thinning], axis=1, ) - - if self.track_gelman_rubin: - # Get chains up to this point and compute Gelman-Rubin R statistic - chains = self.summary[summary_mode]["chains"] - R = gelman_rubin(chains) - R = jnp.reshape(R, (-1, 1)) - self.summary[summary_mode]["gelman_rubin"] = jnp.append( - self.summary[summary_mode]["gelman_rubin"], R, axis=1 - ) last_step = self.summary[summary_mode]["chains"][:, -1] @@ -525,4 +510,4 @@ def save_summary(self, path: str): path (str): Path to save the summary. """ with open(path, "wb") as f: - pickle.dump(self.summary, f) \ No newline at end of file + pickle.dump(self.summary, f)