diff --git a/.gitignore b/.gitignore index d9186a6e9..25b11a123 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,6 @@ # Created by https://www.gitignore.io/api/python # Edit at https://www.gitignore.io/?templates=python -explore.py - ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 1cc698e8f..58e7ed810 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -24,12 +24,15 @@ __all__ = [ "mclachlan", + "omelyan", "velocity_verlet", "yoshida", - "implicit_midpoint", - "isokinetic_leapfrog", + "with_isokinetic_maruyama", + "isokinetic_velocity_verlet", "isokinetic_mclachlan", + "isokinetic_omelyan", "isokinetic_yoshida", + "implicit_midpoint", ] @@ -70,7 +73,7 @@ def generalized_two_stage_integrator( .. math:: \\frac{d}{dt}f = (O_1+O_2)f - The leapfrog operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}` + The velocity_verlet operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}` by :math:`e^{\\epsilon O_1/2}e^{\\epsilon O_2}e^{\\epsilon O_1/2}`. In a standard Hamiltonian, the forms of :math:`e^{\\epsilon O_2}` and @@ -210,7 +213,7 @@ def format_euclidean_state_output( return IntegratorState(position, momentum, logdensity, logdensity_grad) -def generate_euclidean_integrator(cofficients): +def generate_euclidean_integrator(coefficients): """Generate symplectic integrator for solving a Hamiltonian system. The resulting integrator is volume-preserve and preserves the symplectic structure @@ -225,7 +228,7 @@ def euclidean_integrator( one_step = generalized_two_stage_integrator( momentum_update_fn, position_update_fn, - cofficients, + coefficients, format_output_fn=format_euclidean_state_output, ) return one_step @@ -251,8 +254,8 @@ def euclidean_integrator( of the kinetic energy. We are trading accuracy in exchange, and it is not clear whether this is the right tradeoff. """ -velocity_verlet_cofficients = [0.5, 1.0, 0.5] -velocity_verlet = generate_euclidean_integrator(velocity_verlet_cofficients) +velocity_verlet_coefficients = [0.5, 1.0, 0.5] +velocity_verlet = generate_euclidean_integrator(velocity_verlet_coefficients) """ Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`. @@ -268,8 +271,8 @@ def euclidean_integrator( b1 = 0.1931833275037836 a1 = 0.5 b2 = 1 - 2 * b1 -mclachlan_cofficients = [b1, a1, b2, a1, b1] -mclachlan = generate_euclidean_integrator(mclachlan_cofficients) +mclachlan_coefficients = [b1, a1, b2, a1, b1] +mclachlan = generate_euclidean_integrator(mclachlan_coefficients) """ Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical` @@ -284,8 +287,20 @@ def euclidean_integrator( a1 = 0.29619504261126 b2 = 0.5 - b1 a2 = 1 - 2 * a1 -yoshida_cofficients = [b1, a1, b2, a2, b2, a1, b1] -yoshida = generate_euclidean_integrator(yoshida_cofficients) +yoshida_coefficients = [b1, a1, b2, a2, b2, a1, b1] +yoshida = generate_euclidean_integrator(yoshida_coefficients) + +"""11 stage Omelyan integrator [I.P. Omelyan, I.M. Mryglod and R. Folk, Comput. Phys. Commun. 151 (2003) 272.], +4MN5FV in [Takaishi, Tetsuya, and Philippe De Forcrand. "Testing and tuning symplectic integrators for the hybrid Monte Carlo algorithm in lattice QCD." Physical Review E 73.3 (2006): 036706.] +popular in LQCD""" +b1 = 0.08398315262876693 +a1 = 0.2539785108410595 +b2 = 0.6822365335719091 +a2 = -0.03230286765269967 +b3 = 0.5 - b1 - b2 +a3 = 1 - 2 * (a1 + a2) +omelyan_coefficients = [b1, a1, b2, a2, b3, a3, b3, a2, b2, a1, b1] +omelyan = generate_euclidean_integrator(omelyan_coefficients) # Intergrators with non Euclidean updates @@ -372,9 +387,12 @@ def isokinetic_integrator( return isokinetic_integrator -isokinetic_leapfrog = generate_isokinetic_integrator(velocity_verlet_cofficients) -isokinetic_yoshida = generate_isokinetic_integrator(yoshida_cofficients) -isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_cofficients) +isokinetic_velocity_verlet = generate_isokinetic_integrator( + velocity_verlet_coefficients +) +isokinetic_yoshida = generate_isokinetic_integrator(yoshida_coefficients) +isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_coefficients) +isokinetic_omelyan = generate_isokinetic_integrator(omelyan_coefficients) def partially_refresh_momentum(momentum, rng_key, step_size, L): diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index 85891bda6..7bb1b35a5 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -357,7 +357,7 @@ def buildtree_integrate( """ if tree_depth == 0: - # Base case - take one leapfrog step in the direction v. + # Base case - take one velocity_verlet step in the direction v. next_state = integrator(initial_state, direction * step_size) new_proposal = generate_proposal(initial_energy, next_state) is_diverging = -new_proposal.weight > divergence_threshold diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 937339aaf..c38009e5e 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -136,13 +136,15 @@ def kinetic_energy(p, position=None): "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-4}, "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-4}, + "omelyan": {"algorithm": integrators.omelyan, "precision": 1e-4}, "implicit_midpoint": { "algorithm": integrators.implicit_midpoint, "precision": 1e-4, }, - "isokinetic_leapfrog": {"algorithm": integrators.isokinetic_leapfrog}, + "isokinetic_velocity_verlet": {"algorithm": integrators.isokinetic_velocity_verlet}, "isokinetic_mclachlan": {"algorithm": integrators.isokinetic_mclachlan}, "isokinetic_yoshida": {"algorithm": integrators.isokinetic_yoshida}, + "isokinetic_omelyan": {"algorithm": integrators.isokinetic_omelyan}, } @@ -168,6 +170,7 @@ class IntegratorTest(chex.TestCase): "velocity_verlet", "mclachlan", "yoshida", + "omelyan", "implicit_midpoint", ], ) @@ -241,13 +244,13 @@ def test_esh_momentum_update(self, dims): np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @chex.all_variants(with_pmap=False) - def test_isokinetic_leapfrog(self): + def test_isokinetic_velocity_verlet(self): cov = jnp.asarray([[1.0, 0.5, 0.1], [0.5, 2.0, -0.1], [0.1, -0.1, 3.0]]) logdensity_fn = lambda x: stats.multivariate_normal.logpdf( x, jnp.zeros([3]), cov ) - step = self.variant(integrators.isokinetic_leapfrog(logdensity_fn)) + step = self.variant(integrators.isokinetic_velocity_verlet(logdensity_fn)) rng = jax.random.key(4263456) key0, key1 = jax.random.split(rng, 2) @@ -296,9 +299,10 @@ def test_isokinetic_leapfrog(self): @chex.all_variants(with_pmap=False) @parameterized.parameters( [ - "isokinetic_leapfrog", + "isokinetic_velocity_verlet", "isokinetic_mclachlan", "isokinetic_yoshida", + "isokinetic_omelyan", ], ) def test_isokinetic_integrator(self, integrator_name):