Skip to content

Commit

Permalink
Update Sampler.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong authored Feb 16, 2024
1 parent 3741e22 commit dc2319f
Showing 1 changed file with 1 addition and 16 deletions.
17 changes: 1 addition & 16 deletions src/flowMC/sampler/Sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import optax
from flowMC.sampler.Proposal_Base import ProposalBase
from flowMC.nfmodel.base import NFModel
from flowMC.utils.postprocessing import gelman_rubin
from tqdm import tqdm
import equinox as eqx
import numpy as np
Expand All @@ -34,7 +33,6 @@
"precompile": False,
"verbose": False,
"outdir": "./outdir/",
"track_gelman_rubin": False,
}

class Sampler:
Expand Down Expand Up @@ -128,10 +126,6 @@ def __init__(
production["log_prob"] = jnp.empty((self.n_chains, 0))
production["local_accs"] = jnp.empty((self.n_chains, 0))
production["global_accs"] = jnp.empty((self.n_chains, 0))

if self.track_gelman_rubin:
training["gelman_rubin"] = jnp.empty((self.n_dim, 0))
production["gelman_rubin"] = jnp.empty((self.n_dim, 0))

self.summary = {}
self.summary["training"] = training
Expand Down Expand Up @@ -284,15 +278,6 @@ def sampling_loop(
global_acceptance[:, 1::self.output_thinning],
axis=1,
)

if self.track_gelman_rubin:
# Get chains up to this point and compute Gelman-Rubin R statistic
chains = self.summary[summary_mode]["chains"]
R = gelman_rubin(chains)
R = jnp.reshape(R, (-1, 1))
self.summary[summary_mode]["gelman_rubin"] = jnp.append(
self.summary[summary_mode]["gelman_rubin"], R, axis=1
)

last_step = self.summary[summary_mode]["chains"][:, -1]

Expand Down Expand Up @@ -525,4 +510,4 @@ def save_summary(self, path: str):
path (str): Path to save the summary.
"""
with open(path, "wb") as f:
pickle.dump(self.summary, f)
pickle.dump(self.summary, f)

0 comments on commit dc2319f

Please sign in to comment.