Skip to content

Commit

Permalink
ADD ADJUSTED MCLMC TEST
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed May 19, 2024
1 parent 78f35b6 commit 7b16464
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 5 deletions.
5 changes: 3 additions & 2 deletions blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Stan warmup for the HMC family of sampling algorithms."""
from typing import Callable, NamedTuple, Union
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -241,14 +241,15 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]:

return init, update, final


def window_adaptation(
algorithm,
logdensity_fn: Callable,
is_mass_matrix_diagonal: bool = True,
initial_step_size: float = 1.0,
target_acceptance_rate: float = 0.80,
progress_bar: bool = False,
integrator = mcmc.integrators.velocity_verlet,
integrator=mcmc.integrators.velocity_verlet,
**extra_parameters,
) -> AdaptationAlgorithm:
"""Adapt the value of the inverse mass matrix and step size parameters of
Expand Down
6 changes: 4 additions & 2 deletions blackjax/mcmc/adjusted_mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ def kernel(
key_momentum, key_integrator = jax.random.split(rng_key, 2)
momentum = generate_unit_vector(key_momentum, state.position)
proposal, info, _ = adjusted_mclmc_proposal(
integrator=integrators.with_isokinetic_maruyama(integrator(logdensity_fn, std_mat)),
integrator=integrators.with_isokinetic_maruyama(
integrator(logdensity_fn, std_mat)
),
step_size=step_size,
L_proposal=L_proposal*num_integration_steps,
L_proposal=L_proposal * num_integration_steps,
num_integration_steps=num_integration_steps,
divergence_threshold=divergence_threshold,
)(
Expand Down
1 change: 1 addition & 0 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0):
new streaming average
"""

# x, _ = ravel_pytree(x)
expectation = O(x)
flat_expectation, unravel_fn = ravel_pytree(expectation)
total, average = streaming_avg
Expand Down
113 changes: 112 additions & 1 deletion tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
import blackjax.diagnostics as diagnostics
import blackjax.mcmc.random_walk
from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info
from blackjax.mcmc.integrators import isokinetic_mclachlan
from blackjax.mcmc.adjusted_mclmc import rescale
from blackjax.mcmc.integrators import (
generate_isokinetic_integrator,
integrator_order,
isokinetic_mclachlan,
mclachlan_coefficients,
)
from blackjax.util import run_inference_algorithm


Expand Down Expand Up @@ -145,6 +151,86 @@ def run_mclmc(

return samples

def run_adjusted_mclmc(
self,
logdensity_fn,
num_steps,
initial_position,
key,
diagonal_preconditioning=False,
):
coefficients = mclachlan_coefficients
integrator = generate_isokinetic_integrator(coefficients)

init_key, tune_key, run_key = jax.random.split(key, 3)

initial_state = blackjax.mcmc.adjusted_mclmc.init(
position=initial_position,
logdensity_fn=logdensity_fn,
random_generator_arg=init_key,
)

kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.adjusted_mclmc.build_kernel(
integrator=integrator,
integration_steps_fn=lambda k: jnp.ceil(
jax.random.uniform(k) * rescale(avg_num_integration_steps)
),
std_mat=std_mat,
)(
rng_key=rng_key,
state=state,
step_size=step_size,
logdensity_fn=logdensity_fn,
)

target_acceptance_rate_of_order = {2: 0.65, 4: 0.8}

target_acc_rate = target_acceptance_rate_of_order[
integrator_order(coefficients)
]

(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
params_history,
final_da,
) = blackjax.adaptation.mclmc_adaptation.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=initial_state,
rng_key=tune_key,
target=target_acc_rate,
frac_tune1=0.1,
frac_tune2=0.1,
frac_tune3=0.1,
diagonal_preconditioning=diagonal_preconditioning,
)

step_size = blackjax_mclmc_sampler_params.step_size
L = blackjax_mclmc_sampler_params.L

alg = blackjax.adjusted_mclmc(
logdensity_fn=logdensity_fn,
step_size=step_size,
integration_steps_fn=lambda key: jnp.ceil(
jax.random.uniform(key) * rescale(L / step_size)
),
integrator=integrator,
std_mat=blackjax_mclmc_sampler_params.std_mat,
)

_, out, info = run_inference_algorithm(
rng_key=run_key,
initial_state=blackjax_state_after_tuning,
inference_algorithm=alg,
num_steps=num_steps,
transform=lambda x: x.position,
expectation=lambda x: x.position,
progress_bar=False,
)

return out

@parameterized.parameters(
itertools.product(
regression_test_cases, [True, False], window_adaptation_filters
Expand Down Expand Up @@ -260,6 +346,31 @@ def test_mclmc(self):
np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)

def test_adjusted_mclmc(self):
"""Test the MCLMC kernel."""

init_key0, init_key1, inference_key = jax.random.split(self.key, 3)
x_data = jax.random.normal(init_key0, shape=(1000, 1))
y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape)

logposterior_fn_ = functools.partial(
self.regression_logprob, x=x_data, preds=y_data
)
logdensity_fn = lambda x: logposterior_fn_(**x)

states = self.run_adjusted_mclmc(
initial_position={"coefs": 1.0, "log_scale": 1.0},
logdensity_fn=logdensity_fn,
key=inference_key,
num_steps=10000,
)

coefs_samples = states["coefs"][3000:]
scale_samples = np.exp(states["log_scale"][3000:])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)

def test_mclmc_preconditioning(self):
class IllConditionedGaussian:
"""Gaussian distribution. Covariance matrix has eigenvalues equally spaced in log-space, going from 1/condition_bnumber^1/2 to condition_number^1/2."""
Expand Down

0 comments on commit 7b16464

Please sign in to comment.