Skip to content

Commit

Permalink
switch to using chain state
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdipper committed Aug 6, 2024
1 parent 6c763e7 commit e99b8cb
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 32 deletions.
8 changes: 7 additions & 1 deletion blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,16 +334,22 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
if progress_bar:
print("Running window adaptation")
one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step))
start_state = ((init_state, init_adaptation_state), -1)
else:
one_step_ = jax.jit(one_step)
start_state = (init_state, init_adaptation_state)

keys = jax.random.split(rng_key, num_steps)
schedule = build_schedule(num_steps)
last_state, info = jax.lax.scan(
one_step_,
(init_state, init_adaptation_state),
start_state,
(jnp.arange(num_steps), keys, schedule),
)

if progress_bar:
last_state, _ = last_state

last_chain_state, last_warmup_state, *_ = last_state

step_size, inverse_mass_matrix = adapt_final(last_warmup_state)
Expand Down
55 changes: 28 additions & 27 deletions blackjax/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
from fastprogress.fastprogress import progress_bar
from jax import lax
from jax.experimental import io_callback
from jax.numpy import array


def progress_bar_scan(num_samples, print_rate=None):
"Progress bar for a JAX scan"
progress_bars = {}
idx_map = {}
idx_counter = 0
lock = Lock()

if print_rate is None:
Expand All @@ -34,46 +35,43 @@ def progress_bar_scan(num_samples, print_rate=None):
print_rate = 1 # if you run the sampler for less than 20 iterations

def _calc_chain_idx(iter_num):
iter_num = int(iter_num)
try:
idx = idx_map[iter_num]
except KeyError:
idx = 0
idx_map[iter_num] = 0

idx_map[iter_num] += 1
nonlocal idx_counter
with lock:
idx = idx_counter
idx_counter += 1
return idx

def _update_bar(arg):
with lock:
idx = _calc_chain_idx(arg)
if arg == 0:
progress_bars[idx] = progress_bar(range(num_samples))
progress_bars[idx].update(0)
progress_bars[idx].update_bar(arg + 1)
def _update_bar(arg, chain_id):
chain_id = int(chain_id)
if arg == 0:
chain_id = _calc_chain_idx(arg)
progress_bars[chain_id] = progress_bar(range(num_samples))
progress_bars[chain_id].update(0)

def _close_bar(arg):
with lock:
idx = _calc_chain_idx(arg)
progress_bars[idx].on_iter_end()
progress_bars[chain_id].update_bar(arg + 1)
return chain_id

def _close_bar(arg, chain_id):
progress_bars[int(chain_id)].on_iter_end()

def _update_progress_bar(iter_num):
def _update_progress_bar(iter_num, chain_id):
"Updates progress bar of a JAX scan or loop"

_ = lax.cond(
chain_id = lax.cond(
# update every multiple of `print_rate` except at the end
(iter_num % print_rate == 0) | (iter_num == (num_samples - 1)),
lambda _: io_callback(_update_bar, None, iter_num),
lambda _: None,
lambda _: io_callback(_update_bar, array(0), iter_num, chain_id),
lambda _: chain_id,
operand=None,
)

_ = lax.cond(
iter_num == num_samples - 1,
lambda _: io_callback(_close_bar, None, iter_num + 1),
lambda _: io_callback(_close_bar, None, iter_num + 1, chain_id),
lambda _: None,
operand=None,
)
return chain_id

def _progress_bar_scan(func):
"""Decorator that adds a progress bar to `body_fun` used in `lax.scan`.
Expand All @@ -87,8 +85,11 @@ def wrapper_progress_bar(carry, x):
iter_num, *_ = x
else:
iter_num = x
_update_progress_bar(iter_num)
return func(carry, x)
subcarry, chain_id = carry
chain_id = _update_progress_bar(iter_num, chain_id)
subcarry, y = func(subcarry, x)

return (subcarry, chain_id), y

return wrapper_progress_bar

Expand Down
14 changes: 10 additions & 4 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,19 @@ def one_step(average_and_state, xs, return_state):

one_step = jax.jit(partial(one_step, return_state=return_state_history))

xs = (jnp.arange(num_steps), keys)
if progress_bar:
one_step = progress_bar_scan(num_steps)(one_step)
(((_, average), final_state), _), history = lax.scan(
one_step,
(((0, expectation(transform(initial_state))), initial_state), -1),
xs,
)

xs = (jnp.arange(num_steps), keys)
((_, average), final_state), history = lax.scan(
one_step, ((0, expectation(transform(initial_state))), initial_state), xs
)
else:
((_, average), final_state), history = lax.scan(
one_step, ((0, expectation(transform(initial_state))), initial_state), xs
)

if not return_state_history:
return average, transform(final_state)
Expand Down

0 comments on commit e99b8cb

Please sign in to comment.