Skip to content

Commit

Permalink
burn in and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Aug 8, 2024
1 parent 3f9947c commit e144ce7
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,20 @@ def one_step(state, xs):

if progress_bar:
one_step = progress_bar_scan(num_steps)(one_step)
xs = jnp.arange(num_steps), keys
final_state, history = lax.scan(one_step, (initial_state, -1), xs)
else:
xs = jnp.arange(num_steps), keys
final_state, history = lax.scan(one_step, initial_state, xs)

xs = jnp.arange(num_steps), keys
final_state, history = lax.scan(one_step, initial_state, xs)
return final_state, history


def store_only_expectation_values(
sampling_algorithm,
state_transform=lambda x: x,
incremental_value_transform=lambda x: x,
burn_in=0,
):
"""Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same
kernel but only stores the streaming expectation values of some observables, not the full states; to save memory.
Expand Down Expand Up @@ -262,8 +266,14 @@ def update_fn(rng_key, state_and_incremental_val):
rng_key, state
) # update the state with the sampling algorithm
averaging_state = incremental_value_update(
state_transform(state), averaging_state
) # update the expectation value with the running average
state_transform(state),
averaging_state,
weight=(
averaging_state[0] >= burn_in
), # If we want to eliminate some number of steps as a burn-in
zero_prevention=1e-10 * (burn_in > 0),
)
# update the expectation value with the running average
return (state, averaging_state), info

def transform(state_and_incremental_val, info):
Expand Down

0 comments on commit e144ce7

Please sign in to comment.