diff --git a/blackjax/util.py b/blackjax/util.py index 3cf5c90fb..d761189cf 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -200,14 +200,8 @@ def one_step(state, xs): state, info = inference_algorithm.step(rng_key, state) return state, transform(state, info) - xs = (jnp.arange(num_steps), keys) if progress_bar: one_step = progress_bar_scan(num_steps)(one_step) - (((_, average), final_state), _), history = lax.scan( - one_step, - (((0, expectation(transform(initial_state))), initial_state), -1), - xs, - ) xs = jnp.arange(num_steps), keys final_state, history = lax.scan(one_step, initial_state, xs)