diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index cacd0b4a6..63c54bad0 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -18,7 +18,7 @@ import jax.numpy as jnp import blackjax.mcmc as mcmc -from blackjax.adaptation.base import AdaptationInfo, AdaptationResults +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.adaptation.mass_matrix import ( MassMatrixAdaptationState, mass_matrix_adaptation, @@ -249,11 +249,12 @@ def window_adaptation( initial_step_size: float = 1.0, target_acceptance_rate: float = 0.80, progress_bar: bool = False, + adaptation_info_fn: Callable = return_all_adapt_info, integrator=mcmc.integrators.velocity_verlet, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of - algorithms in the HMC fmaily. + algorithms in the HMC fmaily. See Blackjax.hmc_family Algorithms in the HMC family on a euclidean manifold depend on the value of at least two parameters: the step size, related to the trajectory @@ -280,6 +281,11 @@ def window_adaptation( The acceptance rate that we target during step size adaptation. progress_bar Whether we should display a progress bar. + adaptation_info_fn + Function to select the adaptation info returned. See return_all_adapt_info + and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all + information is saved - this can result in excessive memory usage if the + information is unused. **extra_parameters The extra parameters to pass to the algorithm, e.g. the number of integration steps for HMC. @@ -318,7 +324,7 @@ def one_step(carry, xs): return ( (new_state, new_adaptation_state), - AdaptationInfo(new_state, info, new_adaptation_state), + adaptation_info_fn(new_state, info, new_adaptation_state), ) def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):