Skip to content

Commit

Permalink
fix parameter split + docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdipper committed Jun 13, 2024
1 parent 7a8d3ce commit 93adbef
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
3 changes: 2 additions & 1 deletion blackjax/smc/adaptive_tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def as_top_level_api(
mcmc_init_fn
The MCMC init function used to build a MCMC state from a particle position.
mcmc_parameters
The parameters of the MCMC step function.
The parameters of the MCMC step function. Parameters with leading dimension
length of 1 are shared amongst the particles.
resampling_fn
The function used to resample the particles.
target_ess
Expand Down
19 changes: 12 additions & 7 deletions blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ def kernel(
Current state of the tempered SMC algorithm
lmbda
Current value of the tempering parameter
mcmc_parameters
The parameters of the MCMC step function. Parameters with leading dimension
length of 1 are shared amongst the particles.
Returns
-------
Expand All @@ -120,12 +123,13 @@ def kernel(
"""
delta = lmbda - state.lmbda

shared_mcmc_parameters = {
k: v[0, ...] for k, v in mcmc_parameters.items() if v.shape[0] == 1
}
unshared_mcmc_parameters = {
k: v for k, v in mcmc_parameters.items() if v.shape[0] != 1
}
shared_mcmc_parameters = {}
unshared_mcmc_parameters = {}
for k, v in mcmc_parameters.items():
if v.shape[0] == 1:
shared_mcmc_parameters[k] = v[0, ...]
else:
unshared_mcmc_parameters[k] = v

def log_weights_fn(position: ArrayLikeTree) -> float:
return delta * loglikelihood_fn(position)
Expand Down Expand Up @@ -188,7 +192,8 @@ def as_top_level_api(
mcmc_init_fn
The MCMC init function used to build a MCMC state from a particle position.
mcmc_parameters
The parameters of the MCMC step function.
The parameters of the MCMC step function. Parameters with leading dimension
length of 1 are shared amongst the particles.
resampling_fn
The function used to resample the particles.
num_mcmc_steps
Expand Down

0 comments on commit 93adbef

Please sign in to comment.