diff --git a/src/flowMC/Sampler.py b/src/flowMC/Sampler.py index c167406..9be37ec 100644 --- a/src/flowMC/Sampler.py +++ b/src/flowMC/Sampler.py @@ -127,7 +127,7 @@ def __init__( self.likelihood_vec = self.local_sampler.logpdf_vmap self.optim = optax.chain( - optax.clip(1.0), optax.adamw(self.learning_rate, self.momentum), optax.ema(0.999) + optax.clip(1.0), optax.adam(self.learning_rate, self.momentum) ) self.optim_state = self.optim.init(eqx.filter(self.nf_model, eqx.is_array))