From 27dfc9e30dd5b8c8f0771f5f6f3cbf7ec3f4c7ec Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Wed, 7 Aug 2024 06:22:43 -0700 Subject: [PATCH] Enable progress bar under pmap (#712) * enable pmap progbar * fix bar creation * add locking * fix formatting * switch to using chain state --- blackjax/adaptation/window_adaptation.py | 8 +++- blackjax/progress_bar.py | 55 ++++++++++++++---------- blackjax/util.py | 14 ++++-- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 63c54bad0..cb02eb2c4 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -334,16 +334,22 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): if progress_bar: print("Running window adaptation") one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step)) + start_state = ((init_state, init_adaptation_state), -1) else: one_step_ = jax.jit(one_step) + start_state = (init_state, init_adaptation_state) keys = jax.random.split(rng_key, num_steps) schedule = build_schedule(num_steps) last_state, info = jax.lax.scan( one_step_, - (init_state, init_adaptation_state), + start_state, (jnp.arange(num_steps), keys, schedule), ) + + if progress_bar: + last_state, _ = last_state + last_chain_state, last_warmup_state, *_ = last_state step_size, inverse_mass_matrix = adapt_final(last_warmup_state) diff --git a/blackjax/progress_bar.py b/blackjax/progress_bar.py index ac509b9b6..188ab7dba 100644 --- a/blackjax/progress_bar.py +++ b/blackjax/progress_bar.py @@ -14,14 +14,19 @@ """Progress bar decorators for use with step functions. Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`. """ +from threading import Lock + from fastprogress.fastprogress import progress_bar from jax import lax from jax.experimental import io_callback +from jax.numpy import array def progress_bar_scan(num_samples, print_rate=None): "Progress bar for a JAX scan" progress_bars = {} + idx_counter = 0 + lock = Lock() if print_rate is None: if num_samples > 20: @@ -29,41 +34,44 @@ def progress_bar_scan(num_samples, print_rate=None): else: print_rate = 1 # if you run the sampler for less than 20 iterations - def _define_bar(arg): - del arg - progress_bars[0] = progress_bar(range(num_samples)) - progress_bars[0].update(0) + def _calc_chain_idx(iter_num): + nonlocal idx_counter + with lock: + idx = idx_counter + idx_counter += 1 + return idx + + def _update_bar(arg, chain_id): + chain_id = int(chain_id) + if arg == 0: + chain_id = _calc_chain_idx(arg) + progress_bars[chain_id] = progress_bar(range(num_samples)) + progress_bars[chain_id].update(0) - def _update_bar(arg): - progress_bars[0].update_bar(arg + 1) + progress_bars[chain_id].update_bar(arg + 1) + return chain_id - def _close_bar(arg): - del arg - progress_bars[0].on_iter_end() + def _close_bar(arg, chain_id): + progress_bars[int(chain_id)].on_iter_end() - def _update_progress_bar(iter_num): + def _update_progress_bar(iter_num, chain_id): "Updates progress bar of a JAX scan or loop" - _ = lax.cond( - iter_num == 0, - lambda _: io_callback(_define_bar, None, iter_num), - lambda _: None, - operand=None, - ) - _ = lax.cond( + chain_id = lax.cond( # update every multiple of `print_rate` except at the end (iter_num % print_rate == 0) | (iter_num == (num_samples - 1)), - lambda _: io_callback(_update_bar, None, iter_num), - lambda _: None, + lambda _: io_callback(_update_bar, array(0), iter_num, chain_id), + lambda _: chain_id, operand=None, ) _ = lax.cond( iter_num == num_samples - 1, - lambda _: io_callback(_close_bar, None, None), + lambda _: io_callback(_close_bar, None, iter_num + 1, chain_id), lambda _: None, operand=None, ) + return chain_id def _progress_bar_scan(func): """Decorator that adds a progress bar to `body_fun` used in `lax.scan`. @@ -77,8 +85,11 @@ def wrapper_progress_bar(carry, x): iter_num, *_ = x else: iter_num = x - _update_progress_bar(iter_num) - return func(carry, x) + subcarry, chain_id = carry + chain_id = _update_progress_bar(iter_num, chain_id) + subcarry, y = func(subcarry, x) + + return (subcarry, chain_id), y return wrapper_progress_bar diff --git a/blackjax/util.py b/blackjax/util.py index cdb9f4c91..78a7c0633 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -224,13 +224,19 @@ def one_step(average_and_state, xs, return_state): one_step = jax.jit(partial(one_step, return_state=return_state_history)) + xs = (jnp.arange(num_steps), keys) if progress_bar: one_step = progress_bar_scan(num_steps)(one_step) + (((_, average), final_state), _), history = lax.scan( + one_step, + (((0, expectation(transform(initial_state))), initial_state), -1), + xs, + ) - xs = (jnp.arange(num_steps), keys) - ((_, average), final_state), history = lax.scan( - one_step, ((0, expectation(transform(initial_state))), initial_state), xs - ) + else: + ((_, average), final_state), history = lax.scan( + one_step, ((0, expectation(transform(initial_state))), initial_state), xs + ) if not return_state_history: return average, transform(final_state)