Skip to content

Commit

Permalink
FIX MERGE
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed May 19, 2024
1 parent 9dd740f commit 4d03b89
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jax.numpy as jnp

import blackjax.mcmc as mcmc
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.adaptation.mass_matrix import (
MassMatrixAdaptationState,
mass_matrix_adaptation,
Expand Down Expand Up @@ -249,11 +249,12 @@ def window_adaptation(
initial_step_size: float = 1.0,
target_acceptance_rate: float = 0.80,
progress_bar: bool = False,
adaptation_info_fn: Callable = return_all_adapt_info,
integrator=mcmc.integrators.velocity_verlet,
**extra_parameters,
) -> AdaptationAlgorithm:
"""Adapt the value of the inverse mass matrix and step size parameters of
algorithms in the HMC fmaily.
algorithms in the HMC fmaily. See Blackjax.hmc_family
Algorithms in the HMC family on a euclidean manifold depend on the value of
at least two parameters: the step size, related to the trajectory
Expand All @@ -280,6 +281,11 @@ def window_adaptation(
The acceptance rate that we target during step size adaptation.
progress_bar
Whether we should display a progress bar.
adaptation_info_fn
Function to select the adaptation info returned. See return_all_adapt_info
and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
information is saved - this can result in excessive memory usage if the
information is unused.
**extra_parameters
The extra parameters to pass to the algorithm, e.g. the number of
integration steps for HMC.
Expand Down Expand Up @@ -318,7 +324,7 @@ def one_step(carry, xs):

return (
(new_state, new_adaptation_state),
AdaptationInfo(new_state, info, new_adaptation_state),
adaptation_info_fn(new_state, info, new_adaptation_state),
)

def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
Expand Down

0 comments on commit 4d03b89

Please sign in to comment.