From 3dc38096ca9ae8f4e8f47fdcee4b876f4b4c6bc3 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 25 Mar 2024 11:56:21 -0300 Subject: [PATCH] SMC: allow each mutation kernel to have different parameters. (#649) * vmaping over parameters in base * switch from mcmc_factory to just passing in parameters * pre-commit and typing * CRU and docs improvement * pre-commit * code review updates * pre-commit * rename test --- blackjax/smc/__init__.py | 8 ++- blackjax/smc/base.py | 20 ++++-- blackjax/smc/inner_kernel_tuning.py | 39 +++++------ blackjax/smc/tempered.py | 6 +- tests/mcmc/test_sampling.py | 8 ++- tests/smc/test_inner_kernel_tuning.py | 89 ++++++++++++++++---------- tests/smc/test_kernel_compatibility.py | 72 +++++++++++++++------ tests/smc/test_smc.py | 68 ++++++++++++++------ tests/smc/test_tempered_smc.py | 40 +++++++----- 9 files changed, 229 insertions(+), 121 deletions(-) diff --git a/blackjax/smc/__init__.py b/blackjax/smc/__init__.py index 180cd8259..ef10b10e6 100644 --- a/blackjax/smc/__init__.py +++ b/blackjax/smc/__init__.py @@ -1,3 +1,9 @@ from . import adaptive_tempered, inner_kernel_tuning, tempered +from .base import extend_params -__all__ = ["adaptive_tempered", "tempered", "inner_kernel_tuning"] +__all__ = [ + "adaptive_tempered", + "tempered", + "inner_kernel_tuning", + "extend_params", +] diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 409f588d2..4a9ff17c3 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -40,6 +40,7 @@ class SMCState(NamedTuple): particles: ArrayTree weights: Array + update_parameters: ArrayTree class SMCInfo(NamedTuple): @@ -59,12 +60,12 @@ class SMCInfo(NamedTuple): update_info: NamedTuple -def init(particles: ArrayLikeTree): +def init(particles: ArrayLikeTree, init_update_params): # Infer the number of particles from the size of the leading dimension of # the first leaf of the inputted PyTree. num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] weights = jnp.ones(num_particles) / num_particles - return SMCState(particles, weights) + return SMCState(particles, weights, init_update_params) def step( @@ -137,13 +138,24 @@ def step( particles = jax.tree_map(lambda x: x[resampling_idx], state.particles) keys = jax.random.split(updating_key, num_resampled) - particles, update_info = update_fn(keys, particles) + particles, update_info = update_fn(keys, particles, state.update_parameters) log_weights = weight_fn(particles) logsum_weights = jax.scipy.special.logsumexp(log_weights) normalizing_constant = logsum_weights - jnp.log(num_particles) weights = jnp.exp(log_weights - logsum_weights) - return SMCState(particles, weights), SMCInfo( + return SMCState(particles, weights, state.update_parameters), SMCInfo( resampling_idx, normalizing_constant, update_info ) + + +def extend_params(n_particles, params): + """Given a dictionary of params, repeats them for every single particle. The expected + usage is in cases where the aim is to repeat the same parameters for all chains within SMC. + """ + + def extend(param): + return jnp.repeat(jnp.asarray(param)[None, ...], n_particles, axis=0) + + return jax.tree_map(extend, params) diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index 6aaf3a5d3..705a60c35 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -8,8 +8,15 @@ class StateWithParameterOverride(NamedTuple): + """ + Stores both the sampling status and also a dictionary + that contains an dictionary with parameter names as key + and (n_particles, *) arrays as meanings. The latter + represent a parameter per chain for the next mutation step. + """ + sampler_state: ArrayTree - parameter_override: ArrayTree + parameter_override: Dict[str, ArrayTree] def init(alg_init_fn, position, initial_parameter_value): @@ -20,11 +27,10 @@ def build_kernel( smc_algorithm, logprior_fn: Callable, loglikelihood_fn: Callable, - mcmc_factory: Callable, + mcmc_step_fn: Callable, mcmc_init_fn: Callable, - mcmc_parameters: Dict, resampling_fn: Callable, - mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree], + mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], num_mcmc_steps: int = 10, **extra_parameters, ) -> Callable: @@ -41,12 +47,11 @@ def build_kernel( A function that computes the log density of the prior distribution loglikelihood_fn A function that returns the probability at a given position. - mcmc_factory - A callable that can construct an inner kernel out of the newly-computed parameter + mcmc_step_fn: + The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. + mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn()) mcmc_init_fn A callable that initializes the inner kernel - mcmc_parameters - Other (fixed across SMC iterations) parameters for the inner kernel mcmc_parameter_update_fn A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. extra_parameters: @@ -59,9 +64,9 @@ def kernel( step_fn = smc_algorithm( logprior_fn=logprior_fn, loglikelihood_fn=loglikelihood_fn, - mcmc_step_fn=mcmc_factory(state.parameter_override), + mcmc_step_fn=mcmc_step_fn, mcmc_init_fn=mcmc_init_fn, - mcmc_parameters=mcmc_parameters, + mcmc_parameters=state.parameter_override, resampling_fn=resampling_fn, num_mcmc_steps=num_mcmc_steps, **extra_parameters, @@ -89,17 +94,15 @@ class inner_kernel_tuning: A function that computes the log density of the prior distribution loglikelihood_fn A function that returns the probability at a given position. - mcmc_factory - A callable that can construct an inner kernel out of the newly-computed parameter + mcmc_step_fn + The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. mcmc_init_fn A callable that initializes the inner kernel - mcmc_parameters - Other (fixed across SMC iterations) parameters for the inner kernel step mcmc_parameter_update_fn A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. initial_parameter_value - Paramter to be used by the mcmc_factory before the first iteration. + Parameter to be used by the mcmc_factory before the first iteration. extra_parameters: parameters to be used for the creation of the smc_algorithm. @@ -117,9 +120,8 @@ def __new__( # type: ignore[misc] smc_algorithm: Union[adaptive_tempered_smc, tempered_smc], logprior_fn: Callable, loglikelihood_fn: Callable, - mcmc_factory: Callable, + mcmc_step_fn: Callable, mcmc_init_fn: Callable, - mcmc_parameters: Dict, resampling_fn: Callable, mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree], initial_parameter_value, @@ -130,9 +132,8 @@ def __new__( # type: ignore[misc] smc_algorithm, logprior_fn, loglikelihood_fn, - mcmc_factory, + mcmc_step_fn, mcmc_init_fn, - mcmc_parameters, resampling_fn, mcmc_parameter_update_fn, num_mcmc_steps, diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 49fa21277..561eadecc 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -127,12 +127,12 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: tempered_loglikelihood = state.lmbda * loglikelihood_fn(position) return logprior + tempered_loglikelihood - def mcmc_kernel(rng_key, position): + def mcmc_kernel(rng_key, position, step_parameters): state = mcmc_init_fn(position, tempered_logposterior_fn) def body_fn(state, rng_key): new_state, info = mcmc_step_fn( - rng_key, state, tempered_logposterior_fn, **mcmc_parameters + rng_key, state, tempered_logposterior_fn, **step_parameters ) return new_state, info @@ -142,7 +142,7 @@ def body_fn(state, rng_key): smc_state, info = smc.base.step( rng_key, - SMCState(state.particles, state.weights), + SMCState(state.particles, state.weights, mcmc_parameters), jax.vmap(mcmc_kernel), jax.vmap(log_weights_fn), resampling_fn, diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 7d20805ab..51831b587 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -27,12 +27,12 @@ def sample_orbit(orbit, weights, rng_key): return samples -def irmh_proposal_distribution(rng_key): +def irmh_proposal_distribution(rng_key, mean): """ The proposal distribution is chosen to be wider than the target, so that the RMH rejection doesn't make the sample overemphasize the center of the target distribution. """ - return 1.0 + jax.random.normal(rng_key) * 25.0 + return mean + jax.random.normal(rng_key) * 25.0 def rmh_proposal_distribution(rng_key, position): @@ -657,7 +657,9 @@ def test_univariate_normal( self, algorithm, initial_position, parameters, num_sampling_steps, burnin ): if algorithm == blackjax.irmh: - parameters["proposal_distribution"] = irmh_proposal_distribution + parameters["proposal_distribution"] = functools.partial( + irmh_proposal_distribution, mean=1.0 + ) if algorithm == blackjax.rmh: parameters["proposal_generator"] = rmh_proposal_distribution diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index 1bbc68970..cf1db09dd 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -12,6 +12,8 @@ import blackjax import blackjax.smc.resampling as resampling from blackjax import adaptive_tempered_smc, tempered_smc +from blackjax.mcmc.random_walk import build_irmh +from blackjax.smc import extend_params from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate from blackjax.smc.tuning.from_particles import ( @@ -92,38 +94,37 @@ def smc_inner_kernel_tuning_test_case( proposal_factory.return_value = 100 def mcmc_parameter_update_fn(state, info): - return 100 + return extend_params(1000, {"mean": 100}) - mcmc_factory = MagicMock() - sampling_algorithm = MagicMock() - mcmc_factory.return_value = sampling_algorithm prior = lambda x: stats.norm.logpdf(x) - def kernel_factory(proposal_distribution): - kernel = blackjax.irmh.build_kernel() - - def wrapped_kernel(rng_key, state, logdensity): - return kernel(rng_key, state, logdensity, proposal_distribution) - - return wrapped_kernel + def wrapped_kernel(rng_key, state, logdensity, mean): + return build_irmh()( + rng_key, + state, + logdensity, + functools.partial(irmh_proposal_distribution, mean=mean), + ) kernel = inner_kernel_tuning( logprior_fn=prior, loglikelihood_fn=specialized_log_weights_fn, - mcmc_factory=kernel_factory, + mcmc_step_fn=wrapped_kernel, mcmc_init_fn=blackjax.irmh.init, resampling_fn=resampling.systematic, smc_algorithm=smc_algorithm, - mcmc_parameters={}, mcmc_parameter_update_fn=mcmc_parameter_update_fn, - initial_parameter_value=irmh_proposal_distribution, + initial_parameter_value=extend_params(1000, {"mean": 1.0}), **smc_parameters, ) new_state, new_info = kernel.step( self.key, state=kernel.init(init_particles), **step_parameters ) - assert new_state.parameter_override == 100 + assert set(new_state.parameter_override.keys()) == { + "mean", + } + np.testing.assert_allclose(new_state.parameter_override["mean"], 100) class MeanAndStdFromParticlesTest(chex.TestCase): @@ -270,14 +271,6 @@ def setUp(self): super().setUp() self.key = jax.random.key(42) - def mcmc_factory(self, mass_matrix): - return functools.partial( - blackjax.hmc.build_kernel(), - inverse_mass_matrix=mass_matrix, - step_size=10e-2, - num_integration_steps=50, - ) - @chex.all_variants(with_pmap=False) def test_with_adaptive_tempered(self): ( @@ -286,18 +279,32 @@ def test_with_adaptive_tempered(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() + def parameter_update(state, info): + return extend_params( + 100, + { + "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "step_size": 10e-2, + "num_integration_steps": 50, + }, + ) + init, step = blackjax.inner_kernel_tuning( adaptive_tempered_smc, logprior_fn, loglikelihood_fn, - self.mcmc_factory, + blackjax.hmc.build_kernel(), blackjax.hmc.init, - {}, resampling.systematic, - mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles( - state.particles + mcmc_parameter_update_fn=parameter_update, + initial_parameter_value=extend_params( + 100, + dict( + inverse_mass_matrix=jnp.eye(2), + step_size=10e-2, + num_integration_steps=50, + ), ), - initial_parameter_value=jnp.eye(2), num_mcmc_steps=10, target_ess=0.5, ) @@ -319,7 +326,7 @@ def body(carry): state, _ = inference_loop(smc_kernel, self.key, init_state) - assert state.parameter_override.shape == (2, 2) + assert state.parameter_override["inverse_mass_matrix"].shape == (100, 2, 2) self.assert_linear_regression_test_case(state.sampler_state) @chex.all_variants(with_pmap=False) @@ -331,18 +338,32 @@ def test_with_tempered_smc(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() + def parameter_update(state, info): + return extend_params( + 100, + { + "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "step_size": 10e-2, + "num_integration_steps": 50, + }, + ) + init, step = blackjax.inner_kernel_tuning( tempered_smc, logprior_fn, loglikelihood_fn, - self.mcmc_factory, + blackjax.hmc.build_kernel(), blackjax.hmc.init, - {}, resampling.systematic, - mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles( - state.particles + mcmc_parameter_update_fn=parameter_update, + initial_parameter_value=extend_params( + 100, + dict( + inverse_mass_matrix=jnp.eye(2), + step_size=10e-2, + num_integration_steps=50, + ), ), - initial_parameter_value=jnp.eye(2), num_mcmc_steps=10, ) diff --git a/tests/smc/test_kernel_compatibility.py b/tests/smc/test_kernel_compatibility.py index 3d2469914..3e675c2cc 100644 --- a/tests/smc/test_kernel_compatibility.py +++ b/tests/smc/test_kernel_compatibility.py @@ -7,6 +7,7 @@ import blackjax from blackjax import adaptive_tempered_smc from blackjax.mcmc.random_walk import normal +from blackjax.smc import extend_params class SMCAndMCMCIntegrationTest(unittest.TestCase): @@ -18,8 +19,9 @@ class SMCAndMCMCIntegrationTest(unittest.TestCase): def setUp(self): super().setUp() self.key = jax.random.key(42) + self.n_particles = 3 self.initial_particles = jax.random.multivariate_normal( - self.key, jnp.zeros(2), jnp.eye(2), (3,) + self.key, jnp.zeros(2), jnp.eye(2), (self.n_particles,) ) def check_compatible(self, mcmc_step_fn, mcmc_init_fn, mcmc_parameters): @@ -40,54 +42,82 @@ def check_compatible(self, mcmc_step_fn, mcmc_init_fn, mcmc_parameters): kernel(self.key, init(self.initial_particles)) def test_compatible_with_rwm(self): + rwm = blackjax.additive_step_random_walk.build_kernel() + + def kernel(rng_key, state, logdensity_fn, proposal_mean): + return rwm(rng_key, state, logdensity_fn, normal(proposal_mean)) + self.check_compatible( - blackjax.additive_step_random_walk.build_kernel(), + kernel, blackjax.additive_step_random_walk.init, - {"random_step": normal(1.0)}, + extend_params(self.n_particles, {"proposal_mean": 1.0}), ) def test_compatible_with_rmh(self): + rmh = blackjax.rmh.build_kernel() + + def kernel( + rng_key, state, logdensity_fn, proposal_mean, proposal_logdensity_fn=None + ): + return rmh( + rng_key, + state, + logdensity_fn, + lambda a, b: blackjax.mcmc.random_walk.normal(proposal_mean)(a, b), + proposal_logdensity_fn, + ) + self.check_compatible( - blackjax.rmh.build_kernel(), + kernel, blackjax.rmh.init, - { - "transition_generator": lambda a, b: blackjax.mcmc.random_walk.normal( - 1.0 - )(a, b) - }, + extend_params(self.n_particles, {"proposal_mean": 1.0}), ) def test_compatible_with_hmc(self): self.check_compatible( blackjax.hmc.build_kernel(), blackjax.hmc.init, - { - "step_size": 0.3, - "inverse_mass_matrix": jnp.array([1]), - "num_integration_steps": 1, - }, + extend_params( + self.n_particles, + { + "step_size": 0.3, + "inverse_mass_matrix": jnp.array([1.0]), + "num_integration_steps": 1, + }, + ), ) def test_compatible_with_irmh(self): + def kernel(rng_key, state, logdensity_fn, mean, proposal_logdensity_fn=None): + return blackjax.irmh.build_kernel()( + rng_key, + state, + logdensity_fn, + lambda key: mean + jax.random.normal(key), + proposal_logdensity_fn, + ) + self.check_compatible( - blackjax.irmh.build_kernel(), + kernel, blackjax.irmh.init, - { - "proposal_distribution": lambda key: jnp.array([1.0, 1.0]) - + jax.random.normal(key) - }, + extend_params(self.n_particles, {"mean": jnp.array([1.0, 1.0])}), ) def test_compatible_with_nuts(self): self.check_compatible( blackjax.nuts.build_kernel(), blackjax.nuts.init, - {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}, + extend_params( + self.n_particles, + {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}, + ), ) def test_compatible_with_mala(self): self.check_compatible( - blackjax.mala.build_kernel(), blackjax.mala.init, {"step_size": 1e-10} + blackjax.mala.build_kernel(), + blackjax.mala.init, + extend_params(self.n_particles, {"step_size": 1e-10}), ) @staticmethod diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 242e11c55..2838e984f 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -8,7 +8,7 @@ import blackjax import blackjax.smc.resampling as resampling -from blackjax.smc.base import init, step +from blackjax.smc.base import extend_params, init, step def logdensity_fn(position): @@ -31,14 +31,8 @@ def test_smc(self): num_mcmc_steps = 20 num_particles = 1000 - hmc = blackjax.hmc( - logdensity_fn, - step_size=1e-2, - inverse_mass_matrix=jnp.eye(1), - num_integration_steps=50, - ) - - def update_fn(rng_key, position): + def update_fn(rng_key, position, update_params): + hmc = blackjax.hmc(logdensity_fn, **update_params) state = hmc.init(position) def body_fn(state, rng_key): @@ -53,7 +47,13 @@ def body_fn(state, rng_key): # Initialize the state of the SMC sampler init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) - state = init(init_particles) + same_for_all_params = dict( + step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50 + ) + state = init( + init_particles, + extend_params(num_particles, same_for_all_params), + ) # Run the SMC sampler once new_state, info = self.variant(step, static_argnums=(2, 3, 4))( @@ -74,15 +74,9 @@ def test_smc_waste_free(self): num_particles = 1000 num_resampled = num_particles // num_mcmc_steps - hmc = blackjax.hmc( - logdensity_fn, - step_size=1e-2, - inverse_mass_matrix=jnp.eye(1), - num_integration_steps=100, - ) - - def waste_free_update_fn(keys, particles): - def one_particle_fn(rng_key, position): + def waste_free_update_fn(keys, particles, update_params): + def one_particle_fn(rng_key, position, particle_update_params): + hmc = blackjax.hmc(logdensity_fn, **particle_update_params) state = hmc.init(position) def body_fn(state, rng_key): @@ -93,7 +87,7 @@ def body_fn(state, rng_key): _, (states, info) = jax.lax.scan(body_fn, state, keys) return states.position, info - particles, info = jax.vmap(one_particle_fn)(keys, particles) + particles, info = jax.vmap(one_particle_fn)(keys, particles, update_params) particles = particles.reshape((num_particles,)) return particles, info @@ -101,7 +95,17 @@ def body_fn(state, rng_key): # Initialize the state of the SMC sampler init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) - state = init(init_particles) + state = init( + init_particles, + extend_params( + num_resampled, + dict( + step_size=1e-2, + inverse_mass_matrix=jnp.eye(1), + num_integration_steps=100, + ), + ), + ) # Run the SMC sampler once new_state, info = self.variant(step, static_argnums=(2, 3, 4, 5))( @@ -118,5 +122,27 @@ def body_fn(state, rng_key): np.testing.assert_allclose(1.0, std, atol=1e-1) +class ExtendParamsTest(chex.TestCase): + def test_extend_params(self): + extended = extend_params( + 3, + { + "a": 50, + "b": np.array([50]), + "c": np.array([50, 60]), + "d": np.array([[1, 2], [3, 4]]), + }, + ) + np.testing.assert_allclose(extended["a"], np.ones((3,)) * 50) + np.testing.assert_allclose(extended["b"], np.array([[50], [50], [50]])) + np.testing.assert_allclose( + extended["c"], np.array([[50, 60], [50, 60], [50, 60]]) + ) + np.testing.assert_allclose( + extended["d"], + np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]], [[1, 2], [3, 4]]]), + ) + + if __name__ == "__main__": absltest.main() diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index f4234d117..3ab387e14 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -12,6 +12,7 @@ import blackjax.smc.resampling as resampling import blackjax.smc.solver as solver from blackjax import adaptive_tempered_smc, tempered_smc +from blackjax.smc import extend_params from tests.smc import SMCLinearRegressionTestCase @@ -64,11 +65,14 @@ def logprior_fn(x): hmc_kernel = blackjax.hmc.build_kernel() hmc_init = blackjax.hmc.init - hmc_parameters = { - "step_size": 10e-2, - "inverse_mass_matrix": jnp.eye(2), - "num_integration_steps": 50, - } + hmc_parameters = extend_params( + num_particles, + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(2), + "num_integration_steps": 50, + }, + ) for target_ess in [0.5, 0.75]: tempering = adaptive_tempered_smc( @@ -110,11 +114,14 @@ def test_fixed_schedule_tempered_smc(self): lambda_schedule = np.logspace(-5, 0, num_tempering_steps) hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() - hmc_parameters = { - "step_size": 10e-2, - "inverse_mass_matrix": jnp.eye(2), - "num_integration_steps": 50, - } + hmc_parameters = extend_params( + 100, + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(2), + "num_integration_steps": 50, + }, + ) tempering = tempered_smc( logprior_fn, @@ -174,11 +181,14 @@ def test_normalizing_constant(self): hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() - hmc_parameters = { - "step_size": 10e-2, - "inverse_mass_matrix": jnp.eye(num_dim), - "num_integration_steps": 50, - } + hmc_parameters = extend_params( + num_particles, + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(num_dim), + "num_integration_steps": 50, + }, + ) tempering = adaptive_tempered_smc( logprior_fn,