diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index b373d062f..ca41cc39f 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, NamedTuple import jax @@ -119,6 +120,13 @@ def kernel( """ delta = lmbda - state.lmbda + shared_mcmc_parameters = { + k: v[0, ...] for k, v in mcmc_parameters.items() if v.shape[0] == 1 + } + unshared_mcmc_parameters = { + k: v for k, v in mcmc_parameters.items() if v.shape[0] != 1 + } + def log_weights_fn(position: ArrayLikeTree) -> float: return delta * loglikelihood_fn(position) @@ -127,11 +135,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: tempered_loglikelihood = state.lmbda * loglikelihood_fn(position) return logprior + tempered_loglikelihood + shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) + def mcmc_kernel(rng_key, position, step_parameters): state = mcmc_init_fn(position, tempered_logposterior_fn) def body_fn(state, rng_key): - new_state, info = mcmc_step_fn( + new_state, info = shared_mcmc_step_fn( rng_key, state, tempered_logposterior_fn, **step_parameters ) return new_state, info @@ -142,7 +152,7 @@ def body_fn(state, rng_key): smc_state, info = smc.base.step( rng_key, - SMCState(state.particles, state.weights, mcmc_parameters), + SMCState(state.particles, state.weights, unshared_mcmc_parameters), jax.vmap(mcmc_kernel), jax.vmap(log_weights_fn), resampling_fn, diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index bf970ae47..1ec806063 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -281,7 +281,7 @@ def test_with_adaptive_tempered(self): def parameter_update(state, info): return extend_params( - 100, + 1, { "inverse_mass_matrix": mass_matrix_from_particles(state.particles), "step_size": 10e-2, @@ -298,7 +298,7 @@ def parameter_update(state, info): resampling.systematic, mcmc_parameter_update_fn=parameter_update, initial_parameter_value=extend_params( - 100, + 1, dict( inverse_mass_matrix=jnp.eye(2), step_size=10e-2, @@ -326,7 +326,7 @@ def body(carry): _, state = inference_loop(smc_kernel, self.key, init_state) - assert state.parameter_override["inverse_mass_matrix"].shape == (100, 2, 2) + assert state.parameter_override["inverse_mass_matrix"].shape == (1, 2, 2) self.assert_linear_regression_test_case(state.sampler_state) @chex.all_variants(with_pmap=False) @@ -340,7 +340,7 @@ def test_with_tempered_smc(self): def parameter_update(state, info): return extend_params( - 100, + 1, { "inverse_mass_matrix": mass_matrix_from_particles(state.particles), "step_size": 10e-2, @@ -357,7 +357,7 @@ def parameter_update(state, info): resampling.systematic, mcmc_parameter_update_fn=parameter_update, initial_parameter_value=extend_params( - 100, + 1, dict( inverse_mass_matrix=jnp.eye(2), step_size=10e-2, diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index a7d9acdd8..dac590a09 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -65,16 +65,25 @@ def logprior_fn(x): hmc_kernel = blackjax.hmc.build_kernel() hmc_init = blackjax.hmc.init - hmc_parameters = extend_params( - num_particles, - { - "step_size": 10e-2, - "inverse_mass_matrix": jnp.eye(2), - "num_integration_steps": 50, - }, + hmc_parameters_list = [ + extend_params( + num_particles if extend else 1, + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(2), + "num_integration_steps": 50, + }, + ) + for extend in [True, False] + ] + hmc_parameters_list.append( + extend_params( + num_particles, {"step_size": 10e-2, "num_integration_steps": 50} + ) + | extend_params(num_particles, {"inverse_mass_matrix": jnp.eye(2)}) ) - for target_ess in [0.5, 0.75]: + for target_ess, hmc_parameters in zip([0.5, 0.5, 0.75], hmc_parameters_list): tempering = adaptive_tempered_smc( logprior_fn, loglikelihood_fn, @@ -115,7 +124,7 @@ def test_fixed_schedule_tempered_smc(self): hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() hmc_parameters = extend_params( - 100, + 1, { "step_size": 10e-2, "inverse_mass_matrix": jnp.eye(2),