From 1cbe9cfa19dff0cf1b9d6bd07c98fa44c36918eb Mon Sep 17 00:00:00 2001 From: Adrien Corenflos Date: Mon, 16 Sep 2024 20:09:09 +0100 Subject: [PATCH] Merged comments from Junpeng --- blackjax/mcmc/metrics.py | 12 +++++++++--- blackjax/types.py | 4 ++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index 8ce6c0e56..678b5cc37 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -33,10 +33,9 @@ import jax import jax.numpy as jnp import jax.scipy as jscipy -from chex import Numeric from jax.flatten_util import ravel_pytree -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import Array, ArrayLikeTree, ArrayTree, Numeric, PRNGKey from blackjax.util import generate_gaussian_noise, linear_map __all__ = ["default_metric", "gaussian_euclidean", "gaussian_riemannian"] @@ -61,11 +60,18 @@ def __call__( ... +class Scale(Protocol): + def __call__( + self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree + ) -> ArrayLikeTree: + ... + + class Metric(NamedTuple): sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree] kinetic_energy: KineticEnergy check_turning: CheckTurning - scale: Callable[[ArrayLikeTree, ArrayLikeTree, bool], ArrayLikeTree] + scale: Scale MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]] diff --git a/blackjax/types.py b/blackjax/types.py index 5a3b59f07..be73b0d29 100644 --- a/blackjax/types.py +++ b/blackjax/types.py @@ -43,3 +43,7 @@ class WelfordAlgorithmState(NamedTuple): #: JAX PRNGKey PRNGKey = jax.Array + +#: JAX Scalar types +Scalar = Union[float, int] +Numeric = Union[Array, Scalar]