Skip to content

Commit

Permalink
Apply function on pytree directly.
Browse files Browse the repository at this point in the history
Avoiding unnecssary unpacking
  • Loading branch information
junpenglao committed Jun 5, 2024
1 parent 83bc3a0 commit bd920f5
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,31 +240,30 @@ def one_step(average_and_state, xs, return_state):


def streaming_average_update(
expectation, streaming_avg, weight=1.0, zero_prevention=0.0
current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
):
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
expectation
the value of the expectation at the current timestep
streaming_avg
tuple of (total, average) where total is the sum of weights and average is
the current average
current_value
the current value of the function that we want to take average of
previous_weight_and_average
tuple of (previous_weight, previous_average) where previous_weight is the
sum of weights and average is the current estimated average
weight
weight of the current state
zero_prevention
small value to prevent division by zero
Returns:
----------
new streaming average
new total weight and streaming average
"""

flat_expectation, unravel_fn = ravel_pytree(expectation)
total, average = streaming_avg
flat_average, _ = ravel_pytree(average)
average = (total * flat_average + weight * flat_expectation) / (
total + weight + zero_prevention
previous_weight, previous_average = previous_weight_and_average
current_weight = previous_weight + weight
current_average = jax.tree.map(
lambda x, avg: (previous_weight * avg + weight * x)
/ (current_weight + zero_prevention),
current_value,
previous_average,
)
total += weight
streaming_avg = (total, unravel_fn(average))
return streaming_avg
return current_weight, current_average

0 comments on commit bd920f5

Please sign in to comment.