Skip to content

Commit

Permalink
Fix sampling test. (#693)
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Jun 5, 2024
1 parent a4408d3 commit dd9ba03
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,17 +721,18 @@ def univariate_normal_test_case(
num_sampling_steps,
burnin,
postprocess_samples=None,
**kwargs,
):
inference_key, orbit_key = jax.random.split(rng_key)
_, states, _ = self.variant(
functools.partial(
run_inference_algorithm,
inference_algorithm=inference_algorithm,
num_steps=num_sampling_steps,
**kwargs,
)
)(rng_key=inference_key, initial_state=initial_state)

# else:
if postprocess_samples:
samples = postprocess_samples(states, orbit_key)
else:
Expand Down Expand Up @@ -844,9 +845,8 @@ def test_orbital_hmc(self):
burnin = 15_000

def postprocess_samples(states, key):
return orbit_samples(
states.positions[burnin:], states.weights[burnin:], key
)
positions, weights = states
return orbit_samples(positions[burnin:], weights[burnin:], key)

self.univariate_normal_test_case(
inference_algorithm,
Expand All @@ -855,6 +855,7 @@ def postprocess_samples(states, key):
20_000,
burnin,
postprocess_samples,
transform=lambda x: (x.positions, x.weights),
)

@chex.all_variants(with_pmap=False)
Expand Down

0 comments on commit dd9ba03

Please sign in to comment.