Skip to content

Commit

Permalink
Fix MALA transition energy (#653)
Browse files Browse the repository at this point in the history
* Fix MALA transition energy

* Use a different logic.
  • Loading branch information
ksnxr authored Mar 31, 2024
1 parent 2e25624 commit f77297f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def transition_energy(state, new_state, step_size):
"""Transition energy to go from `state` to `new_state`"""
theta = jax.tree_util.tree_map(
lambda new_x, x, g: new_x - x - step_size * g,
new_state.position,
state.position,
state.logdensity_grad,
new_state.position,
new_state.logdensity_grad,
)
theta_dot = jax.tree_util.tree_reduce(
operator.add, jax.tree_util.tree_map(lambda x: jnp.sum(x * x), theta)
)
return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot
return -new_state.logdensity + 0.25 * (1.0 / step_size) * theta_dot

compute_acceptance_ratio = proposal.compute_asymmetric_acceptance_ratio(
transition_energy
Expand Down

0 comments on commit f77297f

Please sign in to comment.