Skip to content

Commit

Permalink
Fix deprecated call to jnp.clip (#664)
Browse files Browse the repository at this point in the history
  • Loading branch information
GaetanLepage committed May 8, 2024
1 parent 1bc6f93 commit 3f92393
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions blackjax/mcmc/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def progressive_biased_sampling(
biases the transition away from the trajectory's initial state.
"""
p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), a_max=1)
p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), max=1)
do_accept = jax.random.bernoulli(rng_key, p_accept)
new_weight = jnp.logaddexp(proposal.weight, new_proposal.weight)
new_sum_log_p_accept = jnp.logaddexp(
Expand Down Expand Up @@ -224,7 +224,7 @@ def static_binomial_sampling(
then the new proposal is accepted with probability 1.
"""
p_accept = jnp.clip(jnp.exp(log_p_accept), a_max=1)
p_accept = jnp.clip(jnp.exp(log_p_accept), max=1)
do_accept = jax.random.bernoulli(rng_key, p_accept)
info = do_accept, p_accept, None
return (
Expand Down Expand Up @@ -253,7 +253,7 @@ def nonreversible_slice_sampling(
to the accept/reject step of a current state and new proposal.
"""
p_accept = jnp.clip(jnp.exp(delta_energy), a_max=1)
p_accept = jnp.clip(jnp.exp(delta_energy), max=1)
do_accept = jnp.log(jnp.abs(slice)) <= delta_energy
slice_next = slice * (jnp.exp(-delta_energy) * do_accept + (1 - do_accept))
info = do_accept, p_accept, slice_next
Expand Down

0 comments on commit 3f92393

Please sign in to comment.