Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply function on pytree directly. #692

Merged
merged 2 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ def step(iteration_state, weight_and_key):
x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
streaming_avg = streaming_average_update(
expectation=jnp.array([x, jnp.square(x)]),
streaming_avg=streaming_avg,
current_value=jnp.array([x, jnp.square(x)]),
previous_weight_and_average=streaming_avg,
weight=(1 - mask) * success * params.step_size,
zero_prevention=mask,
)
Expand Down
31 changes: 15 additions & 16 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,31 +240,30 @@ def one_step(average_and_state, xs, return_state):


def streaming_average_update(
expectation, streaming_avg, weight=1.0, zero_prevention=0.0
current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
):
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
expectation
the value of the expectation at the current timestep
streaming_avg
tuple of (total, average) where total is the sum of weights and average is
the current average
current_value
the current value of the function that we want to take average of
previous_weight_and_average
tuple of (previous_weight, previous_average) where previous_weight is the
sum of weights and average is the current estimated average
weight
weight of the current state
zero_prevention
small value to prevent division by zero
Returns:
----------
new streaming average
new total weight and streaming average
"""

flat_expectation, unravel_fn = ravel_pytree(expectation)
total, average = streaming_avg
flat_average, _ = ravel_pytree(average)
average = (total * flat_average + weight * flat_expectation) / (
total + weight + zero_prevention
previous_weight, previous_average = previous_weight_and_average
current_weight = previous_weight + weight
current_average = jax.tree.map(
lambda x, avg: (previous_weight * avg + weight * x)
/ (current_weight + zero_prevention),
current_value,
previous_average,
)
total += weight
streaming_avg = (total, unravel_fn(average))
return streaming_avg
return current_weight, current_average
Loading