From 3fbdac64af810b69385c570af5c055c1cc268e14 Mon Sep 17 00:00:00 2001 From: Reuben Date: Mon, 3 Jun 2024 01:30:03 -0400 Subject: [PATCH] MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT (#687) --- blackjax/adaptation/window_adaptation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index dd3e7b282..63c54bad0 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp +import blackjax.mcmc as mcmc from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.adaptation.mass_matrix import ( MassMatrixAdaptationState, @@ -249,10 +250,11 @@ def window_adaptation( target_acceptance_rate: float = 0.80, progress_bar: bool = False, adaptation_info_fn: Callable = return_all_adapt_info, + integrator=mcmc.integrators.velocity_verlet, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of - algorithms in the HMC family. See Blackjax.hmc_family + algorithms in the HMC fmaily. See Blackjax.hmc_family Algorithms in the HMC family on a euclidean manifold depend on the value of at least two parameters: the step size, related to the trajectory @@ -294,7 +296,7 @@ def window_adaptation( """ - mcmc_kernel = algorithm.build_kernel() + mcmc_kernel = algorithm.build_kernel(integrator) adapt_init, adapt_step, adapt_final = base( is_mass_matrix_diagonal,