diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index fd8ee1bfa..1d4b95a09 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -374,9 +374,8 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, *args, **kwargs + logdensity_fn: Callable, sqrt_diag_cov: ArrayTree = 1.0 ) -> GeneralIntegrator: - sqrt_diag_cov = kwargs.get("sqrt_diag_cov", 1.0) position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( esh_dynamics_momentum_update_one_step(sqrt_diag_cov),