diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml deleted file mode 100644 index f9b6f998b..000000000 --- a/.github/workflows/nightly.yml +++ /dev/null @@ -1,48 +0,0 @@ -name: Nightly - -on: - push: - branches: [main] - -jobs: - build_and_publish: - name: Build and publish on PyPi - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - uses: actions/setup-python@v4 - with: - python-version: 3.9 - - name: Update pyproject.toml - # Taken from https://github.com/aesara-devs/aesara/pull/1375 - run: | - curl -sSLf https://github.com/TomWright/dasel/releases/download/v2.0.2/dasel_linux_amd64 \ - -L -o /tmp/dasel && chmod +x /tmp/dasel - /tmp/dasel put -f pyproject.toml project.name -v blackjax-nightly - /tmp/dasel put -f pyproject.toml tool.setuptools_scm.version_scheme -v post-release - /tmp/dasel put -f pyproject.toml tool.setuptools_scm.local_scheme -v no-local-version - - name: Build the sdist and wheel - run: | - python -m pip install -U pip - python -m pip install build - python -m build - - name: Check sdist install and imports - run: | - mkdir -p test-sdist - cd test-sdist - python -m venv venv-sdist - venv-sdist/bin/python -m pip install ../dist/blackjax-nightly-*.tar.gz - venv-sdist/bin/python -c "import blackjax" - - name: Check wheel install and imports - run: | - mkdir -p test-wheel - cd test-wheel - python -m venv venv-wheel - venv-wheel/bin/python -m pip install ../dist/blackjax_nightly-*.whl - - name: Publish to PyPi - uses: pypa/gh-action-pypi-publish@v1.4.2 - with: - user: __token__ - password: ${{ secrets.PYPI_NIGHTLY_TOKEN }} diff --git a/.github/workflows/publish_documentation.yml b/.github/workflows/publish_documentation.yml index db41816f1..b8685bb91 100644 --- a/.github/workflows/publish_documentation.yml +++ b/.github/workflows/publish_documentation.yml @@ -14,10 +14,10 @@ jobs: with: persist-credentials: false - - name: Set up Python 3.9 + - name: Set up Python uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 - name: Build the documentation with Sphinx run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f0a7319de..4bca92fd7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 - name: Build sdist and wheel run: | python -m pip install -U pip @@ -51,7 +51,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 - name: Give PyPI some time to update the index run: sleep 240 - name: Attempt install from PyPI diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f6cc6cd36..93a9cb5d6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 - uses: pre-commit/action@v3.0.0 test: @@ -24,7 +24,7 @@ jobs: - style strategy: matrix: - python-version: [ '3.9', '3.11'] + python-version: ['3.11', '3.12'] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/README.md b/README.md index 45fc65dd2..06d5b46cf 100644 --- a/README.md +++ b/README.md @@ -41,12 +41,6 @@ or via conda-forge: conda install -c conda-forge blackjax ``` -Nightly builds (bleeding edge) of Blackjax can also be installed using `pip`: - -```bash -pip install 
blackjax-nightly -``` - BlackJAX is written in pure Python but depends on XLA via JAX. By default, the version of JAX that will be installed along with BlackJAX will make your code run on CPU only. **If you want to use BlackJAX on GPU/TPU** we recommend you follow @@ -81,9 +75,10 @@ state = nuts.init(initial_position) # Iterate rng_key = jax.random.key(0) -for _ in range(100): - rng_key, nuts_key = jax.random.split(rng_key) - state, _ = nuts.step(nuts_key, state) +step = jax.jit(nuts.step) +for i in range(100): + nuts_key = jax.random.fold_in(rng_key, i) + state, _ = step(nuts_key, state) ``` See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc. @@ -138,12 +133,13 @@ Please follow our [short guide](https://github.com/blackjax-devs/blackjax/blob/m To cite this repository: ``` -@software{blackjax2020github, - author = {Cabezas, Alberto, Lao, Junpeng, and Louf, R\'emi}, - title = {{B}lackjax: A sampling library for {JAX}}, - url = {http://github.com/blackjax-devs/blackjax}, - version = {}, - year = {2023}, +@misc{cabezas2024blackjax, + title={BlackJAX: Composable {B}ayesian inference in {JAX}}, + author={Alberto Cabezas and Adrien Corenflos and Junpeng Lao and Rémi Louf}, + year={2024}, + eprint={2402.10797}, + archivePrefix={arXiv}, + primaryClass={cs.MS} } ``` In the above bibtex entry, names are in alphabetical order, the version number should be the last tag on the `main` branch. diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 96df14920..dfdcfc545 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -1,3 +1,6 @@ +import dataclasses +from typing import Callable + from blackjax._version import __version__ from .adaptation.chees_adaptation import chees_adaptation @@ -5,67 +8,156 @@ from .adaptation.meads_adaptation import meads_adaptation from .adaptation.pathfinder_adaptation import pathfinder_adaptation from .adaptation.window_adaptation import window_adaptation +from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat -from .mcmc.barker import barker_proposal -from .mcmc.dynamic_hmc import dynamic_hmc -from .mcmc.elliptical_slice import elliptical_slice -from .mcmc.ghmc import ghmc -from .mcmc.hmc import hmc -from .mcmc.mala import mala -from .mcmc.marginal_latent_gaussian import mgrad_gaussian -from .mcmc.mclmc import mclmc -from .mcmc.nuts import nuts -from .mcmc.periodic_orbital import orbital_hmc -from .mcmc.random_walk import additive_step_random_walk, irmh, rmh -from .mcmc.rmhmc import rmhmc +from .mcmc import barker +from .mcmc import dynamic_hmc as _dynamic_hmc +from .mcmc import elliptical_slice as _elliptical_slice +from .mcmc import ghmc as _ghmc +from .mcmc import hmc as _hmc +from .mcmc import mala as _mala +from .mcmc import marginal_latent_gaussian +from .mcmc import mclmc as _mclmc +from .mcmc import nuts as _nuts +from .mcmc import periodic_orbital, random_walk +from .mcmc import rmhmc as _rmhmc +from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk +from .mcmc.random_walk import ( + irmh_as_top_level_api, + normal_random_walk, + rmh_as_top_level_api, +) from .optimizers import dual_averaging, lbfgs -from .sgmcmc.csgld import csgld -from .sgmcmc.sghmc import sghmc -from .sgmcmc.sgld import sgld -from .sgmcmc.sgnht import sgnht -from .smc.adaptive_tempered 
import adaptive_tempered_smc -from .smc.inner_kernel_tuning import inner_kernel_tuning -from .smc.tempered import tempered_smc -from .vi.meanfield_vi import meanfield_vi -from .vi.pathfinder import pathfinder -from .vi.schrodinger_follmer import schrodinger_follmer -from .vi.svgd import svgd +from .sgmcmc import csgld as _csgld +from .sgmcmc import sghmc as _sghmc +from .sgmcmc import sgld as _sgld +from .sgmcmc import sgnht as _sgnht +from .smc import adaptive_tempered +from .smc import inner_kernel_tuning as _inner_kernel_tuning +from .smc import tempered +from .vi import meanfield_vi as _meanfield_vi +from .vi import pathfinder as _pathfinder +from .vi import schrodinger_follmer as _schrodinger_follmer +from .vi import svgd as _svgd +from .vi.pathfinder import PathFinderAlgorithm + +""" +The above three classes exist as a backwards compatible way of exposing both the high level, differentiable +factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower +level to be mostly functional programming in nature and reducing boilerplate code. +""" + + +@dataclasses.dataclass +class GenerateSamplingAPI: + differentiable: Callable + init: Callable + build_kernel: Callable + + def __call__(self, *args, **kwargs) -> SamplingAlgorithm: + return self.differentiable(*args, **kwargs) + + def register_factory(self, name, callable): + setattr(self, name, callable) + + +@dataclasses.dataclass +class GenerateVariationalAPI: + differentiable: Callable + init: Callable + step: Callable + sample: Callable + + def __call__(self, *args, **kwargs) -> VIAlgorithm: + return self.differentiable(*args, **kwargs) + + +@dataclasses.dataclass +class GeneratePathfinderAPI: + differentiable: Callable + approximate: Callable + sample: Callable + + def __call__(self, *args, **kwargs) -> PathFinderAlgorithm: + return self.differentiable(*args, **kwargs) + + +def generate_top_level_api_from(module): + return GenerateSamplingAPI( + module.as_top_level_api, module.init, module.build_kernel + ) + + +# MCMC +hmc = generate_top_level_api_from(_hmc) +nuts = generate_top_level_api_from(_nuts) +rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh) +irmh = GenerateSamplingAPI( + irmh_as_top_level_api, random_walk.init, random_walk.build_irmh +) +dynamic_hmc = generate_top_level_api_from(_dynamic_hmc) +rmhmc = generate_top_level_api_from(_rmhmc) +mala = generate_top_level_api_from(_mala) +mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian) +orbital_hmc = generate_top_level_api_from(periodic_orbital) + +additive_step_random_walk = GenerateSamplingAPI( + _additive_step_random_walk, random_walk.init, random_walk.build_additive_step +) + +additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) + +mclmc = generate_top_level_api_from(_mclmc) +elliptical_slice = generate_top_level_api_from(_elliptical_slice) +ghmc = generate_top_level_api_from(_ghmc) +barker_proposal = generate_top_level_api_from(barker) + +hmc_family = [hmc, nuts] + +# SMC +adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered) +tempered_smc = generate_top_level_api_from(tempered) +inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning) + +smc_family = [tempered_smc, adaptive_tempered_smc] +"Step_fn returning state has a .particles attribute" + +# stochastic gradient mcmc +sgld = generate_top_level_api_from(_sgld) +sghmc = generate_top_level_api_from(_sghmc) +sgnht = generate_top_level_api_from(_sgnht) 
+csgld = generate_top_level_api_from(_csgld) +svgd = generate_top_level_api_from(_svgd) + +# variational inference +meanfield_vi = GenerateVariationalAPI( + _meanfield_vi.as_top_level_api, + _meanfield_vi.init, + _meanfield_vi.step, + _meanfield_vi.sample, +) +schrodinger_follmer = GenerateVariationalAPI( + _schrodinger_follmer.as_top_level_api, + _schrodinger_follmer.init, + _schrodinger_follmer.step, + _schrodinger_follmer.sample, +) + +pathfinder = GeneratePathfinderAPI( + _pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample +) + __all__ = [ "__version__", "dual_averaging", # optimizers "lbfgs", - "hmc", # mcmc - "dynamic_hmc", - "rmhmc", - "mala", - "mgrad_gaussian", - "nuts", - "orbital_hmc", - "additive_step_random_walk", - "rmh", - "irmh", - "mclmc", - "elliptical_slice", - "ghmc", - "barker_proposal", - "sgld", # stochastic gradient mcmc - "sghmc", - "sgnht", - "csgld", "window_adaptation", # mcmc adaptation "meads_adaptation", "chees_adaptation", "pathfinder_adaptation", "mclmc_find_L_and_step_size", # mclmc adaptation - "adaptive_tempered_smc", # smc - "tempered_smc", - "inner_kernel_tuning", - "meanfield_vi", # variational inference - "pathfinder", - "schrodinger_follmer", - "svgd", "ess", # diagnostics "rhat", ] diff --git a/blackjax/adaptation/base.py b/blackjax/adaptation/base.py index e0a01e596..c510abaf4 100644 --- a/blackjax/adaptation/base.py +++ b/blackjax/adaptation/base.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import NamedTuple +from typing import NamedTuple, Set + +import jax from blackjax.types import ArrayTree @@ -25,3 +27,34 @@ class AdaptationInfo(NamedTuple): state: NamedTuple info: NamedTuple adaptation_state: NamedTuple + + +def return_all_adapt_info(state, info, adaptation_state): + """Return fully populated AdaptationInfo. Used for adaptation_info_fn + parameters of the adaptation algorithms. + """ + return AdaptationInfo(state, info, adaptation_state) + + +def get_filter_adapt_info_fn( + state_keys: Set[str] = set(), + info_keys: Set[str] = set(), + adapt_state_keys: Set[str] = set(), +): + """Generate a function to filter what is saved in AdaptationInfo. Used + for adptation_info_fn parameters of the adaptation algorithms. 
+ adaptation_info_fn=get_filter_adapt_info_fn() saves no auxiliary information + """ + + def filter_tuple(tup, key_set): + mapfn = lambda key, val: None if key not in key_set else val + return jax.tree.map(mapfn, type(tup)(*tup._fields), tup) + + def filter_fn(state, info, adaptation_state): + sample_state = filter_tuple(state, state_keys) + new_info = filter_tuple(info, info_keys) + new_adapt_state = filter_tuple(adaptation_state, adapt_state_keys) + + return AdaptationInfo(sample_state, new_info, new_adapt_state) + + return filter_fn diff --git a/blackjax/adaptation/chees_adaptation.py b/blackjax/adaptation/chees_adaptation.py index 0448c26cb..60b3e719f 100644 --- a/blackjax/adaptation/chees_adaptation.py +++ b/blackjax/adaptation/chees_adaptation.py @@ -10,7 +10,7 @@ import blackjax.mcmc.dynamic_hmc as dynamic_hmc import blackjax.optimizers.dual_averaging as dual_averaging -from blackjax.adaptation.base import AdaptationInfo, AdaptationResults +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.base import AdaptationAlgorithm from blackjax.types import Array, ArrayLikeTree, PRNGKey from blackjax.util import pytree_size @@ -278,6 +278,7 @@ def chees_adaptation( jitter_amount: float = 1.0, target_acceptance_rate: float = OPTIMAL_TARGET_ACCEPTANCE_RATE, decay_rate: float = 0.5, + adaptation_info_fn: Callable = return_all_adapt_info, ) -> AdaptationAlgorithm: """Adapt the step size and trajectory length (number of integration steps / step size) parameters of the jittered HMC algorthm. @@ -337,6 +338,11 @@ def chees_adaptation( Float representing how much to favor recent iterations over earlier ones in the optimization of step size and trajectory length. A value of 1 gives equal weight to all history. A value of 0 gives weight only to the most recent iteration. + adaptation_info_fn + Function to select the adaptation info returned. See return_all_adapt_info + and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all + information is saved - this can result in excessive memory usage if the + information is unused. 
Returns ------- @@ -361,20 +367,18 @@ def run( ), "initial `positions` leading dimension must be equal to the `num_chains`" num_dim = pytree_size(positions) // num_chains - key_init, key_step = jax.random.split(rng_key) + next_random_arg_fn = lambda i: i + 1 + init_random_arg = 0 if jitter_generator is not None: - jitter_gn = lambda key: jitter_generator(key) * jitter_amount + ( - 1.0 - jitter_amount - ) - next_random_arg_fn = lambda key: jax.random.split(key)[1] - init_random_arg = key_init + rng_key, carry_key = jax.random.split(rng_key) + jitter_gn = lambda i: jitter_generator( + jax.random.fold_in(carry_key, i) + ) * jitter_amount + (1.0 - jitter_amount) else: jitter_gn = lambda i: dynamic_hmc.halton_sequence( i, np.ceil(np.log2(num_steps + max_sampling_steps)) ) * jitter_amount + (1.0 - jitter_amount) - next_random_arg_fn = lambda i: i + 1 - init_random_arg = 0 def integration_steps_fn(random_generator_arg, trajectory_length_adjusted): return jnp.asarray( @@ -413,10 +417,8 @@ def one_step(carry, rng_key): info.is_divergent, ) - return (new_states, new_adaptation_state), AdaptationInfo( - new_states, - info, - new_adaptation_state, + return (new_states, new_adaptation_state), adaptation_info_fn( + new_states, info, new_adaptation_state ) batch_init = jax.vmap( @@ -425,7 +427,7 @@ def one_step(carry, rng_key): init_states = batch_init(positions) init_adaptation_state = init(init_random_arg, step_size) - keys_step = jax.random.split(key_step, num_steps) + keys_step = jax.random.split(rng_key, num_steps) (last_states, last_adaptation_state), info = jax.lax.scan( one_step, (init_states, init_adaptation_state), keys_step ) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index ba7d0f399..7645a890b 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -20,7 +20,7 @@ from jax.flatten_util import ravel_pytree from blackjax.diagnostics import effective_sample_size -from blackjax.util import pytree_size +from blackjax.util import incremental_value_update, pytree_size class MCLMCAdaptationState(NamedTuple): @@ -30,10 +30,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. + sqrt_diag_cov + A matrix used for preconditioning. """ L: float step_size: float + sqrt_diag_cov: float def mclmc_find_L_and_step_size( @@ -47,6 +50,7 @@ def mclmc_find_L_and_step_size( desired_energy_var=5e-4, trust_in_estimate=1.5, num_effective_samples=150, + diagonal_preconditioning=True, ): """ Finds the optimal value of the parameters for the MCLMC algorithm. @@ -78,38 +82,30 @@ def mclmc_find_L_and_step_size( ------- A tuple containing the final state of the MCMC algorithm and the final hyperparameters. - - Examples + Example ------- - .. 
code:: + kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=logdensity_fn, + integrator=integrator, + std_mat=std_mat, + ) - # Define the kernel function - def kernel(x): - return x ** 2 - - # Define the initial state - initial_state = MCMCState(position=0, momentum=1) - - # Generate a random number generator key - rng_key = jax.random.key(0) - - # Find the optimal parameters for the MCLMC algorithm - final_state, final_params = mclmc_find_L_and_step_size( + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + ) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, - num_steps=1000, + num_steps=num_steps, state=initial_state, - rng_key=rng_key, - frac_tune1=0.2, - frac_tune2=0.3, - frac_tune3=0.1, - desired_energy_var=1e-4, - trust_in_estimate=2.0, - num_effective_samples=200, + rng_key=tune_key, + diagonal_preconditioning=preconditioning, ) """ dim = pytree_size(state.position) - params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25) + params = MCLMCAdaptationState( + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,)) + ) part1_key, part2_key = jax.random.split(rng_key, 2) state, params = make_L_step_size_adaptation( @@ -120,12 +116,13 @@ def kernel(x): desired_energy_var=desired_energy_var, trust_in_estimate=trust_in_estimate, num_effective_samples=num_effective_samples, + diagonal_preconditioning=diagonal_preconditioning, )(state, params, num_steps, part1_key) if frac_tune3 != 0: - state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)( - state, params, num_steps, part2_key - ) + state, params = make_adaptation_L( + mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4 + )(state, params, num_steps, part2_key) return state, params @@ -135,6 +132,7 @@ def make_L_step_size_adaptation( dim, frac_tune1, frac_tune2, + diagonal_preconditioning, desired_energy_var=1e-3, trust_in_estimate=1.5, num_effective_samples=150, @@ -150,7 +148,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): time, x_average, step_size_max = adaptive_state # dynamics - next_state, info = kernel( + next_state, info = kernel(params.sqrt_diag_cov)( rng_key=rng_key, state=previous_state, L=params.L, @@ -185,68 +183,84 @@ def predictor(previous_state, params, adaptive_state, rng_key): ) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences params_new = params._replace(step_size=step_size) - return state, params_new, params_new, (time, x_average, step_size_max), success - - def update_kalman(x, state, outer_weight, success, step_size): - """kalman filter to estimate the size of the posterior""" - time, x_average, x_squared_average = state - weight = outer_weight * step_size * success - zero_prevention = 1 - outer_weight - x_average = (time * x_average + weight * x) / ( - time + weight + zero_prevention - ) # Update with a Kalman filter - x_squared_average = (time * x_squared_average + weight * jnp.square(x)) / ( - time + weight + zero_prevention - ) # Update with a Kalman filter - time += weight - return (time, x_average, x_squared_average) + adaptive_state = (time, x_average, step_size_max) - adap0 = (0.0, 0.0, jnp.inf) + return state, params_new, adaptive_state, success def step(iteration_state, weight_and_key): """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" - outer_weight, rng_key = weight_and_key - state, params, adaptive_state, kalman_state = iteration_state - state, params, params_final, 
adaptive_state, success = predictor( + mask, rng_key = weight_and_key + state, params, adaptive_state, streaming_avg = iteration_state + + state, params, adaptive_state, success = predictor( state, params, adaptive_state, rng_key ) - position, _ = ravel_pytree(state.position) - kalman_state = update_kalman( - position, kalman_state, outer_weight, success, params.step_size + + x = ravel_pytree(state.position)[0] + # update the running average of x, x^2 + streaming_avg = incremental_value_update( + expectation=jnp.array([x, jnp.square(x)]), + incremental_val=streaming_avg, + weight=(1 - mask) * success * params.step_size, + zero_prevention=mask, ) - return (state, params_final, adaptive_state, kalman_state), None + return (state, params, adaptive_state, streaming_avg), None + + run_steps = lambda xs, state, params: jax.lax.scan( + step, + init=( + state, + params, + (0.0, 0.0, jnp.inf), + (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), + ), + xs=xs, + )[0] def L_step_size_adaptation(state, params, num_steps, rng_key): - num_steps1, num_steps2 = int(num_steps * frac_tune1), int( - num_steps * frac_tune2 + num_steps1, num_steps2 = ( + int(num_steps * frac_tune1) + 1, + int(num_steps * frac_tune2) + 1, + ) + L_step_size_adaptation_keys = jax.random.split( + rng_key, num_steps1 + num_steps2 + 1 + ) + L_step_size_adaptation_keys, final_key = ( + L_step_size_adaptation_keys[:-1], + L_step_size_adaptation_keys[-1], ) - L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2) # we use the last num_steps2 to compute the diagonal preconditioner - outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) - - # initial state of the kalman filter - kalman_state = (0.0, jnp.zeros(dim), jnp.zeros(dim)) + mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) # run the steps - kalman_state = jax.lax.scan( - step, - init=(state, params, adap0, kalman_state), - xs=(outer_weights, L_step_size_adaptation_keys), - length=num_steps1 + num_steps2, - )[0] - state, params, _, kalman_state_output = kalman_state + state, params, _, (_, average) = run_steps( + xs=(mask, L_step_size_adaptation_keys), state=state, params=params + ) L = params.L # determine L + sqrt_diag_cov = params.sqrt_diag_cov if num_steps2 != 0.0: - _, F1, F2 = kalman_state_output - variances = F2 - jnp.square(F1) + x_average, x_squared_average = average[0], average[1] + variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) - return state, MCLMCAdaptationState(L, params.step_size) + if diagonal_preconditioning: + sqrt_diag_cov = jnp.sqrt(variances) + params = params._replace(sqrt_diag_cov=sqrt_diag_cov) + L = jnp.sqrt(dim) + + # readjust the stepsize + steps = num_steps2 // 3 # we do some small number of steps + keys = jax.random.split(final_key, steps) + state, params, _, (_, average) = run_steps( + xs=(jnp.ones(steps), keys), state=state, params=params + ) + + return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov) return L_step_size_adaptation @@ -258,7 +272,6 @@ def adaptation_L(state, params, num_steps, key): num_steps = int(num_steps * frac) adaptation_L_keys = jax.random.split(key, num_steps) - # run kernel in the normal way def step(state, key): next_state, _ = kernel( rng_key=key, @@ -297,5 +310,4 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch (next_state, step_size_max, kinetic_change), (previous_state, step_size * reduced_step_size, 0.0), ) - return nonans, state, step_size, kinetic_change 
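
The `adaptation_info_fn` hook added in `blackjax/adaptation/base.py` and threaded through the chees, meads, pathfinder and window adaptation routines in this patch is meant to let users avoid stacking the full `AdaptationInfo` over every warm-up step. Below is a minimal sketch of how it might be wired into `window_adaptation` on a toy standard-normal target; the field names passed to `get_filter_adapt_info_fn` (`position`, `acceptance_rate`) are assumptions about the NUTS state/info tuples rather than something this diff itself specifies.

```python
import jax
import jax.numpy as jnp

import blackjax
from blackjax.adaptation.base import get_filter_adapt_info_fn


def logdensity_fn(x):
    # Toy target: standard normal in two dimensions.
    return -0.5 * jnp.sum(jnp.square(x))


rng_key = jax.random.key(0)

# Keep only the chain positions and acceptance rates from each warm-up step;
# every other field of AdaptationInfo is replaced by None instead of being
# accumulated across iterations.
warmup = blackjax.window_adaptation(
    blackjax.nuts,
    logdensity_fn,
    adaptation_info_fn=get_filter_adapt_info_fn(
        state_keys={"position"}, info_keys={"acceptance_rate"}
    ),
)

(last_state, parameters), adapt_info = warmup.run(
    rng_key, jnp.zeros(2), num_steps=500
)
```

Passing `get_filter_adapt_info_fn()` with no arguments keeps no auxiliary information at all, while the default `return_all_adapt_info` preserves the previous behaviour of saving everything.
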
diff --git a/blackjax/adaptation/meads_adaptation.py b/blackjax/adaptation/meads_adaptation.py index e50065710..b383653e8 100644 --- a/blackjax/adaptation/meads_adaptation.py +++ b/blackjax/adaptation/meads_adaptation.py @@ -17,7 +17,7 @@ import jax.numpy as jnp import blackjax.mcmc as mcmc -from blackjax.adaptation.base import AdaptationInfo, AdaptationResults +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.base import AdaptationAlgorithm from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -36,7 +36,7 @@ class MEADSAdaptationState(NamedTuple): alpha Value of the alpha parameter of the generalized HMC algorithm. delta - Value of the alpha parameter of the generalized HMC algorithm. + Value of the delta parameter of the generalized HMC algorithm. """ @@ -60,7 +60,7 @@ def base(): with shape. This is an implementation of Algorithm 3 of :cite:p:`hoffman2022tuning` using cross-chain - adaptation instead of parallel ensample chain adaptation. + adaptation instead of parallel ensemble chain adaptation. Returns ------- @@ -93,16 +93,16 @@ def compute_parameters( of the generalized HMC algorithm. """ - mean_position = jax.tree_map(lambda p: p.mean(axis=0), positions) - sd_position = jax.tree_map(lambda p: p.std(axis=0), positions) - normalized_positions = jax.tree_map( + mean_position = jax.tree.map(lambda p: p.mean(axis=0), positions) + sd_position = jax.tree.map(lambda p: p.std(axis=0), positions) + normalized_positions = jax.tree.map( lambda p, mu, sd: (p - mu) / sd, positions, mean_position, sd_position, ) - batch_grad_scaled = jax.tree_map( + batch_grad_scaled = jax.tree.map( lambda grad, sd: grad * sd, logdensity_grad, sd_position ) @@ -165,6 +165,7 @@ def update( def meads_adaptation( logdensity_fn: Callable, num_chains: int, + adaptation_info_fn: Callable = return_all_adapt_info, ) -> AdaptationAlgorithm: """Adapt the parameters of the Generalized HMC algorithm. @@ -194,6 +195,11 @@ def meads_adaptation( The log density probability density function from which we wish to sample. num_chains Number of chains used for cross-chain warm-up training. + adaptation_info_fn + Function to select the adaptation info returned. See return_all_adapt_info + and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all + information is saved - this can result in excessive memory usage if the + information is unused. Returns ------- @@ -227,10 +233,8 @@ def one_step(carry, rng_key): adaptation_state, new_states.position, new_states.logdensity_grad ) - return (new_states, new_adaptation_state), AdaptationInfo( - new_states, - info, - new_adaptation_state, + return (new_states, new_adaptation_state), adaptation_info_fn( + new_states, info, new_adaptation_state ) def run(rng_key: PRNGKey, positions: ArrayLikeTree, num_steps: int = 1000): diff --git a/blackjax/adaptation/pathfinder_adaptation.py b/blackjax/adaptation/pathfinder_adaptation.py index c70ed3f99..c0b4ebc50 100644 --- a/blackjax/adaptation/pathfinder_adaptation.py +++ b/blackjax/adaptation/pathfinder_adaptation.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Implementation of the Pathinder warmup for the HMC family of sampling algorithms.""" -from typing import Callable, NamedTuple, Union +from typing import Callable, NamedTuple import jax import jax.numpy as jnp -import blackjax.mcmc as mcmc import blackjax.vi as vi -from blackjax.adaptation.base import AdaptationInfo, AdaptationResults +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.adaptation.step_size import ( DualAveragingAdaptationState, dual_averaging_adaptation, @@ -138,10 +137,11 @@ def final(warmup_state: PathfinderAdaptationState) -> tuple[float, Array]: def pathfinder_adaptation( - algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts], + algorithm, logdensity_fn: Callable, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.80, + adaptation_info_fn: Callable = return_all_adapt_info, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of @@ -157,6 +157,11 @@ def pathfinder_adaptation( The initial step size used in the algorithm. target_acceptance_rate The acceptance rate that we target during step size adaptation. + adaptation_info_fn + Function to select the adaptation info returned. See return_all_adapt_info + and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all + information is saved - this can result in excessive memory usage if the + information is unused. **extra_parameters The extra parameters to pass to the algorithm, e.g. the number of integration steps for HMC. @@ -189,7 +194,7 @@ def one_step(carry, rng_key): ) return ( (new_state, new_adaptation_state), - AdaptationInfo(new_state, info, new_adaptation_state), + adaptation_info_fn(new_state, info, new_adaptation_state), ) def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 400): diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 2d6b0182f..2b06172c0 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -158,8 +158,8 @@ def final(da_state: DualAveragingAdaptationState) -> float: class ReasonableStepSizeState(NamedTuple): """State carried through the search for a reasonable first step size. - rng_key - Key used by JAX's random number generator. + step + The current iteration step. direction: {-1, 1} Determines whether the step size should be increased or decreased during the previous step search. If direction = 1 it will be increased, otherwise decreased. 
@@ -171,7 +171,7 @@ class ReasonableStepSizeState(NamedTuple): """ - rng_key: PRNGKey + step: int direction: int previous_direction: int step_size: float @@ -243,17 +243,17 @@ def do_continue(rss_state: ReasonableStepSizeState) -> bool: def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: """Perform one step of the step size search.""" - rng_key, direction, _, step_size = rss_state - rng_key, subkey = jax.random.split(rng_key) + i, direction, _, step_size = rss_state + subkey = jax.random.fold_in(rng_key, i) step_size = (2.0**direction) * step_size kernel = kernel_generator(step_size) _, info = kernel(subkey, reference_state) new_direction = jnp.where(target_accept < info.acceptance_rate, 1, -1) - return ReasonableStepSizeState(rng_key, new_direction, direction, step_size) + return ReasonableStepSizeState(i + 1, new_direction, direction, step_size) - rss_state = ReasonableStepSizeState(rng_key, 0, 0, initial_step_size) + rss_state = ReasonableStepSizeState(0, 0, 0, initial_step_size) rss_state = jax.lax.while_loop(do_continue, update, rss_state) return rss_state.step_size diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index cc871b4b6..69a098325 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Stan warmup for the HMC family of sampling algorithms.""" -from typing import Callable, NamedTuple, Union +from typing import Callable, NamedTuple import jax import jax.numpy as jnp import blackjax.mcmc as mcmc -from blackjax.adaptation.base import AdaptationInfo, AdaptationResults +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.adaptation.mass_matrix import ( MassMatrixAdaptationState, mass_matrix_adaptation, @@ -28,7 +28,7 @@ dual_averaging_adaptation, ) from blackjax.base import AdaptationAlgorithm -from blackjax.progress_bar import progress_bar_scan +from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, PRNGKey from blackjax.util import pytree_size @@ -243,16 +243,18 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]: def window_adaptation( - algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts], + algorithm, logdensity_fn: Callable, is_mass_matrix_diagonal: bool = True, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.80, progress_bar: bool = False, + adaptation_info_fn: Callable = return_all_adapt_info, + integrator=mcmc.integrators.velocity_verlet, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of - algorithms in the HMC fmaily. + algorithms in the HMC fmaily. See Blackjax.hmc_family Algorithms in the HMC family on a euclidean manifold depend on the value of at least two parameters: the step size, related to the trajectory @@ -279,6 +281,11 @@ def window_adaptation( The acceptance rate that we target during step size adaptation. progress_bar Whether we should display a progress bar. + adaptation_info_fn + Function to select the adaptation info returned. See return_all_adapt_info + and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all + information is saved - this can result in excessive memory usage if the + information is unused. **extra_parameters The extra parameters to pass to the algorithm, e.g. 
the number of integration steps for HMC. @@ -289,7 +296,7 @@ def window_adaptation( """ - mcmc_kernel = algorithm.build_kernel() + mcmc_kernel = algorithm.build_kernel(integrator) adapt_init, adapt_step, adapt_final = base( is_mass_matrix_diagonal, @@ -317,7 +324,7 @@ def one_step(carry, xs): return ( (new_state, new_adaptation_state), - AdaptationInfo(new_state, info, new_adaptation_state), + adaptation_info_fn(new_state, info, new_adaptation_state), ) def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): @@ -326,17 +333,16 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): if progress_bar: print("Running window adaptation") - one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step)) - else: - one_step_ = jax.jit(one_step) - + scan_fn = gen_scan_fn(num_steps, progress_bar=progress_bar) + start_state = (init_state, init_adaptation_state) keys = jax.random.split(rng_key, num_steps) schedule = build_schedule(num_steps) - last_state, info = jax.lax.scan( - one_step_, - (init_state, init_adaptation_state), + last_state, info = scan_fn( + one_step, + start_state, (jnp.arange(num_steps), keys, schedule), ) + last_chain_state, last_warmup_state, *_ = last_state step_size, inverse_mass_matrix = adapt_final(last_warmup_state) diff --git a/blackjax/base.py b/blackjax/base.py index f766e98b5..8ea24cd70 100644 --- a/blackjax/base.py +++ b/blackjax/base.py @@ -89,7 +89,7 @@ class SamplingAlgorithm(NamedTuple): """A pair of functions that represents a MCMC sampling algorithm. Blackjax sampling algorithms are implemented as a pair of pure functions: a - kernel, that takes a new samples starting from the current state, and an + kernel, that generates a new sample from the current state, and an initialization function that creates a kernel state from a chain position. As they represent Markov kernels, the kernel functions are pure functions diff --git a/blackjax/mcmc/barker.py b/blackjax/mcmc/barker.py index b91721a71..9923bd5f3 100644 --- a/blackjax/mcmc/barker.py +++ b/blackjax/mcmc/barker.py @@ -24,7 +24,7 @@ from blackjax.mcmc.proposal import static_binomial_sampling from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "barker_proposal"] +__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"] class BarkerState(NamedTuple): @@ -128,7 +128,10 @@ def kernel( return kernel -class barker_proposal: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a Gaussian base kernel. 
@@ -179,24 +182,16 @@ class barker_proposal: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel() - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) + def step_fn(rng_key: PRNGKey, state): + return kernel(rng_key, state, logdensity_fn, step_size) - def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, state, logdensity_fn, step_size) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def _barker_sample_nd(key, mean, a, scale): diff --git a/blackjax/mcmc/dynamic_hmc.py b/blackjax/mcmc/dynamic_hmc.py index 0fe4ec992..de77be825 100644 --- a/blackjax/mcmc/dynamic_hmc.py +++ b/blackjax/mcmc/dynamic_hmc.py @@ -27,8 +27,8 @@ "DynamicHMCState", "init", "build_kernel", - "dynamic_hmc", "halton_sequence", + "as_top_level_api", ] @@ -115,7 +115,16 @@ def kernel( return kernel -class dynamic_hmc: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: Array, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.velocity_verlet, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: """Implements the (basic) user interface for the dynamic HMC kernel. Parameters @@ -144,41 +153,26 @@ class dynamic_hmc: ------- A ``SamplingAlgorithm``. """ - - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - inverse_mass_matrix: Array, - *, - divergence_threshold: int = 1000, - integrator: Callable = integrators.velocity_verlet, - next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), - ) -> SamplingAlgorithm: - kernel = cls.build_kernel( - integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn + kernel = build_kernel( + integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + # Note that rng_key here is not necessarily a PRNGKey, could be a Array that + # for generates a sequence of pseudo or quasi-random numbers (previously + # named as `random_generator_arg`) + return init(position, logdensity_fn, rng_key) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, ) - def init_fn(position: ArrayLikeTree, rng_key: Array): - # Note that rng_key here is not necessarily a PRNGKey, could be a Array that - # for generates a sequence of pseudo or quasi-random numbers (previously - # named as `random_generator_arg`) - return cls.init(position, logdensity_fn, rng_key) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - inverse_mass_matrix, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def halton_sequence(i: Array, max_bits: int = 10) -> float: diff --git a/blackjax/mcmc/elliptical_slice.py b/blackjax/mcmc/elliptical_slice.py index c0d1c5998..09ed66c86 
100644 --- a/blackjax/mcmc/elliptical_slice.py +++ b/blackjax/mcmc/elliptical_slice.py @@ -26,7 +26,7 @@ "EllipSliceInfo", "init", "build_kernel", - "elliptical_slice", + "as_top_level_api", ] @@ -119,7 +119,12 @@ def kernel( return kernel -class elliptical_slice: +def as_top_level_api( + loglikelihood_fn: Callable, + *, + mean: Array, + cov: Array, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Elliptical Slice sampling kernel. Examples @@ -151,31 +156,20 @@ class elliptical_slice: ------- A ``SamplingAlgorithm``. """ + kernel = build_kernel(cov, mean) - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - loglikelihood_fn: Callable, - *, - mean: Array, - cov: Array, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(cov, mean) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, loglikelihood_fn) - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, loglikelihood_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - loglikelihood_fn, - ) + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + loglikelihood_fn, + ) - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def elliptical_proposal( @@ -208,7 +202,7 @@ def generate( rng_key: PRNGKey, state: EllipSliceState ) -> tuple[EllipSliceState, EllipSliceInfo]: position, logdensity = state - key_momentum, key_uniform, key_theta = jax.random.split(rng_key, 3) + key_slice, key_momentum, key_uniform, key_theta = jax.random.split(rng_key, 4) # step 1: sample momentum momentum = momentum_generator(key_momentum, position) # step 2: get slice (y) @@ -235,20 +229,20 @@ def slice_fn(vals): likelihood is continuous with respect to the parameter being sampled. 
""" - rng, _, subiter, theta, theta_min, theta_max, *_ = vals - rng, thetak = jax.random.split(rng) + _, subiter, theta, theta_min, theta_max, *_ = vals + thetak = jax.random.fold_in(key_slice, subiter) theta = jax.random.uniform(thetak, minval=theta_min, maxval=theta_max) p, m = ellipsis(position, momentum, theta, mean) logdensity = logdensity_fn(p) theta_min = jnp.where(theta < 0, theta, theta_min) theta_max = jnp.where(theta > 0, theta, theta_max) subiter += 1 - return rng, logdensity, subiter, theta, theta_min, theta_max, p, m + return logdensity, subiter, theta, theta_min, theta_max, p, m - _, logdensity, subiter, theta, *_, position, momentum = jax.lax.while_loop( - lambda vals: vals[1] <= logy, + logdensity, subiter, theta, *_, position, momentum = jax.lax.while_loop( + lambda vals: vals[0] <= logy, slice_fn, - (rng_key, logdensity, 1, theta, theta_min, theta_max, p, m), + (logdensity, 1, theta, theta_min, theta_max, p, m), ) return ( EllipSliceState(position, logdensity), diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index ada6bea9c..a04ce0641 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -25,7 +25,7 @@ from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise -__all__ = ["GHMCState", "init", "build_kernel", "ghmc"] +__all__ = ["GHMCState", "init", "build_kernel", "as_top_level_api"] class GHMCState(NamedTuple): @@ -185,7 +185,7 @@ def update_momentum(rng_key, state, alpha, momentum_generator): """ position, momentum, *_ = state - momentum = jax.tree_map( + momentum = jax.tree.map( lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha) + jnp.sqrt(alpha) * shifted_momentum, momentum, @@ -195,7 +195,16 @@ def update_momentum(rng_key, state, alpha, momentum_generator): return momentum -class ghmc: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + momentum_inverse_scale: ArrayLikeTree, + alpha: float, + delta: float, + *, + divergence_threshold: int = 1000, + noise_gn: Callable = lambda _: 0.0, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Generalized HMC kernel. The Generalized HMC kernel performs a similar procedure to the standard HMC @@ -257,34 +266,20 @@ class ghmc: A ``SamplingAlgorithm``. 
""" - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel(noise_gn, divergence_threshold) - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - momentum_inverse_scale: ArrayLikeTree, - alpha: float, - delta: float, - *, - divergence_threshold: int = 1000, - noise_gn: Callable = lambda _: 0.0, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(noise_gn, divergence_threshold) - - def init_fn(position: ArrayLikeTree, rng_key: PRNGKey): - return cls.init(position, rng_key, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - momentum_inverse_scale, - alpha, - delta, - ) - - return SamplingAlgorithm(init_fn, step_fn) + def init_fn(position: ArrayLikeTree, rng_key: PRNGKey): + return init(position, rng_key, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + momentum_inverse_scale, + alpha, + delta, + ) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index b48834e5f..452b94e44 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -29,7 +29,7 @@ "HMCInfo", "init", "build_kernel", - "hmc", + "as_top_level_api", ] @@ -150,7 +150,15 @@ def kernel( return kernel -class hmc: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: metrics.MetricTypes, + num_integration_steps: int, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.velocity_verlet, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the HMC kernel. The general hmc kernel builder (:meth:`blackjax.mcmc.hmc.build_kernel`, alias @@ -225,36 +233,23 @@ class hmc: A ``SamplingAlgorithm``. 
""" - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel(integrator, divergence_threshold) - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - inverse_mass_matrix: metrics.MetricTypes, - num_integration_steps: int, - *, - divergence_threshold: int = 1000, - integrator: Callable = integrators.velocity_verlet, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(integrator, divergence_threshold) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - inverse_mass_matrix, - num_integration_steps, - ) - - return SamplingAlgorithm(init_fn, step_fn) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + num_integration_steps, + ) + + return SamplingAlgorithm(init_fn, step_fn) def hmc_proposal( diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index f4009b16e..1d4b95a09 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -17,18 +17,22 @@ import jax import jax.numpy as jnp from jax.flatten_util import ravel_pytree +from jax.random import normal from blackjax.mcmc.metrics import KineticEnergy from blackjax.types import ArrayTree __all__ = [ "mclachlan", + "omelyan", "velocity_verlet", "yoshida", - "implicit_midpoint", - "isokinetic_leapfrog", + "with_isokinetic_maruyama", + "isokinetic_velocity_verlet", "isokinetic_mclachlan", + "isokinetic_omelyan", "isokinetic_yoshida", + "implicit_midpoint", ] @@ -69,7 +73,7 @@ def generalized_two_stage_integrator( .. math:: \\frac{d}{dt}f = (O_1+O_2)f - The leapfrog operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}` + The velocity_verlet operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}` by :math:`e^{\\epsilon O_1/2}e^{\\epsilon O_2}e^{\\epsilon O_1/2}`. In a standard Hamiltonian, the forms of :math:`e^{\\epsilon O_2}` and @@ -209,7 +213,7 @@ def format_euclidean_state_output( return IntegratorState(position, momentum, logdensity, logdensity_grad) -def generate_euclidean_integrator(cofficients): +def generate_euclidean_integrator(coefficients): """Generate symplectic integrator for solving a Hamiltonian system. The resulting integrator is volume-preserve and preserves the symplectic structure @@ -224,7 +228,7 @@ def euclidean_integrator( one_step = generalized_two_stage_integrator( momentum_update_fn, position_update_fn, - cofficients, + coefficients, format_output_fn=format_euclidean_state_output, ) return one_step @@ -250,8 +254,8 @@ def euclidean_integrator( of the kinetic energy. We are trading accuracy in exchange, and it is not clear whether this is the right tradeoff. """ -velocity_verlet_cofficients = [0.5, 1.0, 0.5] -velocity_verlet = generate_euclidean_integrator(velocity_verlet_cofficients) +velocity_verlet_coefficients = [0.5, 1.0, 0.5] +velocity_verlet = generate_euclidean_integrator(velocity_verlet_coefficients) """ Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`. 
@@ -267,8 +271,8 @@ def euclidean_integrator( b1 = 0.1931833275037836 a1 = 0.5 b2 = 1 - 2 * b1 -mclachlan_cofficients = [b1, a1, b2, a1, b1] -mclachlan = generate_euclidean_integrator(mclachlan_cofficients) +mclachlan_coefficients = [b1, a1, b2, a1, b1] +mclachlan = generate_euclidean_integrator(mclachlan_coefficients) """ Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical` @@ -283,8 +287,22 @@ def euclidean_integrator( a1 = 0.29619504261126 b2 = 0.5 - b1 a2 = 1 - 2 * a1 -yoshida_cofficients = [b1, a1, b2, a2, b2, a1, b1] -yoshida = generate_euclidean_integrator(yoshida_cofficients) +yoshida_coefficients = [b1, a1, b2, a2, b2, a1, b1] +yoshida = generate_euclidean_integrator(yoshida_coefficients) + +""" +Eleven-stage palindromic symplectic integrator derived in :cite:p:`omelyan2003symplectic`. + +Popular in LQCD, see also :cite:p:`takaishi2006testing`. +""" +b1 = 0.08398315262876693 +a1 = 0.2539785108410595 +b2 = 0.6822365335719091 +a2 = -0.03230286765269967 +b3 = 0.5 - b1 - b2 +a3 = 1 - 2 * (a1 + a2) +omelyan_coefficients = [b1, a1, b2, a2, b3, a3, b3, a2, b2, a1, b1] +omelyan = generate_euclidean_integrator(omelyan_coefficients) # Intergrators with non Euclidean updates @@ -293,43 +311,49 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step( - momentum: ArrayTree, - logdensity_grad: ArrayTree, - step_size: float, - coef: float, - previous_kinetic_energy_change=None, - is_last_call=False, -): - """Momentum update based on Esh dynamics. +def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0): + def update( + momentum: ArrayTree, + logdensity_grad: ArrayTree, + step_size: float, + coef: float, + previous_kinetic_energy_change=None, + is_last_call=False, + ): + """Momentum update based on Esh dynamics. + + The momentum updating map of the esh dynamics as derived in :cite:p:`steeg2021hamiltonian` + There are no exponentials e^delta, which prevents overflows when the gradient norm + is large. + """ + del is_last_call + + logdensity_grad = logdensity_grad + flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) + flatten_grads = flatten_grads * sqrt_diag_cov + flatten_momentum, _ = ravel_pytree(momentum) + dims = flatten_momentum.shape[0] + normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) + momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) + delta = step_size * coef * gradient_norm / (dims - 1) + zeta = jnp.exp(-delta) + new_momentum_raw = ( + normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) + + 2 * zeta * flatten_momentum + ) + new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) + gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov) + next_momentum = unravel_fn(new_momentum_normalized) + kinetic_energy_change = ( + delta + - jnp.log(2) + + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2) + ) * (dims - 1) + if previous_kinetic_energy_change is not None: + kinetic_energy_change += previous_kinetic_energy_change + return next_momentum, gr, kinetic_energy_change - The momentum updating map of the esh dynamics as derived in :cite:p:`steeg2021hamiltonian` - There are no exponentials e^delta, which prevents overflows when the gradient norm - is large. 
- """ - del is_last_call - - flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_momentum, _ = ravel_pytree(momentum) - dims = flatten_momentum.shape[0] - normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) - momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) - delta = step_size * coef * gradient_norm / (dims - 1) - zeta = jnp.exp(-delta) - new_momentum_raw = ( - normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) - + 2 * zeta * flatten_momentum - ) - new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - next_momentum = unravel_fn(new_momentum_normalized) - kinetic_energy_change = ( - delta - - jnp.log(2) - + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2) - ) * (dims - 1) - if previous_kinetic_energy_change is not None: - kinetic_energy_change += previous_kinetic_energy_change - return next_momentum, next_momentum, kinetic_energy_change + return update def format_isokinetic_state_output( @@ -348,15 +372,15 @@ def format_isokinetic_state_output( ) -def generate_isokinetic_integrator(cofficients): +def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, *args, **kwargs + logdensity_fn: Callable, sqrt_diag_cov: ArrayTree = 1.0 ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step, + esh_dynamics_momentum_update_one_step(sqrt_diag_cov), position_update_fn, - cofficients, + coefficients, format_output_fn=format_isokinetic_state_output, ) return one_step @@ -364,9 +388,66 @@ def isokinetic_integrator( return isokinetic_integrator -isokinetic_leapfrog = generate_isokinetic_integrator(velocity_verlet_cofficients) -isokinetic_yoshida = generate_isokinetic_integrator(yoshida_cofficients) -isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_cofficients) +isokinetic_velocity_verlet = generate_isokinetic_integrator( + velocity_verlet_coefficients +) +isokinetic_yoshida = generate_isokinetic_integrator(yoshida_coefficients) +isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_coefficients) +isokinetic_omelyan = generate_isokinetic_integrator(omelyan_coefficients) + + +def partially_refresh_momentum(momentum, rng_key, step_size, L): + """Adds a small noise to momentum and normalizes. + + Parameters + ---------- + rng_key + The pseudo-random number generator key used to generate random numbers. + momentum + PyTree that the structure the output should to match. 
+ step_size + Step size + L + controls rate of momentum change + + Returns + ------- + momentum with random change in angle + """ + m, unravel_fn = ravel_pytree(momentum) + dim = m.shape[0] + nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) + z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) + return unravel_fn((m + z) / jnp.linalg.norm(m + z)) + + +def with_isokinetic_maruyama(integrator): + def stochastic_integrator(init_state, step_size, L_proposal, rng_key): + key1, key2 = jax.random.split(rng_key) + # partial refreshment + state = init_state._replace( + momentum=partially_refresh_momentum( + momentum=init_state.momentum, + rng_key=key1, + L=L_proposal, + step_size=step_size * 0.5, + ) + ) + # one step of the deterministic dynamics + state, info = integrator(state, step_size) + # partial refreshment + state = state._replace( + momentum=partially_refresh_momentum( + momentum=state.momentum, + rng_key=key2, + L=L_proposal, + step_size=step_size * 0.5, + ) + ) + return state, info + + return stochastic_integrator + FixedPointSolver = Callable[ [Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree], diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index 9690bc7f5..56c0c0077 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -23,7 +23,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["MALAState", "MALAInfo", "init", "build_kernel", "mala"] +__all__ = ["MALAState", "MALAInfo", "init", "build_kernel", "as_top_level_api"] class MALAState(NamedTuple): @@ -79,15 +79,15 @@ def build_kernel(): def transition_energy(state, new_state, step_size): """Transition energy to go from `state` to `new_state`""" theta = jax.tree_util.tree_map( - lambda new_x, x, g: new_x - x - step_size * g, - new_state.position, + lambda x, new_x, g: x - new_x - step_size * g, state.position, - state.logdensity_grad, + new_state.position, + new_state.logdensity_grad, ) theta_dot = jax.tree_util.tree_reduce( operator.add, jax.tree_util.tree_map(lambda x: jnp.sum(x * x), theta) ) - return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot + return -new_state.logdensity + 0.25 * (1.0 / step_size) * theta_dot compute_acceptance_ratio = proposal.compute_asymmetric_acceptance_ratio( transition_energy @@ -117,7 +117,10 @@ def kernel( return kernel -class mala: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the MALA kernel. 
The general mala kernel builder (:meth:`blackjax.mcmc.mala.build_kernel`, alias `blackjax.mala.build_kernel`) can be @@ -167,21 +170,13 @@ class mala: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() + kernel = build_kernel() - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) - def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, state, logdensity_fn, step_size) + def step_fn(rng_key: PRNGKey, state): + return kernel(rng_key, state, logdensity_fn, step_size) - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/marginal_latent_gaussian.py b/blackjax/mcmc/marginal_latent_gaussian.py index 8d4d76f6a..d2783f8d9 100644 --- a/blackjax/mcmc/marginal_latent_gaussian.py +++ b/blackjax/mcmc/marginal_latent_gaussian.py @@ -22,7 +22,13 @@ from blackjax.mcmc.proposal import static_binomial_sampling from blackjax.types import Array, PRNGKey -__all__ = ["MarginalState", "MarginalInfo", "init", "build_kernel", "mgrad_gaussian"] +__all__ = [ + "MarginalState", + "MarginalInfo", + "init", + "build_kernel", + "as_top_level_api", +] # [TODO](https://github.com/blackjax-devs/blackjax/issues/237) @@ -206,7 +212,13 @@ def kernel(key: PRNGKey, state: MarginalState, logdensity_fn, delta): return kernel -class mgrad_gaussian: +def as_top_level_api( + logdensity_fn: Callable, + covariance: Optional[Array] = None, + mean: Optional[Array] = None, + cov_svd: Optional[CovarianceSVD] = None, + step_size: float = 1.0, +) -> SamplingAlgorithm: """Implements the marginal sampler for latent Gaussian model of :cite:p:`titsias2018auxiliary`. It uses a first order approximation to the log_likelihood of a model with Gaussian prior. 
@@ -247,41 +259,28 @@ class mgrad_gaussian: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - covariance: Optional[Array] = None, - mean: Optional[Array] = None, - cov_svd: Optional[CovarianceSVD] = None, - step_size: float = 1.0, - ) -> SamplingAlgorithm: - if cov_svd is None: - if covariance is None: - raise ValueError("Either covariance or cov_svd must be provided.") - cov_svd = svd_from_covariance(covariance) - - U, Gamma, U_t = cov_svd - - if mean is not None: - logdensity_fn = generate_mean_shifted_logprob( - logdensity_fn, mean, covariance - ) - - kernel = cls.build_kernel(cov_svd) - - def init_fn(position: Array, rng_key=None): - del rng_key - return init(position, logdensity_fn, U_t) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - ) - - return SamplingAlgorithm(init_fn, step_fn) + if cov_svd is None: + if covariance is None: + raise ValueError("Either covariance or cov_svd must be provided.") + cov_svd = svd_from_covariance(covariance) + + U, Gamma, U_t = cov_svd + + if mean is not None: + logdensity_fn = generate_mean_shifted_logprob(logdensity_fn, mean, covariance) + + kernel = build_kernel(cov_svd) + + def init_fn(position: Array, rng_key=None): + del rng_key + return init(position, logdensity_fn, U_t) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + ) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 7c636181f..e7a69849b 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -15,16 +15,17 @@ from typing import Callable, NamedTuple import jax -import jax.numpy as jnp -from jax.flatten_util import ravel_pytree -from jax.random import normal from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.integrators import IntegratorState, isokinetic_mclachlan +from blackjax.mcmc.integrators import ( + IntegratorState, + isokinetic_mclachlan, + with_isokinetic_maruyama, +) from blackjax.types import ArrayLike, PRNGKey from blackjax.util import generate_unit_vector, pytree_size -__all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] +__all__ = ["MCLMCInfo", "init", "build_kernel", "as_top_level_api"] class MCLMCInfo(NamedTuple): @@ -59,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, integrator): +def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """Build a HMC kernel. Parameters @@ -78,18 +79,16 @@ def build_kernel(logdensity_fn, integrator): information about the transition. 
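Editor's note: the MCLMC kernel built below composes the isokinetic integrator with the Langevin-style partial momentum refreshment added earlier in this changeset. A standalone sketch, with a placeholder target density and arbitrary `step_size`/`L_proposal` values:

```python
import jax
import jax.numpy as jnp

from blackjax.mcmc import mclmc
from blackjax.mcmc.integrators import isokinetic_mclachlan, with_isokinetic_maruyama

logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)  # placeholder target density

# deterministic isokinetic dynamics wrapped with the partial momentum refreshment
step = with_isokinetic_maruyama(isokinetic_mclachlan(logdensity_fn, sqrt_diag_cov=1.0))

state = mclmc.init(jnp.ones(2), logdensity_fn, jax.random.key(0))
new_state, kinetic_energy_change = step(
    state, step_size=0.1, L_proposal=1.0, rng_key=jax.random.key(1)
)
```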
""" - step = integrator(logdensity_fn) + + step = with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + ) def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float ) -> tuple[IntegratorState, MCLMCInfo]: (position, momentum, logdensity, logdensitygrad), kinetic_change = step( - state, step_size - ) - - # Langevin-like noise - momentum = partially_refresh_momentum( - momentum=momentum, rng_key=rng_key, L=L, step_size=step_size + state, step_size, L, rng_key ) return IntegratorState( @@ -103,7 +102,13 @@ def kernel( return kernel -class mclmc: +def as_top_level_api( + logdensity_fn: Callable, + L, + step_size, + integrator=isokinetic_mclachlan, + sqrt_diag_cov=1.0, +) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel parameters at initialization time, we provide a helper function that @@ -150,47 +155,12 @@ class mclmc: A ``SamplingAlgorithm``. """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - L, - step_size, - integrator=isokinetic_mclachlan, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(logdensity_fn, integrator) - - def init_fn(position: ArrayLike, rng_key: PRNGKey): - return cls.init(position, logdensity_fn, rng_key) - - def update_fn(rng_key, state): - return kernel(rng_key, state, L, step_size) - - return SamplingAlgorithm(init_fn, update_fn) + kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator) + def init_fn(position: ArrayLike, rng_key: PRNGKey): + return init(position, logdensity_fn, rng_key) -def partially_refresh_momentum(momentum, rng_key, step_size, L): - """Adds a small noise to momentum and normalizes. + def update_fn(rng_key, state): + return kernel(rng_key, state, L, step_size) - Parameters - ---------- - rng_key - The pseudo-random number generator key used to generate random numbers. - momentum - PyTree that the structure the output should to match. - step_size - Step size - L - controls rate of momentum change - - Returns - ------- - momentum with random change in angle - """ - m, unravel_fn = ravel_pytree(momentum) - dim = m.shape[0] - nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) - z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) - return unravel_fn((m + z) / jnp.linalg.norm(m + z)) + return SamplingAlgorithm(init_fn, update_fn) diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index 5ffc083b1..c75ecdec6 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -27,7 +27,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["NUTSInfo", "init", "build_kernel", "nuts"] +__all__ = ["NUTSInfo", "init", "build_kernel", "as_top_level_api"] init = hmc.init @@ -147,7 +147,15 @@ def kernel( return kernel -class nuts: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: metrics.MetricTypes, + *, + max_num_doublings: int = 10, + divergence_threshold: int = 1000, + integrator: Callable = integrators.velocity_verlet, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the nuts kernel. Examples @@ -202,37 +210,23 @@ class nuts: A ``SamplingAlgorithm``. 
""" + kernel = build_kernel(integrator, divergence_threshold) + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + max_num_doublings, + ) - init = staticmethod(hmc.init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - inverse_mass_matrix: metrics.MetricTypes, - *, - max_num_doublings: int = 10, - divergence_threshold: int = 1000, - integrator: Callable = integrators.velocity_verlet, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(integrator, divergence_threshold) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - inverse_mass_matrix, - max_num_doublings, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def iterative_nuts_proposal( diff --git a/blackjax/mcmc/periodic_orbital.py b/blackjax/mcmc/periodic_orbital.py index 6e4a2ca5e..61625a0b8 100644 --- a/blackjax/mcmc/periodic_orbital.py +++ b/blackjax/mcmc/periodic_orbital.py @@ -22,7 +22,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["PeriodicOrbitalState", "init", "build_kernel", "orbital_hmc"] +__all__ = ["PeriodicOrbitalState", "init", "build_kernel", "as_top_level_api"] class PeriodicOrbitalState(NamedTuple): @@ -217,7 +217,14 @@ def kernel( return kernel -class orbital_hmc: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: Array, # assume momentum is always Gaussian + period: int, + *, + bijection: Callable = integrators.velocity_verlet, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Periodic orbital MCMC kernel. Each iteration of the periodic orbital MCMC outputs ``period`` weighted samples from @@ -261,36 +268,23 @@ class orbital_hmc: ------- A ``SamplingAlgorithm``. """ + kernel = build_kernel(bijection) + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn, period) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + period, + ) - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - inverse_mass_matrix: Array, # assume momentum is always Gaussian - period: int, - *, - bijection: Callable = integrators.velocity_verlet, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(bijection) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn, period) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - inverse_mass_matrix, - period, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def periodic_orbital_proposal( diff --git a/blackjax/mcmc/proposal.py b/blackjax/mcmc/proposal.py index 5ec95edf0..258dbf29e 100644 --- a/blackjax/mcmc/proposal.py +++ b/blackjax/mcmc/proposal.py @@ -153,7 +153,7 @@ def progressive_biased_sampling( biases the transition away from the trajectory's initial state. 
""" - p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), a_max=1) + p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), max=1) do_accept = jax.random.bernoulli(rng_key, p_accept) new_weight = jnp.logaddexp(proposal.weight, new_proposal.weight) new_sum_log_p_accept = jnp.logaddexp( @@ -224,7 +224,7 @@ def static_binomial_sampling( then the new proposal is accepted with probability 1. """ - p_accept = jnp.clip(jnp.exp(log_p_accept), a_max=1) + p_accept = jnp.clip(jnp.exp(log_p_accept), max=1) do_accept = jax.random.bernoulli(rng_key, p_accept) info = do_accept, p_accept, None return ( @@ -253,7 +253,7 @@ def nonreversible_slice_sampling( to the accept/reject step of a current state and new proposal. """ - p_accept = jnp.clip(jnp.exp(delta_energy), a_max=1) + p_accept = jnp.clip(jnp.exp(delta_energy), max=1) do_accept = jnp.log(jnp.abs(slice)) <= delta_energy slice_next = slice * (jnp.exp(-delta_energy) * do_accept + (1 - do_accept)) info = do_accept, p_accept, slice_next diff --git a/blackjax/mcmc/random_walk.py b/blackjax/mcmc/random_walk.py index e454c057d..a1d1c3bd6 100644 --- a/blackjax/mcmc/random_walk.py +++ b/blackjax/mcmc/random_walk.py @@ -80,8 +80,9 @@ "rmh_proposal", "build_rmh_transition_energy", "additive_step_random_walk", - "irmh", - "rmh", + "irmh_as_top_level_api", + "rmh_as_top_level_api", + "normal_random_walk", ] @@ -182,7 +183,25 @@ def proposal_generator(key_proposal, position): return kernel -class additive_step_random_walk: +def normal_random_walk(logdensity_fn: Callable, sigma): + """ + Parameters + ---------- + logdensity_fn + The log density probability density function from which we wish to sample. + sigma + The value of the covariance matrix of the gaussian proposal distribution. + + Returns + ------- + A ``SamplingAlgorithm``. + """ + return additive_step_random_walk(logdensity_fn, normal(sigma)) + + +def additive_step_random_walk( + logdensity_fn: Callable, random_step: Callable +) -> SamplingAlgorithm: """Implements the user interface for the Additive Step RMH Examples @@ -218,39 +237,16 @@ class additive_step_random_walk: ------- A ``SamplingAlgorithm``. """ + kernel = build_additive_step() - init = staticmethod(init) - build_kernel = staticmethod(build_additive_step) - - @classmethod - def normal_random_walk(cls, logdensity_fn: Callable, sigma): - """ - Parameters - ---------- - logdensity_fn - The log density probability density function from which we wish to sample. - sigma - The value of the covariance matrix of the gaussian proposal distribution. - - Returns - ------- - A ``SamplingAlgorithm``. 
- """ - return cls(logdensity_fn, normal(sigma)) - - def __new__( # type: ignore[misc] - cls, logdensity_fn: Callable, random_step: Callable - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) + def step_fn(rng_key: PRNGKey, state): + return kernel(rng_key, state, logdensity_fn, random_step) - def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, state, logdensity_fn, random_step) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def build_irmh() -> Callable: @@ -297,7 +293,11 @@ def proposal_generator(rng_key: PRNGKey, position: ArrayTree): return kernel -class irmh: +def irmh_as_top_level_api( + logdensity_fn: Callable, + proposal_distribution: Callable, + proposal_logdensity_fn: Optional[Callable] = None, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the independent RMH. Examples @@ -334,32 +334,22 @@ class irmh: A ``SamplingAlgorithm``. """ + kernel = build_irmh() + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + proposal_distribution, + proposal_logdensity_fn, + ) - init = staticmethod(init) - build_kernel = staticmethod(build_irmh) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - proposal_distribution: Callable, - proposal_logdensity_fn: Optional[Callable] = None, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - proposal_distribution, - proposal_logdensity_fn, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def build_rmh(): @@ -420,7 +410,11 @@ def kernel( return kernel -class rmh: +def rmh_as_top_level_api( + logdensity_fn: Callable, + proposal_generator: Callable[[PRNGKey, ArrayLikeTree], ArrayTree], + proposal_logdensity_fn: Optional[Callable[[ArrayLikeTree], ArrayTree]] = None, +) -> SamplingAlgorithm: """Implements the user interface for the RMH. Examples @@ -456,32 +450,22 @@ class rmh: ------- A ``SamplingAlgorithm``. 
""" + kernel = build_rmh() + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + proposal_generator, + proposal_logdensity_fn, + ) - init = staticmethod(init) - build_kernel = staticmethod(build_rmh) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - proposal_generator: Callable[[PRNGKey, ArrayLikeTree], ArrayTree], - proposal_logdensity_fn: Optional[Callable[[ArrayLikeTree], ArrayTree]] = None, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - proposal_generator, - proposal_logdensity_fn, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def build_rmh_transition_energy(proposal_logdensity_fn: Optional[Callable]) -> Callable: diff --git a/blackjax/mcmc/rmhmc.py b/blackjax/mcmc/rmhmc.py index edcfb3571..a4551a781 100644 --- a/blackjax/mcmc/rmhmc.py +++ b/blackjax/mcmc/rmhmc.py @@ -20,14 +20,22 @@ from blackjax.mcmc import hmc from blackjax.types import ArrayTree, PRNGKey -__all__ = ["init", "build_kernel", "rmhmc"] +__all__ = ["init", "build_kernel", "as_top_level_api"] init = hmc.init build_kernel = hmc.build_kernel -class rmhmc: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + mass_matrix: Union[metrics.Metric, Callable], + num_integration_steps: int, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.implicit_midpoint, +) -> SamplingAlgorithm: """A Riemannian Manifold Hamiltonian Monte Carlo kernel Of note, this kernel is simply an alias of the ``hmc`` kernel with a @@ -62,34 +70,20 @@ class rmhmc: ------- A ``SamplingAlgorithm``. """ + kernel = build_kernel(integrator, divergence_threshold) - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + def init_fn(position: ArrayTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - mass_matrix: Union[metrics.Metric, Callable], - num_integration_steps: int, - *, - divergence_threshold: int = 1000, - integrator: Callable = integrators.implicit_midpoint, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(integrator, divergence_threshold) + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + mass_matrix, + num_integration_steps, + ) - def init_fn(position: ArrayTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - mass_matrix, - num_integration_steps, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/termination.py b/blackjax/mcmc/termination.py index 24e17c3a5..eb1276da3 100644 --- a/blackjax/mcmc/termination.py +++ b/blackjax/mcmc/termination.py @@ -64,16 +64,10 @@ def _leaf_idx_to_ckpt_idxs(n): """Find the checkpoint id from a step number.""" # computes the number of non-zero bits except the last bit # e.g. 
6 -> 2, 7 -> 2, 13 -> 2 - _, idx_max = jax.lax.while_loop( - lambda nc: nc[0] > 0, - lambda nc: (nc[0] >> 1, nc[1] + (nc[0] & 1)), - (n >> 1, 0), - ) + idx_max = jnp.bitwise_count(n >> 1).astype(jnp.int32) # computes the number of contiguous last non-zero bits # e.g. 6 -> 0, 7 -> 3, 13 -> 1 - _, num_subtrees = jax.lax.while_loop( - lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0) - ) + num_subtrees = jnp.bitwise_count((~n & (n + 1)) - 1).astype(jnp.int32) idx_min = idx_max - num_subtrees + 1 return idx_min, idx_max diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index 6deeb9bef..7bb1b35a5 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -201,7 +201,7 @@ def integrate( def do_keep_integrating(loop_state): """Decide whether we should continue integrating the trajectory""" - _, integration_state, (is_diverging, has_terminated) = loop_state + integration_state, (is_diverging, has_terminated) = loop_state return ( (integration_state.step < max_num_steps) & ~has_terminated @@ -209,9 +209,9 @@ def do_keep_integrating(loop_state): ) def add_one_state(loop_state): - rng_key, integration_state, _ = loop_state + integration_state, _ = loop_state step, proposal, trajectory, termination_state = integration_state - rng_key, proposal_key = jax.random.split(rng_key) + proposal_key = jax.random.fold_in(rng_key, step) new_state = integrator(trajectory.rightmost_state, direction * step_size) new_proposal = generate_proposal(initial_energy, new_state) @@ -246,7 +246,7 @@ def add_one_state(loop_state): new_termination_state, ) - return (rng_key, new_integration_state, (is_diverging, has_terminated)) + return (new_integration_state, (is_diverging, has_terminated)) proposal_placeholder = generate_proposal(initial_energy, initial_state) trajectory_placeholder = Trajectory( @@ -259,12 +259,12 @@ def add_one_state(loop_state): termination_state, ) - _, integration_state, (is_diverging, has_terminated) = jax.lax.while_loop( + new_integration_state, (is_diverging, has_terminated) = jax.lax.while_loop( do_keep_integrating, add_one_state, - (rng_key, integration_state_placeholder, (False, False)), + (integration_state_placeholder, (False, False)), ) - step, proposal, trajectory, termination_state = integration_state + _, proposal, trajectory, termination_state = new_integration_state # In the while_loop we always extend on the right most direction. new_trajectory = jax.lax.cond( @@ -357,7 +357,7 @@ def buildtree_integrate( """ if tree_depth == 0: - # Base case - take one leapfrog step in the direction v. + # Base case - take one velocity_verlet step in the direction v. next_state = integrator(initial_state, direction * step_size) new_proposal = generate_proposal(initial_energy, next_state) is_diverging = -new_proposal.weight > divergence_threshold @@ -385,7 +385,7 @@ def buildtree_integrate( initial_energy, ) # Note that is_diverging and is_turning is inplace updated - if ~is_diverging & ~is_turning: + if (not is_diverging) & (not is_turning): start_state = jax.lax.cond( direction > 0, lambda _: trajectory.rightmost_state, @@ -411,7 +411,7 @@ def buildtree_integrate( ) trajectory = merge_trajectories(left_trajectory, right_trajectory) - if ~is_turning: + if not is_turning: is_turning = uturn_check_fn( trajectory.leftmost_state.momentum, trajectory.rightmost_state.momentum, @@ -496,8 +496,7 @@ def dynamic_multiplicative_expansion( step_size The step size used by the symplectic integrator. 
max_num_expansions - The maximum number of trajectory expansions until the proposal is - returned. + The maximum number of trajectory expansions until the proposal is returned. rate The rate of the geometrical expansion. Typically 2 in NUTS, this is why the literature often refers to "tree doubling". @@ -513,7 +512,7 @@ def expand( ): def do_keep_expanding(loop_state) -> bool: """Determine whether we need to keep expanding the trajectory.""" - _, expansion_state, (is_diverging, is_turning) = loop_state + expansion_state, (is_diverging, is_turning) = loop_state return ( (expansion_state.step < max_num_expansions) & ~is_diverging @@ -531,12 +530,11 @@ def expand_once(loop_state): the subtrajectory. """ - rng_key, expansion_state, _ = loop_state + expansion_state, _ = loop_state step, proposal, trajectory, termination_state = expansion_state - rng_key, direction_key, trajectory_key, proposal_key = jax.random.split( - rng_key, 4 - ) + subkey = jax.random.fold_in(rng_key, step) + direction_key, trajectory_key, proposal_key = jax.random.split(subkey, 3) # create new subtrajectory that is twice as long as the current # trajectory. @@ -608,12 +606,12 @@ def update_sum_log_p_accept(inputs): ) info = (is_diverging, is_turning_subtree | is_turning) - return (rng_key, new_state, info) + return (new_state, info) - _, expansion_state, (is_diverging, is_turning) = jax.lax.while_loop( + expansion_state, (is_diverging, is_turning) = jax.lax.while_loop( do_keep_expanding, expand_once, - (rng_key, initial_expansion_state, (False, False)), + (initial_expansion_state, (False, False)), ) return expansion_state, (is_diverging, is_turning) diff --git a/blackjax/optimizers/lbfgs.py b/blackjax/optimizers/lbfgs.py index d549240df..0dd59f003 100644 --- a/blackjax/optimizers/lbfgs.py +++ b/blackjax/optimizers/lbfgs.py @@ -139,7 +139,7 @@ def minimize_lbfgs( f=history_raveled.f, g=unravel_fn_mapped(history_raveled.g), alpha=unravel_fn_mapped(history_raveled.alpha), - update_mask=jax.tree_map( + update_mask=jax.tree.map( lambda x: x.astype(history_raveled.update_mask.dtype), unravel_fn_mapped(history_raveled.update_mask.astype(x0_raveled.dtype)), ), @@ -230,7 +230,7 @@ def scan_body(tup, it): scan_body, ((init_step, initial_history), True), jnp.arange(maxiter) ) # Append initial state to history. - history = jax.tree_map( + history = jax.tree.map( lambda x, y: jnp.concatenate([x[None, ...], y], axis=0), initial_history, history, diff --git a/blackjax/progress_bar.py b/blackjax/progress_bar.py index d4fa45ca5..a1425df88 100644 --- a/blackjax/progress_bar.py +++ b/blackjax/progress_bar.py @@ -14,14 +14,19 @@ """Progress bar decorators for use with step functions. Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`. 
""" +from threading import Lock + from fastprogress.fastprogress import progress_bar from jax import lax -from jax.experimental import host_callback +from jax.experimental import io_callback +from jax.numpy import array def progress_bar_scan(num_samples, print_rate=None): "Progress bar for a JAX scan" progress_bars = {} + idx_counter = 0 + lock = Lock() if print_rate is None: if num_samples > 20: @@ -29,57 +34,44 @@ def progress_bar_scan(num_samples, print_rate=None): else: print_rate = 1 # if you run the sampler for less than 20 iterations - def _define_bar(arg, transform, device): - progress_bars[0] = progress_bar(range(num_samples)) - progress_bars[0].update(0) + def _calc_chain_idx(iter_num): + nonlocal idx_counter + with lock: + idx = idx_counter + idx_counter += 1 + return idx + + def _update_bar(arg, chain_id): + chain_id = int(chain_id) + if arg == 0: + chain_id = _calc_chain_idx(arg) + progress_bars[chain_id] = progress_bar(range(num_samples)) + progress_bars[chain_id].update(0) + + progress_bars[chain_id].update_bar(arg + 1) + return chain_id - def _update_bar(arg, transform, device): - progress_bars[0].update_bar(arg) + def _close_bar(arg, chain_id): + progress_bars[int(chain_id)].on_iter_end() - def _update_progress_bar(iter_num): + def _update_progress_bar(iter_num, chain_id): "Updates progress bar of a JAX scan or loop" - _ = lax.cond( - iter_num == 0, - lambda _: host_callback.id_tap( - _define_bar, iter_num, result=iter_num, tap_with_device=True - ), - lambda _: iter_num, - operand=None, - ) - _ = lax.cond( + chain_id = lax.cond( # update every multiple of `print_rate` except at the end - (iter_num % print_rate == 0), - lambda _: host_callback.id_tap( - _update_bar, iter_num, result=iter_num, tap_with_device=True - ), - lambda _: iter_num, + (iter_num % print_rate == 0) | (iter_num == (num_samples - 1)), + lambda _: io_callback(_update_bar, array(0), iter_num, chain_id), + lambda _: chain_id, operand=None, ) _ = lax.cond( - # update by `remainder` iter_num == num_samples - 1, - lambda _: host_callback.id_tap( - _update_bar, num_samples, result=iter_num, tap_with_device=True - ), - lambda _: iter_num, - operand=None, - ) - - def _close_bar(arg, transform, device): - progress_bars[0].on_iter_end() - print() - - def close_bar(result, iter_num): - return lax.cond( - iter_num == num_samples - 1, - lambda _: host_callback.id_tap( - _close_bar, None, result=result, tap_with_device=True - ), - lambda _: result, + lambda _: io_callback(_close_bar, None, iter_num + 1, chain_id), + lambda _: None, operand=None, ) + return chain_id def _progress_bar_scan(func): """Decorator that adds a progress bar to `body_fun` used in `lax.scan`. 
@@ -93,10 +85,26 @@ def wrapper_progress_bar(carry, x): iter_num, *_ = x else: iter_num = x - _update_progress_bar(iter_num) - result = func(carry, x) - return close_bar(result, iter_num) + subcarry, chain_id = carry + chain_id = _update_progress_bar(iter_num, chain_id) + subcarry, y = func(subcarry, x) + + return (subcarry, chain_id), y return wrapper_progress_bar return _progress_bar_scan + + +def gen_scan_fn(num_samples, progress_bar, print_rate=None): + if progress_bar: + + def scan_wrap(f, init, *args, **kwargs): + func = progress_bar_scan(num_samples, print_rate)(f) + carry = (init, -1) + (last_state, _), output = lax.scan(func, carry, *args, **kwargs) + return last_state, output + + return scan_wrap + else: + return lax.scan diff --git a/blackjax/sgmcmc/csgld.py b/blackjax/sgmcmc/csgld.py index e0e008a33..506740c50 100644 --- a/blackjax/sgmcmc/csgld.py +++ b/blackjax/sgmcmc/csgld.py @@ -23,7 +23,7 @@ from blackjax.sgmcmc.diffusions import overdamped_langevin from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["ContourSGLDState", "init", "build_kernel", "csgld"] +__all__ = ["ContourSGLDState", "init", "build_kernel", "as_top_level_api"] class ContourSGLDState(NamedTuple): @@ -174,7 +174,14 @@ def kernel( return kernel -class csgld: +def as_top_level_api( + logdensity_estimator: Callable, + gradient_estimator: Callable, + zeta: float = 1, + num_partitions: int = 512, + energy_gap: float = 100, + min_energy: float = 0, +) -> SamplingAlgorithm: r"""Implements the (basic) user interface for the Contour SGLD kernel. Parameters @@ -209,42 +216,30 @@ class csgld: A ``SamplingAlgorithm``. """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel(num_partitions, energy_gap, min_energy) - def __new__( # type: ignore[misc] - cls, - logdensity_estimator: Callable, - gradient_estimator: Callable, - zeta: float = 1, - num_partitions: int = 512, - energy_gap: float = 100, - min_energy: float = 0, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(num_partitions, energy_gap, min_energy) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, num_partitions) - - def step_fn( - rng_key: PRNGKey, - state: ContourSGLDState, - minibatch: ArrayLikeTree, - step_size_diff: float, - step_size_stoch: float, - temperature: float = 1.0, - ) -> ContourSGLDState: - return kernel( - rng_key, - state, - logdensity_estimator, - gradient_estimator, - minibatch, - step_size_diff, - step_size_stoch, - zeta, - temperature, - ) - - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, num_partitions) + + def step_fn( + rng_key: PRNGKey, + state: ContourSGLDState, + minibatch: ArrayLikeTree, + step_size_diff: float, + step_size_stoch: float, + temperature: float = 1.0, + ) -> ContourSGLDState: + return kernel( + rng_key, + state, + logdensity_estimator, + gradient_estimator, + minibatch, + step_size_diff, + step_size_stoch, + zeta, + temperature, + ) + + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/sgmcmc/gradients.py b/blackjax/sgmcmc/gradients.py index a326fefaa..f3686924b 100644 --- a/blackjax/sgmcmc/gradients.py +++ b/blackjax/sgmcmc/gradients.py @@ -125,7 +125,7 @@ def cv_grad_estimator_fn( grad_estimate = logdensity_grad_estimator(position, minibatch) center_grad_estimate = logdensity_grad_estimator(centering_position, minibatch) - return jax.tree_map( + 
return jax.tree.map( lambda grad_est, cv_grad_est, cv_grad: cv_grad + grad_est - cv_grad_est, grad_estimate, center_grad_estimate, diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index 806bbc14e..afa8e2e42 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -21,7 +21,7 @@ from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise -__all__ = ["init", "build_kernel", "sghmc"] +__all__ = ["init", "build_kernel", "as_top_level_api"] def init(position: ArrayLikeTree) -> ArrayLikeTree: @@ -58,7 +58,12 @@ def body_fn(state, rng_key): return kernel -class sghmc: +def as_top_level_api( + grad_estimator: Callable, + num_integration_steps: int = 10, + alpha: float = 0.01, + beta: float = 0, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the SGHMC kernel. The general sghmc kernel builder (:meth:`blackjax.sgmcmc.sghmc.build_kernel`, alias @@ -111,37 +116,27 @@ class sghmc: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel(alpha, beta) - def __new__( # type: ignore[misc] - cls, - grad_estimator: Callable, - num_integration_steps: int = 10, - alpha: float = 0.01, - beta: float = 0, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(alpha, beta) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position) - - def step_fn( - rng_key: PRNGKey, - state: ArrayLikeTree, - minibatch: ArrayLikeTree, - step_size: float, - temperature: float = 1, - ) -> ArrayTree: - return kernel( - rng_key, - state, - grad_estimator, - minibatch, - step_size, - num_integration_steps, - temperature, - ) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position) - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + def step_fn( + rng_key: PRNGKey, + state: ArrayLikeTree, + minibatch: ArrayLikeTree, + step_size: float, + temperature: float = 1, + ) -> ArrayTree: + return kernel( + rng_key, + state, + grad_estimator, + minibatch, + step_size, + num_integration_steps, + temperature, + ) + + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index e2055c511..dca47a983 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -18,7 +18,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["init", "build_kernel", "sgld"] +__all__ = ["init", "build_kernel", "as_top_level_api"] def init(position: ArrayLikeTree) -> ArrayLikeTree: @@ -47,7 +47,9 @@ def kernel( return kernel -class sgld: +def as_top_level_api( + grad_estimator: Callable, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the SGLD kernel. 
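Editor's note: a usage sketch for the functional SGLD interface below. The model, minibatch, and step size are placeholders, and it assumes `blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)` is available from the existing `gradients` module touched above:

```python
import jax
import jax.numpy as jnp

import blackjax.sgmcmc.gradients as gradients
import blackjax.sgmcmc.sgld as sgld

# placeholder model: Gaussian prior, Gaussian likelihood over a dataset of 1000 points
logprior_fn = lambda position: -0.5 * jnp.sum(position**2)
loglikelihood_fn = lambda position, x: -0.5 * jnp.sum((x - position) ** 2)
data_size = 1_000

grad_estimator = gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)
algorithm = sgld.as_top_level_api(grad_estimator)

position = algorithm.init(jnp.zeros(2))
minibatch = jnp.ones((32, 2))  # placeholder minibatch
new_position = algorithm.step(jax.random.key(0), position, minibatch, step_size=1e-3)
```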
The general sgld kernel builder (:meth:`blackjax.sgmcmc.sgld.build_kernel`, alias @@ -100,28 +102,19 @@ class sgld: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel() - def __new__( # type: ignore[misc] - cls, - grad_estimator: Callable, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position) - - def step_fn( - rng_key: PRNGKey, - state: ArrayLikeTree, - minibatch: ArrayLikeTree, - step_size: float, - temperature: float = 1, - ) -> ArrayTree: - return kernel( - rng_key, state, grad_estimator, minibatch, step_size, temperature - ) - - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position) + + def step_fn( + rng_key: PRNGKey, + state: ArrayLikeTree, + minibatch: ArrayLikeTree, + step_size: float, + temperature: float = 1, + ) -> ArrayTree: + return kernel(rng_key, state, grad_estimator, minibatch, step_size, temperature) + + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/sgmcmc/sgnht.py b/blackjax/sgmcmc/sgnht.py index 57b0a4ca2..ad9547406 100644 --- a/blackjax/sgmcmc/sgnht.py +++ b/blackjax/sgmcmc/sgnht.py @@ -19,7 +19,7 @@ from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise -__all__ = ["SGNHTState", "init", "build_kernel", "sgnht"] +__all__ = ["SGNHTState", "init", "build_kernel", "as_top_level_api"] class SGNHTState(NamedTuple): @@ -67,7 +67,11 @@ def kernel( return kernel -class sgnht: +def as_top_level_api( + grad_estimator: Callable, + alpha: float = 0.01, + beta: float = 0.0, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the SGNHT kernel. The general sgnht kernel (:meth:`blackjax.sgmcmc.sgnht.build_kernel`, alias @@ -121,33 +125,22 @@ class sgnht: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel(alpha, beta) - def __new__( # type: ignore[misc] - cls, - grad_estimator: Callable, - alpha: float = 0.01, - beta: float = 0.0, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(alpha, beta) - - def init_fn( - position: ArrayLikeTree, - rng_key: PRNGKey, - init_xi: Union[None, float] = None, - ): - return cls.init(position, rng_key, init_xi or alpha) - - def step_fn( - rng_key: PRNGKey, - state: SGNHTState, - minibatch: ArrayLikeTree, - step_size: float, - temperature: float = 1, - ) -> SGNHTState: - return kernel( - rng_key, state, grad_estimator, minibatch, step_size, temperature - ) - - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + def init_fn( + position: ArrayLikeTree, + rng_key: PRNGKey, + init_xi: Union[None, float] = None, + ): + return init(position, rng_key, init_xi or alpha) + + def step_fn( + rng_key: PRNGKey, + state: SGNHTState, + minibatch: ArrayLikeTree, + step_size: float, + temperature: float = 1, + ) -> SGNHTState: + return kernel(rng_key, state, grad_estimator, minibatch, step_size, temperature) + + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/smc/__init__.py b/blackjax/smc/__init__.py index 180cd8259..ef10b10e6 100644 --- a/blackjax/smc/__init__.py +++ b/blackjax/smc/__init__.py @@ -1,3 +1,9 @@ from . 
import adaptive_tempered, inner_kernel_tuning, tempered +from .base import extend_params -__all__ = ["adaptive_tempered", "tempered", "inner_kernel_tuning"] +__all__ = [ + "adaptive_tempered", + "tempered", + "inner_kernel_tuning", + "extend_params", +] diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py index 5b02c783b..10fb194fa 100644 --- a/blackjax/smc/adaptive_tempered.py +++ b/blackjax/smc/adaptive_tempered.py @@ -23,7 +23,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, PRNGKey -__all__ = ["build_kernel", "adaptive_tempered_smc"] +__all__ = ["build_kernel", "init", "as_top_level_api"] def build_kernel( @@ -103,7 +103,20 @@ def kernel( return kernel -class adaptive_tempered_smc: +init = tempered.init + + +def as_top_level_api( + logprior_fn: Callable, + loglikelihood_fn: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + mcmc_parameters: dict, + resampling_fn: Callable, + target_ess: float, + root_solver: Callable = solver.dichotomy, + num_mcmc_steps: int = 10, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. Parameters @@ -117,7 +130,8 @@ class adaptive_tempered_smc: mcmc_init_fn The MCMC init function used to build a MCMC state from a particle position. mcmc_parameters - The parameters of the MCMC step function. + The parameters of the MCMC step function. Parameters with leading dimension + length of 1 are shared amongst the particles. resampling_fn The function used to resample the particles. target_ess @@ -133,42 +147,26 @@ class adaptive_tempered_smc: A ``SamplingAlgorithm``. """ + kernel = build_kernel( + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + target_ess, + root_solver, + ) - init = staticmethod(tempered.init) - build_kernel = staticmethod(build_kernel) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position) - def __new__( # type: ignore[misc] - cls, - logprior_fn: Callable, - loglikelihood_fn: Callable, - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - mcmc_parameters: dict, - resampling_fn: Callable, - target_ess: float, - root_solver: Callable = solver.dichotomy, - num_mcmc_steps: int = 10, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel( - logprior_fn, - loglikelihood_fn, - mcmc_step_fn, - mcmc_init_fn, - resampling_fn, - target_ess, - root_solver, + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + num_mcmc_steps, + mcmc_parameters, ) - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - num_mcmc_steps, - mcmc_parameters, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 409f588d2..5093cf06b 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -40,6 +40,7 @@ class SMCState(NamedTuple): particles: ArrayTree weights: Array + update_parameters: ArrayTree class SMCInfo(NamedTuple): @@ -59,12 +60,12 @@ class SMCInfo(NamedTuple): update_info: NamedTuple -def init(particles: ArrayLikeTree): +def init(particles: ArrayLikeTree, init_update_params): # Infer the number of particles from the size of the leading dimension of # the first leaf of the inputted PyTree. 
num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] weights = jnp.ones(num_particles) / num_particles - return SMCState(particles, weights) + return SMCState(particles, weights, init_update_params) def step( @@ -134,16 +135,24 @@ def step( num_resampled = num_particles resampling_idx = resample_fn(resampling_key, state.weights, num_resampled) - particles = jax.tree_map(lambda x: x[resampling_idx], state.particles) + particles = jax.tree.map(lambda x: x[resampling_idx], state.particles) keys = jax.random.split(updating_key, num_resampled) - particles, update_info = update_fn(keys, particles) + particles, update_info = update_fn(keys, particles, state.update_parameters) log_weights = weight_fn(particles) logsum_weights = jax.scipy.special.logsumexp(log_weights) normalizing_constant = logsum_weights - jnp.log(num_particles) weights = jnp.exp(log_weights - logsum_weights) - return SMCState(particles, weights), SMCInfo( + return SMCState(particles, weights, state.update_parameters), SMCInfo( resampling_idx, normalizing_constant, update_info ) + + +def extend_params(params): + """Given a dictionary of params, adds a leading dimension of length 1 to each leaf, which + the SMC kernels interpret as the same parameters being shared by all particles (chains). + """ + + return jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params) diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index 6aaf3a5d3..2a63fd1ce 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -1,15 +1,20 @@ -from typing import Callable, Dict, NamedTuple, Tuple, Union +from typing import Callable, Dict, NamedTuple, Tuple from blackjax.base import SamplingAlgorithm -from blackjax.smc.adaptive_tempered import adaptive_tempered_smc from blackjax.smc.base import SMCInfo, SMCState -from blackjax.smc.tempered import tempered_smc from blackjax.types import ArrayTree, PRNGKey class StateWithParameterOverride(NamedTuple): + """ + Stores both the sampling state and a dictionary whose keys are + parameter names and whose values are (n_particles, *) arrays. The latter + hold one parameter value per chain for the next mutation step. + """ + sampler_state: ArrayTree - parameter_override: ArrayTree + parameter_override: Dict[str, ArrayTree] def init(alg_init_fn, position, initial_parameter_value): @@ -20,11 +25,10 @@ def build_kernel( smc_algorithm, logprior_fn: Callable, loglikelihood_fn: Callable, - mcmc_factory: Callable, + mcmc_step_fn: Callable, mcmc_init_fn: Callable, - mcmc_parameters: Dict, resampling_fn: Callable, - mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree], + mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], num_mcmc_steps: int = 10, **extra_parameters, ) -> Callable: @@ -41,12 +45,11 @@ def build_kernel( A function that computes the log density of the prior distribution loglikelihood_fn A function that returns the probability at a given position. - mcmc_factory - A callable that can construct an inner kernel out of the newly-computed parameter + mcmc_step_fn: + The transition kernel; it should accept the entries of the dictionary returned by mcmc_parameter_update_fn as keyword arguments, e.g.
+            mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn()) mcmc_init_fn A callable that initializes the inner kernel - mcmc_parameters - Other (fixed across SMC iterations) parameters for the inner kernel mcmc_parameter_update_fn A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. extra_parameters: parameters to be used for the creation of the smc_algorithm. @@ -59,9 +62,9 @@ def kernel( step_fn = smc_algorithm( logprior_fn=logprior_fn, loglikelihood_fn=loglikelihood_fn, - mcmc_step_fn=mcmc_factory(state.parameter_override), + mcmc_step_fn=mcmc_step_fn, mcmc_init_fn=mcmc_init_fn, - mcmc_parameters=mcmc_parameters, + mcmc_parameters=state.parameter_override, resampling_fn=resampling_fn, num_mcmc_steps=num_mcmc_steps, **extra_parameters, @@ -73,7 +76,18 @@ def kernel( return kernel -class inner_kernel_tuning: +def as_top_level_api( + smc_algorithm, + logprior_fn: Callable, + loglikelihood_fn: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + resampling_fn: Callable, + mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], + initial_parameter_value, + num_mcmc_steps: int = 10, + **extra_parameters, +) -> SamplingAlgorithm: """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner MCMC that is used to perturbate/update each of the particles. This adaptation tunes some @@ -84,22 +98,20 @@ class inner_kernel_tuning: ---------- smc_algorithm Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of - a sampling algorithm that returns an SMCState and SMCInfo pair). + a sampling algorithm that returns an SMCState and SMCInfo pair). See blackjax.smc_family logprior_fn A function that computes the log density of the prior distribution loglikelihood_fn A function that returns the probability at a given position. - mcmc_factory - A callable that can construct an inner kernel out of the newly-computed parameter + mcmc_step_fn + The transition kernel; it should accept the entries of the dictionary returned by mcmc_parameter_update_fn as keyword arguments. mcmc_init_fn A callable that initializes the inner kernel - mcmc_parameters - Other (fixed across SMC iterations) parameters for the inner kernel step mcmc_parameter_update_fn A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. initial_parameter_value - Paramter to be used by the mcmc_factory before the first iteration. + Parameter to be used by the inner kernel (mcmc_step_fn) in the first iteration, before mcmc_parameter_update_fn has been applied. extra_parameters: parameters to be used for the creation of the smc_algorithm.
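Editor's note: to make the `mcmc_parameter_update_fn` contract concrete, here is a hypothetical update rule. The tuning logic itself is a placeholder; only the returned shapes follow the `(n_particles, *)` convention documented above.

```python
import jax.numpy as jnp

def mcmc_parameter_update_fn(state, info):
    # `state` and `info` are the state/info pair returned by the wrapped SMC step
    n_particles = state.weights.shape[0]
    new_step_size = 1e-2  # a real rule would derive this from `state` and `info`
    # one value per particle, ready for the next mutation step
    return {"step_size": jnp.full((n_particles,), new_step_size)}
```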
@@ -109,43 +121,25 @@ class inner_kernel_tuning: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - smc_algorithm: Union[adaptive_tempered_smc, tempered_smc], - logprior_fn: Callable, - loglikelihood_fn: Callable, - mcmc_factory: Callable, - mcmc_init_fn: Callable, - mcmc_parameters: Dict, - resampling_fn: Callable, - mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree], - initial_parameter_value, - num_mcmc_steps: int = 10, + kernel = build_kernel( + smc_algorithm, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + mcmc_parameter_update_fn, + num_mcmc_steps, **extra_parameters, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel( - smc_algorithm, - logprior_fn, - loglikelihood_fn, - mcmc_factory, - mcmc_init_fn, - mcmc_parameters, - resampling_fn, - mcmc_parameter_update_fn, - num_mcmc_steps, - **extra_parameters, - ) + ) - def init_fn(position, rng_key=None): - del rng_key - return cls.init(smc_algorithm.init, position, initial_parameter_value) + def init_fn(position, rng_key=None): + del rng_key + return init(smc_algorithm.init, position, initial_parameter_value) - def step_fn( - rng_key: PRNGKey, state, **extra_step_parameters - ) -> Tuple[StateWithParameterOverride, SMCInfo]: - return kernel(rng_key, state, **extra_step_parameters) + def step_fn( + rng_key: PRNGKey, state, **extra_step_parameters + ) -> Tuple[StateWithParameterOverride, SMCInfo]: + return kernel(rng_key, state, **extra_step_parameters) - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 49fa21277..43b83d034 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, NamedTuple import jax @@ -21,7 +22,7 @@ from blackjax.smc.base import SMCState from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["TemperedSMCState", "init", "build_kernel"] +__all__ = ["TemperedSMCState", "init", "build_kernel", "as_top_level_api"] class TemperedSMCState(NamedTuple): @@ -108,6 +109,9 @@ def kernel( Current state of the tempered SMC algorithm lmbda Current value of the tempering parameter + mcmc_parameters + The parameters of the MCMC step function. Parameters with leading dimension + length of 1 are shared amongst the particles. Returns ------- @@ -119,6 +123,14 @@ def kernel( """ delta = lmbda - state.lmbda + shared_mcmc_parameters = {} + unshared_mcmc_parameters = {} + for k, v in mcmc_parameters.items(): + if v.shape[0] == 1: + shared_mcmc_parameters[k] = v[0, ...] 
+ else: + unshared_mcmc_parameters[k] = v + def log_weights_fn(position: ArrayLikeTree) -> float: return delta * loglikelihood_fn(position) @@ -127,12 +139,14 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: tempered_loglikelihood = state.lmbda * loglikelihood_fn(position) return logprior + tempered_loglikelihood - def mcmc_kernel(rng_key, position): + shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) + + def mcmc_kernel(rng_key, position, step_parameters): state = mcmc_init_fn(position, tempered_logposterior_fn) def body_fn(state, rng_key): - new_state, info = mcmc_step_fn( - rng_key, state, tempered_logposterior_fn, **mcmc_parameters + new_state, info = shared_mcmc_step_fn( + rng_key, state, tempered_logposterior_fn, **step_parameters ) return new_state, info @@ -142,7 +156,7 @@ def body_fn(state, rng_key): smc_state, info = smc.base.step( rng_key, - SMCState(state.particles, state.weights), + SMCState(state.particles, state.weights, unshared_mcmc_parameters), jax.vmap(mcmc_kernel), jax.vmap(log_weights_fn), resampling_fn, @@ -156,7 +170,15 @@ def body_fn(state, rng_key): return kernel -class tempered_smc: +def as_top_level_api( + logprior_fn: Callable, + loglikelihood_fn: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + mcmc_parameters: dict, + resampling_fn: Callable, + num_mcmc_steps: int = 10, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. Parameters @@ -170,7 +192,8 @@ class tempered_smc: mcmc_init_fn The MCMC init function used to build a MCMC state from a particle position. mcmc_parameters - The parameters of the MCMC step function. + The parameters of the MCMC step function. Parameters with leading dimension + length of 1 are shared amongst the particles. resampling_fn The function used to resample the particles. num_mcmc_steps @@ -181,39 +204,25 @@ class tempered_smc: A ``SamplingAlgorithm``. 
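Editor's note: a construction sketch under the new shared-parameter convention described above. The target densities, HMC settings, particle count, and `lmbda` value are placeholders, and it assumes the HMC kernel accepts `step_size`, `inverse_mass_matrix`, and `num_integration_steps` as keyword arguments:

```python
import jax
import jax.numpy as jnp

import blackjax.smc.resampling as resampling
from blackjax.mcmc import hmc
from blackjax.smc import extend_params, tempered

logprior_fn = lambda x: -0.5 * jnp.sum(x**2)                 # placeholder prior
loglikelihood_fn = lambda x: -0.5 * jnp.sum((x - 1.0) ** 2)  # placeholder likelihood

# leading dimension of length 1 => parameters shared by every particle
hmc_parameters = extend_params(
    {"step_size": 1e-2, "inverse_mass_matrix": jnp.eye(2), "num_integration_steps": 50}
)

algorithm = tempered.as_top_level_api(
    logprior_fn,
    loglikelihood_fn,
    hmc.build_kernel(),
    hmc.init,
    hmc_parameters,
    resampling.systematic,
    num_mcmc_steps=10,
)

particles = jax.random.normal(jax.random.key(0), (100, 2))  # 100 particles in 2D
state = algorithm.init(particles)
new_state, info = algorithm.step(jax.random.key(1), state, lmbda=0.5)
```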
""" - - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logprior_fn: Callable, - loglikelihood_fn: Callable, - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - mcmc_parameters: dict, - resampling_fn: Callable, - num_mcmc_steps: int = 10, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel( - logprior_fn, - loglikelihood_fn, - mcmc_step_fn, - mcmc_init_fn, - resampling_fn, + kernel = build_kernel( + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + ) + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position) + + def step_fn(rng_key: PRNGKey, state, lmbda): + return kernel( + rng_key, + state, + num_mcmc_steps, + lmbda, + mcmc_parameters, ) - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position) - - def step_fn(rng_key: PRNGKey, state, lmbda): - return kernel( - rng_key, - state, - num_mcmc_steps, - lmbda, - mcmc_parameters, - ) - - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/util.py b/blackjax/util.py index df527ed01..b6c5367b5 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -1,4 +1,5 @@ """Utility functions for BlackJax.""" + from functools import partial from typing import Callable, Union @@ -6,10 +7,10 @@ from jax import jit, lax from jax.flatten_util import ravel_pytree from jax.random import normal, split -from jax.tree_util import tree_leaves +from jax.tree_util import tree_leaves, tree_map -from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm -from blackjax.progress_bar import progress_bar_scan +from blackjax.base import SamplingAlgorithm, VIAlgorithm +from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -142,12 +143,13 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree: def run_inference_algorithm( rng_key: PRNGKey, - initial_state_or_position: ArrayLikeTree, inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm], num_steps: int, + initial_state: ArrayLikeTree = None, + initial_position: ArrayLikeTree = None, progress_bar: bool = False, - transform: Callable = lambda x: x, -) -> tuple[State, State, Info]: + transform: Callable = lambda state, info: (state, info), +) -> tuple: """Wrapper to run an inference algorithm. Note that this utility function does not work for Stochastic Gradient MCMC samplers @@ -158,9 +160,10 @@ def run_inference_algorithm( ---------- rng_key The random state used by JAX's random numbers generator. - initial_state_or_position - The initial state OR the initial position of the inference algorithm. If an initial position - is passed in, the function will automatically convert it into an initial state. + initial_state + The initial state of the inference algorithm. + initial_position + The initial position of the inference algorithm. This is used when the initial state is not provided. inference_algorithm One of blackjax's sampling algorithms or variational inference algorithms. num_steps @@ -168,37 +171,141 @@ def run_inference_algorithm( progress_bar Whether to display a progress bar. transform - A transformation of the trace of states to be returned. This is useful for + A transformation of the trace of states (and info) to be returned. This is useful for computing determinstic variables, or returning a subset of the states. 
By default, the states are returned as is. Returns ------- - Tuple[State, State, Info] - 1. The final state of the inference algorithm. - 2. The trace of states of the inference algorithm (contains the MCMC samples). - 3. The trace of the info of the inference algorithm for diagnostics. + 1. The final state. + 2. The history of states. """ - init_key, sample_key = split(rng_key, 2) - try: - initial_state = inference_algorithm.init(initial_state_or_position, init_key) - except (TypeError, ValueError, AttributeError): - # We assume initial_state is already in the right format. - initial_state = initial_state_or_position - keys = split(sample_key, num_steps) + if initial_state is None and initial_position is None: + raise ValueError( + "Either `initial_state` or `initial_position` must be provided." + ) + if initial_state is not None and initial_position is not None: + raise ValueError( + "Only one of `initial_state` or `initial_position` must be provided." + ) + + if initial_state is None: + rng_key, init_key = split(rng_key, 2) + initial_state = inference_algorithm.init(initial_position, init_key) + + keys = split(rng_key, num_steps) - @jit - def _one_step(state, xs): + def one_step(state, xs): _, rng_key = xs state, info = inference_algorithm.step(rng_key, state) - return state, (transform(state), info) + return state, transform(state, info) - if progress_bar: - one_step = progress_bar_scan(num_steps)(_one_step) - else: - one_step = _one_step + scan_fn = gen_scan_fn(num_steps, progress_bar) + + xs = jnp.arange(num_steps), keys + final_state, history = scan_fn(one_step, initial_state, xs) + + return final_state, history + + +def store_only_expectation_values( + sampling_algorithm, + state_transform=lambda x: x, + incremental_value_transform=lambda x: x, + burn_in=0, +): + """Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same + kernel but only stores the streaming expectation values of some observables, not the full states, in order to save memory. + + It saves incremental_value_transform(E[state_transform(x)]) at each step i, where the expectation is computed using the samples up to the i-th step. + + Example: + + ..
code:: + + init_key, state_key, run_key = jax.random.split(jax.random.PRNGKey(0),3) + model = StandardNormal(2) + initial_position = model.sample_init(init_key) + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=model.logdensity_fn, rng_key=state_key + ) + integrator_type = "mclachlan" + L = 1.0 + step_size = 0.1 + num_steps = 4 + + integrator = map_integrator_type_to_integrator['mclmc'][integrator_type] + state_transform = lambda state: state.position + memory_efficient_sampling_alg, transform = store_only_expectation_values( + sampling_algorithm=sampling_alg, + state_transform=state_transform) + + initial_state = memory_efficient_sampling_alg.init(initial_state) + + final_state, trace_at_every_step = run_inference_algorithm( + + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=memory_efficient_sampling_alg, + num_steps=num_steps, + transform=transform, + progress_bar=True, + ) + """ + + def init_fn(state): + averaging_state = (0.0, state_transform(state)) + return (state, averaging_state) + + def update_fn(rng_key, state_and_incremental_val): + state, averaging_state = state_and_incremental_val + state, info = sampling_algorithm.step( + rng_key, state + ) # update the state with the sampling algorithm + averaging_state = incremental_value_update( + state_transform(state), + averaging_state, + weight=( + averaging_state[0] >= burn_in + ), # If we want to eliminate some number of steps as a burn-in + zero_prevention=1e-10 * (burn_in > 0), + ) + # update the expectation value with the running average + return (state, averaging_state), info + + def transform(state_and_incremental_val, info): + (state, (_, incremental_value)) = state_and_incremental_val + return incremental_value_transform(incremental_value), info + + return SamplingAlgorithm(init_fn, update_fn), transform + + +def incremental_value_update( + expectation, incremental_val, weight=1.0, zero_prevention=0.0 +): + """Compute the streaming average of a function O(x) using a weight. 
+ Parameters: + ---------- + expectation + the value of the expectation at the current timestep + incremental_val + tuple of (total, average) where total is the sum of weights and average is the current average + weight + weight of the current state + zero_prevention + small value to prevent division by zero + Returns: + ---------- + new streaming average + """ - xs = (jnp.arange(num_steps), keys) - final_state, (state_history, info_history) = lax.scan(one_step, initial_state, xs) - return final_state, state_history, info_history + total, average = incremental_val + average = tree_map( + lambda exp, av: (total * av + weight * exp) + / (total + weight + zero_prevention), + expectation, + average, + ) + total += weight + return total, average diff --git a/blackjax/vi/meanfield_vi.py b/blackjax/vi/meanfield_vi.py index 8d5defa15..6f379c7b0 100644 --- a/blackjax/vi/meanfield_vi.py +++ b/blackjax/vi/meanfield_vi.py @@ -27,7 +27,7 @@ "sample", "generate_meanfield_logdensity", "step", - "meanfield_vi", + "as_top_level_api", ] @@ -48,8 +48,8 @@ def init( **optimizer_kwargs, ) -> MFVIState: """Initialize the mean-field VI state.""" - mu = jax.tree_map(jnp.zeros_like, position) - rho = jax.tree_map(lambda x: -2.0 * jnp.ones_like(x), position) + mu = jax.tree.map(jnp.zeros_like, position) + rho = jax.tree.map(lambda x: -2.0 * jnp.ones_like(x), position) opt_state = optimizer.init((mu, rho)) return MFVIState(mu, rho, opt_state) @@ -99,7 +99,7 @@ def kl_divergence_fn(parameters): elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters) updates, new_opt_state = optimizer.update(elbo_grad, state.opt_state, parameters) - new_parameters = jax.tree_map(lambda p, u: p + u, parameters, updates) + new_parameters = jax.tree.map(lambda p, u: p + u, parameters, updates) new_state = MFVIState(new_parameters[0], new_parameters[1], new_opt_state) return new_state, MFVIInfo(elbo) @@ -109,7 +109,11 @@ def sample(rng_key: PRNGKey, state: MFVIState, num_samples: int = 1): return _sample(rng_key, state.mu, state.rho, num_samples) -class meanfield_vi: +def as_top_level_api( + logdensity_fn: Callable, + optimizer: GradientTransformation, + num_samples: int = 100, +): """High-level implementation of Mean-Field Variational Inference. 
Parameters @@ -128,30 +132,20 @@ class meanfield_vi: """ - init = staticmethod(init) - step = staticmethod(step) - sample = staticmethod(sample) - - def __new__( - cls, - logdensity_fn: Callable, - optimizer: GradientTransformation, - num_samples: int = 100, - ): # type: ignore[misc] - def init_fn(position: ArrayLikeTree): - return cls.init(position, optimizer) + def init_fn(position: ArrayLikeTree): + return init(position, optimizer) - def step_fn(rng_key: PRNGKey, state: MFVIState) -> tuple[MFVIState, MFVIInfo]: - return cls.step(rng_key, state, logdensity_fn, optimizer, num_samples) + def step_fn(rng_key: PRNGKey, state: MFVIState) -> tuple[MFVIState, MFVIInfo]: + return step(rng_key, state, logdensity_fn, optimizer, num_samples) - def sample_fn(rng_key: PRNGKey, state: MFVIState, num_samples: int): - return cls.sample(rng_key, state, num_samples) + def sample_fn(rng_key: PRNGKey, state: MFVIState, num_samples: int): + return sample(rng_key, state, num_samples) - return VIAlgorithm(init_fn, step_fn, sample_fn) + return VIAlgorithm(init_fn, step_fn, sample_fn) def _sample(rng_key, mu, rho, num_samples): - sigma = jax.tree_map(jnp.exp, rho) + sigma = jax.tree.map(jnp.exp, rho) mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu) sigma_flat, _ = jax.flatten_util.ravel_pytree(sigma) flatten_sample = ( @@ -162,11 +156,11 @@ def _sample(rng_key, mu, rho, num_samples): def generate_meanfield_logdensity(mu, rho): - sigma_param = jax.tree_map(jnp.exp, rho) + sigma_param = jax.tree.map(jnp.exp, rho) def meanfield_logdensity(position): - logq_pytree = jax.tree_map(jsp.stats.norm.logpdf, position, mu, sigma_param) - logq = jax.tree_map(jnp.sum, logq_pytree) + logq_pytree = jax.tree.map(jsp.stats.norm.logpdf, position, mu, sigma_param) + logq = jax.tree.map(jnp.sum, logq_pytree) return jax.tree_util.tree_reduce(jnp.add, logq) return meanfield_logdensity diff --git a/blackjax/vi/pathfinder.py b/blackjax/vi/pathfinder.py index 504d0a2a0..c1b7dc113 100644 --- a/blackjax/vi/pathfinder.py +++ b/blackjax/vi/pathfinder.py @@ -25,7 +25,7 @@ ) from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["PathfinderState", "approximate", "sample", "pathfinder"] +__all__ = ["PathfinderState", "approximate", "sample", "as_top_level_api"] class PathfinderState(NamedTuple): @@ -197,7 +197,7 @@ def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad): ) max_elbo_idx = jnp.argmax(elbo) - return jax.tree_map(lambda x: x[max_elbo_idx], pathfinder_result), PathfinderInfo( + return jax.tree.map(lambda x: x[max_elbo_idx], pathfinder_result), PathfinderInfo( pathfinder_result ) @@ -242,7 +242,7 @@ def sample( return jax.vmap(unravel_fn)(phi), logq -class pathfinder: +def as_top_level_api(logdensity_fn: Callable) -> PathFinderAlgorithm: """Implements the (basic) user interface for the pathfinder kernel. 
Pathfinder locates normal approximations to the target density along a @@ -266,21 +266,17 @@ class pathfinder: """ - approximate = staticmethod(approximate) - sample = staticmethod(sample) - - def __new__(cls, logdensity_fn: Callable) -> PathFinderAlgorithm: # type: ignore[misc] - def approximate_fn( - rng_key: PRNGKey, - position: ArrayLikeTree, - num_samples: int = 200, - **lbfgs_parameters, - ): - return cls.approximate( - rng_key, logdensity_fn, position, num_samples, **lbfgs_parameters - ) + def approximate_fn( + rng_key: PRNGKey, + position: ArrayLikeTree, + num_samples: int = 200, + **lbfgs_parameters, + ): + return approximate( + rng_key, logdensity_fn, position, num_samples, **lbfgs_parameters + ) - def sample_fn(rng_key: PRNGKey, state: PathfinderState, num_samples: int): - return cls.sample(rng_key, state, num_samples) + def sample_fn(rng_key: PRNGKey, state: PathfinderState, num_samples: int): + return sample(rng_key, state, num_samples) - return PathFinderAlgorithm(approximate_fn, sample_fn) + return PathFinderAlgorithm(approximate_fn, sample_fn) diff --git a/blackjax/vi/schrodinger_follmer.py b/blackjax/vi/schrodinger_follmer.py index 07d0186dc..51d1e88fe 100644 --- a/blackjax/vi/schrodinger_follmer.py +++ b/blackjax/vi/schrodinger_follmer.py @@ -55,7 +55,7 @@ class SchrodingerFollmerInfo(NamedTuple): def init(example_position: ArrayLikeTree) -> SchrodingerFollmerState: - zero = jax.tree_map(jnp.zeros_like, example_position) + zero = jax.tree.map(jnp.zeros_like, example_position) return SchrodingerFollmerState(zero, 0.0) @@ -95,7 +95,7 @@ def step( eps_drift = jax.random.normal(drift_key, (n_samples,) + ravelled_position.shape) eps_drift = jax.vmap(unravel_fn)(eps_drift) - perturbed_position = jax.tree_map( + perturbed_position = jax.tree.map( lambda a, b: a[None, ...] + scale * b, state.position, eps_drift ) @@ -105,14 +105,14 @@ def step( log_pdf -= jnp.max(log_pdf, axis=0, keepdims=True) pdf = jnp.exp(log_pdf) - num = jax.tree_map(lambda a: pdf @ a, eps_drift) + num = jax.tree.map(lambda a: pdf @ a, eps_drift) den = scale * jnp.sum(pdf, axis=0) - drift = jax.tree_map(lambda a: a / den, num) + drift = jax.tree.map(lambda a: a / den, num) eps_sde = jax.random.normal(sde_key, ravelled_position.shape) eps_sde = unravel_fn(eps_sde) - next_position = jax.tree_map( + next_position = jax.tree.map( lambda a, b, c: a + step_size * b + step_size**0.5 * c, state.position, drift, @@ -151,20 +151,20 @@ def sample( dt = 1.0 / n_steps initial_position = initial_state.position - initial_positions = jax.tree_map( + initial_positions = jax.tree.map( lambda a: jnp.zeros([n_samples, *a.shape], dtype=a.dtype), initial_position ) initial_states = SchrodingerFollmerState(initial_positions, jnp.zeros((n_samples,))) - def body(_, carry): - key, states = carry - keys = jax.random.split(key, 1 + n_samples) - states, _ = jax.vmap(step, [0, 0, None, None, None])( - keys[1:], states, log_density_fn, dt, n_inner_samples + def body(i, states): + subkey = jax.random.fold_in(rng_key, i) + keys = jax.random.split(subkey, n_samples) + next_states, _ = jax.vmap(step, [0, 0, None, None, None])( + keys, states, log_density_fn, dt, n_inner_samples ) - return keys[0], states + return next_states - _, final_states = jax.lax.fori_loop(0, n_steps, body, (rng_key, initial_states)) + final_states = jax.lax.fori_loop(0, n_steps, body, initial_states) return final_states @@ -176,12 +176,12 @@ def _log_fn_corrected(position, logdensity_fn): This corrects the gradient of the log-density function to account for this. 
""" log_pdf_val = logdensity_fn(position) - norm = jax.tree_map(lambda a: 0.5 * jnp.sum(a**2), position) + norm = jax.tree.map(lambda a: 0.5 * jnp.sum(a**2), position) norm = sum(tree_leaves(norm)) return log_pdf_val + norm -class schrodinger_follmer: +def as_top_level_api(logdensity_fn: Callable, n_steps: int, n_inner_samples: int) -> VIAlgorithm: # type: ignore[misc] """Implements the (basic) user interface for the Schrödinger-Föllmer algortithm :cite:p:`huang2021schrodingerfollmer`. The Schrödinger-Föllmer algorithm obtains (approximate) samples from the target distribution by means of a diffusion with @@ -202,22 +202,17 @@ class schrodinger_follmer: """ - init = staticmethod(init) - step = staticmethod(step) - sample = staticmethod(sample) + def init_fn(position: ArrayLikeTree): + return init(position) - def __new__(cls, logdensity_fn: Callable, n_steps: int, n_inner_samples: int) -> VIAlgorithm: # type: ignore[misc] - def init_fn(position: ArrayLikeTree): - return cls.init(position) + def step_fn( + rng_key: PRNGKey, state: SchrodingerFollmerState + ) -> tuple[SchrodingerFollmerState, SchrodingerFollmerInfo]: + return step(rng_key, state, logdensity_fn, 1 / n_steps, n_inner_samples) - def step_fn( - rng_key: PRNGKey, state: SchrodingerFollmerState - ) -> tuple[SchrodingerFollmerState, SchrodingerFollmerInfo]: - return cls.step(rng_key, state, logdensity_fn, 1 / n_steps, n_inner_samples) - - def sample_fn(rng_key: PRNGKey, state: SchrodingerFollmerState, n_samples: int): - return cls.sample( - rng_key, state, logdensity_fn, n_steps, n_inner_samples, n_samples - ) + def sample_fn(rng_key: PRNGKey, state: SchrodingerFollmerState, n_samples: int): + return sample( + rng_key, state, logdensity_fn, n_steps, n_inner_samples, n_samples + ) - return VIAlgorithm(init_fn, step_fn, sample_fn) + return VIAlgorithm(init_fn, step_fn, sample_fn) diff --git a/blackjax/vi/svgd.py b/blackjax/vi/svgd.py index f93941aee..e287b813f 100644 --- a/blackjax/vi/svgd.py +++ b/blackjax/vi/svgd.py @@ -9,7 +9,13 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree -__all__ = ["svgd", "rbf_kernel", "update_median_heuristic"] +__all__ = [ + "as_top_level_api", + "init", + "build_kernel", + "rbf_kernel", + "update_median_heuristic", +] class SVGDState(NamedTuple): @@ -123,8 +129,13 @@ def update_median_heuristic(state: SVGDState) -> SVGDState: return SVGDState(position, median_heuristic(kernel_parameters, position), opt_state) -class svgd: - """Implements the (basic) user interface for the svgd algorithm. +def as_top_level_api( + grad_logdensity_fn: Callable, + optimizer, + kernel: Callable = rbf_kernel, + update_kernel_parameters: Callable = update_median_heuristic, +): + """Implements the (basic) user interface for the svgd algorithm :cite:p:`liu2016stein`. Parameters ---------- @@ -142,26 +153,16 @@ class svgd: A ``SamplingAlgorithm``. 
""" - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel_ = build_kernel(optimizer) - def __new__( - cls, - grad_logdensity_fn: Callable, - optimizer, - kernel: Callable = rbf_kernel, - update_kernel_parameters: Callable = update_median_heuristic, + def init_fn( + initial_position: ArrayLikeTree, + kernel_parameters: dict[str, Any] = {"length_scale": 1.0}, ): - kernel_ = cls.build_kernel(optimizer) - - def init_fn( - initial_position: ArrayLikeTree, - kernel_parameters: dict[str, Any] = {"length_scale": 1.0}, - ): - return cls.init(initial_position, kernel_parameters, optimizer) + return init(initial_position, kernel_parameters, optimizer) - def step_fn(state, **grad_params): - state = kernel_(state, grad_logdensity_fn, kernel, **grad_params) - return update_kernel_parameters(state) + def step_fn(state, **grad_params): + state = kernel_(state, grad_logdensity_fn, kernel, **grad_params) + return update_kernel_parameters(state) - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/docs/examples/howto_custom_gradients.md b/docs/examples/howto_custom_gradients.md index 731de0eea..653e6d393 100644 --- a/docs/examples/howto_custom_gradients.md +++ b/docs/examples/howto_custom_gradients.md @@ -29,10 +29,9 @@ Functions can be defined as the minimum of another one, $f(x) = min_{y} g(x,y)$. Our example is taken from the theory of [convex conjugates](https://en.wikipedia.org/wiki/Convex_conjugate), used for example in optimal transport. Let's consider the following function: $$ -\begin{align*} -g(x, y) &= h(y) - \langle x, y\rangle\\ -h(x) &= \frac{1}{p}|x|^p,\qquad p > 1.\\ -\end{align*} +\begin{equation*} +g(x, y) = h(y) - \langle x, y\rangle,\qquad h(x) = \frac{1}{p}|x|^p,\qquad p > 1. +\end{equation*} $$ And define the function $f$ as $f(x) = -min_y g(x, y)$ which we can be implemented as: @@ -69,7 +68,7 @@ Note the we also return the value of $y$ where the minimum of $g$ is achieved (t ### Trying to differentate the function with `jax.grad` -The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops, and trying to compute it directly raises an error: +The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops used in BFGS, and trying to compute it directly raises an error: ```{code-cell} ipython3 # We only want the gradient with respect to `x` @@ -97,7 +96,7 @@ The first order optimality criterion \end{equation*} ``` -Ensures that: +ensures that ```{math} \begin{equation*} @@ -105,7 +104,7 @@ Ensures that: \end{equation*} ``` -i.e. the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved. +In other words, the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved. 
### Telling JAX to use a custom gradient diff --git a/docs/examples/howto_metropolis_within_gibbs.md b/docs/examples/howto_metropolis_within_gibbs.md index 44f7ed9bb..e3edb8b6b 100644 --- a/docs/examples/howto_metropolis_within_gibbs.md +++ b/docs/examples/howto_metropolis_within_gibbs.md @@ -325,7 +325,7 @@ positions_general = sampling_loop_general( ### Check Result ```{code-cell} ipython3 -jax.tree_map(lambda x, y: jnp.max(jnp.abs(x-y)), positions, positions_general) +jax.tree.map(lambda x, y: jnp.max(jnp.abs(x-y)), positions, positions_general) ``` ## Developer Notes diff --git a/docs/examples/howto_reproduce_the_blackjax_image.md b/docs/examples/howto_reproduce_the_blackjax_image.md index 8320c10ac..6e3ccdb73 100644 --- a/docs/examples/howto_reproduce_the_blackjax_image.md +++ b/docs/examples/howto_reproduce_the_blackjax_image.md @@ -139,12 +139,12 @@ def smc_inference_loop(loop_key, smc_kernel, init_state, schedule): """ def body_fn(carry, lmbda): - carry_key, state = carry - carry_key, subkey = jax.random.split(carry_key) + i, state = carry + subkey = jax.random.fold_in(loop_key, i) new_state, info = smc_kernel(subkey, state, lmbda) - return (rng_key, new_state), (new_state, info) + return (i + 1, new_state), (new_state, info) - _, (all_samples, _) = jax.lax.scan(body_fn, (loop_key, init_state), schedule) + _, (all_samples, _) = jax.lax.scan(body_fn, (0, init_state), schedule) return all_samples diff --git a/docs/index.md b/docs/index.md index 59e4b3f0d..fca4787c4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -39,9 +39,9 @@ state = nuts.init(initial_position) # Iterate rng_key = jax.random.key(0) step = jax.jit(nuts.step) -for _ in range(1_000): - rng_key, nuts_key = jax.random.split(rng_key) - state, _ = nuts.step(nuts_key, state) +for i in range(1_000): + nuts_key = jax.random.fold_in(rng_key, i) + state, _ = step(nuts_key, state) ``` :::{note} @@ -57,13 +57,7 @@ If you want to use Blackjax with a model implemented with a PPL, go to the relat ```{code-block} bash pip install blackjax ``` -::: -:::{tab-item} Nightly -```{code-block} bash -pip install blackjax-nightly -``` -::: :::{tab-item} Conda ```{code-block} bash @@ -114,6 +108,7 @@ Sample with multiple chains? Use custom gradients? Use non-JAX log-prob functions? Build a Metropolis-Within-Gibbs sampler? +Sample from the word BlackJAX using BlackJAX? 
``` ```{toctree} diff --git a/docs/refs.bib b/docs/refs.bib index c5e66ca41..dda41909d 100644 --- a/docs/refs.bib +++ b/docs/refs.bib @@ -423,3 +423,25 @@ @misc{huang2021schrodingerfollmer archivePrefix={arXiv}, primaryClass={stat.CO} } + +@article{omelyan2003symplectic, + title={Symplectic analytically integrable decomposition algorithms: classification, derivation, and application to molecular dynamics, quantum and celestial mechanics simulations}, + author={Omelyan, IP and Mryglod, IM and Folk, R}, + journal={Computer Physics Communications}, + volume={151}, + number={3}, + pages={272--314}, + year={2003}, + publisher={Elsevier} +} + +@article{takaishi2006testing, + title={Testing and tuning symplectic integrators for the hybrid Monte Carlo algorithm in lattice QCD}, + author={Takaishi, Tetsuya and De Forcrand, Philippe}, + journal={Physical Review E}, + volume={73}, + number={3}, + pages={036706}, + year={2006}, + publisher={APS} +} diff --git a/requirements-doc.txt b/requirements-doc.txt index 338073a88..83af1ffe3 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -4,8 +4,8 @@ aesara>=2.8.8 arviz flax ipython -jax>=0.4.16 -jaxlib>=0.4.16 +jax>=0.4.25 +jaxlib>=0.4.25 jaxopt jupytext myst_nb>=1.0.0 diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py index 286bf30aa..4b34511be 100644 --- a/tests/adaptation/test_adaptation.py +++ b/tests/adaptation/test_adaptation.py @@ -6,6 +6,7 @@ import blackjax from blackjax.adaptation import window_adaptation +from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info from blackjax.util import run_inference_algorithm @@ -34,7 +35,32 @@ def test_adaptation_schedule(num_steps, expected_schedule): assert np.array_equal(adaptation_schedule, expected_schedule) -def test_chees_adaptation(): +@pytest.mark.parametrize( + "adaptation_filters", + [ + { + "filter_fn": return_all_adapt_info, + "return_sets": None, + }, + { + "filter_fn": get_filter_adapt_info_fn(), + "return_sets": (set(), set(), set()), + }, + { + "filter_fn": get_filter_adapt_info_fn( + {"logdensity"}, + {"proposal"}, + {"random_generator_arg", "step", "da_state"}, + ), + "return_sets": ( + {"logdensity"}, + {"proposal"}, + {"random_generator_arg", "step", "da_state"}, + ), + }, + ], +) +def test_chees_adaptation(adaptation_filters): logprob_fn = lambda x: jax.scipy.stats.norm.logpdf( x, loc=0.0, scale=jnp.array([1.0, 10.0]) ).sum() @@ -44,10 +70,13 @@ def test_chees_adaptation(): num_chains = 16 step_size = 0.1 - init_key, warmup_key, inference_key = jax.random.split(jax.random.key(0), 3) + init_key, warmup_key, inference_key = jax.random.split(jax.random.key(346), 3) warmup = blackjax.chees_adaptation( - logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75 + logprob_fn, + num_chains=num_chains, + target_acceptance_rate=0.75, + adaptation_info_fn=adaptation_filters["filter_fn"], ) initial_positions = jax.random.normal(init_key, (num_chains, 2)) @@ -61,11 +90,35 @@ def test_chees_adaptation(): algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) - _, _, infos = jax.vmap( - lambda key, state: run_inference_algorithm(key, state, algorithm, num_results) + _, (_, infos) = jax.vmap( + lambda key, state: run_inference_algorithm( + rng_key=key, + initial_state=state, + inference_algorithm=algorithm, + num_steps=num_results, + ) )(chain_keys, last_states) harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate) - np.testing.assert_allclose(harmonic_mean, 
0.75, rtol=1e-1) + + def check_attrs(attribute, keyset): + for name, param in getattr(warmup_info, attribute)._asdict().items(): + print(name, param) + if name in keyset: + assert param is not None + else: + assert param is None + + keysets = adaptation_filters["return_sets"] + if keysets is None: + keysets = ( + warmup_info.state._fields, + warmup_info.info._fields, + warmup_info.adaptation_state._fields, + ) + for i, attribute in enumerate(["state", "info", "adaptation_state"]): + check_attrs(attribute, keysets[i]) + + np.testing.assert_allclose(harmonic_mean, 0.75, atol=1e-1) np.testing.assert_allclose(parameters["step_size"], 1.5, rtol=2e-1) - np.testing.assert_allclose(infos.num_integration_steps.mean(), 15.0, rtol=3e-1) + np.testing.assert_array_less(infos.num_integration_steps.mean(), 15.0) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index ddb13ad57..c38009e5e 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -134,15 +134,17 @@ def kinetic_energy(p, position=None): algorithms = { "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, - "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, - "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, + "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-4}, + "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-4}, + "omelyan": {"algorithm": integrators.omelyan, "precision": 1e-4}, "implicit_midpoint": { "algorithm": integrators.implicit_midpoint, "precision": 1e-4, }, - "isokinetic_leapfrog": {"algorithm": integrators.isokinetic_leapfrog}, + "isokinetic_velocity_verlet": {"algorithm": integrators.isokinetic_velocity_verlet}, "isokinetic_mclachlan": {"algorithm": integrators.isokinetic_mclachlan}, "isokinetic_yoshida": {"algorithm": integrators.isokinetic_yoshida}, + "isokinetic_omelyan": {"algorithm": integrators.isokinetic_omelyan}, } @@ -168,6 +170,7 @@ class IntegratorTest(chex.TestCase): "velocity_verlet", "mclachlan", "yoshida", + "omelyan", "implicit_midpoint", ], ) @@ -234,18 +237,20 @@ def test_esh_momentum_update(self, dims): ) / (jnp.cosh(delta) + jnp.dot(gradient_normalized, momentum * jnp.sinh(delta))) # Efficient implementation - update_stable = self.variant(esh_dynamics_momentum_update_one_step) + update_stable = self.variant( + esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + ) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @chex.all_variants(with_pmap=False) - def test_isokinetic_leapfrog(self): + def test_isokinetic_velocity_verlet(self): cov = jnp.asarray([[1.0, 0.5, 0.1], [0.5, 2.0, -0.1], [0.1, -0.1, 3.0]]) logdensity_fn = lambda x: stats.multivariate_normal.logpdf( x, jnp.zeros([3]), cov ) - step = self.variant(integrators.isokinetic_leapfrog(logdensity_fn)) + step = self.variant(integrators.isokinetic_velocity_verlet(logdensity_fn)) rng = jax.random.key(4263456) key0, key1 = jax.random.split(rng, 2) @@ -258,7 +263,7 @@ def test_isokinetic_leapfrog(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step + op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( @@ -294,9 +299,10 @@ def test_isokinetic_leapfrog(self): 
@chex.all_variants(with_pmap=False) @parameterized.parameters( [ - "isokinetic_leapfrog", + "isokinetic_velocity_verlet", "isokinetic_mclachlan", "isokinetic_yoshida", + "isokinetic_omelyan", ], ) def test_isokinetic_integrator(self, integrator_name): diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 7d20805ab..c399929da 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -13,6 +13,8 @@ import blackjax import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk +from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info +from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -27,12 +29,12 @@ def sample_orbit(orbit, weights, rng_key): return samples -def irmh_proposal_distribution(rng_key): +def irmh_proposal_distribution(rng_key, mean): """ The proposal distribution is chosen to be wider than the target, so that the RMH rejection doesn't make the sample overemphasize the center of the target distribution. """ - return 1.0 + jax.random.normal(rng_key) * 25.0 + return mean + jax.random.normal(rng_key) * 25.0 def rmh_proposal_distribution(rng_key, position): @@ -56,6 +58,27 @@ def rmh_proposal_distribution(rng_key, position): }, ] +window_adaptation_filters = [ + { + "filter_fn": return_all_adapt_info, + "return_sets": None, + }, + { + "filter_fn": get_filter_adapt_info_fn(), + "return_sets": (set(), set(), set()), + }, + { + "filter_fn": get_filter_adapt_info_fn( + {"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"} + ), + "return_sets": ( + {"position"}, + {"is_divergent"}, + {"ss_state", "inverse_mass_matrix"}, + ), + }, +] + class LinearRegressionTest(chex.TestCase): """Test sampling of a linear regression model.""" @@ -74,16 +97,24 @@ def regression_logprob(self, log_scale, coefs, preds, x): # reduce sum otherwise broacasting will make the logprob biased. 
return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf]) - def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): + def run_mclmc( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): init_key, tune_key, run_key = jax.random.split(key, 3) initial_state = blackjax.mcmc.mclmc.init( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, - integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, + sqrt_diag_cov=sqrt_diag_cov, ) ( @@ -94,26 +125,34 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): num_steps=num_steps, state=initial_state, rng_key=tune_key, + diagonal_preconditioning=diagonal_preconditioning, ) sampling_alg = blackjax.mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, + sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, ) - _, samples, _ = run_inference_algorithm( + _, samples = run_inference_algorithm( rng_key=run_key, - initial_state_or_position=blackjax_state_after_tuning, + initial_state=blackjax_state_after_tuning, inference_algorithm=sampling_alg, num_steps=num_steps, - transform=lambda x: x.position, + transform=lambda state, info: state.position, ) return samples - @parameterized.parameters(itertools.product(regression_test_cases, [True, False])) - def test_window_adaptation(self, case, is_mass_matrix_diagonal): + @parameterized.parameters( + itertools.product( + regression_test_cases, [True, False], window_adaptation_filters + ) + ) + def test_window_adaptation( + self, case, is_mass_matrix_diagonal, window_adapt_config + ): """Test the HMC kernel and the Stan warmup.""" rng_key, init_key0, init_key1 = jax.random.split(self.key, 3) x_data = jax.random.normal(init_key0, shape=(1000, 1)) @@ -131,17 +170,38 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal): logposterior_fn, is_mass_matrix_diagonal, progress_bar=True, + adaptation_info_fn=window_adapt_config["filter_fn"], **case["parameters"], ) - (state, parameters), _ = warmup.run( + (state, parameters), info = warmup.run( warmup_key, case["initial_position"], case["num_warmup_steps"], ) inference_algorithm = case["algorithm"](logposterior_fn, **parameters) - _, states, _ = run_inference_algorithm( - inference_key, state, inference_algorithm, case["num_sampling_steps"] + def check_attrs(attribute, keyset): + for name, param in getattr(info, attribute)._asdict().items(): + if name in keyset: + assert param is not None + else: + assert param is None + + keysets = window_adapt_config["return_sets"] + if keysets is None: + keysets = ( + info.state._fields, + info.info._fields, + info.adaptation_state._fields, + ) + for i, attribute in enumerate(["state", "info", "adaptation_state"]): + check_attrs(attribute, keysets[i]) + + _, (states, _) = run_inference_algorithm( + rng_key=inference_key, + initial_state=state, + inference_algorithm=inference_algorithm, + num_steps=case["num_sampling_steps"], ) coefs_samples = states.position["coefs"] @@ -163,10 +223,16 @@ def test_mala(self): mala = blackjax.mala(logposterior_fn, 1e-5) state = mala.init({"coefs": 1.0, "log_scale": 1.0}) - _, states, _ = run_inference_algorithm(inference_key, state, mala, 10_000) + _, states = run_inference_algorithm( + rng_key=inference_key, + initial_state=state, + 
inference_algorithm=mala, + transform=lambda state, info: state.position, + num_steps=10_000, + ) - coefs_samples = states.position["coefs"][3000:] - scale_samples = np.exp(states.position["log_scale"][3000:]) + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -195,6 +261,88 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + def test_mclmc_preconditioning(self): + class IllConditionedGaussian: + """Gaussian distribution. Covariance matrix has eigenvalues equally spaced in log-space, going from 1/condition_bnumber^1/2 to condition_number^1/2.""" + + def __init__(self, d, condition_number): + """numpy_seed is used to generate a random rotation for the covariance matrix. + If None, the covariance matrix is diagonal.""" + + self.ndims = d + self.name = "IllConditionedGaussian" + self.condition_number = condition_number + eigs = jnp.logspace( + -0.5 * jnp.log10(condition_number), + 0.5 * jnp.log10(condition_number), + d, + ) + self.E_x2 = eigs + self.R = jnp.eye(d) + self.Hessian = jnp.diag(1 / eigs) + self.Cov = jnp.diag(eigs) + self.Var_x2 = 2 * jnp.square(self.E_x2) + + self.logdensity_fn = lambda x: -0.5 * x.T @ self.Hessian @ x + self.transform = lambda x: x + + self.sample_init = lambda key: jax.random.normal( + key, shape=(self.ndims,) + ) * jnp.max(jnp.sqrt(eigs)) + + dim = 100 + condition_number = 10 + eigs = jnp.logspace( + -0.5 * jnp.log10(condition_number), 0.5 * jnp.log10(condition_number), dim + ) + model = IllConditionedGaussian(dim, condition_number) + num_steps = 20000 + key = jax.random.PRNGKey(2) + + integrator = isokinetic_mclachlan + + def get_sqrt_diag_cov(): + init_key, tune_key = jax.random.split(key) + + initial_position = model.sample_init(init_key) + + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, + logdensity_fn=model.logdensity_fn, + rng_key=init_key, + ) + + kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=model.logdensity_fn, + integrator=integrator, + sqrt_diag_cov=sqrt_diag_cov, + ) + + ( + _, + blackjax_mclmc_sampler_params, + ) = blackjax.mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + diagonal_preconditioning=True, + ) + + return blackjax_mclmc_sampler_params.sqrt_diag_cov + + sqrt_diag_cov = get_sqrt_diag_cov() + assert ( + jnp.abs( + jnp.dot( + (sqrt_diag_cov**2) / jnp.linalg.norm(sqrt_diag_cov**2), + eigs / jnp.linalg.norm(eigs), + ) + - 1 + ) + < 0.1 + ) + @parameterized.parameters(regression_test_cases) def test_pathfinder_adaptation( self, @@ -228,12 +376,16 @@ def test_pathfinder_adaptation( ) inference_algorithm = algorithm(logposterior_fn, **parameters) - _, states, _ = run_inference_algorithm( - inference_key, state, inference_algorithm, num_sampling_steps + _, states = run_inference_algorithm( + rng_key=inference_key, + initial_state=state, + inference_algorithm=inference_algorithm, + num_steps=num_sampling_steps, + transform=lambda state, info: state.position, ) - coefs_samples = states.position["coefs"] - scale_samples = np.exp(states.position["log_scale"]) + coefs_samples = states["coefs"] + scale_samples = np.exp(states["log_scale"]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) 
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -268,14 +420,18 @@ def test_meads(self): inference_algorithm = blackjax.ghmc(logposterior_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) - _, states, _ = jax.vmap( + _, states = jax.vmap( lambda key, state: run_inference_algorithm( - key, state, inference_algorithm, 100 + rng_key=key, + initial_state=state, + inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, + num_steps=100, ) )(chain_keys, last_states) - coefs_samples = states.position["coefs"] - scale_samples = np.exp(states.position["log_scale"]) + coefs_samples = states["coefs"] + scale_samples = np.exp(states["log_scale"]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -312,14 +468,18 @@ def test_chees(self, jitter_generator): inference_algorithm = blackjax.dynamic_hmc(logposterior_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) - _, states, _ = jax.vmap( + _, states = jax.vmap( lambda key, state: run_inference_algorithm( - key, state, inference_algorithm, 100 + rng_key=key, + initial_state=state, + inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, + num_steps=100, ) )(chain_keys, last_states) - coefs_samples = states.position["coefs"] - scale_samples = np.exp(states.position["log_scale"]) + coefs_samples = states["coefs"] + scale_samples = np.exp(states["log_scale"]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -338,10 +498,16 @@ def test_barker(self): barker = blackjax.barker_proposal(logposterior_fn, 1e-1) state = barker.init({"coefs": 1.0, "log_scale": 1.0}) - _, states, _ = run_inference_algorithm(inference_key, state, barker, 10_000) + _, states = run_inference_algorithm( + rng_key=inference_key, + initial_state=state, + inference_algorithm=barker, + transform=lambda state, info: state.position, + num_steps=10_000, + ) - coefs_samples = states.position["coefs"][3000:] - scale_samples = np.exp(states.position["log_scale"][3000:]) + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @@ -457,7 +623,7 @@ def test_linear_regression_sghmc_cv(self): _ = sghmc.step(rng_key, init_position, data_batch, 1e-3) def test_linear_regression_sgnht(self): - rng_key, data_key = jax.random.split(self.key, 2) + step_key, data_key = jax.random.split(self.key, 2) data_size = 1000 X_data = jax.random.normal(data_key, shape=(data_size, 5)) @@ -467,15 +633,14 @@ def test_linear_regression_sgnht(self): ) sgnht = blackjax.sgnht(grad_fn) - _, rng_key = jax.random.split(rng_key) data_batch = X_data[100:200, :] init_position = 1.0 data_batch = X_data[:100, :] init_state = sgnht.init(init_position, self.key) - _ = sgnht.step(rng_key, init_state, data_batch, 1e-3) + _ = sgnht.step(step_key, init_state, data_batch, 1e-3) def test_linear_regression_sgnhtc_cv(self): - rng_key, data_key = jax.random.split(self.key, 2) + step_key, data_key = jax.random.split(self.key, 2) data_size = 1000 X_data = jax.random.normal(data_key, shape=(data_size, 5)) @@ -490,11 +655,10 @@ def test_linear_regression_sgnhtc_cv(self): sgnht = blackjax.sgnht(cv_grad_fn) - _, rng_key = jax.random.split(rng_key) init_position = 1.0 data_batch = 
X_data[:100, :] init_state = sgnht.init(init_position, self.key) - _ = sgnht.step(rng_key, init_state, data_batch, 1e-3) + _ = sgnht.step(step_key, init_state, data_batch, 1e-3) class LatentGaussianTest(chex.TestCase): @@ -520,19 +684,20 @@ def test_latent_gaussian(self): initial_state = inference_algorithm.init(jnp.zeros((1,))) - _, states, _ = self.variant( + _, states = self.variant( functools.partial( run_inference_algorithm, inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, num_steps=self.sampling_steps, ), - )(self.key, initial_state) + )(rng_key=self.key, initial_state=initial_state) np.testing.assert_allclose( - np.var(states.position[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2 + np.var(states[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2 ) np.testing.assert_allclose( - np.mean(states.position[self.burnin :]), 2 / 3, rtol=1e-2, atol=1e-2 + np.mean(states[self.burnin :]), 2 / 3, rtol=1e-2, atol=1e-2 ) @@ -541,103 +706,6 @@ def rmhmc_static_mass_matrix_fn(position): return jnp.array([1.0]) -normal_test_cases = [ - { - "algorithm": blackjax.hmc, - "initial_position": jnp.array(3.0), - "parameters": { - "step_size": 3.9, - "inverse_mass_matrix": jnp.array([1.0]), - "num_integration_steps": 30, - }, - "num_sampling_steps": 6000, - "burnin": 1_000, - }, - { - "algorithm": blackjax.nuts, - "initial_position": jnp.array(3.0), - "parameters": {"step_size": 4.0, "inverse_mass_matrix": jnp.array([1.0])}, - "num_sampling_steps": 6000, - "burnin": 1_000, - }, - { - "algorithm": blackjax.orbital_hmc, - "initial_position": jnp.array(100.0), - "parameters": { - "step_size": 0.1, - "inverse_mass_matrix": jnp.array([0.1]), - "period": 100, - }, - "num_sampling_steps": 20_000, - "burnin": 15_000, - }, - { - "algorithm": blackjax.additive_step_random_walk.normal_random_walk, - "initial_position": 1.0, - "parameters": {"sigma": jnp.array([1.0])}, - "num_sampling_steps": 20_000, - "burnin": 5_000, - }, - { - "algorithm": blackjax.rmh, - "parameters": {}, - "initial_position": 1.0, - "num_sampling_steps": 20_000, - "burnin": 5_000, - }, - { - "algorithm": blackjax.mala, - "initial_position": 1.0, - "parameters": {"step_size": 1e-1}, - "num_sampling_steps": 45_000, - "burnin": 5_000, - }, - { - "algorithm": blackjax.elliptical_slice, - "initial_position": 1.0, - "parameters": {"cov": jnp.array([2.0**2]), "mean": 1.0}, - "num_sampling_steps": 20_000, - "burnin": 5_000, - }, - { - "algorithm": blackjax.irmh, - "initial_position": jnp.array(1.0), - "parameters": {}, - "num_sampling_steps": 50_000, - "burnin": 5_000, - }, - { - "algorithm": blackjax.ghmc, - "initial_position": jnp.array(1.0), - "parameters": { - "step_size": 1.0, - "momentum_inverse_scale": jnp.array(1.0), - "alpha": 0.8, - "delta": 2.0, - }, - "num_sampling_steps": 6000, - "burnin": 1_000, - }, - { - "algorithm": blackjax.barker_proposal, - "initial_position": 1.0, - "parameters": {"step_size": 1.5}, - "num_sampling_steps": 20_000, - "burnin": 2_000, - }, - { - "algorithm": blackjax.rmhmc, - "initial_position": jnp.array(3.0), - "parameters": { - "step_size": 1.0, - "num_integration_steps": 30, - }, - "num_sampling_steps": 6000, - "burnin": 1_000, - }, -] - - class UnivariateNormalTest(chex.TestCase): """Test sampling of a univariate Normal distribution. 
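The rewritten sampling tests in this file call `run_inference_algorithm` with keyword arguments, a two-argument `transform(state, info)`, and the new `(final_state, history)` return value defined in `blackjax/util.py` earlier in this diff. A minimal sketch of the new calling convention, using an illustrative two-dimensional standard-normal target rather than the test suite's regression model:

```python
import jax
import jax.numpy as jnp

import blackjax
from blackjax.util import run_inference_algorithm


# Illustrative target: a 2D standard normal.
def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)


nuts = blackjax.nuts(logdensity_fn, step_size=0.5, inverse_mass_matrix=jnp.ones(2))
initial_state = nuts.init(jnp.zeros(2))

final_state, positions = run_inference_algorithm(
    rng_key=jax.random.key(0),
    inference_algorithm=nuts,
    num_steps=500,
    initial_state=initial_state,
    # `transform` now receives the state and the step info; keep only positions.
    transform=lambda state, info: state.position,
)
print(positions.shape)  # (500, 2)
```

Passing `initial_position` instead of `initial_state` makes the helper call `inference_algorithm.init` internally; supplying both, or neither, raises a `ValueError`.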
@@ -651,50 +719,180 @@ def setUp(self): def normal_logprob(self, x): return stats.norm.logpdf(x, loc=1.0, scale=2.0) - @chex.all_variants(with_pmap=False) - @parameterized.parameters(normal_test_cases) - def test_univariate_normal( - self, algorithm, initial_position, parameters, num_sampling_steps, burnin + def univariate_normal_test_case( + self, + inference_algorithm, + rng_key, + initial_state, + num_sampling_steps, + burnin, + postprocess_samples=None, + **kwargs, ): - if algorithm == blackjax.irmh: - parameters["proposal_distribution"] = irmh_proposal_distribution - - if algorithm == blackjax.rmh: - parameters["proposal_generator"] = rmh_proposal_distribution - - if algorithm == blackjax.rmhmc: - parameters["mass_matrix"] = rmhmc_static_mass_matrix_fn - - inference_algorithm = algorithm(self.normal_logprob, **parameters) - rng_key = self.key - if algorithm == blackjax.elliptical_slice: - inference_algorithm = algorithm(lambda x: jnp.ones_like(x), **parameters) - if algorithm == blackjax.ghmc: - rng_key, initial_state_key = jax.random.split(rng_key) - initial_state = inference_algorithm.init( - initial_position, initial_state_key - ) - else: - initial_state = inference_algorithm.init(initial_position) - inference_key, orbit_key = jax.random.split(rng_key) - _, states, _ = self.variant( + _, (states, info) = self.variant( functools.partial( run_inference_algorithm, inference_algorithm=inference_algorithm, num_steps=num_sampling_steps, + **kwargs, ) - )(inference_key, initial_state) + )(rng_key=inference_key, initial_state=initial_state) - if algorithm == blackjax.orbital_hmc: - samples = orbit_samples( - states.positions[burnin:], states.weights[burnin:], orbit_key - ) + if postprocess_samples: + samples = postprocess_samples(states, orbit_key) else: samples = states.position[burnin:] np.testing.assert_allclose(np.mean(samples), 1.0, rtol=1e-1) np.testing.assert_allclose(np.var(samples), 4.0, rtol=1e-1) + @chex.all_variants(with_pmap=False) + def test_irmh(self): + inference_algorithm = blackjax.irmh( + self.normal_logprob, + proposal_distribution=functools.partial( + irmh_proposal_distribution, mean=1.0 + ), + ) + initial_state = inference_algorithm.init(jnp.array(1.0)) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 50000, 5000 + ) + + @chex.all_variants(with_pmap=False) + def test_nuts(self): + inference_algorithm = blackjax.nuts( + self.normal_logprob, step_size=4.0, inverse_mass_matrix=jnp.array([1.0]) + ) + + initial_state = inference_algorithm.init(jnp.array(3.0)) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 5000, 1000 + ) + + @chex.all_variants(with_pmap=False) + def test_rmh(self): + inference_algorithm = blackjax.rmh( + self.normal_logprob, proposal_generator=rmh_proposal_distribution + ) + initial_state = inference_algorithm.init(1.0) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 20_000, 5_000 + ) + + @chex.all_variants(with_pmap=False) + def test_rmhmc(self): + inference_algorithm = blackjax.rmhmc( + self.normal_logprob, + mass_matrix=rmhmc_static_mass_matrix_fn, + step_size=1.0, + num_integration_steps=30, + ) + + initial_state = inference_algorithm.init(jnp.array(3.0)) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 6_000, 1_000 + ) + + @chex.all_variants(with_pmap=False) + def test_elliptical_slice(self): + inference_algorithm = blackjax.elliptical_slice( + lambda x: jnp.ones_like(x), cov=jnp.array([2.0**2]), 
mean=1.0 + ) + + initial_state = inference_algorithm.init(1.0) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 20_000, 5_000 + ) + + @chex.all_variants(with_pmap=False) + def test_ghmc(self): + rng_key, initial_state_key = jax.random.split(self.key) + inference_algorithm = blackjax.ghmc( + self.normal_logprob, + step_size=1.0, + momentum_inverse_scale=jnp.array(1.0), + alpha=0.8, + delta=2.0, + ) + initial_state = inference_algorithm.init(jnp.array(1.0), initial_state_key) + self.univariate_normal_test_case( + inference_algorithm, rng_key, initial_state, 6000, 1000 + ) + + @chex.all_variants(with_pmap=False) + def test_hmc(self): + rng_key, initial_state_key = jax.random.split(self.key) + inference_algorithm = blackjax.hmc( + self.normal_logprob, + step_size=3.9, + inverse_mass_matrix=jnp.array([1.0]), + num_integration_steps=30, + ) + initial_state = inference_algorithm.init(jnp.array(3.0)) + self.univariate_normal_test_case( + inference_algorithm, rng_key, initial_state, 6000, 1000 + ) + + @chex.all_variants(with_pmap=False) + def test_orbital_hmc(self): + inference_algorithm = blackjax.orbital_hmc( + self.normal_logprob, + step_size=0.1, + inverse_mass_matrix=jnp.array([0.1]), + period=100, + ) + initial_state = inference_algorithm.init(jnp.array(100.0)) + burnin = 15_000 + + def postprocess_samples(states, key): + positions, weights = states + return orbit_samples(positions[burnin:], weights[burnin:], key) + + self.univariate_normal_test_case( + inference_algorithm, + self.key, + initial_state, + 20_000, + burnin, + postprocess_samples, + transform=lambda state, info: ((state.positions, state.weights), info), + ) + + @chex.all_variants(with_pmap=False) + def test_random_walk(self): + inference_algorithm = blackjax.additive_step_random_walk.normal_random_walk( + self.normal_logprob, sigma=jnp.array([1.0]) + ) + initial_state = inference_algorithm.init(jnp.array(1.0)) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 20_000, 5_000 + ) + + @chex.all_variants(with_pmap=False) + def test_mala(self): + inference_algorithm = blackjax.mala(self.normal_logprob, step_size=0.2) + initial_state = inference_algorithm.init(jnp.array(1.0)) + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 45000, 5_000 + ) + + @chex.all_variants(with_pmap=False) + def test_barker(self): + inference_algorithm = blackjax.barker_proposal( + self.normal_logprob, step_size=1.5 + ) + initial_state = inference_algorithm.init(jnp.array(1.0)) + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 20000, 2_000 + ) + mcse_test_cases = [ { @@ -736,7 +934,7 @@ class MonteCarloStandardErrorTest(chex.TestCase): def setUp(self): super().setUp() - self.key = jax.random.key(20220203) + self.key = jax.random.key(8456) def generate_multivariate_target(self, rng=None): """Genrate a Multivariate Normal distribution as target.""" @@ -805,21 +1003,22 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): functools.partial( run_inference_algorithm, inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, num_steps=2_000, ) ) - _, states, _ = inference_loop_multiple_chains( - multi_chain_sample_key, initial_states + _, states = inference_loop_multiple_chains( + rng_key=multi_chain_sample_key, initial_state=initial_states ) - posterior_samples = states.position[:, -1000:] + posterior_samples = states[:, -1000:] posterior_delta = posterior_samples - true_loc 
posterior_variance = posterior_delta**2.0 posterior_correlation = jnp.prod(posterior_delta, axis=-1, keepdims=True) / ( true_scale[0] * true_scale[1] ) - _ = jax.tree_map( + _ = jax.tree.map( self.mcse_test, [posterior_samples, posterior_variance, posterior_correlation], [true_loc, true_scale**2, true_rho], diff --git a/tests/mcmc/test_trajectory.py b/tests/mcmc/test_trajectory.py index dccffa564..c8a5aa908 100644 --- a/tests/mcmc/test_trajectory.py +++ b/tests/mcmc/test_trajectory.py @@ -1,6 +1,4 @@ """Test the trajectory integration""" -import functools - import chex import jax import jax.numpy as jnp @@ -75,7 +73,7 @@ def test_dynamic_progressive_integration_divergence( assert is_diverging.item() is should_diverge def test_dynamic_progressive_equal_recursive(self): - rng_key = jax.random.key(23132) + rng_key = jax.random.key(23133) def logdensity_fn(x): return -((1.0 - x[0]) ** 2) - 1.5 * (x[1] - x[0] ** 2) ** 2 @@ -124,15 +122,16 @@ def logdensity_fn(x): divergence_threshold, ) - for _ in range(50): + for i in range(50): + subkey = jax.random.fold_in(rng_key, i) ( - rng_key, + rng_buildtree, rng_direction, rng_tree_depth, rng_step_size, rng_position, rng_momentum, - ) = jax.random.split(rng_key, 6) + ) = jax.random.split(subkey, 6) direction = jax.random.choice(rng_direction, jnp.array([-1, 1])) tree_depth = jax.random.choice(rng_tree_depth, np.arange(2, 5)) initial_state = integrators.new_integrator_state( @@ -153,7 +152,7 @@ def logdensity_fn(x): is_diverging0, has_terminated0, ) = trajectory_integrator( - rng_key, + rng_buildtree, initial_state, direction, termination_state, @@ -169,7 +168,7 @@ def logdensity_fn(x): is_diverging1, has_terminated1, ) = buildtree_integrator( - rng_key, + rng_buildtree, initial_state, direction, tree_depth, @@ -177,11 +176,8 @@ def logdensity_fn(x): initial_energy, ) # Assert that the trajectory being built is the same - jax.tree_map( - functools.partial(np.testing.assert_allclose, rtol=1e-5), - trajectory0, - trajectory1, - ) + chex.assert_trees_all_close(trajectory0, trajectory1, rtol=1e-5) + assert is_diverging0 == is_diverging1 assert has_terminated0 == has_terminated1 # We dont expect the proposal to be the same (even with the same PRNGKey @@ -287,11 +283,7 @@ def test_static_integration_variable_num_steps(self): # we still get the same result fori_state = jax.jit(static_integration)(initial_state, 0.1, 10) - jax.tree_util.tree_map( - functools.partial(np.testing.assert_allclose, rtol=1e-5), - fori_state, - scan_state, - ) + chex.assert_trees_all_close(fori_state, scan_state, rtol=1e-5) def test_dynamic_hmc_integration_steps(self): rng_key = jax.random.key(0) diff --git a/tests/optimizers/test_optimizers.py b/tests/optimizers/test_optimizers.py index a715acc18..a7549842f 100644 --- a/tests/optimizers/test_optimizers.py +++ b/tests/optimizers/test_optimizers.py @@ -88,7 +88,7 @@ def regression_model(key): minimize_lbfgs, objective_fn, maxiter=maxiter, maxcor=maxcor ) )(b0_flatten) - history = jax.tree_map(lambda x: x[: status.iter_num + 1], history) + history = jax.tree.map(lambda x: x[: status.iter_num + 1], history) # Test recover alpha S = jnp.diff(history.x, axis=0) @@ -138,7 +138,7 @@ def loss_fn(x): (result, status), history = self.variant( functools.partial(minimize_lbfgs, loss_fn, maxiter=50) )(np.zeros(nd)) - history = jax.tree_map(lambda x: x[: status.iter_num + 1], history) + history = jax.tree.map(lambda x: x[: status.iter_num + 1], history) np.testing.assert_allclose(result, mean, rtol=0.01) diff --git 
a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index 1bbc68970..7d6190af5 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -12,7 +12,9 @@ import blackjax import blackjax.smc.resampling as resampling from blackjax import adaptive_tempered_smc, tempered_smc -from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning +from blackjax.mcmc.random_walk import build_irmh +from blackjax.smc import extend_params +from blackjax.smc.inner_kernel_tuning import as_top_level_api as inner_kernel_tuning from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate from blackjax.smc.tuning.from_particles import ( mass_matrix_from_particles, @@ -92,38 +94,37 @@ def smc_inner_kernel_tuning_test_case( proposal_factory.return_value = 100 def mcmc_parameter_update_fn(state, info): - return 100 + return extend_params({"mean": 100}) - mcmc_factory = MagicMock() - sampling_algorithm = MagicMock() - mcmc_factory.return_value = sampling_algorithm prior = lambda x: stats.norm.logpdf(x) - def kernel_factory(proposal_distribution): - kernel = blackjax.irmh.build_kernel() - - def wrapped_kernel(rng_key, state, logdensity): - return kernel(rng_key, state, logdensity, proposal_distribution) - - return wrapped_kernel + def wrapped_kernel(rng_key, state, logdensity, mean): + return build_irmh()( + rng_key, + state, + logdensity, + functools.partial(irmh_proposal_distribution, mean=mean), + ) kernel = inner_kernel_tuning( logprior_fn=prior, loglikelihood_fn=specialized_log_weights_fn, - mcmc_factory=kernel_factory, + mcmc_step_fn=wrapped_kernel, mcmc_init_fn=blackjax.irmh.init, resampling_fn=resampling.systematic, smc_algorithm=smc_algorithm, - mcmc_parameters={}, mcmc_parameter_update_fn=mcmc_parameter_update_fn, - initial_parameter_value=irmh_proposal_distribution, + initial_parameter_value=extend_params({"mean": 1.0}), **smc_parameters, ) new_state, new_info = kernel.step( self.key, state=kernel.init(init_particles), **step_parameters ) - assert new_state.parameter_override == 100 + assert set(new_state.parameter_override.keys()) == { + "mean", + } + np.testing.assert_allclose(new_state.parameter_override["mean"], 100) class MeanAndStdFromParticlesTest(chex.TestCase): @@ -270,14 +271,6 @@ def setUp(self): super().setUp() self.key = jax.random.key(42) - def mcmc_factory(self, mass_matrix): - return functools.partial( - blackjax.hmc.build_kernel(), - inverse_mass_matrix=mass_matrix, - step_size=10e-2, - num_integration_steps=50, - ) - @chex.all_variants(with_pmap=False) def test_with_adaptive_tempered(self): ( @@ -286,18 +279,30 @@ def test_with_adaptive_tempered(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() + def parameter_update(state, info): + return extend_params( + { + "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "step_size": 10e-2, + "num_integration_steps": 50, + }, + ) + init, step = blackjax.inner_kernel_tuning( adaptive_tempered_smc, logprior_fn, loglikelihood_fn, - self.mcmc_factory, + blackjax.hmc.build_kernel(), blackjax.hmc.init, - {}, resampling.systematic, - mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles( - state.particles + mcmc_parameter_update_fn=parameter_update, + initial_parameter_value=extend_params( + dict( + inverse_mass_matrix=jnp.eye(2), + step_size=10e-2, + num_integration_steps=50, + ), ), - initial_parameter_value=jnp.eye(2), num_mcmc_steps=10, target_ess=0.5, ) @@ -306,20 +311,20 @@ def 
test_with_adaptive_tempered(self): def inference_loop(kernel, rng_key, initial_state): def cond(carry): - state, key = carry + _, state = carry return state.sampler_state.lmbda < 1 def body(carry): - state, op_key = carry - op_key, subkey = jax.random.split(op_key, 2) + i, state = carry + subkey = jax.random.fold_in(rng_key, i) state, _ = kernel(subkey, state) - return state, op_key + return i + 1, state - return jax.lax.while_loop(cond, body, (initial_state, rng_key)) + return jax.lax.while_loop(cond, body, (0, initial_state)) - state, _ = inference_loop(smc_kernel, self.key, init_state) + _, state = inference_loop(smc_kernel, self.key, init_state) - assert state.parameter_override.shape == (2, 2) + assert state.parameter_override["inverse_mass_matrix"].shape == (1, 2, 2) self.assert_linear_regression_test_case(state.sampler_state) @chex.all_variants(with_pmap=False) @@ -331,18 +336,30 @@ def test_with_tempered_smc(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() + def parameter_update(state, info): + return extend_params( + { + "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "step_size": 10e-2, + "num_integration_steps": 50, + }, + ) + init, step = blackjax.inner_kernel_tuning( tempered_smc, logprior_fn, loglikelihood_fn, - self.mcmc_factory, + blackjax.hmc.build_kernel(), blackjax.hmc.init, - {}, resampling.systematic, - mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles( - state.particles + mcmc_parameter_update_fn=parameter_update, + initial_parameter_value=extend_params( + dict( + inverse_mass_matrix=jnp.eye(2), + step_size=10e-2, + num_integration_steps=50, + ), ), - initial_parameter_value=jnp.eye(2), num_mcmc_steps=10, ) @@ -352,12 +369,12 @@ def test_with_tempered_smc(self): lambda_schedule = np.logspace(-5, 0, num_tempering_steps) def body_fn(carry, lmbda): - rng_key, state = carry - rng_key, subkey = jax.random.split(rng_key) + i, state = carry + subkey = jax.random.fold_in(self.key, i) new_state, info = smc_kernel(subkey, state, lmbda=lmbda) - return (rng_key, new_state), (new_state, info) + return (i + 1, new_state), (new_state, info) - (_, result), _ = jax.lax.scan(body_fn, (self.key, init_state), lambda_schedule) + (_, result), _ = jax.lax.scan(body_fn, (0, init_state), lambda_schedule) self.assert_linear_regression_test_case(result.sampler_state) diff --git a/tests/smc/test_kernel_compatibility.py b/tests/smc/test_kernel_compatibility.py index 3d2469914..fdda30b3a 100644 --- a/tests/smc/test_kernel_compatibility.py +++ b/tests/smc/test_kernel_compatibility.py @@ -7,6 +7,7 @@ import blackjax from blackjax import adaptive_tempered_smc from blackjax.mcmc.random_walk import normal +from blackjax.smc import extend_params class SMCAndMCMCIntegrationTest(unittest.TestCase): @@ -18,8 +19,9 @@ class SMCAndMCMCIntegrationTest(unittest.TestCase): def setUp(self): super().setUp() self.key = jax.random.key(42) + self.n_particles = 3 self.initial_particles = jax.random.multivariate_normal( - self.key, jnp.zeros(2), jnp.eye(2), (3,) + self.key, jnp.zeros(2), jnp.eye(2), (self.n_particles,) ) def check_compatible(self, mcmc_step_fn, mcmc_init_fn, mcmc_parameters): @@ -40,54 +42,80 @@ def check_compatible(self, mcmc_step_fn, mcmc_init_fn, mcmc_parameters): kernel(self.key, init(self.initial_particles)) def test_compatible_with_rwm(self): + rwm = blackjax.additive_step_random_walk.build_kernel() + + def kernel(rng_key, state, logdensity_fn, proposal_mean): + return rwm(rng_key, state, logdensity_fn, normal(proposal_mean)) + 
self.check_compatible( - blackjax.additive_step_random_walk.build_kernel(), + kernel, blackjax.additive_step_random_walk.init, - {"random_step": normal(1.0)}, + extend_params({"proposal_mean": 1.0}), ) def test_compatible_with_rmh(self): + rmh = blackjax.rmh.build_kernel() + + def kernel( + rng_key, state, logdensity_fn, proposal_mean, proposal_logdensity_fn=None + ): + return rmh( + rng_key, + state, + logdensity_fn, + lambda a, b: blackjax.mcmc.random_walk.normal(proposal_mean)(a, b), + proposal_logdensity_fn, + ) + self.check_compatible( - blackjax.rmh.build_kernel(), + kernel, blackjax.rmh.init, - { - "transition_generator": lambda a, b: blackjax.mcmc.random_walk.normal( - 1.0 - )(a, b) - }, + extend_params({"proposal_mean": 1.0}), ) def test_compatible_with_hmc(self): self.check_compatible( blackjax.hmc.build_kernel(), blackjax.hmc.init, - { - "step_size": 0.3, - "inverse_mass_matrix": jnp.array([1]), - "num_integration_steps": 1, - }, + extend_params( + { + "step_size": 0.3, + "inverse_mass_matrix": jnp.array([1.0]), + "num_integration_steps": 1, + }, + ), ) def test_compatible_with_irmh(self): + def kernel(rng_key, state, logdensity_fn, mean, proposal_logdensity_fn=None): + return blackjax.irmh.build_kernel()( + rng_key, + state, + logdensity_fn, + lambda key: mean + jax.random.normal(key), + proposal_logdensity_fn, + ) + self.check_compatible( - blackjax.irmh.build_kernel(), + kernel, blackjax.irmh.init, - { - "proposal_distribution": lambda key: jnp.array([1.0, 1.0]) - + jax.random.normal(key) - }, + extend_params({"mean": jnp.array([1.0, 1.0])}), ) def test_compatible_with_nuts(self): self.check_compatible( blackjax.nuts.build_kernel(), blackjax.nuts.init, - {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}, + extend_params( + {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}, + ), ) def test_compatible_with_mala(self): self.check_compatible( - blackjax.mala.build_kernel(), blackjax.mala.init, {"step_size": 1e-10} + blackjax.mala.build_kernel(), + blackjax.mala.init, + extend_params({"step_size": 1e-10}), ) @staticmethod diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 242e11c55..6366182a8 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -8,7 +8,7 @@ import blackjax import blackjax.smc.resampling as resampling -from blackjax.smc.base import init, step +from blackjax.smc.base import extend_params, init, step def logdensity_fn(position): @@ -31,14 +31,8 @@ def test_smc(self): num_mcmc_steps = 20 num_particles = 1000 - hmc = blackjax.hmc( - logdensity_fn, - step_size=1e-2, - inverse_mass_matrix=jnp.eye(1), - num_integration_steps=50, - ) - - def update_fn(rng_key, position): + def update_fn(rng_key, position, update_params): + hmc = blackjax.hmc(logdensity_fn, **update_params) state = hmc.init(position) def body_fn(state, rng_key): @@ -53,13 +47,20 @@ def body_fn(state, rng_key): # Initialize the state of the SMC sampler init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) - state = init(init_particles) + same_for_all_params = dict( + step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50 + ) + + state = init( + init_particles, + same_for_all_params, + ) # Run the SMC sampler once new_state, info = self.variant(step, static_argnums=(2, 3, 4))( sample_key, state, - jax.vmap(update_fn), + jax.vmap(update_fn, in_axes=(0, 0, None)), jax.vmap(logdensity_fn), resampling.systematic, ) @@ -74,15 +75,9 @@ def test_smc_waste_free(self): num_particles = 1000 num_resampled = num_particles // num_mcmc_steps - 
hmc = blackjax.hmc( - logdensity_fn, - step_size=1e-2, - inverse_mass_matrix=jnp.eye(1), - num_integration_steps=100, - ) - - def waste_free_update_fn(keys, particles): - def one_particle_fn(rng_key, position): + def waste_free_update_fn(keys, particles, update_params): + def one_particle_fn(rng_key, position, particle_update_params): + hmc = blackjax.hmc(logdensity_fn, **particle_update_params) state = hmc.init(position) def body_fn(state, rng_key): @@ -93,7 +88,9 @@ def body_fn(state, rng_key): _, (states, info) = jax.lax.scan(body_fn, state, keys) return states.position, info - particles, info = jax.vmap(one_particle_fn)(keys, particles) + particles, info = jax.vmap(one_particle_fn, in_axes=(0, 0, None))( + keys, particles, update_params + ) particles = particles.reshape((num_particles,)) return particles, info @@ -101,7 +98,14 @@ def body_fn(state, rng_key): # Initialize the state of the SMC sampler init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) - state = init(init_particles) + state = init( + init_particles, + dict( + step_size=1e-2, + inverse_mass_matrix=jnp.eye(1), + num_integration_steps=100, + ), + ) # Run the SMC sampler once new_state, info = self.variant(step, static_argnums=(2, 3, 4, 5))( @@ -118,5 +122,24 @@ def body_fn(state, rng_key): np.testing.assert_allclose(1.0, std, atol=1e-1) +class ExtendParamsTest(chex.TestCase): + def test_extend_params(self): + extended = extend_params( + { + "a": 50, + "b": np.array([50]), + "c": np.array([50, 60]), + "d": np.array([[1, 2], [3, 4]]), + }, + ) + np.testing.assert_allclose(extended["a"], np.ones((1,)) * 50) + np.testing.assert_allclose(extended["b"], np.array([[50]])) + np.testing.assert_allclose(extended["c"], np.array([[50, 60]])) + np.testing.assert_allclose( + extended["d"], + np.array([[[1, 2], [3, 4]]]), + ) + + if __name__ == "__main__": absltest.main() diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index f4234d117..527457d62 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -12,6 +12,7 @@ import blackjax.smc.resampling as resampling import blackjax.smc.solver as solver from blackjax import adaptive_tempered_smc, tempered_smc +from blackjax.smc import extend_params from tests.smc import SMCLinearRegressionTestCase @@ -21,13 +22,13 @@ def cond(carry): return state.lmbda < 1 def body(carry): - i, state, op_key, curr_loglikelihood = carry - op_key, subkey = jax.random.split(op_key, 2) + i, state, curr_loglikelihood = carry + subkey = jax.random.fold_in(rng_key, i) state, info = kernel(subkey, state) - return i + 1, state, op_key, curr_loglikelihood + info.log_likelihood_increment + return i + 1, state, curr_loglikelihood + info.log_likelihood_increment - total_iter, final_state, _, log_likelihood = jax.lax.while_loop( - cond, body, (0, initial_state, rng_key, 0.0) + total_iter, final_state, log_likelihood = jax.lax.while_loop( + cond, body, (0, initial_state, 0.0) ) return total_iter, final_state, log_likelihood @@ -64,13 +65,28 @@ def logprior_fn(x): hmc_kernel = blackjax.hmc.build_kernel() hmc_init = blackjax.hmc.init - hmc_parameters = { - "step_size": 10e-2, - "inverse_mass_matrix": jnp.eye(2), - "num_integration_steps": 50, - } - for target_ess in [0.5, 0.75]: + base_params = extend_params( + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(2), + "num_integration_steps": 50, + } + ) + + # verify results are equivalent with all shared, all unshared, and mixed params + hmc_parameters_list = [ + base_params, + 
jax.tree.map(lambda x: jnp.repeat(x, num_particles, axis=0), base_params), + jax.tree_util.tree_map_with_path( + lambda path, x: jnp.repeat(x, num_particles, axis=0) + if path[0].key == "step_size" + else x, + base_params, + ), + ] + + for target_ess, hmc_parameters in zip([0.5, 0.5, 0.75], hmc_parameters_list): tempering = adaptive_tempered_smc( logprior_fn, loglikelihood_fn, @@ -110,11 +126,13 @@ def test_fixed_schedule_tempered_smc(self): lambda_schedule = np.logspace(-5, 0, num_tempering_steps) hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() - hmc_parameters = { - "step_size": 10e-2, - "inverse_mass_matrix": jnp.eye(2), - "num_integration_steps": 50, - } + hmc_parameters = extend_params( + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(2), + "num_integration_steps": 50, + }, + ) tempering = tempered_smc( logprior_fn, @@ -129,12 +147,12 @@ def test_fixed_schedule_tempered_smc(self): smc_kernel = self.variant(tempering.step) def body_fn(carry, lmbda): - rng_key, state = carry - rng_key, subkey = jax.random.split(rng_key) + i, state = carry + subkey = jax.random.fold_in(self.key, i) new_state, info = smc_kernel(subkey, state, lmbda) - return (rng_key, new_state), (new_state, info) + return (i + 1, new_state), (new_state, info) - (_, result), _ = jax.lax.scan(body_fn, (self.key, init_state), lambda_schedule) + (_, result), _ = jax.lax.scan(body_fn, (0, init_state), lambda_schedule) self.assert_linear_regression_test_case(result) @@ -174,11 +192,13 @@ def test_normalizing_constant(self): hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() - hmc_parameters = { - "step_size": 10e-2, - "inverse_mass_matrix": jnp.eye(num_dim), - "num_integration_steps": 50, - } + hmc_parameters = extend_params( + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(num_dim), + "num_integration_steps": 50, + }, + ) tempering = adaptive_tempered_smc( logprior_fn, diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index d8f09cea0..2d108a48d 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -48,8 +48,11 @@ def run_regression(algorithm, **parameters): ) inference_algorithm = algorithm(logdensity_fn, **parameters) - _, states, _ = run_inference_algorithm( - inference_key, state, inference_algorithm, 10_000 + _, (states, _) = run_inference_algorithm( + rng_key=inference_key, + initial_state=state, + inference_algorithm=inference_algorithm, + num_steps=10_000, ) return states diff --git a/tests/test_compilation.py b/tests/test_compilation.py index e16f8ff3c..7179b71ba 100644 --- a/tests/test_compilation.py +++ b/tests/test_compilation.py @@ -40,8 +40,8 @@ def logdensity_fn(x): ) step = jax.jit(kernel.step) - for _ in range(10): - rng_key, sample_key = jax.random.split(rng_key) + for i in range(10): + sample_key = jax.random.fold_in(rng_key, i) state, _ = step(sample_key, state) def test_nuts(self): @@ -66,8 +66,8 @@ def logdensity_fn(x): ) step = jax.jit(kernel.step) - for _ in range(10): - rng_key, sample_key = jax.random.split(rng_key) + for i in range(10): + sample_key = jax.random.fold_in(rng_key, i) state, _ = step(sample_key, state) def test_hmc_warmup(self): @@ -94,8 +94,8 @@ def logdensity_fn(x): (state, parameters), _ = warmup.run(rng_key, 1.0, num_steps=100) kernel = jax.jit(blackjax.hmc(logdensity_fn, **parameters).step) - for _ in range(10): - rng_key, sample_key = jax.random.split(rng_key) + for i in range(10): + sample_key = jax.random.fold_in(rng_key, i) state, _ = kernel(sample_key, state) def 
test_nuts_warmup(self): @@ -121,8 +121,8 @@ def logdensity_fn(x): (state, parameters), _ = warmup.run(rng_key, 1.0, num_steps=100) step = jax.jit(blackjax.nuts(logdensity_fn, **parameters).step) - for _ in range(10): - rng_key, sample_key = jax.random.split(rng_key) + for i in range(10): + sample_key = jax.random.fold_in(rng_key, i) state, _ = step(sample_key, state) diff --git a/tests/test_util.py b/tests/test_util.py index d3eed1193..78198f013 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -3,15 +3,15 @@ import jax.numpy as jnp from absl.testing import absltest, parameterized -from blackjax.mcmc.hmc import hmc -from blackjax.util import run_inference_algorithm +import blackjax +from blackjax.util import run_inference_algorithm, store_only_expectation_values class RunInferenceAlgorithmTest(chex.TestCase): def setUp(self): super().setUp() self.key = jax.random.key(42) - self.algorithm = hmc( + self.algorithm = blackjax.hmc( logdensity_fn=self.logdensity_fn, inverse_mass_matrix=jnp.eye(2), step_size=1.0, @@ -19,23 +19,82 @@ def setUp(self): ) self.num_steps = 10 - def check_compatible(self, initial_state_or_position, progress_bar): + def check_compatible(self, initial_state, progress_bar): """ Runs 10 steps with `run_inference_algorithm` starting with - `initial_state_or_position` and potentially a progress bar. + `initial_state` and potentially a progress bar. """ _ = run_inference_algorithm( - self.key, - initial_state_or_position, - self.algorithm, - self.num_steps, - progress_bar, - transform=lambda x: x.position, + rng_key=self.key, + initial_state=initial_state, + inference_algorithm=self.algorithm, + num_steps=self.num_steps, + progress_bar=progress_bar, + transform=lambda state, info: state.position, ) + def test_streaming(self): + def logdensity_fn(x): + return -0.5 * jnp.sum(jnp.square(x)) + + initial_position = jnp.ones( + 10, + ) + + init_key, state_key, run_key = jax.random.split(self.key, 3) + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=logdensity_fn, rng_key=state_key + ) + L = 1.0 + step_size = 0.1 + num_steps = 4 + + sampling_alg = blackjax.mclmc( + logdensity_fn, + L=L, + step_size=step_size, + ) + + state_transform = lambda x: x.position + + _, samples = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=sampling_alg, + num_steps=num_steps, + transform=lambda state, info: state_transform(state), + progress_bar=True, + ) + + print("average of steps (slow way):", samples.mean(axis=0)) + + memory_efficient_sampling_alg, transform = store_only_expectation_values( + sampling_algorithm=sampling_alg, state_transform=state_transform + ) + + initial_state = memory_efficient_sampling_alg.init(initial_state) + + final_state, trace_at_every_step = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=memory_efficient_sampling_alg, + num_steps=num_steps, + transform=transform, + progress_bar=True, + ) + + assert jnp.allclose(trace_at_every_step[0][-1], samples.mean(axis=0)) + @parameterized.parameters([True, False]) def test_compatible_with_initial_pos(self, progress_bar): - self.check_compatible(jnp.array([1.0, 1.0]), progress_bar) + _ = run_inference_algorithm( + rng_key=self.key, + initial_position=jnp.array([1.0, 1.0]), + inference_algorithm=self.algorithm, + num_steps=self.num_steps, + progress_bar=progress_bar, + transform=lambda state, info: state.position, + ) @parameterized.parameters([True, False]) def 
test_compatible_with_initial_state(self, progress_bar): diff --git a/tests/vi/test_meanfield_vi.py b/tests/vi/test_meanfield_vi.py index c5a8a0865..689949720 100644 --- a/tests/vi/test_meanfield_vi.py +++ b/tests/vi/test_meanfield_vi.py @@ -36,12 +36,12 @@ def logdensity_fn(x): state = mfvi.init(initial_position) rng_key = self.key - for _ in range(num_steps): - rng_key, subkey = jax.random.split(rng_key) + for i in range(num_steps): + subkey = jax.random.fold_in(rng_key, i) state, _ = jax.jit(mfvi.step)(subkey, state) loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"] - scale = jax.tree_map(jnp.exp, state.rho) + scale = jax.tree.map(jnp.exp, state.rho) scale_1, scale_2 = scale["x_1"], scale["x_2"] self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01) self.assertAlmostEqual(scale_1, ground_truth[0][1], delta=0.01) diff --git a/tests/vi/test_schrodinger_follmer.py b/tests/vi/test_schrodinger_follmer.py index c59af6cf4..fd58fed0a 100644 --- a/tests/vi/test_schrodinger_follmer.py +++ b/tests/vi/test_schrodinger_follmer.py @@ -6,7 +6,7 @@ import jax.scipy.stats as stats from absl.testing import absltest -from blackjax.vi.schrodinger_follmer import schrodinger_follmer +from blackjax.vi.schrodinger_follmer import as_top_level_api as schrodinger_follmer class SchrodingerFollmerTest(chex.TestCase):
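The SMC test changes above revolve around the new `blackjax.smc.extend_params` helper: inner-kernel parameters are now passed as a dict whose leaves carry a leading particle axis, and `extend_params` adds a size-1 leading axis so a single value is shared across every particle (this is exactly what `ExtendParamsTest` asserts). Below is a minimal sketch of the shapes involved, assuming the version of `extend_params` introduced in this changeset; the concrete parameter values are illustrative only.

```python
import jax
import jax.numpy as jnp

from blackjax.smc import extend_params

# extend_params adds a leading axis of size 1 to every leaf, so the same
# value is broadcast to all particles by the SMC machinery.
shared = extend_params(
    {
        "step_size": 1e-1,
        "inverse_mass_matrix": jnp.eye(2),
        "num_integration_steps": 50,
    }
)
# shared["step_size"].shape           -> (1,)
# shared["inverse_mass_matrix"].shape -> (1, 2, 2)

# Per-particle parameters instead carry a leading axis of length
# num_particles; the tempered SMC test builds them from the shared dict:
num_particles = 100  # illustrative
per_particle = jax.tree.map(
    lambda x: jnp.repeat(x, num_particles, axis=0), shared
)
# per_particle["inverse_mass_matrix"].shape -> (100, 2, 2)
```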
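Several loops in these tests also replace a split key threaded through the loop carry with `jax.random.fold_in(base_key, i)`, which derives an independent key from a fixed base key and the iteration index, so the carry only needs an integer counter and the state. The following is a minimal sketch of that pattern inside `jax.lax.scan`, with a toy update standing in for the SMC kernel used in the tests.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.key(42)

def body_fn(carry, lmbda):
    i, state = carry
    # Derive this iteration's key from the fixed base key and the step
    # index instead of carrying and splitting a key at every step.
    subkey = jax.random.fold_in(base_key, i)
    new_state = state + lmbda * jax.random.normal(subkey)
    return (i + 1, new_state), new_state

schedule = jnp.logspace(-5, 0, 10)
(_, final_state), trace = jax.lax.scan(body_fn, (0, jnp.zeros(())), schedule)
```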
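The `tests/test_util.py` and `tests/test_benchmarks.py` hunks reflect the keyword-only calling convention of `blackjax.util.run_inference_algorithm` and its two-argument `transform(state, info)`. A minimal sketch of a call in that style follows; the HMC settings (step size, mass matrix, number of integration steps) are placeholders, and the result is unpacked the way the streaming test above does, with the second element holding the transformed trace.

```python
import jax
import jax.numpy as jnp

import blackjax
from blackjax.util import run_inference_algorithm

def logdensity_fn(x):
    return -0.5 * jnp.sum(jnp.square(x))

# Illustrative HMC settings; only the calling convention matters here.
algorithm = blackjax.hmc(
    logdensity_fn,
    step_size=1.0,
    inverse_mass_matrix=jnp.eye(2),
    num_integration_steps=10,
)

final_state, positions = run_inference_algorithm(
    rng_key=jax.random.key(0),
    initial_position=jnp.array([1.0, 1.0]),
    inference_algorithm=algorithm,
    num_steps=100,
    transform=lambda state, info: state.position,  # transform now receives (state, info)
    progress_bar=False,
)
```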