From 3f9947cc2a695427c4461f434940cee6e859c593 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 7 Aug 2024 13:47:58 -0400 Subject: [PATCH] merge main --- blackjax/util.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 3cf5c90fb..d761189cf 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -200,14 +200,8 @@ def one_step(state, xs): state, info = inference_algorithm.step(rng_key, state) return state, transform(state, info) - xs = (jnp.arange(num_steps), keys) if progress_bar: one_step = progress_bar_scan(num_steps)(one_step) - (((_, average), final_state), _), history = lax.scan( - one_step, - (((0, expectation(transform(initial_state))), initial_state), -1), - xs, - ) xs = jnp.arange(num_steps), keys final_state, history = lax.scan(one_step, initial_state, xs)