Skip to content

Commit

Permalink
renaming vars
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Aug 5, 2024
1 parent 0d29909 commit aa602ca
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size
from blackjax.util import pytree_size, streaming_average_update
from blackjax.util import pytree_size, incremental_value_update


class MCLMCAdaptationState(NamedTuple):
Expand Down
28 changes: 14 additions & 14 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,11 @@ def one_step(state, xs):
return final_state, history


def store_only_expectation_values(sampling_algorithm, state_transform= lambda x: x, exp_vals_transform= lambda x: x):
def store_only_expectation_values(sampling_algorithm, state_transform= lambda x: x, incremental_value_transform= lambda x: x):
"""Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same
kernel but only stores the streaming expectation values of some observables, not the full states; to save memory.
It saves exp_vals_transform(E[state_transform(x)]) at each step i, where expectation is computed with samples up to i-th sample.
It saves incremental_value_transform(E[state_transform(x)]) at each step i, where expectation is computed with samples up to i-th sample.
Example:
Expand All @@ -304,7 +304,7 @@ def store_only_expectation_values(sampling_algorithm, state_transform= lambda x:
num_steps = 4
integrator = map_integrator_type_to_integrator['mclmc'][integrator_type]
state_transform = lambda state: x.position
state_transform = lambda state: state.position
memory_efficient_sampling_alg, transform = store_only_expectation_values(
sampling_algorithm=sampling_alg,
state_transform=state_transform)
Expand All @@ -326,27 +326,27 @@ def init_fn(state):
averaging_state = (0., state_transform(state))
return (state, averaging_state)

def update_fn(rng_key, state_full):
state, averaging_state = state_full
def update_fn(rng_key, state_and_incremental_val):
state, averaging_state = state_and_incremental_val
state, info = sampling_algorithm.step(rng_key, state) # update the state with the sampling algorithm
averaging_state = streaming_average_update(state_transform(state), averaging_state) # update the expectation value with the Kalman filter
averaging_state = incremental_value_update(state_transform(state), averaging_state) # update the expectation value with the running average
return (state, averaging_state), info

def transform(full_state, info):
exp_vals = full_state[1][1]
return exp_vals_transform(exp_vals), info
def transform(state_and_incremental_val, info):
(state, (_, incremental_value)) = state_and_incremental_val
return incremental_value_transform(incremental_value), info

return SamplingAlgorithm(init_fn, update_fn), transform



def streaming_average_update(expectation, streaming_avg, weight=1.0, zero_prevention=0.0):
def incremental_value_update(expectation, incremental_val, weight=1.0, zero_prevention=0.0):
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
expectation
the value of the expectation at the current timestep
streaming_avg
incremental_val
tuple of (total, average) where total is the sum of weights and average is the current average
weight
weight of the current state
Expand All @@ -358,11 +358,11 @@ def streaming_average_update(expectation, streaming_avg, weight=1.0, zero_preven
"""

flat_expectation, unravel_fn = ravel_pytree(expectation)
total, average = streaming_avg
total, average = incremental_val
flat_average, _ = ravel_pytree(average)
average = (total * flat_average + weight * flat_expectation) / (
total + weight + zero_prevention
)
total += weight
streaming_avg = (total, unravel_fn(average))
return streaming_avg
incremental_val = (total, unravel_fn(average))
return incremental_val

0 comments on commit aa602ca

Please sign in to comment.