Skip to content

Commit

Permalink
Merged comments from Junpeng
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrienCorenflos committed Sep 16, 2024
1 parent c8ee838 commit 1cbe9cf
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
12 changes: 9 additions & 3 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@
import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from chex import Numeric
from jax.flatten_util import ravel_pytree

from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.types import Array, ArrayLikeTree, ArrayTree, Numeric, PRNGKey
from blackjax.util import generate_gaussian_noise, linear_map

__all__ = ["default_metric", "gaussian_euclidean", "gaussian_riemannian"]
Expand All @@ -61,11 +60,18 @@ def __call__(
...


class Scale(Protocol):
def __call__(
self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
) -> ArrayLikeTree:
...


class Metric(NamedTuple):
sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree]
kinetic_energy: KineticEnergy
check_turning: CheckTurning
scale: Callable[[ArrayLikeTree, ArrayLikeTree, bool], ArrayLikeTree]
scale: Scale


MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]]
Expand Down
4 changes: 4 additions & 0 deletions blackjax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,7 @@ class WelfordAlgorithmState(NamedTuple):

#: JAX PRNGKey
PRNGKey = jax.Array

#: JAX Scalar types
Scalar = Union[float, int]
Numeric = Union[Array, Scalar]

0 comments on commit 1cbe9cf

Please sign in to comment.