diff --git a/blackjax/smc/waste_free.py b/blackjax/smc/waste_free.py index 2f0ced582..9727567e9 100644 --- a/blackjax/smc/waste_free.py +++ b/blackjax/smc/waste_free.py @@ -67,4 +67,6 @@ def reshape_step_particles(x): def waste_free_smc(n_particles, p): + if not n_particles % p ==0: + raise ValueError("p must be a divider of n_particles ") return functools.partial(update_waste_free, num_resampled=int(n_particles / p), p=p) diff --git a/tests/smc/test_waste_free_smc.py b/tests/smc/test_waste_free_smc.py index 8b470aceb..886d08492 100644 --- a/tests/smc/test_waste_free_smc.py +++ b/tests/smc/test_waste_free_smc.py @@ -1,4 +1,5 @@ """Test the tempered SMC steps and routine""" + import functools import chex @@ -7,11 +8,14 @@ import numpy as np import pytest from absl.testing import absltest +from scipy.stats import stats import blackjax import blackjax.smc.resampling as resampling from blackjax import adaptive_tempered_smc, tempered_smc -from blackjax.smc import extend_params +from blackjax.mcmc.random_walk import build_rmh +from blackjax.smc import extend_params, base +from blackjax.smc.base import SMCState from blackjax.smc.waste_free import update_waste_free, waste_free_smc from tests.smc import SMCLinearRegressionTestCase from tests.smc.test_tempered_smc import inference_loop @@ -105,6 +109,44 @@ def test_adaptive_tempered_smc(self): self.assert_linear_regression_test_case(result) +class Update_waste_free_multivariate_particles(chex.TestCase): + + @chex.variants(with_jit=True) + def test_update_waste_free_multivariate_particles(self): + """ + Given resampled multivariate particles, + when updating with waste free, they are joined + by the result of iterating the MCMC chain to + get a bigger set of particles. + """ + resampled_particles = np.ones((50, 3)) + n_particles = 100 + + def normal_logdensity(x): + return jnp.log( + jax.scipy.stats.multivariate_normal.pdf( + x, mean=np.zeros(3), cov=np.diag(np.ones(3)) + ) + ) + + def rmh_proposal_distribution(rng_key, position): + return position + jax.random.normal(rng_key, (3,)) * 25.0 + + kernel = functools.partial( + blackjax.rmh.build_kernel(), transition_generator=rmh_proposal_distribution + ) + init = blackjax.rmh.init + update, _ = waste_free_smc(n_particles, 2)( + init, normal_logdensity, kernel, n_particles + ) + + updated_particles, infos = self.variant(update)( + jax.random.split(jax.random.PRNGKey(10), 50), resampled_particles, {} + ) + + assert updated_particles.shape == (n_particles, 3) + + def test_waste_free_set_num_mcmc_steps(): with pytest.raises(ValueError) as exc_info: update_waste_free( @@ -115,5 +157,11 @@ def test_waste_free_set_num_mcmc_steps(): ) +def test_waste_free_p_non_divier(): + with pytest.raises(ValueError) as exc_info: + waste_free_smc(100, 3) + assert str(exc_info.value).startswith("p must be a divider") + + if __name__ == "__main__": absltest.main()