diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py index b8a611606..10fb194fa 100644 --- a/blackjax/smc/adaptive_tempered.py +++ b/blackjax/smc/adaptive_tempered.py @@ -130,7 +130,8 @@ def as_top_level_api( mcmc_init_fn The MCMC init function used to build a MCMC state from a particle position. mcmc_parameters - The parameters of the MCMC step function. + The parameters of the MCMC step function. Parameters with leading dimension + length of 1 are shared amongst the particles. resampling_fn The function used to resample the particles. target_ess diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index ca41cc39f..43b83d034 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -109,6 +109,9 @@ def kernel( Current state of the tempered SMC algorithm lmbda Current value of the tempering parameter + mcmc_parameters + The parameters of the MCMC step function. Parameters with leading dimension + length of 1 are shared amongst the particles. Returns ------- @@ -120,12 +123,13 @@ def kernel( """ delta = lmbda - state.lmbda - shared_mcmc_parameters = { - k: v[0, ...] for k, v in mcmc_parameters.items() if v.shape[0] == 1 - } - unshared_mcmc_parameters = { - k: v for k, v in mcmc_parameters.items() if v.shape[0] != 1 - } + shared_mcmc_parameters = {} + unshared_mcmc_parameters = {} + for k, v in mcmc_parameters.items(): + if v.shape[0] == 1: + shared_mcmc_parameters[k] = v[0, ...] + else: + unshared_mcmc_parameters[k] = v def log_weights_fn(position: ArrayLikeTree) -> float: return delta * loglikelihood_fn(position) @@ -188,7 +192,8 @@ def as_top_level_api( mcmc_init_fn The MCMC init function used to build a MCMC state from a particle position. mcmc_parameters - The parameters of the MCMC step function. + The parameters of the MCMC step function. Parameters with leading dimension + length of 1 are shared amongst the particles. resampling_fn The function used to resample the particles. num_mcmc_steps