Skip to content

Commit

Permalink
better test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
ciguaran committed Aug 16, 2024
1 parent 8145cbb commit 9550d5d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
2 changes: 2 additions & 0 deletions blackjax/smc/waste_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,6 @@ def reshape_step_particles(x):


def waste_free_smc(n_particles, p):
if not n_particles % p ==0:
raise ValueError("p must be a divider of n_particles ")
return functools.partial(update_waste_free, num_resampled=int(n_particles / p), p=p)
50 changes: 49 additions & 1 deletion tests/smc/test_waste_free_smc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test the tempered SMC steps and routine"""

import functools

import chex
Expand All @@ -7,11 +8,14 @@
import numpy as np
import pytest
from absl.testing import absltest
from scipy.stats import stats

import blackjax
import blackjax.smc.resampling as resampling
from blackjax import adaptive_tempered_smc, tempered_smc
from blackjax.smc import extend_params
from blackjax.mcmc.random_walk import build_rmh
from blackjax.smc import extend_params, base
from blackjax.smc.base import SMCState
from blackjax.smc.waste_free import update_waste_free, waste_free_smc
from tests.smc import SMCLinearRegressionTestCase
from tests.smc.test_tempered_smc import inference_loop
Expand Down Expand Up @@ -105,6 +109,44 @@ def test_adaptive_tempered_smc(self):
self.assert_linear_regression_test_case(result)


class Update_waste_free_multivariate_particles(chex.TestCase):

@chex.variants(with_jit=True)
def test_update_waste_free_multivariate_particles(self):
"""
Given resampled multivariate particles,
when updating with waste free, they are joined
by the result of iterating the MCMC chain to
get a bigger set of particles.
"""
resampled_particles = np.ones((50, 3))
n_particles = 100

def normal_logdensity(x):
return jnp.log(
jax.scipy.stats.multivariate_normal.pdf(
x, mean=np.zeros(3), cov=np.diag(np.ones(3))
)
)

def rmh_proposal_distribution(rng_key, position):
return position + jax.random.normal(rng_key, (3,)) * 25.0

kernel = functools.partial(
blackjax.rmh.build_kernel(), transition_generator=rmh_proposal_distribution
)
init = blackjax.rmh.init
update, _ = waste_free_smc(n_particles, 2)(
init, normal_logdensity, kernel, n_particles
)

updated_particles, infos = self.variant(update)(
jax.random.split(jax.random.PRNGKey(10), 50), resampled_particles, {}
)

assert updated_particles.shape == (n_particles, 3)


def test_waste_free_set_num_mcmc_steps():
with pytest.raises(ValueError) as exc_info:
update_waste_free(
Expand All @@ -115,5 +157,11 @@ def test_waste_free_set_num_mcmc_steps():
)


def test_waste_free_p_non_divier():
with pytest.raises(ValueError) as exc_info:
waste_free_smc(100, 3)
assert str(exc_info.value).startswith("p must be a divider")


if __name__ == "__main__":
absltest.main()

0 comments on commit 9550d5d

Please sign in to comment.