diff --git a/blackjax/mcmc/proposal.py b/blackjax/mcmc/proposal.py index 5ec95edf0..258dbf29e 100644 --- a/blackjax/mcmc/proposal.py +++ b/blackjax/mcmc/proposal.py @@ -153,7 +153,7 @@ def progressive_biased_sampling( biases the transition away from the trajectory's initial state. """ - p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), a_max=1) + p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), max=1) do_accept = jax.random.bernoulli(rng_key, p_accept) new_weight = jnp.logaddexp(proposal.weight, new_proposal.weight) new_sum_log_p_accept = jnp.logaddexp( @@ -224,7 +224,7 @@ def static_binomial_sampling( then the new proposal is accepted with probability 1. """ - p_accept = jnp.clip(jnp.exp(log_p_accept), a_max=1) + p_accept = jnp.clip(jnp.exp(log_p_accept), max=1) do_accept = jax.random.bernoulli(rng_key, p_accept) info = do_accept, p_accept, None return ( @@ -253,7 +253,7 @@ def nonreversible_slice_sampling( to the accept/reject step of a current state and new proposal. """ - p_accept = jnp.clip(jnp.exp(delta_energy), a_max=1) + p_accept = jnp.clip(jnp.exp(delta_energy), max=1) do_accept = jnp.log(jnp.abs(slice)) <= delta_energy slice_next = slice * (jnp.exp(-delta_energy) * do_accept + (1 - do_accept)) info = do_accept, p_accept, slice_next