From de4d4a6e7741211e72de208368914cf59de2f02c Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Thu, 8 Aug 2024 11:55:18 -0700 Subject: [PATCH] remove labels --- blackjax/adaptation/window_adaptation.py | 17 +++++------------ blackjax/progress_bar.py | 14 ++++++++++++++ blackjax/util.py | 20 +++++++------------- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index cb02eb2c4..69a098325 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -28,7 +28,7 @@ dual_averaging_adaptation, ) from blackjax.base import AdaptationAlgorithm -from blackjax.progress_bar import progress_bar_scan +from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, PRNGKey from blackjax.util import pytree_size @@ -333,23 +333,16 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): if progress_bar: print("Running window adaptation") - one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step)) - start_state = ((init_state, init_adaptation_state), -1) - else: - one_step_ = jax.jit(one_step) - start_state = (init_state, init_adaptation_state) - + scan_fn = gen_scan_fn(num_steps, progress_bar=progress_bar) + start_state = (init_state, init_adaptation_state) keys = jax.random.split(rng_key, num_steps) schedule = build_schedule(num_steps) - last_state, info = jax.lax.scan( - one_step_, + last_state, info = scan_fn( + one_step, start_state, (jnp.arange(num_steps), keys, schedule), ) - if progress_bar: - last_state, _ = last_state - last_chain_state, last_warmup_state, *_ = last_state step_size, inverse_mass_matrix = adapt_final(last_warmup_state) diff --git a/blackjax/progress_bar.py b/blackjax/progress_bar.py index 188ab7dba..a1425df88 100644 --- a/blackjax/progress_bar.py +++ b/blackjax/progress_bar.py @@ -94,3 +94,17 @@ def wrapper_progress_bar(carry, x): return wrapper_progress_bar return _progress_bar_scan + + +def gen_scan_fn(num_samples, progress_bar, print_rate=None): + if progress_bar: + + def scan_wrap(f, init, *args, **kwargs): + func = progress_bar_scan(num_samples, print_rate)(f) + carry = (init, -1) + (last_state, _), output = lax.scan(func, carry, *args, **kwargs) + return last_state, output + + return scan_wrap + else: + return lax.scan diff --git a/blackjax/util.py b/blackjax/util.py index 78a7c0633..9f4d6f9c7 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -11,7 +11,7 @@ from jax.tree_util import tree_leaves from blackjax.base import SamplingAlgorithm, VIAlgorithm -from blackjax.progress_bar import progress_bar_scan +from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -225,18 +225,12 @@ def one_step(average_and_state, xs, return_state): one_step = jax.jit(partial(one_step, return_state=return_state_history)) xs = (jnp.arange(num_steps), keys) - if progress_bar: - one_step = progress_bar_scan(num_steps)(one_step) - (((_, average), final_state), _), history = lax.scan( - one_step, - (((0, expectation(transform(initial_state))), initial_state), -1), - xs, - ) - - else: - ((_, average), final_state), history = lax.scan( - one_step, ((0, expectation(transform(initial_state))), initial_state), xs - ) + scan_fn = gen_scan_fn(num_steps, progress_bar) + ((_, average), final_state), history = scan_fn( + one_step, + ((0, expectation(transform(initial_state))), initial_state), + xs, + ) if not return_state_history: return average, transform(final_state)