Skip to content

Commit

Permalink
Merge branch 'kalman_reuben' of github.com:reubenharry/blackjax into …
Browse files Browse the repository at this point in the history
…kalman_reuben
  • Loading branch information
reubenharry committed Aug 5, 2024
2 parents e4a7668 + 96aeb94 commit 0d29909
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,6 @@ def run_inference_algorithm(
A transformation of the trace of states to be returned. This is useful for
computing determinstic variables, or returning a subset of the states.
By default, the states are returned as is.
expectation
A function that computes the expectation of the state. This is done incrementally, so doesn't require storing all the states.
return_state_history
if False, `run_inference_algorithm` will only return an expectation of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck.
Expand Down Expand Up @@ -283,6 +281,7 @@ def one_step(state, xs):
final_state, history = lax.scan(one_step, initial_state, xs)
return final_state, history


def store_only_expectation_values(sampling_algorithm, state_transform= lambda x: x, exp_vals_transform= lambda x: x):
"""Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same
kernel but only stores the streaming expectation values of some observables, not the full states; to save memory.
Expand All @@ -305,7 +304,7 @@ def store_only_expectation_values(sampling_algorithm, state_transform= lambda x:
num_steps = 4
integrator = map_integrator_type_to_integrator['mclmc'][integrator_type]
state_transform = lambda x: x.position
state_transform = lambda state: x.position
memory_efficient_sampling_alg, transform = store_only_expectation_values(
sampling_algorithm=sampling_alg,
state_transform=state_transform)
Expand All @@ -331,9 +330,11 @@ def update_fn(rng_key, state_full):
state, averaging_state = state_full
state, info = sampling_algorithm.step(rng_key, state) # update the state with the sampling algorithm
averaging_state = streaming_average_update(state_transform(state), averaging_state) # update the expectation value with the Kalman filter
return (state, averaging_state), (averaging_state, info)

transform= lambda full_state, info: exp_vals_transform(full_state[1][1])
return (state, averaging_state), info

def transform(full_state, info):
exp_vals = full_state[1][1]
return exp_vals_transform(exp_vals), info

return SamplingAlgorithm(init_fn, update_fn), transform

Expand Down

0 comments on commit 0d29909

Please sign in to comment.