diff --git a/blackjax/util.py b/blackjax/util.py index d761189cf..b5bd1a89e 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -202,9 +202,12 @@ def one_step(state, xs): if progress_bar: one_step = progress_bar_scan(num_steps)(one_step) + xs = jnp.arange(num_steps), keys + final_state, history = lax.scan(one_step, (initial_state, -1), xs) + else: + xs = jnp.arange(num_steps), keys + final_state, history = lax.scan(one_step, initial_state, xs) - xs = jnp.arange(num_steps), keys - final_state, history = lax.scan(one_step, initial_state, xs) return final_state, history @@ -212,6 +215,7 @@ def store_only_expectation_values( sampling_algorithm, state_transform=lambda x: x, incremental_value_transform=lambda x: x, + burn_in=0, ): """Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same kernel but only stores the streaming expectation values of some observables, not the full states; to save memory. @@ -262,8 +266,14 @@ def update_fn(rng_key, state_and_incremental_val): rng_key, state ) # update the state with the sampling algorithm averaging_state = incremental_value_update( - state_transform(state), averaging_state - ) # update the expectation value with the running average + state_transform(state), + averaging_state, + weight=( + averaging_state[0] >= burn_in + ), # If we want to eliminate some number of steps as a burn-in + zero_prevention=1e-10 * (burn_in > 0), + ) + # update the expectation value with the running average return (state, averaging_state), info def transform(state_and_incremental_val, info):