diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index f6dd7c106..1f1345cc4 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -79,7 +79,7 @@ def build_kernel(): def transition_energy(state, new_state, step_size): """Transition energy to go from `state` to `new_state`""" theta = jax.tree_util.tree_map( - lambda new_x, x, g: new_x - x - step_size * g, + lambda x, new_x, g: x - new_x - step_size * g, state.position, new_state.position, new_state.logdensity_grad,