From 75d24d3d28d165f60a49ff55aa13a928dbcfb946 Mon Sep 17 00:00:00 2001 From: "jakob.robnik@gmail.com" Date: Sat, 3 Aug 2024 02:44:38 -0700 Subject: [PATCH 1/2] storing only expectation values --- blackjax/util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/blackjax/util.py b/blackjax/util.py index 888bed515..8647a983f 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -283,6 +283,7 @@ def one_step(state, xs): final_state, history = lax.scan(one_step, initial_state, xs) return final_state, history + def store_only_expectation_values(sampling_algorithm, state_transform= lambda x: x, exp_vals_transform= lambda x: x): """Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same kernel but only stores the streaming expectation values of some observables, not the full states; to save memory. From 96aeb94b85c0fbf6006e6b7b19766db2e42c8215 Mon Sep 17 00:00:00 2001 From: "jakob.robnik@gmail.com" Date: Mon, 5 Aug 2024 10:00:01 -0700 Subject: [PATCH 2/2] fixed memory efficient sampling --- blackjax/util.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 8647a983f..072fc641c 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -246,8 +246,6 @@ def run_inference_algorithm( A transformation of the trace of states to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. - expectation - A function that computes the expectation of the state. This is done incrementally, so doesn't require storing all the states. return_state_history if False, `run_inference_algorithm` will only return an expectation of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. @@ -306,7 +304,7 @@ def store_only_expectation_values(sampling_algorithm, state_transform= lambda x: num_steps = 4 integrator = map_integrator_type_to_integrator['mclmc'][integrator_type] - state_transform = lambda x: x.position + state_transform = lambda state: x.position memory_efficient_sampling_alg, transform = store_only_expectation_values( sampling_algorithm=sampling_alg, state_transform=state_transform) @@ -332,9 +330,11 @@ def update_fn(rng_key, state_full): state, averaging_state = state_full state, info = sampling_algorithm.step(rng_key, state) # update the state with the sampling algorithm averaging_state = streaming_average_update(state_transform(state), averaging_state) # update the expectation value with the Kalman filter - return (state, averaging_state), (averaging_state, info) - - transform= lambda full_state, info: exp_vals_transform(full_state[1][1]) + return (state, averaging_state), info + + def transform(full_state, info): + exp_vals = full_state[1][1] + return exp_vals_transform(exp_vals), info return SamplingAlgorithm(init_fn, update_fn), transform