diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 96df14920..dfdcfc545 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -1,3 +1,6 @@ +import dataclasses +from typing import Callable + from blackjax._version import __version__ from .adaptation.chees_adaptation import chees_adaptation @@ -5,67 +8,156 @@ from .adaptation.meads_adaptation import meads_adaptation from .adaptation.pathfinder_adaptation import pathfinder_adaptation from .adaptation.window_adaptation import window_adaptation +from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat -from .mcmc.barker import barker_proposal -from .mcmc.dynamic_hmc import dynamic_hmc -from .mcmc.elliptical_slice import elliptical_slice -from .mcmc.ghmc import ghmc -from .mcmc.hmc import hmc -from .mcmc.mala import mala -from .mcmc.marginal_latent_gaussian import mgrad_gaussian -from .mcmc.mclmc import mclmc -from .mcmc.nuts import nuts -from .mcmc.periodic_orbital import orbital_hmc -from .mcmc.random_walk import additive_step_random_walk, irmh, rmh -from .mcmc.rmhmc import rmhmc +from .mcmc import barker +from .mcmc import dynamic_hmc as _dynamic_hmc +from .mcmc import elliptical_slice as _elliptical_slice +from .mcmc import ghmc as _ghmc +from .mcmc import hmc as _hmc +from .mcmc import mala as _mala +from .mcmc import marginal_latent_gaussian +from .mcmc import mclmc as _mclmc +from .mcmc import nuts as _nuts +from .mcmc import periodic_orbital, random_walk +from .mcmc import rmhmc as _rmhmc +from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk +from .mcmc.random_walk import ( + irmh_as_top_level_api, + normal_random_walk, + rmh_as_top_level_api, +) from .optimizers import dual_averaging, lbfgs -from .sgmcmc.csgld import csgld -from .sgmcmc.sghmc import sghmc -from .sgmcmc.sgld import sgld -from .sgmcmc.sgnht import sgnht -from .smc.adaptive_tempered import adaptive_tempered_smc -from .smc.inner_kernel_tuning import inner_kernel_tuning -from .smc.tempered import tempered_smc -from .vi.meanfield_vi import meanfield_vi -from .vi.pathfinder import pathfinder -from .vi.schrodinger_follmer import schrodinger_follmer -from .vi.svgd import svgd +from .sgmcmc import csgld as _csgld +from .sgmcmc import sghmc as _sghmc +from .sgmcmc import sgld as _sgld +from .sgmcmc import sgnht as _sgnht +from .smc import adaptive_tempered +from .smc import inner_kernel_tuning as _inner_kernel_tuning +from .smc import tempered +from .vi import meanfield_vi as _meanfield_vi +from .vi import pathfinder as _pathfinder +from .vi import schrodinger_follmer as _schrodinger_follmer +from .vi import svgd as _svgd +from .vi.pathfinder import PathFinderAlgorithm + +""" +The above three classes exist as a backwards compatible way of exposing both the high level, differentiable +factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower +level to be mostly functional programming in nature and reducing boilerplate code. +""" + + +@dataclasses.dataclass +class GenerateSamplingAPI: + differentiable: Callable + init: Callable + build_kernel: Callable + + def __call__(self, *args, **kwargs) -> SamplingAlgorithm: + return self.differentiable(*args, **kwargs) + + def register_factory(self, name, callable): + setattr(self, name, callable) + + +@dataclasses.dataclass +class GenerateVariationalAPI: + differentiable: Callable + init: Callable + step: Callable + sample: Callable + + def __call__(self, *args, **kwargs) -> VIAlgorithm: + return self.differentiable(*args, **kwargs) + + +@dataclasses.dataclass +class GeneratePathfinderAPI: + differentiable: Callable + approximate: Callable + sample: Callable + + def __call__(self, *args, **kwargs) -> PathFinderAlgorithm: + return self.differentiable(*args, **kwargs) + + +def generate_top_level_api_from(module): + return GenerateSamplingAPI( + module.as_top_level_api, module.init, module.build_kernel + ) + + +# MCMC +hmc = generate_top_level_api_from(_hmc) +nuts = generate_top_level_api_from(_nuts) +rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh) +irmh = GenerateSamplingAPI( + irmh_as_top_level_api, random_walk.init, random_walk.build_irmh +) +dynamic_hmc = generate_top_level_api_from(_dynamic_hmc) +rmhmc = generate_top_level_api_from(_rmhmc) +mala = generate_top_level_api_from(_mala) +mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian) +orbital_hmc = generate_top_level_api_from(periodic_orbital) + +additive_step_random_walk = GenerateSamplingAPI( + _additive_step_random_walk, random_walk.init, random_walk.build_additive_step +) + +additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) + +mclmc = generate_top_level_api_from(_mclmc) +elliptical_slice = generate_top_level_api_from(_elliptical_slice) +ghmc = generate_top_level_api_from(_ghmc) +barker_proposal = generate_top_level_api_from(barker) + +hmc_family = [hmc, nuts] + +# SMC +adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered) +tempered_smc = generate_top_level_api_from(tempered) +inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning) + +smc_family = [tempered_smc, adaptive_tempered_smc] +"Step_fn returning state has a .particles attribute" + +# stochastic gradient mcmc +sgld = generate_top_level_api_from(_sgld) +sghmc = generate_top_level_api_from(_sghmc) +sgnht = generate_top_level_api_from(_sgnht) +csgld = generate_top_level_api_from(_csgld) +svgd = generate_top_level_api_from(_svgd) + +# variational inference +meanfield_vi = GenerateVariationalAPI( + _meanfield_vi.as_top_level_api, + _meanfield_vi.init, + _meanfield_vi.step, + _meanfield_vi.sample, +) +schrodinger_follmer = GenerateVariationalAPI( + _schrodinger_follmer.as_top_level_api, + _schrodinger_follmer.init, + _schrodinger_follmer.step, + _schrodinger_follmer.sample, +) + +pathfinder = GeneratePathfinderAPI( + _pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample +) + __all__ = [ "__version__", "dual_averaging", # optimizers "lbfgs", - "hmc", # mcmc - "dynamic_hmc", - "rmhmc", - "mala", - "mgrad_gaussian", - "nuts", - "orbital_hmc", - "additive_step_random_walk", - "rmh", - "irmh", - "mclmc", - "elliptical_slice", - "ghmc", - "barker_proposal", - "sgld", # stochastic gradient mcmc - "sghmc", - "sgnht", - "csgld", "window_adaptation", # mcmc adaptation "meads_adaptation", "chees_adaptation", "pathfinder_adaptation", "mclmc_find_L_and_step_size", # mclmc adaptation - "adaptive_tempered_smc", # smc - "tempered_smc", - "inner_kernel_tuning", - "meanfield_vi", # variational inference - "pathfinder", - "schrodinger_follmer", - "svgd", "ess", # diagnostics "rhat", ] diff --git a/blackjax/adaptation/pathfinder_adaptation.py b/blackjax/adaptation/pathfinder_adaptation.py index c70ed3f99..efcc55741 100644 --- a/blackjax/adaptation/pathfinder_adaptation.py +++ b/blackjax/adaptation/pathfinder_adaptation.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Pathinder warmup for the HMC family of sampling algorithms.""" -from typing import Callable, NamedTuple, Union +from typing import Callable, NamedTuple import jax import jax.numpy as jnp -import blackjax.mcmc as mcmc import blackjax.vi as vi from blackjax.adaptation.base import AdaptationInfo, AdaptationResults from blackjax.adaptation.step_size import ( @@ -138,7 +137,7 @@ def final(warmup_state: PathfinderAdaptationState) -> tuple[float, Array]: def pathfinder_adaptation( - algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts], + algorithm, logdensity_fn: Callable, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.80, diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index cc871b4b6..e15121dc5 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Stan warmup for the HMC family of sampling algorithms.""" -from typing import Callable, NamedTuple, Union +from typing import Callable, NamedTuple import jax import jax.numpy as jnp -import blackjax.mcmc as mcmc from blackjax.adaptation.base import AdaptationInfo, AdaptationResults from blackjax.adaptation.mass_matrix import ( MassMatrixAdaptationState, @@ -243,7 +242,7 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]: def window_adaptation( - algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts], + algorithm, logdensity_fn: Callable, is_mass_matrix_diagonal: bool = True, initial_step_size: float = 1.0, @@ -252,7 +251,7 @@ def window_adaptation( **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of - algorithms in the HMC fmaily. + algorithms in the HMC family. See Blackjax.hmc_family Algorithms in the HMC family on a euclidean manifold depend on the value of at least two parameters: the step size, related to the trajectory diff --git a/blackjax/mcmc/barker.py b/blackjax/mcmc/barker.py index b91721a71..9923bd5f3 100644 --- a/blackjax/mcmc/barker.py +++ b/blackjax/mcmc/barker.py @@ -24,7 +24,7 @@ from blackjax.mcmc.proposal import static_binomial_sampling from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "barker_proposal"] +__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"] class BarkerState(NamedTuple): @@ -128,7 +128,10 @@ def kernel( return kernel -class barker_proposal: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a Gaussian base kernel. @@ -179,24 +182,16 @@ class barker_proposal: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel() - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) + def step_fn(rng_key: PRNGKey, state): + return kernel(rng_key, state, logdensity_fn, step_size) - def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, state, logdensity_fn, step_size) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def _barker_sample_nd(key, mean, a, scale): diff --git a/blackjax/mcmc/dynamic_hmc.py b/blackjax/mcmc/dynamic_hmc.py index 0fe4ec992..de77be825 100644 --- a/blackjax/mcmc/dynamic_hmc.py +++ b/blackjax/mcmc/dynamic_hmc.py @@ -27,8 +27,8 @@ "DynamicHMCState", "init", "build_kernel", - "dynamic_hmc", "halton_sequence", + "as_top_level_api", ] @@ -115,7 +115,16 @@ def kernel( return kernel -class dynamic_hmc: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: Array, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.velocity_verlet, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: """Implements the (basic) user interface for the dynamic HMC kernel. Parameters @@ -144,41 +153,26 @@ class dynamic_hmc: ------- A ``SamplingAlgorithm``. """ - - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - inverse_mass_matrix: Array, - *, - divergence_threshold: int = 1000, - integrator: Callable = integrators.velocity_verlet, - next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), - ) -> SamplingAlgorithm: - kernel = cls.build_kernel( - integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn + kernel = build_kernel( + integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + # Note that rng_key here is not necessarily a PRNGKey, could be a Array that + # for generates a sequence of pseudo or quasi-random numbers (previously + # named as `random_generator_arg`) + return init(position, logdensity_fn, rng_key) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, ) - def init_fn(position: ArrayLikeTree, rng_key: Array): - # Note that rng_key here is not necessarily a PRNGKey, could be a Array that - # for generates a sequence of pseudo or quasi-random numbers (previously - # named as `random_generator_arg`) - return cls.init(position, logdensity_fn, rng_key) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - inverse_mass_matrix, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def halton_sequence(i: Array, max_bits: int = 10) -> float: diff --git a/blackjax/mcmc/elliptical_slice.py b/blackjax/mcmc/elliptical_slice.py index 52f242210..09ed66c86 100644 --- a/blackjax/mcmc/elliptical_slice.py +++ b/blackjax/mcmc/elliptical_slice.py @@ -26,7 +26,7 @@ "EllipSliceInfo", "init", "build_kernel", - "elliptical_slice", + "as_top_level_api", ] @@ -119,7 +119,12 @@ def kernel( return kernel -class elliptical_slice: +def as_top_level_api( + loglikelihood_fn: Callable, + *, + mean: Array, + cov: Array, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Elliptical Slice sampling kernel. Examples @@ -151,31 +156,20 @@ class elliptical_slice: ------- A ``SamplingAlgorithm``. """ + kernel = build_kernel(cov, mean) + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, loglikelihood_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + loglikelihood_fn, + ) - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - loglikelihood_fn: Callable, - *, - mean: Array, - cov: Array, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(cov, mean) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, loglikelihood_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - loglikelihood_fn, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def elliptical_proposal( diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index 3cd0c86f6..a04ce0641 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -25,7 +25,7 @@ from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise -__all__ = ["GHMCState", "init", "build_kernel", "ghmc"] +__all__ = ["GHMCState", "init", "build_kernel", "as_top_level_api"] class GHMCState(NamedTuple): @@ -195,7 +195,16 @@ def update_momentum(rng_key, state, alpha, momentum_generator): return momentum -class ghmc: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + momentum_inverse_scale: ArrayLikeTree, + alpha: float, + delta: float, + *, + divergence_threshold: int = 1000, + noise_gn: Callable = lambda _: 0.0, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Generalized HMC kernel. The Generalized HMC kernel performs a similar procedure to the standard HMC @@ -257,34 +266,20 @@ class ghmc: A ``SamplingAlgorithm``. """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel(noise_gn, divergence_threshold) - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - momentum_inverse_scale: ArrayLikeTree, - alpha: float, - delta: float, - *, - divergence_threshold: int = 1000, - noise_gn: Callable = lambda _: 0.0, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(noise_gn, divergence_threshold) - - def init_fn(position: ArrayLikeTree, rng_key: PRNGKey): - return cls.init(position, rng_key, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - momentum_inverse_scale, - alpha, - delta, - ) - - return SamplingAlgorithm(init_fn, step_fn) + def init_fn(position: ArrayLikeTree, rng_key: PRNGKey): + return init(position, rng_key, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + momentum_inverse_scale, + alpha, + delta, + ) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index b48834e5f..452b94e44 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -29,7 +29,7 @@ "HMCInfo", "init", "build_kernel", - "hmc", + "as_top_level_api", ] @@ -150,7 +150,15 @@ def kernel( return kernel -class hmc: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: metrics.MetricTypes, + num_integration_steps: int, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.velocity_verlet, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the HMC kernel. The general hmc kernel builder (:meth:`blackjax.mcmc.hmc.build_kernel`, alias @@ -225,36 +233,23 @@ class hmc: A ``SamplingAlgorithm``. """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel(integrator, divergence_threshold) - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - inverse_mass_matrix: metrics.MetricTypes, - num_integration_steps: int, - *, - divergence_threshold: int = 1000, - integrator: Callable = integrators.velocity_verlet, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(integrator, divergence_threshold) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - inverse_mass_matrix, - num_integration_steps, - ) - - return SamplingAlgorithm(init_fn, step_fn) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + num_integration_steps, + ) + + return SamplingAlgorithm(init_fn, step_fn) def hmc_proposal( diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index 1f1345cc4..56c0c0077 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -23,7 +23,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["MALAState", "MALAInfo", "init", "build_kernel", "mala"] +__all__ = ["MALAState", "MALAInfo", "init", "build_kernel", "as_top_level_api"] class MALAState(NamedTuple): @@ -117,7 +117,10 @@ def kernel( return kernel -class mala: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the MALA kernel. The general mala kernel builder (:meth:`blackjax.mcmc.mala.build_kernel`, alias `blackjax.mala.build_kernel`) can be @@ -167,21 +170,13 @@ class mala: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel() - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) + def step_fn(rng_key: PRNGKey, state): + return kernel(rng_key, state, logdensity_fn, step_size) - def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, state, logdensity_fn, step_size) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/marginal_latent_gaussian.py b/blackjax/mcmc/marginal_latent_gaussian.py index 8d4d76f6a..d2783f8d9 100644 --- a/blackjax/mcmc/marginal_latent_gaussian.py +++ b/blackjax/mcmc/marginal_latent_gaussian.py @@ -22,7 +22,13 @@ from blackjax.mcmc.proposal import static_binomial_sampling from blackjax.types import Array, PRNGKey -__all__ = ["MarginalState", "MarginalInfo", "init", "build_kernel", "mgrad_gaussian"] +__all__ = [ + "MarginalState", + "MarginalInfo", + "init", + "build_kernel", + "as_top_level_api", +] # [TODO](https://github.com/blackjax-devs/blackjax/issues/237) @@ -206,7 +212,13 @@ def kernel(key: PRNGKey, state: MarginalState, logdensity_fn, delta): return kernel -class mgrad_gaussian: +def as_top_level_api( + logdensity_fn: Callable, + covariance: Optional[Array] = None, + mean: Optional[Array] = None, + cov_svd: Optional[CovarianceSVD] = None, + step_size: float = 1.0, +) -> SamplingAlgorithm: """Implements the marginal sampler for latent Gaussian model of :cite:p:`titsias2018auxiliary`. It uses a first order approximation to the log_likelihood of a model with Gaussian prior. @@ -247,41 +259,28 @@ class mgrad_gaussian: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - covariance: Optional[Array] = None, - mean: Optional[Array] = None, - cov_svd: Optional[CovarianceSVD] = None, - step_size: float = 1.0, - ) -> SamplingAlgorithm: - if cov_svd is None: - if covariance is None: - raise ValueError("Either covariance or cov_svd must be provided.") - cov_svd = svd_from_covariance(covariance) - - U, Gamma, U_t = cov_svd - - if mean is not None: - logdensity_fn = generate_mean_shifted_logprob( - logdensity_fn, mean, covariance - ) - - kernel = cls.build_kernel(cov_svd) - - def init_fn(position: Array, rng_key=None): - del rng_key - return init(position, logdensity_fn, U_t) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - ) - - return SamplingAlgorithm(init_fn, step_fn) + if cov_svd is None: + if covariance is None: + raise ValueError("Either covariance or cov_svd must be provided.") + cov_svd = svd_from_covariance(covariance) + + U, Gamma, U_t = cov_svd + + if mean is not None: + logdensity_fn = generate_mean_shifted_logprob(logdensity_fn, mean, covariance) + + kernel = build_kernel(cov_svd) + + def init_fn(position: Array, rng_key=None): + del rng_key + return init(position, logdensity_fn, U_t) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + ) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 7c636181f..406c4125d 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -24,7 +24,7 @@ from blackjax.types import ArrayLike, PRNGKey from blackjax.util import generate_unit_vector, pytree_size -__all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] +__all__ = ["MCLMCInfo", "init", "build_kernel", "as_top_level_api"] class MCLMCInfo(NamedTuple): @@ -103,7 +103,12 @@ def kernel( return kernel -class mclmc: +def as_top_level_api( + logdensity_fn: Callable, + L, + step_size, + integrator=isokinetic_mclachlan, +) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel parameters at initialization time, we provide a helper function that @@ -150,25 +155,15 @@ class mclmc: A ``SamplingAlgorithm``. """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel(logdensity_fn, integrator) - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - L, - step_size, - integrator=isokinetic_mclachlan, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(logdensity_fn, integrator) + def init_fn(position: ArrayLike, rng_key: PRNGKey): + return init(position, logdensity_fn, rng_key) - def init_fn(position: ArrayLike, rng_key: PRNGKey): - return cls.init(position, logdensity_fn, rng_key) + def update_fn(rng_key, state): + return kernel(rng_key, state, L, step_size) - def update_fn(rng_key, state): - return kernel(rng_key, state, L, step_size) - - return SamplingAlgorithm(init_fn, update_fn) + return SamplingAlgorithm(init_fn, update_fn) def partially_refresh_momentum(momentum, rng_key, step_size, L): diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index 5ffc083b1..c75ecdec6 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -27,7 +27,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["NUTSInfo", "init", "build_kernel", "nuts"] +__all__ = ["NUTSInfo", "init", "build_kernel", "as_top_level_api"] init = hmc.init @@ -147,7 +147,15 @@ def kernel( return kernel -class nuts: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: metrics.MetricTypes, + *, + max_num_doublings: int = 10, + divergence_threshold: int = 1000, + integrator: Callable = integrators.velocity_verlet, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the nuts kernel. Examples @@ -202,37 +210,23 @@ class nuts: A ``SamplingAlgorithm``. """ + kernel = build_kernel(integrator, divergence_threshold) + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + max_num_doublings, + ) - init = staticmethod(hmc.init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - inverse_mass_matrix: metrics.MetricTypes, - *, - max_num_doublings: int = 10, - divergence_threshold: int = 1000, - integrator: Callable = integrators.velocity_verlet, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(integrator, divergence_threshold) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - inverse_mass_matrix, - max_num_doublings, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def iterative_nuts_proposal( diff --git a/blackjax/mcmc/periodic_orbital.py b/blackjax/mcmc/periodic_orbital.py index 6e4a2ca5e..61625a0b8 100644 --- a/blackjax/mcmc/periodic_orbital.py +++ b/blackjax/mcmc/periodic_orbital.py @@ -22,7 +22,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["PeriodicOrbitalState", "init", "build_kernel", "orbital_hmc"] +__all__ = ["PeriodicOrbitalState", "init", "build_kernel", "as_top_level_api"] class PeriodicOrbitalState(NamedTuple): @@ -217,7 +217,14 @@ def kernel( return kernel -class orbital_hmc: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: Array, # assume momentum is always Gaussian + period: int, + *, + bijection: Callable = integrators.velocity_verlet, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Periodic orbital MCMC kernel. Each iteration of the periodic orbital MCMC outputs ``period`` weighted samples from @@ -261,36 +268,23 @@ class orbital_hmc: ------- A ``SamplingAlgorithm``. """ + kernel = build_kernel(bijection) + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn, period) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + period, + ) - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - inverse_mass_matrix: Array, # assume momentum is always Gaussian - period: int, - *, - bijection: Callable = integrators.velocity_verlet, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(bijection) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn, period) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - inverse_mass_matrix, - period, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def periodic_orbital_proposal( diff --git a/blackjax/mcmc/random_walk.py b/blackjax/mcmc/random_walk.py index e454c057d..a1d1c3bd6 100644 --- a/blackjax/mcmc/random_walk.py +++ b/blackjax/mcmc/random_walk.py @@ -80,8 +80,9 @@ "rmh_proposal", "build_rmh_transition_energy", "additive_step_random_walk", - "irmh", - "rmh", + "irmh_as_top_level_api", + "rmh_as_top_level_api", + "normal_random_walk", ] @@ -182,7 +183,25 @@ def proposal_generator(key_proposal, position): return kernel -class additive_step_random_walk: +def normal_random_walk(logdensity_fn: Callable, sigma): + """ + Parameters + ---------- + logdensity_fn + The log density probability density function from which we wish to sample. + sigma + The value of the covariance matrix of the gaussian proposal distribution. + + Returns + ------- + A ``SamplingAlgorithm``. + """ + return additive_step_random_walk(logdensity_fn, normal(sigma)) + + +def additive_step_random_walk( + logdensity_fn: Callable, random_step: Callable +) -> SamplingAlgorithm: """Implements the user interface for the Additive Step RMH Examples @@ -218,39 +237,16 @@ class additive_step_random_walk: ------- A ``SamplingAlgorithm``. """ + kernel = build_additive_step() - init = staticmethod(init) - build_kernel = staticmethod(build_additive_step) - - @classmethod - def normal_random_walk(cls, logdensity_fn: Callable, sigma): - """ - Parameters - ---------- - logdensity_fn - The log density probability density function from which we wish to sample. - sigma - The value of the covariance matrix of the gaussian proposal distribution. - - Returns - ------- - A ``SamplingAlgorithm``. - """ - return cls(logdensity_fn, normal(sigma)) - - def __new__( # type: ignore[misc] - cls, logdensity_fn: Callable, random_step: Callable - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) + def step_fn(rng_key: PRNGKey, state): + return kernel(rng_key, state, logdensity_fn, random_step) - def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, state, logdensity_fn, random_step) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def build_irmh() -> Callable: @@ -297,7 +293,11 @@ def proposal_generator(rng_key: PRNGKey, position: ArrayTree): return kernel -class irmh: +def irmh_as_top_level_api( + logdensity_fn: Callable, + proposal_distribution: Callable, + proposal_logdensity_fn: Optional[Callable] = None, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the independent RMH. Examples @@ -334,32 +334,22 @@ class irmh: A ``SamplingAlgorithm``. """ + kernel = build_irmh() + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + proposal_distribution, + proposal_logdensity_fn, + ) - init = staticmethod(init) - build_kernel = staticmethod(build_irmh) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - proposal_distribution: Callable, - proposal_logdensity_fn: Optional[Callable] = None, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - proposal_distribution, - proposal_logdensity_fn, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def build_rmh(): @@ -420,7 +410,11 @@ def kernel( return kernel -class rmh: +def rmh_as_top_level_api( + logdensity_fn: Callable, + proposal_generator: Callable[[PRNGKey, ArrayLikeTree], ArrayTree], + proposal_logdensity_fn: Optional[Callable[[ArrayLikeTree], ArrayTree]] = None, +) -> SamplingAlgorithm: """Implements the user interface for the RMH. Examples @@ -456,32 +450,22 @@ class rmh: ------- A ``SamplingAlgorithm``. """ + kernel = build_rmh() + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + proposal_generator, + proposal_logdensity_fn, + ) - init = staticmethod(init) - build_kernel = staticmethod(build_rmh) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - proposal_generator: Callable[[PRNGKey, ArrayLikeTree], ArrayTree], - proposal_logdensity_fn: Optional[Callable[[ArrayLikeTree], ArrayTree]] = None, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - proposal_generator, - proposal_logdensity_fn, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def build_rmh_transition_energy(proposal_logdensity_fn: Optional[Callable]) -> Callable: diff --git a/blackjax/mcmc/rmhmc.py b/blackjax/mcmc/rmhmc.py index edcfb3571..a4551a781 100644 --- a/blackjax/mcmc/rmhmc.py +++ b/blackjax/mcmc/rmhmc.py @@ -20,14 +20,22 @@ from blackjax.mcmc import hmc from blackjax.types import ArrayTree, PRNGKey -__all__ = ["init", "build_kernel", "rmhmc"] +__all__ = ["init", "build_kernel", "as_top_level_api"] init = hmc.init build_kernel = hmc.build_kernel -class rmhmc: +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + mass_matrix: Union[metrics.Metric, Callable], + num_integration_steps: int, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.implicit_midpoint, +) -> SamplingAlgorithm: """A Riemannian Manifold Hamiltonian Monte Carlo kernel Of note, this kernel is simply an alias of the ``hmc`` kernel with a @@ -62,34 +70,20 @@ class rmhmc: ------- A ``SamplingAlgorithm``. """ + kernel = build_kernel(integrator, divergence_threshold) - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + def init_fn(position: ArrayTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - step_size: float, - mass_matrix: Union[metrics.Metric, Callable], - num_integration_steps: int, - *, - divergence_threshold: int = 1000, - integrator: Callable = integrators.implicit_midpoint, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(integrator, divergence_threshold) + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + mass_matrix, + num_integration_steps, + ) - def init_fn(position: ArrayTree, rng_key=None): - del rng_key - return cls.init(position, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - logdensity_fn, - step_size, - mass_matrix, - num_integration_steps, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/sgmcmc/csgld.py b/blackjax/sgmcmc/csgld.py index e0e008a33..506740c50 100644 --- a/blackjax/sgmcmc/csgld.py +++ b/blackjax/sgmcmc/csgld.py @@ -23,7 +23,7 @@ from blackjax.sgmcmc.diffusions import overdamped_langevin from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["ContourSGLDState", "init", "build_kernel", "csgld"] +__all__ = ["ContourSGLDState", "init", "build_kernel", "as_top_level_api"] class ContourSGLDState(NamedTuple): @@ -174,7 +174,14 @@ def kernel( return kernel -class csgld: +def as_top_level_api( + logdensity_estimator: Callable, + gradient_estimator: Callable, + zeta: float = 1, + num_partitions: int = 512, + energy_gap: float = 100, + min_energy: float = 0, +) -> SamplingAlgorithm: r"""Implements the (basic) user interface for the Contour SGLD kernel. Parameters @@ -209,42 +216,30 @@ class csgld: A ``SamplingAlgorithm``. """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel(num_partitions, energy_gap, min_energy) - def __new__( # type: ignore[misc] - cls, - logdensity_estimator: Callable, - gradient_estimator: Callable, - zeta: float = 1, - num_partitions: int = 512, - energy_gap: float = 100, - min_energy: float = 0, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(num_partitions, energy_gap, min_energy) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position, num_partitions) - - def step_fn( - rng_key: PRNGKey, - state: ContourSGLDState, - minibatch: ArrayLikeTree, - step_size_diff: float, - step_size_stoch: float, - temperature: float = 1.0, - ) -> ContourSGLDState: - return kernel( - rng_key, - state, - logdensity_estimator, - gradient_estimator, - minibatch, - step_size_diff, - step_size_stoch, - zeta, - temperature, - ) - - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, num_partitions) + + def step_fn( + rng_key: PRNGKey, + state: ContourSGLDState, + minibatch: ArrayLikeTree, + step_size_diff: float, + step_size_stoch: float, + temperature: float = 1.0, + ) -> ContourSGLDState: + return kernel( + rng_key, + state, + logdensity_estimator, + gradient_estimator, + minibatch, + step_size_diff, + step_size_stoch, + zeta, + temperature, + ) + + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index 806bbc14e..afa8e2e42 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -21,7 +21,7 @@ from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise -__all__ = ["init", "build_kernel", "sghmc"] +__all__ = ["init", "build_kernel", "as_top_level_api"] def init(position: ArrayLikeTree) -> ArrayLikeTree: @@ -58,7 +58,12 @@ def body_fn(state, rng_key): return kernel -class sghmc: +def as_top_level_api( + grad_estimator: Callable, + num_integration_steps: int = 10, + alpha: float = 0.01, + beta: float = 0, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the SGHMC kernel. The general sghmc kernel builder (:meth:`blackjax.sgmcmc.sghmc.build_kernel`, alias @@ -111,37 +116,27 @@ class sghmc: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel(alpha, beta) - def __new__( # type: ignore[misc] - cls, - grad_estimator: Callable, - num_integration_steps: int = 10, - alpha: float = 0.01, - beta: float = 0, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(alpha, beta) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position) - - def step_fn( - rng_key: PRNGKey, - state: ArrayLikeTree, - minibatch: ArrayLikeTree, - step_size: float, - temperature: float = 1, - ) -> ArrayTree: - return kernel( - rng_key, - state, - grad_estimator, - minibatch, - step_size, - num_integration_steps, - temperature, - ) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position) - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + def step_fn( + rng_key: PRNGKey, + state: ArrayLikeTree, + minibatch: ArrayLikeTree, + step_size: float, + temperature: float = 1, + ) -> ArrayTree: + return kernel( + rng_key, + state, + grad_estimator, + minibatch, + step_size, + num_integration_steps, + temperature, + ) + + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index e2055c511..dca47a983 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -18,7 +18,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["init", "build_kernel", "sgld"] +__all__ = ["init", "build_kernel", "as_top_level_api"] def init(position: ArrayLikeTree) -> ArrayLikeTree: @@ -47,7 +47,9 @@ def kernel( return kernel -class sgld: +def as_top_level_api( + grad_estimator: Callable, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the SGLD kernel. The general sgld kernel builder (:meth:`blackjax.sgmcmc.sgld.build_kernel`, alias @@ -100,28 +102,19 @@ class sgld: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel() - def __new__( # type: ignore[misc] - cls, - grad_estimator: Callable, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel() - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position) - - def step_fn( - rng_key: PRNGKey, - state: ArrayLikeTree, - minibatch: ArrayLikeTree, - step_size: float, - temperature: float = 1, - ) -> ArrayTree: - return kernel( - rng_key, state, grad_estimator, minibatch, step_size, temperature - ) - - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position) + + def step_fn( + rng_key: PRNGKey, + state: ArrayLikeTree, + minibatch: ArrayLikeTree, + step_size: float, + temperature: float = 1, + ) -> ArrayTree: + return kernel(rng_key, state, grad_estimator, minibatch, step_size, temperature) + + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/sgmcmc/sgnht.py b/blackjax/sgmcmc/sgnht.py index 57b0a4ca2..ad9547406 100644 --- a/blackjax/sgmcmc/sgnht.py +++ b/blackjax/sgmcmc/sgnht.py @@ -19,7 +19,7 @@ from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise -__all__ = ["SGNHTState", "init", "build_kernel", "sgnht"] +__all__ = ["SGNHTState", "init", "build_kernel", "as_top_level_api"] class SGNHTState(NamedTuple): @@ -67,7 +67,11 @@ def kernel( return kernel -class sgnht: +def as_top_level_api( + grad_estimator: Callable, + alpha: float = 0.01, + beta: float = 0.0, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the SGNHT kernel. The general sgnht kernel (:meth:`blackjax.sgmcmc.sgnht.build_kernel`, alias @@ -121,33 +125,22 @@ class sgnht: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel = build_kernel(alpha, beta) - def __new__( # type: ignore[misc] - cls, - grad_estimator: Callable, - alpha: float = 0.01, - beta: float = 0.0, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel(alpha, beta) - - def init_fn( - position: ArrayLikeTree, - rng_key: PRNGKey, - init_xi: Union[None, float] = None, - ): - return cls.init(position, rng_key, init_xi or alpha) - - def step_fn( - rng_key: PRNGKey, - state: SGNHTState, - minibatch: ArrayLikeTree, - step_size: float, - temperature: float = 1, - ) -> SGNHTState: - return kernel( - rng_key, state, grad_estimator, minibatch, step_size, temperature - ) - - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + def init_fn( + position: ArrayLikeTree, + rng_key: PRNGKey, + init_xi: Union[None, float] = None, + ): + return init(position, rng_key, init_xi or alpha) + + def step_fn( + rng_key: PRNGKey, + state: SGNHTState, + minibatch: ArrayLikeTree, + step_size: float, + temperature: float = 1, + ) -> SGNHTState: + return kernel(rng_key, state, grad_estimator, minibatch, step_size, temperature) + + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py index 5b02c783b..b8a611606 100644 --- a/blackjax/smc/adaptive_tempered.py +++ b/blackjax/smc/adaptive_tempered.py @@ -23,7 +23,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, PRNGKey -__all__ = ["build_kernel", "adaptive_tempered_smc"] +__all__ = ["build_kernel", "init", "as_top_level_api"] def build_kernel( @@ -103,7 +103,20 @@ def kernel( return kernel -class adaptive_tempered_smc: +init = tempered.init + + +def as_top_level_api( + logprior_fn: Callable, + loglikelihood_fn: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + mcmc_parameters: dict, + resampling_fn: Callable, + target_ess: float, + root_solver: Callable = solver.dichotomy, + num_mcmc_steps: int = 10, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. Parameters @@ -133,42 +146,26 @@ class adaptive_tempered_smc: A ``SamplingAlgorithm``. """ + kernel = build_kernel( + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + target_ess, + root_solver, + ) - init = staticmethod(tempered.init) - build_kernel = staticmethod(build_kernel) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position) - def __new__( # type: ignore[misc] - cls, - logprior_fn: Callable, - loglikelihood_fn: Callable, - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - mcmc_parameters: dict, - resampling_fn: Callable, - target_ess: float, - root_solver: Callable = solver.dichotomy, - num_mcmc_steps: int = 10, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel( - logprior_fn, - loglikelihood_fn, - mcmc_step_fn, - mcmc_init_fn, - resampling_fn, - target_ess, - root_solver, + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + num_mcmc_steps, + mcmc_parameters, ) - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - num_mcmc_steps, - mcmc_parameters, - ) - - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index 705a60c35..2a63fd1ce 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -1,9 +1,7 @@ -from typing import Callable, Dict, NamedTuple, Tuple, Union +from typing import Callable, Dict, NamedTuple, Tuple from blackjax.base import SamplingAlgorithm -from blackjax.smc.adaptive_tempered import adaptive_tempered_smc from blackjax.smc.base import SMCInfo, SMCState -from blackjax.smc.tempered import tempered_smc from blackjax.types import ArrayTree, PRNGKey @@ -78,7 +76,18 @@ def kernel( return kernel -class inner_kernel_tuning: +def as_top_level_api( + smc_algorithm, + logprior_fn: Callable, + loglikelihood_fn: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + resampling_fn: Callable, + mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], + initial_parameter_value, + num_mcmc_steps: int = 10, + **extra_parameters, +) -> SamplingAlgorithm: """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner MCMC that is used to perturbate/update each of the particles. This adaptation tunes some @@ -89,7 +98,7 @@ class inner_kernel_tuning: ---------- smc_algorithm Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of - a sampling algorithm that returns an SMCState and SMCInfo pair). + a sampling algorithm that returns an SMCState and SMCInfo pair). See blackjax.smc_family logprior_fn A function that computes the log density of the prior distribution loglikelihood_fn @@ -112,41 +121,25 @@ class inner_kernel_tuning: """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - smc_algorithm: Union[adaptive_tempered_smc, tempered_smc], - logprior_fn: Callable, - loglikelihood_fn: Callable, - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - resampling_fn: Callable, - mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree], - initial_parameter_value, - num_mcmc_steps: int = 10, + kernel = build_kernel( + smc_algorithm, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + mcmc_parameter_update_fn, + num_mcmc_steps, **extra_parameters, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel( - smc_algorithm, - logprior_fn, - loglikelihood_fn, - mcmc_step_fn, - mcmc_init_fn, - resampling_fn, - mcmc_parameter_update_fn, - num_mcmc_steps, - **extra_parameters, - ) + ) - def init_fn(position, rng_key=None): - del rng_key - return cls.init(smc_algorithm.init, position, initial_parameter_value) + def init_fn(position, rng_key=None): + del rng_key + return init(smc_algorithm.init, position, initial_parameter_value) - def step_fn( - rng_key: PRNGKey, state, **extra_step_parameters - ) -> Tuple[StateWithParameterOverride, SMCInfo]: - return kernel(rng_key, state, **extra_step_parameters) + def step_fn( + rng_key: PRNGKey, state, **extra_step_parameters + ) -> Tuple[StateWithParameterOverride, SMCInfo]: + return kernel(rng_key, state, **extra_step_parameters) - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 561eadecc..b373d062f 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -21,7 +21,7 @@ from blackjax.smc.base import SMCState from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["TemperedSMCState", "init", "build_kernel"] +__all__ = ["TemperedSMCState", "init", "build_kernel", "as_top_level_api"] class TemperedSMCState(NamedTuple): @@ -156,7 +156,15 @@ def body_fn(state, rng_key): return kernel -class tempered_smc: +def as_top_level_api( + logprior_fn: Callable, + loglikelihood_fn: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + mcmc_parameters: dict, + resampling_fn: Callable, + num_mcmc_steps: int = 10, +) -> SamplingAlgorithm: """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. Parameters @@ -181,39 +189,25 @@ class tempered_smc: A ``SamplingAlgorithm``. """ - - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logprior_fn: Callable, - loglikelihood_fn: Callable, - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - mcmc_parameters: dict, - resampling_fn: Callable, - num_mcmc_steps: int = 10, - ) -> SamplingAlgorithm: - kernel = cls.build_kernel( - logprior_fn, - loglikelihood_fn, - mcmc_step_fn, - mcmc_init_fn, - resampling_fn, + kernel = build_kernel( + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + ) + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position) + + def step_fn(rng_key: PRNGKey, state, lmbda): + return kernel( + rng_key, + state, + num_mcmc_steps, + lmbda, + mcmc_parameters, ) - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return cls.init(position) - - def step_fn(rng_key: PRNGKey, state, lmbda): - return kernel( - rng_key, - state, - num_mcmc_steps, - lmbda, - mcmc_parameters, - ) - - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/vi/meanfield_vi.py b/blackjax/vi/meanfield_vi.py index f7fc3769f..6f379c7b0 100644 --- a/blackjax/vi/meanfield_vi.py +++ b/blackjax/vi/meanfield_vi.py @@ -27,7 +27,7 @@ "sample", "generate_meanfield_logdensity", "step", - "meanfield_vi", + "as_top_level_api", ] @@ -109,7 +109,11 @@ def sample(rng_key: PRNGKey, state: MFVIState, num_samples: int = 1): return _sample(rng_key, state.mu, state.rho, num_samples) -class meanfield_vi: +def as_top_level_api( + logdensity_fn: Callable, + optimizer: GradientTransformation, + num_samples: int = 100, +): """High-level implementation of Mean-Field Variational Inference. Parameters @@ -128,26 +132,16 @@ class meanfield_vi: """ - init = staticmethod(init) - step = staticmethod(step) - sample = staticmethod(sample) - - def __new__( - cls, - logdensity_fn: Callable, - optimizer: GradientTransformation, - num_samples: int = 100, - ): # type: ignore[misc] - def init_fn(position: ArrayLikeTree): - return cls.init(position, optimizer) + def init_fn(position: ArrayLikeTree): + return init(position, optimizer) - def step_fn(rng_key: PRNGKey, state: MFVIState) -> tuple[MFVIState, MFVIInfo]: - return cls.step(rng_key, state, logdensity_fn, optimizer, num_samples) + def step_fn(rng_key: PRNGKey, state: MFVIState) -> tuple[MFVIState, MFVIInfo]: + return step(rng_key, state, logdensity_fn, optimizer, num_samples) - def sample_fn(rng_key: PRNGKey, state: MFVIState, num_samples: int): - return cls.sample(rng_key, state, num_samples) + def sample_fn(rng_key: PRNGKey, state: MFVIState, num_samples: int): + return sample(rng_key, state, num_samples) - return VIAlgorithm(init_fn, step_fn, sample_fn) + return VIAlgorithm(init_fn, step_fn, sample_fn) def _sample(rng_key, mu, rho, num_samples): diff --git a/blackjax/vi/pathfinder.py b/blackjax/vi/pathfinder.py index 7d7e9f5c2..c1b7dc113 100644 --- a/blackjax/vi/pathfinder.py +++ b/blackjax/vi/pathfinder.py @@ -25,7 +25,7 @@ ) from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["PathfinderState", "approximate", "sample", "pathfinder"] +__all__ = ["PathfinderState", "approximate", "sample", "as_top_level_api"] class PathfinderState(NamedTuple): @@ -242,7 +242,7 @@ def sample( return jax.vmap(unravel_fn)(phi), logq -class pathfinder: +def as_top_level_api(logdensity_fn: Callable) -> PathFinderAlgorithm: """Implements the (basic) user interface for the pathfinder kernel. Pathfinder locates normal approximations to the target density along a @@ -266,21 +266,17 @@ class pathfinder: """ - approximate = staticmethod(approximate) - sample = staticmethod(sample) - - def __new__(cls, logdensity_fn: Callable) -> PathFinderAlgorithm: # type: ignore[misc] - def approximate_fn( - rng_key: PRNGKey, - position: ArrayLikeTree, - num_samples: int = 200, - **lbfgs_parameters, - ): - return cls.approximate( - rng_key, logdensity_fn, position, num_samples, **lbfgs_parameters - ) + def approximate_fn( + rng_key: PRNGKey, + position: ArrayLikeTree, + num_samples: int = 200, + **lbfgs_parameters, + ): + return approximate( + rng_key, logdensity_fn, position, num_samples, **lbfgs_parameters + ) - def sample_fn(rng_key: PRNGKey, state: PathfinderState, num_samples: int): - return cls.sample(rng_key, state, num_samples) + def sample_fn(rng_key: PRNGKey, state: PathfinderState, num_samples: int): + return sample(rng_key, state, num_samples) - return PathFinderAlgorithm(approximate_fn, sample_fn) + return PathFinderAlgorithm(approximate_fn, sample_fn) diff --git a/blackjax/vi/schrodinger_follmer.py b/blackjax/vi/schrodinger_follmer.py index d7f454f22..51d1e88fe 100644 --- a/blackjax/vi/schrodinger_follmer.py +++ b/blackjax/vi/schrodinger_follmer.py @@ -181,7 +181,7 @@ def _log_fn_corrected(position, logdensity_fn): return log_pdf_val + norm -class schrodinger_follmer: +def as_top_level_api(logdensity_fn: Callable, n_steps: int, n_inner_samples: int) -> VIAlgorithm: # type: ignore[misc] """Implements the (basic) user interface for the Schrödinger-Föllmer algortithm :cite:p:`huang2021schrodingerfollmer`. The Schrödinger-Föllmer algorithm obtains (approximate) samples from the target distribution by means of a diffusion with @@ -202,22 +202,17 @@ class schrodinger_follmer: """ - init = staticmethod(init) - step = staticmethod(step) - sample = staticmethod(sample) + def init_fn(position: ArrayLikeTree): + return init(position) - def __new__(cls, logdensity_fn: Callable, n_steps: int, n_inner_samples: int) -> VIAlgorithm: # type: ignore[misc] - def init_fn(position: ArrayLikeTree): - return cls.init(position) + def step_fn( + rng_key: PRNGKey, state: SchrodingerFollmerState + ) -> tuple[SchrodingerFollmerState, SchrodingerFollmerInfo]: + return step(rng_key, state, logdensity_fn, 1 / n_steps, n_inner_samples) - def step_fn( - rng_key: PRNGKey, state: SchrodingerFollmerState - ) -> tuple[SchrodingerFollmerState, SchrodingerFollmerInfo]: - return cls.step(rng_key, state, logdensity_fn, 1 / n_steps, n_inner_samples) - - def sample_fn(rng_key: PRNGKey, state: SchrodingerFollmerState, n_samples: int): - return cls.sample( - rng_key, state, logdensity_fn, n_steps, n_inner_samples, n_samples - ) + def sample_fn(rng_key: PRNGKey, state: SchrodingerFollmerState, n_samples: int): + return sample( + rng_key, state, logdensity_fn, n_steps, n_inner_samples, n_samples + ) - return VIAlgorithm(init_fn, step_fn, sample_fn) + return VIAlgorithm(init_fn, step_fn, sample_fn) diff --git a/blackjax/vi/svgd.py b/blackjax/vi/svgd.py index f93941aee..881de77e6 100644 --- a/blackjax/vi/svgd.py +++ b/blackjax/vi/svgd.py @@ -9,7 +9,13 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree -__all__ = ["svgd", "rbf_kernel", "update_median_heuristic"] +__all__ = [ + "as_top_level_api", + "init", + "build_kernel", + "rbf_kernel", + "update_median_heuristic", +] class SVGDState(NamedTuple): @@ -123,7 +129,12 @@ def update_median_heuristic(state: SVGDState) -> SVGDState: return SVGDState(position, median_heuristic(kernel_parameters, position), opt_state) -class svgd: +def as_top_level_api( + grad_logdensity_fn: Callable, + optimizer, + kernel: Callable = rbf_kernel, + update_kernel_parameters: Callable = update_median_heuristic, +): """Implements the (basic) user interface for the svgd algorithm. Parameters @@ -142,26 +153,16 @@ class svgd: A ``SamplingAlgorithm``. """ - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) + kernel_ = build_kernel(optimizer) - def __new__( - cls, - grad_logdensity_fn: Callable, - optimizer, - kernel: Callable = rbf_kernel, - update_kernel_parameters: Callable = update_median_heuristic, + def init_fn( + initial_position: ArrayLikeTree, + kernel_parameters: dict[str, Any] = {"length_scale": 1.0}, ): - kernel_ = cls.build_kernel(optimizer) - - def init_fn( - initial_position: ArrayLikeTree, - kernel_parameters: dict[str, Any] = {"length_scale": 1.0}, - ): - return cls.init(initial_position, kernel_parameters, optimizer) + return init(initial_position, kernel_parameters, optimizer) - def step_fn(state, **grad_params): - state = kernel_(state, grad_logdensity_fn, kernel, **grad_params) - return update_kernel_parameters(state) + def step_fn(state, **grad_params): + state = kernel_(state, grad_logdensity_fn, kernel, **grad_params) + return update_kernel_parameters(state) - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 6e9961799..39c1b811b 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -539,103 +539,6 @@ def rmhmc_static_mass_matrix_fn(position): return jnp.array([1.0]) -normal_test_cases = [ - { - "algorithm": blackjax.hmc, - "initial_position": jnp.array(3.0), - "parameters": { - "step_size": 3.9, - "inverse_mass_matrix": jnp.array([1.0]), - "num_integration_steps": 30, - }, - "num_sampling_steps": 6000, - "burnin": 1_000, - }, - { - "algorithm": blackjax.nuts, - "initial_position": jnp.array(3.0), - "parameters": {"step_size": 4.0, "inverse_mass_matrix": jnp.array([1.0])}, - "num_sampling_steps": 6000, - "burnin": 1_000, - }, - { - "algorithm": blackjax.orbital_hmc, - "initial_position": jnp.array(100.0), - "parameters": { - "step_size": 0.1, - "inverse_mass_matrix": jnp.array([0.1]), - "period": 100, - }, - "num_sampling_steps": 20_000, - "burnin": 15_000, - }, - { - "algorithm": blackjax.additive_step_random_walk.normal_random_walk, - "initial_position": 1.0, - "parameters": {"sigma": jnp.array([1.0])}, - "num_sampling_steps": 20_000, - "burnin": 5_000, - }, - { - "algorithm": blackjax.rmh, - "parameters": {}, - "initial_position": 1.0, - "num_sampling_steps": 20_000, - "burnin": 5_000, - }, - { - "algorithm": blackjax.mala, - "initial_position": 1.0, - "parameters": {"step_size": 1e-1}, - "num_sampling_steps": 45_000, - "burnin": 5_000, - }, - { - "algorithm": blackjax.elliptical_slice, - "initial_position": 1.0, - "parameters": {"cov": jnp.array([2.0**2]), "mean": 1.0}, - "num_sampling_steps": 20_000, - "burnin": 5_000, - }, - { - "algorithm": blackjax.irmh, - "initial_position": jnp.array(1.0), - "parameters": {}, - "num_sampling_steps": 50_000, - "burnin": 5_000, - }, - { - "algorithm": blackjax.ghmc, - "initial_position": jnp.array(1.0), - "parameters": { - "step_size": 1.0, - "momentum_inverse_scale": jnp.array(1.0), - "alpha": 0.8, - "delta": 2.0, - }, - "num_sampling_steps": 6000, - "burnin": 1_000, - }, - { - "algorithm": blackjax.barker_proposal, - "initial_position": 1.0, - "parameters": {"step_size": 1.5}, - "num_sampling_steps": 20_000, - "burnin": 2_000, - }, - { - "algorithm": blackjax.rmhmc, - "initial_position": jnp.array(3.0), - "parameters": { - "step_size": 1.0, - "num_integration_steps": 30, - }, - "num_sampling_steps": 6000, - "burnin": 1_000, - }, -] - - class UnivariateNormalTest(chex.TestCase): """Test sampling of a univariate Normal distribution. @@ -649,34 +552,15 @@ def setUp(self): def normal_logprob(self, x): return stats.norm.logpdf(x, loc=1.0, scale=2.0) - @chex.all_variants(with_pmap=False) - @parameterized.parameters(normal_test_cases) - def test_univariate_normal( - self, algorithm, initial_position, parameters, num_sampling_steps, burnin + def univariate_normal_test_case( + self, + inference_algorithm, + rng_key, + initial_state, + num_sampling_steps, + burnin, + postprocess_samples=None, ): - if algorithm == blackjax.irmh: - parameters["proposal_distribution"] = functools.partial( - irmh_proposal_distribution, mean=1.0 - ) - - if algorithm == blackjax.rmh: - parameters["proposal_generator"] = rmh_proposal_distribution - - if algorithm == blackjax.rmhmc: - parameters["mass_matrix"] = rmhmc_static_mass_matrix_fn - - inference_algorithm = algorithm(self.normal_logprob, **parameters) - rng_key = self.key - if algorithm == blackjax.elliptical_slice: - inference_algorithm = algorithm(lambda x: jnp.ones_like(x), **parameters) - if algorithm == blackjax.ghmc: - rng_key, initial_state_key = jax.random.split(rng_key) - initial_state = inference_algorithm.init( - initial_position, initial_state_key - ) - else: - initial_state = inference_algorithm.init(initial_position) - inference_key, orbit_key = jax.random.split(rng_key) _, states, _ = self.variant( functools.partial( @@ -686,15 +570,161 @@ def test_univariate_normal( ) )(inference_key, initial_state) - if algorithm == blackjax.orbital_hmc: - samples = orbit_samples( - states.positions[burnin:], states.weights[burnin:], orbit_key - ) + # else: + if postprocess_samples: + samples = postprocess_samples(states, orbit_key) else: samples = states.position[burnin:] np.testing.assert_allclose(np.mean(samples), 1.0, rtol=1e-1) np.testing.assert_allclose(np.var(samples), 4.0, rtol=1e-1) + @chex.all_variants(with_pmap=False) + def test_irmh(self): + inference_algorithm = blackjax.irmh( + self.normal_logprob, + proposal_distribution=functools.partial( + irmh_proposal_distribution, mean=1.0 + ), + ) + initial_state = inference_algorithm.init(jnp.array(1.0)) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 50000, 5000 + ) + + @chex.all_variants(with_pmap=False) + def test_nuts(self): + inference_algorithm = blackjax.nuts( + self.normal_logprob, step_size=4.0, inverse_mass_matrix=jnp.array([1.0]) + ) + + initial_state = inference_algorithm.init(jnp.array(3.0)) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 5000, 1000 + ) + + @chex.all_variants(with_pmap=False) + def test_rmh(self): + inference_algorithm = blackjax.rmh( + self.normal_logprob, proposal_generator=rmh_proposal_distribution + ) + initial_state = inference_algorithm.init(1.0) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 20_000, 5_000 + ) + + @chex.all_variants(with_pmap=False) + def test_rmhmc(self): + inference_algorithm = blackjax.rmhmc( + self.normal_logprob, + mass_matrix=rmhmc_static_mass_matrix_fn, + step_size=1.0, + num_integration_steps=30, + ) + + initial_state = inference_algorithm.init(jnp.array(3.0)) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 6_000, 1_000 + ) + + @chex.all_variants(with_pmap=False) + def test_elliptical_slice(self): + inference_algorithm = blackjax.elliptical_slice( + lambda x: jnp.ones_like(x), cov=jnp.array([2.0**2]), mean=1.0 + ) + + initial_state = inference_algorithm.init(1.0) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 20_000, 5_000 + ) + + @chex.all_variants(with_pmap=False) + def test_ghmc(self): + rng_key, initial_state_key = jax.random.split(self.key) + inference_algorithm = blackjax.ghmc( + self.normal_logprob, + step_size=1.0, + momentum_inverse_scale=jnp.array(1.0), + alpha=0.8, + delta=2.0, + ) + initial_state = inference_algorithm.init(jnp.array(1.0), initial_state_key) + self.univariate_normal_test_case( + inference_algorithm, rng_key, initial_state, 6000, 1000 + ) + + @chex.all_variants(with_pmap=False) + def test_hmc(self): + rng_key, initial_state_key = jax.random.split(self.key) + inference_algorithm = blackjax.hmc( + self.normal_logprob, + step_size=3.9, + inverse_mass_matrix=jnp.array([1.0]), + num_integration_steps=30, + ) + initial_state = inference_algorithm.init(jnp.array(3.0)) + self.univariate_normal_test_case( + inference_algorithm, rng_key, initial_state, 6000, 1000 + ) + + @chex.all_variants(with_pmap=False) + def test_orbital_hmc(self): + inference_algorithm = blackjax.orbital_hmc( + self.normal_logprob, + step_size=0.1, + inverse_mass_matrix=jnp.array([0.1]), + period=100, + ) + initial_state = inference_algorithm.init(jnp.array(100.0)) + burnin = 15_000 + + def postprocess_samples(states, key): + return orbit_samples( + states.positions[burnin:], states.weights[burnin:], key + ) + + self.univariate_normal_test_case( + inference_algorithm, + self.key, + initial_state, + 20_000, + burnin, + postprocess_samples, + ) + + @chex.all_variants(with_pmap=False) + def test_random_walk(self): + inference_algorithm = blackjax.additive_step_random_walk.normal_random_walk( + self.normal_logprob, sigma=jnp.array([1.0]) + ) + initial_state = inference_algorithm.init(jnp.array(1.0)) + + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 20_000, 5_000 + ) + + @chex.all_variants(with_pmap=False) + def test_mala(self): + inference_algorithm = blackjax.mala(self.normal_logprob, step_size=1e-1) + initial_state = inference_algorithm.init(jnp.array(1.0)) + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 45000, 5_000 + ) + + @chex.all_variants(with_pmap=False) + def test_barker(self): + inference_algorithm = blackjax.barker_proposal( + self.normal_logprob, step_size=1.5 + ) + initial_state = inference_algorithm.init(jnp.array(1.0)) + self.univariate_normal_test_case( + inference_algorithm, self.key, initial_state, 20000, 2_000 + ) + mcse_test_cases = [ { diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index e33130d31..bf970ae47 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -14,7 +14,7 @@ from blackjax import adaptive_tempered_smc, tempered_smc from blackjax.mcmc.random_walk import build_irmh from blackjax.smc import extend_params -from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning +from blackjax.smc.inner_kernel_tuning import as_top_level_api as inner_kernel_tuning from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate from blackjax.smc.tuning.from_particles import ( mass_matrix_from_particles, diff --git a/tests/test_util.py b/tests/test_util.py index d3eed1193..a6e023074 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -3,7 +3,7 @@ import jax.numpy as jnp from absl.testing import absltest, parameterized -from blackjax.mcmc.hmc import hmc +import blackjax from blackjax.util import run_inference_algorithm @@ -11,7 +11,7 @@ class RunInferenceAlgorithmTest(chex.TestCase): def setUp(self): super().setUp() self.key = jax.random.key(42) - self.algorithm = hmc( + self.algorithm = blackjax.hmc( logdensity_fn=self.logdensity_fn, inverse_mass_matrix=jnp.eye(2), step_size=1.0, diff --git a/tests/vi/test_schrodinger_follmer.py b/tests/vi/test_schrodinger_follmer.py index c59af6cf4..fd58fed0a 100644 --- a/tests/vi/test_schrodinger_follmer.py +++ b/tests/vi/test_schrodinger_follmer.py @@ -6,7 +6,7 @@ import jax.scipy.stats as stats from absl.testing import absltest -from blackjax.vi.schrodinger_follmer import schrodinger_follmer +from blackjax.vi.schrodinger_follmer import as_top_level_api as schrodinger_follmer class SchrodingerFollmerTest(chex.TestCase):