Skip to content

Commit

Permalink
Update blackjax/smc/waste_free.py
Browse files Browse the repository at this point in the history
Co-authored-by: Junpeng Lao <[email protected]>
  • Loading branch information
ciguaran and junpenglao authored Aug 26, 2024
1 parent ce4959c commit e455f43
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions blackjax/smc/waste_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ def update(rng_key, position, step_parameters):
# step particles is num_resmapled, num_mcmc_steps, dimension_of_variable
# want to transformed into num_resampled * num_mcmc_steps, dimension of variable
def reshape_step_particles(x):
if len(x.shape) > 2:
return x.reshape((x.shape[0] * x.shape[1], -1))
else:
return x.flatten()
num_resmapled, num_mcmc_steps, *dimension_of_variable = x.shape
return x.reshape((num_resmapled*num_mcmc_steps, *dimension_of_variable))

step_particles = jax.tree.map(reshape_step_particles, states.position)
new_particles = jax.tree.map(
Expand Down

0 comments on commit e455f43

Please sign in to comment.