From e99b8cbf919c7706b966a3a13e7b2a59ce0d6d29 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Tue, 6 Aug 2024 11:51:23 -0700 Subject: [PATCH] switch to using chain state --- blackjax/adaptation/window_adaptation.py | 8 +++- blackjax/progress_bar.py | 55 ++++++++++++------------ blackjax/util.py | 14 ++++-- 3 files changed, 45 insertions(+), 32 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 63c54bad0..cb02eb2c4 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -334,16 +334,22 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): if progress_bar: print("Running window adaptation") one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step)) + start_state = ((init_state, init_adaptation_state), -1) else: one_step_ = jax.jit(one_step) + start_state = (init_state, init_adaptation_state) keys = jax.random.split(rng_key, num_steps) schedule = build_schedule(num_steps) last_state, info = jax.lax.scan( one_step_, - (init_state, init_adaptation_state), + start_state, (jnp.arange(num_steps), keys, schedule), ) + + if progress_bar: + last_state, _ = last_state + last_chain_state, last_warmup_state, *_ = last_state step_size, inverse_mass_matrix = adapt_final(last_warmup_state) diff --git a/blackjax/progress_bar.py b/blackjax/progress_bar.py index f1dc5ad70..188ab7dba 100644 --- a/blackjax/progress_bar.py +++ b/blackjax/progress_bar.py @@ -19,12 +19,13 @@ from fastprogress.fastprogress import progress_bar from jax import lax from jax.experimental import io_callback +from jax.numpy import array def progress_bar_scan(num_samples, print_rate=None): "Progress bar for a JAX scan" progress_bars = {} - idx_map = {} + idx_counter = 0 lock = Lock() if print_rate is None: @@ -34,46 +35,43 @@ def progress_bar_scan(num_samples, print_rate=None): print_rate = 1 # if you run the sampler for less than 20 iterations def _calc_chain_idx(iter_num): - iter_num = int(iter_num) - try: - idx = idx_map[iter_num] - except KeyError: - idx = 0 - idx_map[iter_num] = 0 - - idx_map[iter_num] += 1 + nonlocal idx_counter + with lock: + idx = idx_counter + idx_counter += 1 return idx - def _update_bar(arg): - with lock: - idx = _calc_chain_idx(arg) - if arg == 0: - progress_bars[idx] = progress_bar(range(num_samples)) - progress_bars[idx].update(0) - progress_bars[idx].update_bar(arg + 1) + def _update_bar(arg, chain_id): + chain_id = int(chain_id) + if arg == 0: + chain_id = _calc_chain_idx(arg) + progress_bars[chain_id] = progress_bar(range(num_samples)) + progress_bars[chain_id].update(0) - def _close_bar(arg): - with lock: - idx = _calc_chain_idx(arg) - progress_bars[idx].on_iter_end() + progress_bars[chain_id].update_bar(arg + 1) + return chain_id + + def _close_bar(arg, chain_id): + progress_bars[int(chain_id)].on_iter_end() - def _update_progress_bar(iter_num): + def _update_progress_bar(iter_num, chain_id): "Updates progress bar of a JAX scan or loop" - _ = lax.cond( + chain_id = lax.cond( # update every multiple of `print_rate` except at the end (iter_num % print_rate == 0) | (iter_num == (num_samples - 1)), - lambda _: io_callback(_update_bar, None, iter_num), - lambda _: None, + lambda _: io_callback(_update_bar, array(0), iter_num, chain_id), + lambda _: chain_id, operand=None, ) _ = lax.cond( iter_num == num_samples - 1, - lambda _: io_callback(_close_bar, None, iter_num + 1), + lambda _: io_callback(_close_bar, None, iter_num + 1, chain_id), lambda _: None, operand=None, ) + return chain_id def _progress_bar_scan(func): """Decorator that adds a progress bar to `body_fun` used in `lax.scan`. @@ -87,8 +85,11 @@ def wrapper_progress_bar(carry, x): iter_num, *_ = x else: iter_num = x - _update_progress_bar(iter_num) - return func(carry, x) + subcarry, chain_id = carry + chain_id = _update_progress_bar(iter_num, chain_id) + subcarry, y = func(subcarry, x) + + return (subcarry, chain_id), y return wrapper_progress_bar diff --git a/blackjax/util.py b/blackjax/util.py index cdb9f4c91..78a7c0633 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -224,13 +224,19 @@ def one_step(average_and_state, xs, return_state): one_step = jax.jit(partial(one_step, return_state=return_state_history)) + xs = (jnp.arange(num_steps), keys) if progress_bar: one_step = progress_bar_scan(num_steps)(one_step) + (((_, average), final_state), _), history = lax.scan( + one_step, + (((0, expectation(transform(initial_state))), initial_state), -1), + xs, + ) - xs = (jnp.arange(num_steps), keys) - ((_, average), final_state), history = lax.scan( - one_step, ((0, expectation(transform(initial_state))), initial_state), xs - ) + else: + ((_, average), final_state), history = lax.scan( + one_step, ((0, expectation(transform(initial_state))), initial_state), xs + ) if not return_state_history: return average, transform(final_state)