Skip to content

Commit

Permalink
fixed memory efficient sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobRobnik committed Aug 5, 2024
1 parent 75d24d3 commit 96aeb94
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,6 @@ def run_inference_algorithm(
A transformation of the trace of states to be returned. This is useful for
computing determinstic variables, or returning a subset of the states.
By default, the states are returned as is.
expectation
A function that computes the expectation of the state. This is done incrementally, so doesn't require storing all the states.
return_state_history
if False, `run_inference_algorithm` will only return an expectation of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck.
Expand Down Expand Up @@ -306,7 +304,7 @@ def store_only_expectation_values(sampling_algorithm, state_transform= lambda x:
num_steps = 4
integrator = map_integrator_type_to_integrator['mclmc'][integrator_type]
state_transform = lambda x: x.position
state_transform = lambda state: x.position
memory_efficient_sampling_alg, transform = store_only_expectation_values(
sampling_algorithm=sampling_alg,
state_transform=state_transform)
Expand All @@ -332,9 +330,11 @@ def update_fn(rng_key, state_full):
state, averaging_state = state_full
state, info = sampling_algorithm.step(rng_key, state) # update the state with the sampling algorithm
averaging_state = streaming_average_update(state_transform(state), averaging_state) # update the expectation value with the Kalman filter
return (state, averaging_state), (averaging_state, info)

transform= lambda full_state, info: exp_vals_transform(full_state[1][1])
return (state, averaging_state), info

def transform(full_state, info):
exp_vals = full_state[1][1]
return exp_vals_transform(exp_vals), info

return SamplingAlgorithm(init_fn, update_fn), transform

Expand Down

0 comments on commit 96aeb94

Please sign in to comment.