From b956b2509301e1ff491973726cdec07adc89881b Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Mon, 3 Jun 2024 09:28:00 +0200 Subject: [PATCH] Change isokinetic_integrator generation API --- blackjax/mcmc/integrators.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index fd8ee1bfa..1d4b95a09 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -374,9 +374,8 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, *args, **kwargs + logdensity_fn: Callable, sqrt_diag_cov: ArrayTree = 1.0 ) -> GeneralIntegrator: - sqrt_diag_cov = kwargs.get("sqrt_diag_cov", 1.0) position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( esh_dynamics_momentum_update_one_step(sqrt_diag_cov),