diff --git a/blackjax/util.py b/blackjax/util.py index 8647a983f..072fc641c 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -246,8 +246,6 @@ def run_inference_algorithm( A transformation of the trace of states to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. - expectation - A function that computes the expectation of the state. This is done incrementally, so doesn't require storing all the states. return_state_history if False, `run_inference_algorithm` will only return an expectation of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. @@ -306,7 +304,7 @@ def store_only_expectation_values(sampling_algorithm, state_transform= lambda x: num_steps = 4 integrator = map_integrator_type_to_integrator['mclmc'][integrator_type] - state_transform = lambda x: x.position + state_transform = lambda state: x.position memory_efficient_sampling_alg, transform = store_only_expectation_values( sampling_algorithm=sampling_alg, state_transform=state_transform) @@ -332,9 +330,11 @@ def update_fn(rng_key, state_full): state, averaging_state = state_full state, info = sampling_algorithm.step(rng_key, state) # update the state with the sampling algorithm averaging_state = streaming_average_update(state_transform(state), averaging_state) # update the expectation value with the Kalman filter - return (state, averaging_state), (averaging_state, info) - - transform= lambda full_state, info: exp_vals_transform(full_state[1][1]) + return (state, averaging_state), info + + def transform(full_state, info): + exp_vals = full_state[1][1] + return exp_vals_transform(exp_vals), info return SamplingAlgorithm(init_fn, update_fn), transform