diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index e78d6a468..18a07625b 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -721,6 +721,7 @@ def univariate_normal_test_case( num_sampling_steps, burnin, postprocess_samples=None, + **kwargs, ): inference_key, orbit_key = jax.random.split(rng_key) _, states, _ = self.variant( @@ -728,10 +729,10 @@ def univariate_normal_test_case( run_inference_algorithm, inference_algorithm=inference_algorithm, num_steps=num_sampling_steps, + **kwargs, ) )(rng_key=inference_key, initial_state=initial_state) - # else: if postprocess_samples: samples = postprocess_samples(states, orbit_key) else: @@ -844,9 +845,8 @@ def test_orbital_hmc(self): burnin = 15_000 def postprocess_samples(states, key): - return orbit_samples( - states.positions[burnin:], states.weights[burnin:], key - ) + positions, weights = states + return orbit_samples(positions[burnin:], weights[burnin:], key) self.univariate_normal_test_case( inference_algorithm, @@ -855,6 +855,7 @@ def postprocess_samples(states, key): 20_000, burnin, postprocess_samples, + transform=lambda x: (x.positions, x.weights), ) @chex.all_variants(with_pmap=False)