Skip to content

Commit

Permalink
convert to bit twiddling
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdipper committed Jun 19, 2024
1 parent 3353209 commit 18fffac
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions blackjax/mcmc/termination.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,10 @@ def _leaf_idx_to_ckpt_idxs(n):
"""Find the checkpoint id from a step number."""
# computes the number of non-zero bits except the last bit
# e.g. 6 -> 2, 7 -> 2, 13 -> 2
_, idx_max = jax.lax.while_loop(
lambda nc: nc[0] > 0,
lambda nc: (nc[0] >> 1, nc[1] + (nc[0] & 1)),
(n >> 1, 0),
)
idx_max = jnp.bitwise_count(n >> 1).astype(jnp.int32)
# computes the number of contiguous last non-zero bits
# e.g. 6 -> 0, 7 -> 3, 13 -> 1
_, num_subtrees = jax.lax.while_loop(
lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0)
)
num_subtrees = jnp.bitwise_count((~n & (n + 1)) - 1).astype(jnp.int32)
idx_min = idx_max - num_subtrees + 1
return idx_min, idx_max

Expand Down

0 comments on commit 18fffac

Please sign in to comment.