diff --git a/docs/quickstart.md b/docs/quickstart.md index 06b57fe..f31cb62 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -53,7 +53,7 @@ n_chains = 10 rng_key_set = initialize_rng_keys(n_chains, seed=42) initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, jax.random.PRNGKey(21)) +model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, PRNGKeyArray(21)) step_size = 1e-1 local_sampler = MALA(log_posterior, True, {"step_size": step_size}) diff --git a/example/dualmoon.py b/example/dualmoon.py index cdfec16..a8fc842 100644 --- a/example/dualmoon.py +++ b/example/dualmoon.py @@ -38,7 +38,7 @@ def target_dualmoon(x, data): data = jnp.zeros(n_dim) rng_key_set = initialize_rng_keys(n_chains, 42) -model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10)) +model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10)) initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 diff --git a/example/non_jax_likelihood.py b/example/non_jax_likelihood.py index 4bd6567..4272ec6 100644 --- a/example/non_jax_likelihood.py +++ b/example/non_jax_likelihood.py @@ -46,7 +46,7 @@ def neal_funnel(x): data = jnp.zeros(n_dim) rng_key_set = initialize_rng_keys(n_chains, 42) -model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10)) +model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10)) initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 diff --git a/example/notebook/analyzingChains.ipynb b/example/notebook/analyzingChains.ipynb index 4e5cddc..5941368 100644 --- a/example/notebook/analyzingChains.ipynb +++ b/example/notebook/analyzingChains.ipynb @@ -110,7 +110,7 @@ "print(\"Initializing chains, normalizing flow model and local MCMC sampler\")\n", "\n", "initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1\n", - "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10))\n", + "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))\n", "MALA_Sampler = MALA(target_dualmoon, True, {\"step_size\": step_size})\n", "\n", "print(\"Initializing samplers classes\")\n", diff --git a/example/notebook/dualmoon.ipynb b/example/notebook/dualmoon.ipynb index 7ddc60c..53533eb 100644 --- a/example/notebook/dualmoon.ipynb +++ b/example/notebook/dualmoon.ipynb @@ -141,7 +141,7 @@ "data = jnp.zeros(n_dim)\n", "\n", "model = MaskedCouplingRQSpline(\n", - " n_dim, n_layers, hidden_size, num_bins, jax.random.PRNGKey(10)\n", + " n_dim, n_layers, hidden_size, num_bins, PRNGKeyArray(10)\n", ")" ] }, diff --git a/example/notebook/maximizing_likelihood.ipynb b/example/notebook/maximizing_likelihood.ipynb index 148cef7..083d18a 100644 --- a/example/notebook/maximizing_likelihood.ipynb +++ b/example/notebook/maximizing_likelihood.ipynb @@ -164,7 +164,7 @@ "\n", "\n", "rng_key_set = initialize_rng_keys(n_chains, 42)\n", - "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10))\n", + "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))\n", "\n", "initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1\n", "\n", diff --git a/example/notebook/mog_pretrain.ipynb b/example/notebook/mog_pretrain.ipynb index 87a63b7..f236f36 100644 --- a/example/notebook/mog_pretrain.ipynb +++ b/example/notebook/mog_pretrain.ipynb @@ -135,7 +135,7 @@ "## To use instead the more powerful RQSpline:\n", "from 
flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline\n", "\n", - "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10))\n", + "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))\n", "\n", "\n", "# Local sampler\n", diff --git a/example/notebook/normalizingFlow.ipynb b/example/notebook/normalizingFlow.ipynb index 01c304d..0d9a1ff 100644 --- a/example/notebook/normalizingFlow.ipynb +++ b/example/notebook/normalizingFlow.ipynb @@ -101,7 +101,7 @@ "n_layers = 10\n", "n_hidden = 100\n", "\n", - "key, subkey = jax.random.split(jax.random.PRNGKey(0), 2)\n", + "key, subkey = jax.random.split(PRNGKeyArray(0), 2)\n", "\n", "model = RealNVP(\n", " n_feature,\n", @@ -246,7 +246,7 @@ "n_hiddens = [64, 64]\n", "n_bins = 8\n", "\n", - "key, subkey = jax.random.split(jax.random.PRNGKey(1))\n", + "key, subkey = jax.random.split(PRNGKeyArray(1))\n", "\n", "model = MaskedCouplingRQSpline(\n", " n_feature,\n", diff --git a/example/train_normalizing_flow.py b/example/train_normalizing_flow.py index 24e71c8..842d7d5 100644 --- a/example/train_normalizing_flow.py +++ b/example/train_normalizing_flow.py @@ -23,7 +23,7 @@ data = make_moons(n_samples=20000, noise=0.05) data = jnp.array(data[0]) -key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3) +key1, rng, init_rng = jax.random.split(PRNGKeyArray(0), 3) model = MaskedCouplingRQSpline( 2, @@ -40,4 +40,4 @@ key, model, loss = train_flow(rng, model, data, num_epochs, batch_size, verbose=True) -nf_samples = model.sample(jax.random.PRNGKey(124098), 5000) +nf_samples = model.sample(PRNGKeyArray(124098), 5000) diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..bf64e2d --- /dev/null +++ b/ruff.toml @@ -0,0 +1 @@ +ignore = ["F722"] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 686aca5..22a133c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = flowMC -version = 0.2.4 +version = 0.2.5 author = Kaze Wong, Marylou GabriƩ, Dan Foreman-Mackey author_email = kazewong.physics@gmail.com url = https://github.com/kazewong/flowMC @@ -23,7 +23,7 @@ install_requires = tqdm corner -python_requires = >=3.9 +python_requires = >=3.10 [options.packages.find] where=src diff --git a/src/flowMC/nfmodel/base.py b/src/flowMC/nfmodel/base.py index da89048..99fdf5e 100644 --- a/src/flowMC/nfmodel/base.py +++ b/src/flowMC/nfmodel/base.py @@ -1,5 +1,4 @@ from abc import abstractmethod, abstractproperty -from typing import Tuple import equinox as eqx import jax from jaxtyping import Array, PRNGKeyArray, Float @@ -15,48 +14,48 @@ class NFModel(eqx.Module): def __init__(self): return NotImplemented - def __call__(self, x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]: """ Forward pass of the model. Args: - x (Array): Input data. + x (Float[Array, "n_dim"]): Input data. Returns: - Tuple[Array, Array]: Output data and log determinant of the Jacobian. + tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian. 
""" return self.forward(x) @abstractmethod - def log_prob(self, x: Array) -> Array: + def log_prob(self, x: Float[Array, "n_dim"]) -> Float: return NotImplemented @abstractmethod - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: return NotImplemented @abstractmethod - def forward(self, x: Array) -> Tuple[Array, Array]: + def forward(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]: """ Forward pass of the model. Args: - x (Array): Input data. + x (Float[Array, "n_dim"]): Input data. Returns: - Tuple[Array, Array]: Output data and log determinant of the Jacobian.""" + tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.""" return NotImplemented @abstractmethod - def inverse(self, x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]: """ Inverse pass of the model. Args: - x (Array): Input data. + x (Float[Array, "n_dim"]): Input data. Returns: - Tuple[Array, Array]: Output data and log determinant of the Jacobian.""" + tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.""" return NotImplemented @abstractproperty @@ -81,15 +80,15 @@ class Bijection(eqx.Module): def __init__(self): return NotImplemented - def __call__(self, x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Array) -> tuple[Array, Array]: return self.forward(x) @abstractmethod - def forward(self, x: Array) -> Tuple[Array, Array]: + def forward(self, x: Array) -> tuple[Array, Array]: return NotImplemented @abstractmethod - def inverse(self, x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Array) -> tuple[Array, Array]: return NotImplemented diff --git a/src/flowMC/nfmodel/common.py b/src/flowMC/nfmodel/common.py index 0e9cca5..b4d59bf 100644 --- a/src/flowMC/nfmodel/common.py +++ b/src/flowMC/nfmodel/common.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Iterable, Tuple +from typing import Callable, List, Tuple import jax import jax.numpy as jnp @@ -10,15 +10,15 @@ import jax import jax.numpy as jnp import equinox as eqx -from jaxtyping import Array +from jaxtyping import Array, Float, PRNGKeyArray class MLP(eqx.Module): r"""Multilayer perceptron. Args: - shape (Iterable[int]): Shape of the MLP. The first element is the input dimension, the last element is the output dimension. - key (jax.random.PRNGKey): Random key. + shape (List[int]): Shape of the MLP. The first element is the input dimension, the last element is the output dimension. + key (PRNGKeyArray): Random key. Attributes: layers (List): List of layers. 
@@ -29,9 +29,9 @@ class MLP(eqx.Module): def __init__( self, - shape: Iterable[int], - key: jax.random.PRNGKey, - scale: float = 1e-4, + shape: List[int], + key: PRNGKeyArray, + scale: Float = 1e-4, activation: Callable = jax.nn.relu, use_bias: bool = True, ): @@ -52,7 +52,7 @@ def __init__( eqx.nn.Linear(shape[-2], shape[-1], key=subkey, use_bias=use_bias) ) - def __call__(self, x: Array): + def __call__(self, x: Float[Array, "n_in"]) -> Float[Array, "n_out"]: for layer in self.layers: x = layer(x) return x @@ -83,25 +83,25 @@ class MaskedCouplingLayer(Bijection): """ - _mask: Array + _mask: Float[Array, "n_dim"] bijector: Bijection @property - def mask(self): + def mask(self) -> Float[Array, "n_dim"]: return jax.lax.stop_gradient(self._mask) - def __init__(self, bijector: Bijection, mask: Array): + def __init__(self, bijector: Bijection, mask: Float[Array, "n_dim"]): self.bijector = bijector self._mask = mask - def forward(self, x: Array) -> Tuple[Array, Array]: - y, log_det = self.bijector(x, x * self.mask) + def forward(self, x: Float[Array, "n_dim"]) -> Tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]: + y, log_det = self.bijector(x, x * self.mask) # type: ignore y = (1 - self.mask) * y + self.mask * x log_det = ((1 - self.mask) * log_det).sum() return y, log_det - def inverse(self, x: Array) -> Tuple[Array, Array]: - y, log_det = self.bijector.inverse(x, x * self.mask) + def inverse(self, x: Float[Array, "n_dim"]) -> Tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]: + y, log_det = self.bijector.inverse(x, x * self.mask) # type: ignore y = (1 - self.mask) * y + self.mask * x log_det = ((1 - self.mask) * log_det).sum() return y, log_det @@ -110,17 +110,17 @@ def inverse(self, x: Array) -> Tuple[Array, Array]: class MLPAffine(Bijection): scale_MLP: MLP shift_MLP: MLP - dt: float = 1 + dt: Float = 1 - def __init__(self, scale_MLP: MLP, shift_MLP: MLP, dt: float = 1): + def __init__(self, scale_MLP: MLP, shift_MLP: MLP, dt: Float = 1): self.scale_MLP = scale_MLP self.shift_MLP = shift_MLP self.dt = dt - def __call__(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: return self.forward(x, condition_x) - def forward(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def forward(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: # Note that this note output log_det as an array instead of a number. # This is because we need to sum over the log_det in the masked coupling layer. 
scale = jnp.tanh(self.scale_MLP(condition_x)) * self.dt @@ -129,7 +129,7 @@ def forward(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: y = (x + shift) * jnp.exp(scale) return y, log_det - def inverse(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: scale = jnp.tanh(self.scale_MLP(condition_x)) * self.dt shift = self.shift_MLP(condition_x) * self.dt log_det = -scale @@ -141,19 +141,19 @@ class ScalarAffine(Bijection): scale: Array shift: Array - def __init__(self, scale: float, shift: float): + def __init__(self, scale: Float, shift: Float): self.scale = jnp.array(scale) self.shift = jnp.array(shift) - def __call__(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: return self.forward(x, condition_x) - def forward(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def forward(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: y = (x + self.shift) * jnp.exp(self.scale) log_det = self.scale return y, log_det - def inverse(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: y = x * jnp.exp(-self.scale) - self.shift log_det = -self.scale return y, log_det @@ -173,33 +173,33 @@ class Gaussian(Distribution): cov (Array): Covariance matrix. """ - _mean: Array - _cov: Array + _mean: Float[Array, "n_dim"] + _cov: Float[Array, "n_dim n_dim"] learnable: bool = False @property - def mean(self) -> Array: + def mean(self) -> Float[Array, "n_dim"]: if self.learnable: return self._mean else: return jax.lax.stop_gradient(self._mean) @property - def cov(self) -> Array: + def cov(self) -> Float[Array, "n_dim n_dim"]: if self.learnable: return self._cov else: return jax.lax.stop_gradient(self._cov) - def __init__(self, mean: Array, cov: Array, learnable: bool = False): + def __init__(self, mean: Float[Array, "n_dim"], cov: Float[Array, "n_dim n_dim"], learnable: bool = False): self._mean = mean self._cov = cov self.learnable = learnable - def log_prob(self, x: Array) -> Array: + def log_prob(self, x: Float[Array, "n_dim"]) -> Float: return jax.scipy.stats.multivariate_normal.logpdf(x, self.mean, self.cov) - def sample(self, key: jax.random.PRNGKey, n_samples: int = 1) -> Array: + def sample(self, key: PRNGKeyArray, n_samples: int = 1) -> Float[Array, "n_samples n_dim"]: return jax.random.multivariate_normal(key, self.mean, self.cov, (n_samples,)) @@ -212,14 +212,15 @@ def __init__(self, distributions: list[Distribution], partitions: dict): self.distributions = distributions self.partitions = partitions - def log_prob(self, x: Array) -> Array: + def log_prob(self, x: Float[Array, "n_dim"]) -> Float: log_prob = 0 for dist, (_, ranges) in zip(self.distributions, self.partitions.items()): log_prob += dist.log_prob(x[ranges[0] : ranges[1]]) return log_prob - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> dict[str, Float[Array, "n_samples n_dim"]]: samples = {} for dist, (key, _) in zip(self.distributions, self.partitions.items()): rng_key, sub_key = jax.random.split(rng_key) samples[key] = dist.sample(sub_key, n_samples=n_samples) + return samples \ No newline 
at end of file diff --git a/src/flowMC/nfmodel/realNVP.py b/src/flowMC/nfmodel/realNVP.py index b4d2714..26599b3 100644 --- a/src/flowMC/nfmodel/realNVP.py +++ b/src/flowMC/nfmodel/realNVP.py @@ -5,7 +5,7 @@ import equinox as eqx from flowMC.nfmodel.base import NFModel, Distribution from flowMC.nfmodel.common import MLP, MaskedCouplingLayer, MLPAffine, Gaussian -from jaxtyping import Array +from jaxtyping import Array, Float, PRNGKeyArray from functools import partial @@ -19,22 +19,22 @@ class AffineCoupling(eqx.Module): n_features: (int) The number of features in the input. n_hidden: (int) The number of hidden units in the MLP. mask: (ndarray) Alternating mask for the affine coupling layer. - dt: (float) Scaling factor for the affine coupling layer. + dt: (Float) Scaling factor for the affine coupling layer. """ _mask: Array - scale_MLP: eqx.Module - translate_MLP: eqx.Module - dt: float = 1 + scale_MLP: MLP + translate_MLP: MLP + dt: Float = 1 def __init__( self, n_features: int, n_hidden: int, mask: Array, - key: jax.random.PRNGKey, - dt: float = 1, - scale: float = 1e-4, + key: PRNGKeyArray, + dt: Float = 1, + scale: Float = 1e-4, ): self._mask = mask self.dt = dt @@ -102,7 +102,7 @@ class RealNVP(NFModel): n_layer: (int) The number of affine coupling layers. n_features: (int) The number of features in the input. n_hidden: (int) The number of hidden units in the MLP. - dt: (float) Scaling factor for the affine coupling layer. + dt: (Float) Scaling factor for the affine coupling layer. Properties: data_mean: (ndarray) Mean of Gaussian base distribution @@ -112,8 +112,8 @@ class RealNVP(NFModel): base_dist: Distribution affine_coupling: List[MaskedCouplingLayer] _n_features: int - _data_mean: Array - _data_cov: Array + _data_mean: Array | None + _data_cov: Array | None @property def n_features(self): @@ -128,16 +128,11 @@ def data_cov(self): return jax.lax.stop_gradient(self._data_cov) def __init__( - self, - n_features: int, - n_layer: int, - n_hidden: int, - key: jax.random.PRNGKey, - **kwargs + self, n_features: int, n_layer: int, n_hidden: int, key: PRNGKeyArray, **kwargs ): if kwargs.get("base_dist") is not None: - self.base_dist = kwargs.get("base_dist") + self.base_dist = kwargs.get("base_dist") # type: ignore else: self.base_dist = Gaussian( jnp.zeros(n_features), jnp.eye(n_features), learnable=False @@ -177,18 +172,18 @@ def __init__( else: self._data_cov = jnp.eye(n_features) - def __call__(self, x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Array) -> Tuple[Array, Float]: return self.forward(x) - def forward(self, x: Array) -> Tuple[Array, Array]: - log_det = 0 + def forward(self, x: Array) -> Tuple[Array, Float]: + log_det = 0.0 for i in range(len(self.affine_coupling)): x, log_det_i = self.affine_coupling[i](x) log_det += log_det_i return x, log_det @partial(jax.vmap, in_axes=(None, 0)) - def inverse(self, x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Array) -> Tuple[Array, Float]: """From latent space to data space""" log_det = 0.0 for layer in reversed(self.affine_coupling): @@ -197,8 +192,7 @@ def inverse(self, x: Array) -> Tuple[Array, Array]: return x, log_det @eqx.filter_jit - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: - + def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: samples = self.base_dist.sample(rng_key, n_samples) samples = self.inverse(samples)[0] samples = samples * jnp.sqrt(jnp.diag(self.data_cov)) + self.data_mean diff --git a/src/flowMC/nfmodel/rqSpline.py b/src/flowMC/nfmodel/rqSpline.py 
index fa0d0a5..e865209 100644 --- a/src/flowMC/nfmodel/rqSpline.py +++ b/src/flowMC/nfmodel/rqSpline.py @@ -1,7 +1,6 @@ -from typing import Sequence, Tuple import jax import jax.numpy as jnp -from jaxtyping import Array +from jaxtyping import Array, PRNGKeyArray, Float import equinox as eqx from flowMC.nfmodel.base import NFModel, Bijection, Distribution @@ -21,7 +20,7 @@ def _normalize_bin_sizes( @partial(jax.vmap, in_axes=(0, None)) def _normalize_knot_slopes( - unnormalized_knot_slopes: Array, min_knot_slope: float + unnormalized_knot_slopes: Array, min_knot_slope: Float ) -> Array: """Make knot slopes be no less than `min_knot_slope`.""" # The offset is such that the normalized knot slope will be equal to 1 @@ -34,7 +33,7 @@ def _normalize_knot_slopes( @partial(jax.vmap, in_axes=(0, 0, 0, 0)) def _rational_quadratic_spline_fwd( x: Array, x_pos: Array, y_pos: Array, knot_slopes: Array -) -> Tuple[Array, Array]: +) -> tuple[Array, Array]: """Applies a rational-quadratic spline to a scalar. Args: @@ -110,7 +109,9 @@ def _rational_quadratic_spline_fwd( # If x is outside the spline range, we default to a linear transformation. y = jnp.where(below_range, (x - x_pos[0]) * knot_slopes[0] + y_pos[0], y) - y = jnp.where(above_range, (x - x_pos[-1]) * knot_slopes[-1] + y_pos[-1], y) + y = jnp.where( + above_range, (x - x_pos[-1]) * knot_slopes[-1] + y_pos[-1], y + ) # type: ignore logdet = jnp.where(below_range, jnp.log(knot_slopes[0]), logdet) logdet = jnp.where(above_range, jnp.log(knot_slopes[-1]), logdet) return y, logdet @@ -146,7 +147,7 @@ def _safe_quadratic_root(a: Array, b: Array, c: Array) -> Array: @partial(jax.vmap, in_axes=(0, 0, 0, 0)) def _rational_quadratic_spline_inv( y: Array, x_pos: Array, y_pos: Array, knot_slopes: Array -) -> Tuple[Array, Array]: +) -> tuple[Array, Array]: """Applies the inverse of a rational-quadratic spline to a scalar. Args: @@ -217,20 +218,21 @@ def _rational_quadratic_spline_inv( # If y is outside the spline range, we default to a linear transformation. x = jnp.where(below_range, (y - y_pos[0]) / knot_slopes[0] + x_pos[0], x) - x = jnp.where(above_range, (y - y_pos[-1]) / knot_slopes[-1] + x_pos[-1], x) + x = jnp.where( + above_range, (y - y_pos[-1]) / knot_slopes[-1] + x_pos[-1], x + ) # type: ignore logdet = jnp.where(below_range, -jnp.log(knot_slopes[0]), logdet) logdet = jnp.where(above_range, -jnp.log(knot_slopes[-1]), logdet) return x, logdet class RQSpline(Bijection): - _range_min: float _range_max: float _num_bins: int _min_bin_size: float _min_knot_slope: float - conditioner: eqx.Module + conditioner: MLP """A rational-quadratic spline bijection. 
@@ -279,13 +281,12 @@ def dtype(self): def __init__( self, - conditioner: eqx.Module, + conditioner: MLP, range_min: float, range_max: float, min_bin_size: float = 1e-4, min_knot_slope: float = 1e-4, ): - self._range_min = range_min self._range_max = range_max self._min_bin_size = min_bin_size @@ -294,7 +295,11 @@ def __init__( self.conditioner = conditioner - def get_params(self, x: Array) -> Array: + def get_params( + self, x: Float[Array, " n_condition"] + ) -> tuple[ + Float[Array, " n_param"], Float[Array, " n_param"], Float[Array, " n_param"] + ]: params = self.conditioner(x).reshape(-1, self._num_bins * 3 + 1) unnormalized_bin_widths = params[:, : self._num_bins] unnormalized_bin_heights = params[:, self._num_bins : 2 * self._num_bins] @@ -320,14 +325,14 @@ def get_params(self, x: Array) -> Array: ) return x_pos, y_pos, knot_slopes - def __call__(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Array, condition_x: Array) -> tuple[Array, Array]: return self.forward(x, condition_x) - def forward(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def forward(self, x: Array, condition_x: Array) -> tuple[Array, Array]: x_pos, y_pos, knot_slopes = self.get_params(condition_x) return _rational_quadratic_spline_fwd(x, x_pos, y_pos, knot_slopes) - def inverse(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Array, condition_x: Array) -> tuple[Array, Array]: x_pos, y_pos, knot_slopes = self.get_params(condition_x) return _rational_quadratic_spline_inv(x, x_pos, y_pos, knot_slopes) @@ -340,7 +345,7 @@ class MaskedCouplingRQSpline(NFModel): num_layers (int): Number of layers in the conditioner. hidden_size (Sequence[int]): Hidden size of the conditioner. num_bins (int): Number of bins in the spline. - key (jax.random.PRNGKey): Random key for initialization. + key (PRNGKeyArray): Random key for initialization. spline_range (Sequence[float]): Range of the spline. Defaults to (-10.0, 10.0). 
Properties: @@ -350,10 +355,10 @@ class MaskedCouplingRQSpline(NFModel): """ base_dist: Distribution - layers: list[eqx.Module] + layers: list[Bijection] _n_features: int - _data_mean: Array - _data_cov: Array + _data_mean: Float[Array, " n_dim"] + _data_cov: Float[Array, " n_dim n_dim"] @property def n_features(self): @@ -371,27 +376,32 @@ def __init__( self, n_features: int, n_layers: int, - hidden_size: Sequence[int], + hidden_size: list[int], num_bins: int, - key: jax.random.PRNGKey, - spline_range: Sequence[float] = (-10.0, 10.0), + key: PRNGKeyArray, + spline_range: tuple[float, float] = (-10.0, 10.0), **kwargs ): - if kwargs.get("base_dist") is not None: - self.base_dist = kwargs.get("base_dist") + dist = kwargs.get("base_dist") + assert isinstance(dist, Distribution) + self.base_dist = dist else: self.base_dist = Gaussian( jnp.zeros(n_features), jnp.eye(n_features), learnable=False ) if kwargs.get("data_mean") is not None: - self._data_mean = kwargs.get("data_mean") + data_mean = kwargs.get("data_mean") + assert isinstance(data_mean, Array) + self._data_mean = data_mean else: self._data_mean = jnp.zeros(n_features) if kwargs.get("data_cov") is not None: - self._data_cov = kwargs.get("data_cov") + data_cov = kwargs.get("data_cov") + assert isinstance(data_cov, Array) + self._data_cov = data_cov else: self._data_cov = jnp.eye(n_features) @@ -421,10 +431,14 @@ def __init__( mask = jnp.logical_not(mask) self.layers = layers - def __call__(self, x: Array) -> Tuple[Array, Array]: + def __call__( + self, x: Float[Array, " n_dim"] + ) -> tuple[Float[Array, " n_dim"], Float]: return self.forward(x) - def forward(self, x: Array) -> Tuple[Array, Array]: + def forward( + self, x: Float[Array, " n_dim"] + ) -> tuple[Float[Array, " n_dim"], Float]: log_det = 0.0 for layer in self.layers: x, log_det_i = layer(x) @@ -432,7 +446,9 @@ def forward(self, x: Array) -> Tuple[Array, Array]: return x, log_det @partial(jax.vmap, in_axes=(None, 0)) - def inverse(self, x: Array) -> Tuple[Array, Array]: + def inverse( + self, x: Float[Array, " n_dim"] + ) -> tuple[Float[Array, " n_dim"], Float]: """From latent space to data space""" log_det = 0.0 for layer in reversed(self.layers): @@ -441,7 +457,9 @@ def inverse(self, x: Array) -> Tuple[Array, Array]: return x, log_det @eqx.filter_jit - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> Float[Array, "n_samples n_dim"]: samples = self.base_dist.sample(rng_key, n_samples) samples = self.inverse(samples)[0] samples = samples * jnp.sqrt(jnp.diag(self.data_cov)) + self.data_mean @@ -449,7 +467,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: @eqx.filter_jit @partial(jax.vmap, in_axes=(None, 0)) - def log_prob(self, x: Array) -> Array: + def log_prob(self, x: Float[Array, "n_sample n_dim"]) -> Float[Array, " n_sample"]: """From data space to latent space""" x = (x - self.data_mean) / jnp.sqrt(jnp.diag(self.data_cov)) y, log_det = self.__call__(x) diff --git a/src/flowMC/nfmodel/utils.py b/src/flowMC/nfmodel/utils.py index 4881671..a64507b 100644 --- a/src/flowMC/nfmodel/utils.py +++ b/src/flowMC/nfmodel/utils.py @@ -1,13 +1,15 @@ import jax import jax.numpy as jnp # JAX NumPy -from tqdm import trange +from tqdm import trange, tqdm import optax import equinox as eqx from typing import Callable, Tuple -from jaxtyping import Array, PRNGKeyArray +from jaxtyping import Array, PRNGKeyArray, Float -def make_training_loop(optim: optax.GradientTransformation) 
-> Callable: +def make_training_loop( + optim: optax.GradientTransformation, +) -> tuple[Callable, Callable, Callable]: """ Create a function that trains an NF model. @@ -24,7 +26,9 @@ def loss_fn(model, x): return -jnp.mean(model.log_prob(x)) @eqx.filter_jit - def train_step(model, x, opt_state): + def train_step( + model: eqx.Module, x: Float[Array, "n_batch n_dim"], opt_state: optax.OptState + ) -> Tuple[Float, eqx.Module, optax.OptState]: """Train for a single step. Args: @@ -42,7 +46,13 @@ def train_step(model, x, opt_state): model = eqx.apply_updates(model, updates) return loss, model, opt_state - def train_epoch(rng, model, state, train_ds, batch_size): + def train_epoch( + rng: PRNGKeyArray, + model: eqx.Module, + state: optax.OptState, + train_ds: Float[Array, "n_example n_dim"], + batch_size: Float, + )-> Tuple[Float, eqx.Module, optax.OptState]: """Train for a single epoch.""" train_ds_size = len(train_ds) steps_per_epoch = train_ds_size // batch_size @@ -67,7 +77,7 @@ def train_flow( num_epochs: int, batch_size: int, verbose: bool = True, - ) -> Tuple[PRNGKeyArray, eqx.Module, Array]: + ) -> Tuple[PRNGKeyArray, eqx.Module, optax.OptState, Array]: """Train a normalizing flow model. Args: @@ -100,6 +110,7 @@ def train_flow( best_model = model best_loss = loss_values[epoch] if verbose: + assert isinstance(pbar, tqdm) if num_epochs > 10: if epoch % int(num_epochs / 10) == 0: pbar.set_description(f"Training NF, current loss: {value:.3f}") diff --git a/src/flowMC/sampler/Gaussian_random_walk.py b/src/flowMC/sampler/Gaussian_random_walk.py index a45429d..07c57e7 100644 --- a/src/flowMC/sampler/Gaussian_random_walk.py +++ b/src/flowMC/sampler/Gaussian_random_walk.py @@ -18,10 +18,10 @@ class GaussianRandomWalk(ProposalBase): def __init__( self, - logpdf: Callable, + logpdf: Callable[[Float[Array, "n_dim"], PyTree], Float], jit: bool, params: dict, - ) -> Callable: + ): super().__init__(logpdf, jit, params) self.params = params self.logpdf = logpdf @@ -29,32 +29,33 @@ def __init__( def kernel( self, rng_key: PRNGKeyArray, - position: Float[Array, "ndim"], + position: Float[Array, "n_dim"], log_prob: Float[Array, "1"], data: PyTree, - ) -> tuple[Float[Array, "ndim"], Float[Array, "1"], Int[Array, "1"]]: + ) -> tuple[Float[Array, "n_dim"], Float[Array, "1"], Int[Array, "1"]]: """ Random walk gaussian kernel. This is a kernel that only evolve a single chain. 
Args: rng_key (PRNGKeyArray): Jax PRNGKey - position (Float[Array, "ndim"]): current position of the chain + position (Float[Array, "n_dim"]): current position of the chain log_prob (Float[Array, "1"]): current log-probability of the chain data (PyTree): data to be passed to the logpdf function Returns: - position (Float[Array, "ndim"]): new position of the chain + position (Float[Array, "n_dim"]): new position of the chain log_prob (Float[Array, "1"]): new log-probability of the chain do_accept (Int[Array, "1"]): whether the new position is accepted """ key1, key2 = jax.random.split(rng_key) - move_proposal = ( + move_proposal: Float[Array, "n_dim"] = ( jax.random.normal(key1, shape=position.shape) * self.params["step_size"] ) + proposal = position + move_proposal - proposal_log_prob = self.logpdf(proposal, data) + proposal_log_prob: Float[Array, "n_dim"] = self.logpdf(proposal, data) log_uniform = jnp.log(jax.random.uniform(key2)) do_accept = log_uniform < proposal_log_prob - log_prob @@ -63,11 +64,9 @@ def kernel( log_prob = jnp.where(do_accept, proposal_log_prob, log_prob) return position, log_prob, do_accept - def update( - self, i, state - ) -> tuple[ + def update(self, i, state) -> tuple[ PRNGKeyArray, - Float[Array, "nstep ndim"], + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"], PyTree, @@ -86,11 +85,12 @@ def sample( self, rng_key: PRNGKeyArray, n_steps: int, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data: PyTree, verbose: bool = False, ) -> tuple[ - Float[Array, "n_chains n_steps ndim"], + PRNGKeyArray, + Float[Array, "n_chains n_steps n_dim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], ]: diff --git a/src/flowMC/sampler/HMC.py b/src/flowMC/sampler/HMC.py index 26113eb..879bf26 100644 --- a/src/flowMC/sampler/HMC.py +++ b/src/flowMC/sampler/HMC.py @@ -17,11 +17,20 @@ class HMC(ProposalBase): params: dictionary of parameters for the sampler """ - def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable: + def __init__( + self, + logpdf: Callable[[Float[Array, " n_dim"], PyTree], Float], + jit: bool, + params: dict, + ): super().__init__(logpdf, jit, params) - self.potential = lambda x, data: -logpdf(x, data) - self.grad_potential = jax.grad(self.potential) + self.potential: Callable[ + [Float[Array, " n_dim"], PyTree], Float + ] = lambda x, data: -logpdf(x, data) + self.grad_potential: Callable[ + [Float[Array, " n_dim"], PyTree], Float[Array, " n_dim"] + ] = jax.grad(self.potential) self.params = params if "condition_matrix" in params: @@ -42,21 +51,22 @@ def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable: self.n_leapfrog = 10 print("n_leapfrog not specified, using default value 10") - coefs = jnp.ones((self.n_leapfrog+2, 2)) + coefs = jnp.ones((self.n_leapfrog + 2, 2)) coefs = coefs.at[0].set(jnp.array([0, 0.5])) coefs = coefs.at[-1].set(jnp.array([1, 0.5])) self.leapfrog_coefs = coefs - self.kinetic = lambda p, metric: 0.5 * (p**2 * metric).sum() + self.kinetic: Callable[ + [Float[Array, " n_dim"], Float[Array, " n_dim n_dim"]], Float + ] = (lambda p, metric: 0.5 * (p**2 * metric).sum()) self.grad_kinetic = jax.grad(self.kinetic) - def get_initial_hamiltonian( self, - rng_key: jax.random.PRNGKey, - position: jnp.array, - data: jnp.array, - params: dict, + rng_key: PRNGKeyArray, + position: Float[Array, " n_dim"], + data: PyTree, + params: PyTree, ): """ Compute the value of the Hamiltonian from positions with initial 
momentum draw @@ -67,65 +77,76 @@ def get_initial_hamiltonian( jax.random.normal(rng_key, shape=position.shape) * params["condition_matrix"] ** -0.5 ) - return self.potential(position, data) + self.kinetic(momentum, params["condition_matrix"]) + return self.potential(position, data) + self.kinetic( + momentum, params["condition_matrix"] + ) def leapfrog_kernel(self, carry, extras): position, momentum, data, metric, index = carry - position = position + self.params["step_size"] * self.leapfrog_coefs[index][0] * self.grad_kinetic( - momentum, metric - ) - momentum = momentum - self.params["step_size"] * self.leapfrog_coefs[index][1] * self.grad_potential( - position, data - ) + position = position + self.params["step_size"] * self.leapfrog_coefs[index][ + 0 + ] * self.grad_kinetic(momentum, metric) + momentum = momentum - self.params["step_size"] * self.leapfrog_coefs[index][ + 1 + ] * self.grad_potential(position, data) index = index + 1 return (position, momentum, data, metric, index), extras - def leapfrog_step(self, position, momentum, data, metric): + def leapfrog_step( + self, + position: Float[Array, " n_dim"], + momentum: Float[Array, " n_dim"], + data: PyTree, + metric: Float[Array, " n_dim n_dim"], + ) -> tuple[Float[Array, " n_dim"], Float[Array, " n_dim"]]: (position, momentum, data, metric, index), _ = jax.lax.scan( self.leapfrog_kernel, (position, momentum, data, metric, 0), - jnp.arange(self.n_leapfrog+2), + jnp.arange(self.n_leapfrog + 2), ) return position, momentum def kernel( self, rng_key: PRNGKeyArray, - position: Float[Array, "ndim"], + position: Float[Array, " n_dim"], log_prob: Float[Array, "1"], data: PyTree, - ) -> tuple[Float[Array, "ndim"], Float[Array, "1"], Int[Array, "1"]]: + ) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]: """ Note that since the potential function is the negative log likelihood, hamiltonian is going down, but the likelihood value should go up. 
Args: rng_key (n_chains, 2): random key - position (n_chains, n_dim): current position + position (n_chains, n_dim): current position PE (n_chains, ): Potential energy of the current position """ key1, key2 = jax.random.split(rng_key) - momentum = ( + momentum: Float[Array, " n_dim"] = ( jax.random.normal(key1, shape=position.shape) * self.params["condition_matrix"] ** -0.5 ) momentum = jnp.dot( jax.random.normal(key1, shape=position.shape), - jnp.linalg.cholesky(jnp.linalg.inv(self.params["condition_matrix"])).T) - H = - log_prob + self.kinetic(momentum, self.params["condition_matrix"]) + jnp.linalg.cholesky(jnp.linalg.inv(self.params["condition_matrix"])).T, + ) + H = -log_prob + self.kinetic(momentum, self.params["condition_matrix"]) proposed_position, proposed_momentum = self.leapfrog_step( position, momentum, data, self.params["condition_matrix"] ) proposed_PE = self.potential(proposed_position, data) - proposed_ham = proposed_PE + self.kinetic(proposed_momentum, self.params["condition_matrix"]) + proposed_ham = proposed_PE + self.kinetic( + proposed_momentum, self.params["condition_matrix"] + ) log_acc = H - proposed_ham log_uniform = jnp.log(jax.random.uniform(key2)) do_accept = log_uniform < log_acc position = jnp.where(do_accept, proposed_position, position) - log_prob = jnp.where(do_accept, -proposed_PE, log_prob) + log_prob = jnp.where(do_accept, -proposed_PE, log_prob) # type: ignore return position, log_prob, do_accept @@ -133,7 +154,7 @@ def update( self, i, state ) -> tuple[ PRNGKeyArray, - Float[Array, "nstep ndim"], + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"], PyTree, @@ -152,11 +173,12 @@ def sample( self, rng_key: PRNGKeyArray, n_steps: int, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data: PyTree, verbose: bool = False, ) -> tuple[ - Float[Array, "n_chains n_steps ndim"], + PRNGKeyArray, + Float[Array, "n_chains n_steps n_dim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], ]: diff --git a/src/flowMC/sampler/MALA.py b/src/flowMC/sampler/MALA.py index d932086..8339fb9 100644 --- a/src/flowMC/sampler/MALA.py +++ b/src/flowMC/sampler/MALA.py @@ -5,12 +5,13 @@ from tqdm import tqdm from flowMC.sampler.Proposal_Base import ProposalBase from functools import partialmethod -from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray +from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray, Bool class MALA(ProposalBase): """ - Metropolis-adjusted Langevin algorithm sampler class builiding the mala_sampler method + Metropolis-adjusted Langevin algorithm sampler class + building the mala_sampler method Args: logpdf: target logpdf function @@ -18,13 +19,26 @@ class MALA(ProposalBase): params: dictionary of parameters for the sampler """ - def __init__(self, logpdf: Callable, jit: bool, params: dict, use_autotune=False): + def __init__( + self, + logpdf: Callable[[Float[Array, " n_dim"], PyTree], Float], + jit: Bool, + params: dict, + use_autotune=False, + ): super().__init__(logpdf, jit, params) - self.params = params - self.logpdf = logpdf - self.use_autotune = use_autotune + self.params: PyTree = params + self.logpdf: Callable = logpdf + self.use_autotune: Bool = use_autotune - def body(self, carry, this_key): + def body( + self, + carry: tuple[Float[Array, " n_dim"], float, dict], + this_key: PRNGKeyArray, + ) -> tuple[ + tuple[Float[Array, " n_dim"], float, dict], + tuple[Float[Array, " n_dim"], Float[Array, "1"], Float[Array, " n_dim"]], + ]:
print("Compiling MALA body") this_position, dt, data = carry dt2 = dt * dt @@ -36,29 +50,29 @@ def body(self, carry, this_key): def kernel( self, rng_key: PRNGKeyArray, - position: Float[Array, "ndim"], + position: Float[Array, " n_dim"], log_prob: Float[Array, "1"], data: PyTree, - ) -> tuple[Float[Array, "ndim"], Float[Array, "1"], Int[Array, "1"]]: + ) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]: """ Metropolis-adjusted Langevin algorithm kernel. This is a kernel that only evolve a single chain. Args: rng_key (PRNGKeyArray): Jax PRNGKey - position (Float[Array, "ndim"]): current position of the chain + position (Float[Array, " n_dim"]): current position of the chain log_prob (Float[Array, "1"]): current log-probability of the chain data (PyTree): data to be passed to the logpdf function Returns: - position (Float[Array, "ndim"]): new position of the chain + position (Float[Array, " n_dim"]): new position of the chain log_prob (Float[Array, "1"]): new log-probability of the chain do_accept (Int[Array, "1"]): whether the new position is accepted """ key1, key2 = jax.random.split(rng_key) - dt = self.params["step_size"] + dt: Float = self.params["step_size"] dt2 = dt * dt _, (proposal, logprob, d_logprob) = jax.lax.scan( @@ -74,7 +88,7 @@ def kernel( ) log_uniform = jnp.log(jax.random.uniform(key2)) - do_accept = log_uniform < ratio + do_accept: Bool[Array, " n_dim"] = log_uniform < ratio position = jnp.where(do_accept, proposal[0], position) log_prob = jnp.where(do_accept, logprob[1], logprob[0]) @@ -85,7 +99,7 @@ def update( self, i, state ) -> tuple[ PRNGKeyArray, - Float[Array, "nstep ndim"], + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"], PyTree, @@ -114,11 +128,12 @@ def sample( self, rng_key: PRNGKeyArray, n_steps: int, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data: PyTree, verbose: bool = False, ) -> tuple[ - Float[Array, "n_chains n_steps ndim"], + PRNGKeyArray, + Float[Array, "n_chains n_steps n_dim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], ]: @@ -151,13 +166,13 @@ def mala_sampler_autotune( Args: mala_kernel_vmap (Callable): A MALA kernel rng_key: Jax PRNGKey - initial_position (n_chains, n_dim): initial position of the chains + initial_position (n_chains, n_dim): initial position of the chains log_prob (n_chains, ): log-probability of the initial position params (dict): parameters of the MALA kernel max_iter (int): maximal number of iterations to tune the step size """ - tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) + tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) # type: ignore counter = 0 position, log_prob, do_accept = self.kernel_vmap( @@ -167,7 +182,8 @@ def mala_sampler_autotune( while (acceptance_rate <= 0.3) or (acceptance_rate >= 0.5): if counter > max_iter: print( - "Maximal number of iterations reached. Existing tuning with current parameters." + "Maximal number of iterations reached.\ + Existing tuning with current parameters." 
) break if acceptance_rate <= 0.3: @@ -179,5 +195,5 @@ def mala_sampler_autotune( rng_key, initial_position, log_prob, data ) acceptance_rate = jnp.mean(do_accept) - tqdm.__init__ = partialmethod(tqdm.__init__, disable=False) + tqdm.__init__ = partialmethod(tqdm.__init__, disable=False) # type: ignore return params diff --git a/src/flowMC/sampler/NF_proposal.py b/src/flowMC/sampler/NF_proposal.py index 803d98f..c935c23 100644 --- a/src/flowMC/sampler/NF_proposal.py +++ b/src/flowMC/sampler/NF_proposal.py @@ -3,16 +3,14 @@ from jax import random from tqdm import tqdm from flowMC.nfmodel.base import NFModel -from jaxtyping import Array, PRNGKeyArray, PyTree from typing import Callable from flowMC.sampler.Proposal_Base import ProposalBase -from jaxtyping import Array, Float, Int, PRNGKeyArray +from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree from math import ceil @jax.tree_util.register_pytree_node_class class NFProposal(ProposalBase): - model: NFModel def __init__( @@ -28,13 +26,15 @@ def __init__( def kernel( self, rng_key: PRNGKeyArray, - initial_position: Float[Array, "ndim"], - proposal_position: Float[Array, "ndim"], + initial_position: Float[Array, " n_dim"], + proposal_position: Float[Array, " n_dim"], log_prob_initial: Float[Array, "1"], log_prob_proposal: Float[Array, "1"], log_prob_nf_initial: Float[Array, "1"], log_prob_nf_proposal: Float[Array, "1"], - ) -> tuple[Float[Array, "ndim"], Float[Array, "1"], Int[Array, "1"]]: + ) -> tuple[ + Float[Array, " n_dim"], Float[Array, "1"], Float[Array, "1"], Int[Array, "1"] + ]: rng_key, subkey = random.split(rng_key) ratio = (log_prob_proposal - log_prob_initial) - ( @@ -48,13 +48,27 @@ def kernel( return position, log_prob, log_prob_nf, do_accept def update( - self, i, state + self, + i: int, + state: tuple[ + PRNGKeyArray, + Float[Array, "nstep n_dim"], + Float[Array, "nstep n_dim"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], + Int[Array, "nstep 1"], + ], ) -> tuple[ PRNGKeyArray, - Float[Array, "nstep ndim"], + Float[Array, "nstep n_dim"], + Float[Array, "nstep n_dim"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], Float[Array, "nstep 1"], Int[Array, "n_step 1"], - PyTree, ]: ( key, @@ -95,12 +109,13 @@ def sample( self, rng_key: PRNGKeyArray, n_steps: int, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data: PyTree, verbose: bool = False, mode: str = "training", ) -> tuple[ - Float[Array, "n_chains n_steps ndim"], + PRNGKeyArray, + Float[Array, "n_chains n_steps n_dim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], ]: @@ -157,7 +172,7 @@ def sample( def sample_flow( self, rng_key: PRNGKeyArray, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data, n_steps: int, ): diff --git a/src/flowMC/sampler/Proposal_Base.py b/src/flowMC/sampler/Proposal_Base.py index c697892..e89589d 100644 --- a/src/flowMC/sampler/Proposal_Base.py +++ b/src/flowMC/sampler/Proposal_Base.py @@ -7,7 +7,12 @@ @jax.tree_util.register_pytree_node_class class ProposalBase: - def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable: + def __init__( + self, + logpdf: Callable[[Float[Array, " n_dim"], PyTree], Float], + jit: bool, + params: dict, + ): """ Initialize the sampler class """ @@ -31,39 +36,67 @@ def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable: def precompilation(self, 
n_chains, n_dims, n_step, data): if self.jit is True: print("jit is requested, precompiling kernels and update...") + key = jax.random.split(jax.random.PRNGKey(0), n_chains) + + self.logpdf_vmap = ( + jax.jit(self.logpdf_vmap) + .lower(jnp.ones((n_chains, n_dims)), data) + .compile() + ) + self.kernel_vmap = ( + jax.jit(self.kernel_vmap) + .lower( + key, + jnp.ones((n_chains, n_dims)), + jnp.ones((n_chains, )), + data, + ) + .compile() + ) + self.update_vmap = ( + jax.jit(self.update_vmap) + .lower( + 1, + ( + key, + jnp.ones((n_chains, n_step, n_dims)), + jnp.ones((n_chains, n_step, )), + jnp.zeros((n_chains, n_step, )), + data, + ), + ) + .compile() + ) else: print("jit is not requested, compiling only vmap functions...") - - key = jax.random.split(jax.random.PRNGKey(0), n_chains) - - self.logpdf_vmap(jnp.ones((n_chains, n_dims)), data) - self.kernel_vmap( - key, - jnp.ones((n_chains, n_dims)), - jnp.ones((n_chains, 1)), - data, - ) - self.update_vmap( - 1, - ( + key = jax.random.split(jax.random.PRNGKey(0), n_chains) + self.logpdf_vmap = self.logpdf_vmap(jnp.ones((n_chains, n_dims)), data) + self.kernel_vmap( key, - jnp.ones((n_chains, n_step, n_dims)), - jnp.ones((n_chains, n_step, 1)), - jnp.zeros((n_chains, n_step, 1)), + jnp.ones((n_chains, n_dims)), + jnp.ones((n_chains, )), data, - ), - ) + ) + self.update_vmap( + 1, + ( + key, + jnp.ones((n_chains, n_step, n_dims)), + jnp.ones((n_chains, n_step, )), + jnp.zeros((n_chains, n_step, )), + data, + ), + ) @abstractmethod def kernel( self, rng_key: PRNGKeyArray, - position: Float[Array, "nstep ndim"], + position: Float[Array, "nstep n_dim"], log_prob: Float[Array, "nstep 1"], data: PyTree, - params: dict, ) -> tuple[ - Float[Array, "nstep ndim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"] + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"] ]: """ Kernel for one step in the proposal cycle. @@ -71,10 +104,18 @@ def kernel( @abstractmethod def update( - self, i, state + self, + i: Float, + state: tuple[ + PRNGKeyArray, + Float[Array, "nstep n_dim"], + Float[Array, "nstep 1"], + Int[Array, "n_step 1"], + PyTree, + ], ) -> tuple[ PRNGKeyArray, - Float[Array, "nstep ndim"], + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"], PyTree, @@ -88,11 +129,11 @@ def sample( self, rng_key: PRNGKeyArray, n_steps: int, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data: PyTree, verbose: bool = False, ) -> tuple[ - Float[Array, "n_chains n_steps ndim"], + Float[Array, "n_chains n_steps n_dim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], ]: diff --git a/src/flowMC/sampler/Sampler.py b/src/flowMC/sampler/Sampler.py index 81de1d5..afa13cb 100644 --- a/src/flowMC/sampler/Sampler.py +++ b/src/flowMC/sampler/Sampler.py @@ -71,7 +71,7 @@ def __init__( self, n_dim: int, rng_key_set: Tuple, - data: jnp.ndarray, + data: dict, local_sampler: ProposalBase, nf_model: NFModel, **kwargs, diff --git a/src/flowMC/utils/EvolutionaryOptimizer.py b/src/flowMC/utils/EvolutionaryOptimizer.py index 4a44a07..ff0926c 100644 --- a/src/flowMC/utils/EvolutionaryOptimizer.py +++ b/src/flowMC/utils/EvolutionaryOptimizer.py @@ -1,11 +1,11 @@ from evosax import CMA_ES import jax import jax.numpy as jnp +from jaxtyping import PRNGKeyArray import tqdm class EvolutionaryOptimizer: - """ A wrapper class for the evosax package. 
Note that we do not aim to solve any generic optimization problem, @@ -45,7 +45,7 @@ def __init__(self, ndims, popsize=100, verbose=False): self.history = [] self.state = None - def optimize(self, objective, bound, n_loops = 100, seed = 9527, keep_history_step = 0): + def optimize(self, objective, bound, n_loops=100, seed=9527, keep_history_step=0): """ Optimize the objective function. @@ -66,27 +66,44 @@ def optimize(self, objective, bound, n_loops = 100, seed = 9527, keep_history_st """ rng = jax.random.PRNGKey(seed) key, subkey = jax.random.split(rng) - progress_bar = tqdm.tqdm(range(n_loops), "Generation: ") if self.verbose else range(n_loops) + progress_bar = ( + tqdm.tqdm(range(n_loops), "Generation: ") + if self.verbose + else range(n_loops) + ) self.bound = bound self.state = self.strategy.initialize(key, self.es_params) if keep_history_step > 0: self.history = [] for i in progress_bar: - subkey, self.state, theta = self.optimize_step(subkey, self.state, objective, bound) - if i%keep_history_step == 0: self.history.append(theta) - if self.verbose: progress_bar.set_description(f"Generation: {i}, Fitness: {self.state.best_fitness:.4f}") + subkey, self.state, theta = self.optimize_step( + subkey, self.state, objective, bound + ) + if i % keep_history_step == 0: + self.history.append(theta) + if self.verbose: + progress_bar.set_description( + f"Generation: {i}, Fitness: {self.state.best_fitness:.4f}" + ) self.history = jnp.array(self.history) else: for i in progress_bar: - subkey, self.state, _ = self.optimize_step(subkey, self.state, objective, bound) - if self.verbose: progress_bar.set_description(f"Generation: {i}, Fitness: {self.state.best_fitness:.4f}") + subkey, self.state, _ = self.optimize_step( + subkey, self.state, objective, bound + ) + if self.verbose: + progress_bar.set_description( + f"Generation: {i}, Fitness: {self.state.best_fitness:.4f}" + ) - def optimize_step(self, key: jax.random.PRNGKey, state, objective: callable, bound): + def optimize_step(self, key: PRNGKeyArray, state, objective: callable, bound): key, subkey = jax.random.split(key) x, state = self.strategy.ask(subkey, state, self.es_params) theta = x * (bound[:, 1] - bound[:, 0]) + bound[:, 0] fitness = objective(theta) - state = self.strategy.tell(x, fitness.astype(jnp.float32), state, self.es_params) + state = self.strategy.tell( + x, fitness.astype(jnp.float32), state, self.es_params + ) return key, state, theta def get_result(self): diff --git a/test/integration/test_HMC.py b/test/integration/test_HMC.py index bceeeec..92f2f13 100644 --- a/test/integration/test_HMC.py +++ b/test/integration/test_HMC.py @@ -3,14 +3,14 @@ import jax import jax.numpy as jnp from jax.scipy.special import logsumexp +from jaxtyping import Float, Array - -def dual_moon_pe(x, data): +def dual_moon_pe(x: Float[Array, "n_dim"], data: dict): """ Term 2 and 3 separate the distribution and smear it along the first and second dimension """ print("compile count") - term1 = 0.5 * ((jnp.linalg.norm(x - data) - 2) / 0.1) ** 2 + term1 = 0.5 * ((jnp.linalg.norm(x - data['data']) - 2) / 0.1) ** 2 term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2 term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 return -(term1 - logsumexp(term2) - logsumexp(term3)) @@ -22,7 +22,7 @@ def dual_moon_pe(x, data): step_size = 0.1 n_leapfrog = 10 -data = jnp.arange(5) +data = {'data':jnp.arange(5)} rng_key_set = initialize_rng_keys(n_chains, seed=42) @@ -40,7 +40,6 @@ def dual_moon_pe(x, data): initial_PE = 
HMC_sampler.logpdf_vmap(initial_position, data) -HMC_sampler.precompilation(n_chains, n_dim, n_local_steps, data) initial_position = jnp.repeat(initial_position[:, None], n_local_steps, 1) initial_PE = jnp.repeat(initial_PE[:, None], n_local_steps, 1) @@ -80,7 +79,7 @@ def dual_moon_pe(x, data): nf_sampler = Sampler( n_dim, rng_key_set, - jnp.arange(5), + data, HMC_sampler, model, n_loop_training=n_loop_training, @@ -89,6 +88,7 @@ def dual_moon_pe(x, data): n_global_steps=n_global_steps, n_chains=n_chains, use_global=False, + precompile=True, ) nf_sampler.sample(initial_position, data) diff --git a/test/integration/test_MALA.py b/test/integration/test_MALA.py index 4aa2440..1b34888 100644 --- a/test/integration/test_MALA.py +++ b/test/integration/test_MALA.py @@ -3,14 +3,15 @@ import jax import jax.numpy as jnp from jax.scipy.special import logsumexp +from jaxtyping import Float, Array -def dual_moon_pe(x, data): +def dual_moon_pe(x: Float[Array, "n_dim"], data: dict): """ Term 2 and 3 separate the distribution and smear it along the first and second dimension """ print("compile count") - term1 = 0.5 * ((jnp.linalg.norm(x - data) - 2) / 0.1) ** 2 + term1 = 0.5 * ((jnp.linalg.norm(x - data["data"]) - 2) / 0.1) ** 2 term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2 term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 return -(term1 - logsumexp(term2) - logsumexp(term3)) @@ -22,7 +23,7 @@ def dual_moon_pe(x, data): step_size = 0.01 n_leapfrog = 10 -data = jnp.arange(5) +data = {"data": jnp.arange(5)} rng_key_set = initialize_rng_keys(n_chains, seed=42) @@ -30,9 +31,6 @@ def dual_moon_pe(x, data): MALA_Sampler = MALA(dual_moon_pe, True, {"step_size": step_size}) -MALA_Sampler.precompilation(n_chains, n_dim, n_local_steps, data) - - initial_position = jnp.repeat(initial_position[:, None], n_local_steps, 1) initial_logp = jnp.repeat( jax.vmap(dual_moon_pe, in_axes=(0, None))(initial_position[:, 0], data)[:, None], @@ -44,7 +42,12 @@ def dual_moon_pe(x, data): rng_key_set[1], initial_position, initial_logp, - jnp.zeros((n_chains, n_local_steps, 1)), + jnp.zeros( + ( + n_chains, + n_local_steps, + ) + ), data, ) @@ -77,7 +80,7 @@ def dual_moon_pe(x, data): nf_sampler = Sampler( n_dim, rng_key_set, - jnp.arange(5), + data, MALA_Sampler, model, n_loop_training=n_loop_training, @@ -87,6 +90,7 @@ def dual_moon_pe(x, data): n_chains=n_chains, # local_autotune=mala_sampler_autotune, use_global=False, + precompile=True, ) nf_sampler.sample(initial_position, data) diff --git a/test/integration/test_RWMCMC.py b/test/integration/test_RWMCMC.py index 3ac1c42..eb1da0d 100644 --- a/test/integration/test_RWMCMC.py +++ b/test/integration/test_RWMCMC.py @@ -3,14 +3,15 @@ import jax import jax.numpy as jnp from jax.scipy.special import logsumexp +from jaxtyping import Float, Array -def dual_moon_pe(x, data): +def dual_moon_pe(x: Float[Array, "n_dim"], data: dict): """ Term 2 and 3 separate the distribution and smear it along the first and second dimension """ print("compile count") - term1 = 0.5 * ((jnp.linalg.norm(x - data) - 2) / 0.1) ** 2 + term1 = 0.5 * ((jnp.linalg.norm(x - data["data"]) - 2) / 0.1) ** 2 term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2 term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 return -(term1 - logsumexp(term2) - logsumexp(term3)) @@ -22,7 +23,7 @@ def dual_moon_pe(x, data): step_size = 0.1 n_leapfrog = 10 -data = jnp.arange(5) +data = {"data": jnp.arange(5)} rng_key_set = initialize_rng_keys(n_chains, seed=42) @@ -30,26 +31,28 @@ def 
dual_moon_pe(x, data): RWMCMC_sampler = GaussianRandomWalk(dual_moon_pe, True, {"step_size": step_size}) -RWMCMC_sampler.precompilation(n_chains, n_dim, n_local_steps, data) - initial_position = jnp.repeat(initial_position[:, None], n_local_steps, 1) initial_logp = jnp.repeat( jax.vmap(dual_moon_pe, in_axes=(0, None))(initial_position[:, 0], data)[:, None], n_local_steps, 1, -)[..., None] +) state = ( rng_key_set[1], initial_position, initial_logp, - jnp.zeros((n_chains, n_local_steps, 1)), + jnp.zeros( + ( + n_chains, + n_local_steps, + ) + ), data, ) RWMCMC_sampler.update_vmap(1, state) - state = RWMCMC_sampler.sample( rng_key_set[1], n_local_steps, initial_position[:, 0], data ) diff --git a/test/integration/test_flowHMC.py b/test/integration/test_flowHMC.py index 4535af0..a5e5733 100644 --- a/test/integration/test_flowHMC.py +++ b/test/integration/test_flowHMC.py @@ -3,6 +3,7 @@ from flowMC.utils.PRNG_keys import initialize_rng_keys import jax import jax.numpy as jnp +from jaxtyping import Float, Array from jax.scipy.special import logsumexp from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline from flowMC.sampler.Sampler import Sampler @@ -13,18 +14,19 @@ def log_posterior(x, data): Term 2 and 3 separate the distribution and smear it along the first and second dimension """ print("compile count") - term1 = 0.5 * ((jnp.linalg.norm(x - data) - 2) / 0.1) ** 2 + term1 = 0.5 * ((jnp.linalg.norm(x - data['data']) - 2) / 0.1) ** 2 term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2 term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 return -(term1 - logsumexp(term2) - logsumexp(term3)) + n_dim = 8 n_chains = 15 n_local_steps = 30 step_size = 0.1 n_leapfrog = 3 -data = jnp.arange(n_dim) +data = {'data':jnp.arange(n_dim)} rng_key_set = initialize_rng_keys(n_chains, seed=42) @@ -60,18 +62,20 @@ def log_posterior(x, data): momentum = jax.random.normal(subkeys[0], shape=initial_position.shape) -nf_sampler = Sampler(n_dim, - rng_key_set, - data, - local_sampler, - model, - n_local_steps = 50, - n_global_steps = 50, - n_epochs = 30, - learning_rate = 1e-2, - batch_size = 10000, - n_chains = n_chains, - global_sampler = flowHMC_sampler, - verbose = True,) - -nf_sampler.sample(initial_position, data) \ No newline at end of file +nf_sampler = Sampler( + n_dim, + rng_key_set, + data, + local_sampler, + model, + n_local_steps=50, + n_global_steps=50, + n_epochs=30, + learning_rate=1e-2, + batch_size=10000, + n_chains=n_chains, + global_sampler=flowHMC_sampler, + verbose=True, +) + +nf_sampler.sample(initial_position, data) diff --git a/test/integration/test_quickstart.py b/test/integration/test_quickstart.py index 3353c0b..72066ea 100644 --- a/test/integration/test_quickstart.py +++ b/test/integration/test_quickstart.py @@ -7,11 +7,11 @@ from flowMC.nfmodel.utils import * -def log_posterior(x, data): - return -0.5 * jnp.sum((x - data) ** 2) +def log_posterior(x, data: dict): + return -0.5 * jnp.sum((x - data['data']) ** 2) -data = jnp.arange(5) +data = {'data':jnp.arange(5)} n_dim = 5 n_chains = 10 @@ -25,7 +25,7 @@ def log_posterior(x, data): nf_sampler = Sampler( n_dim, rng_key_set, - jnp.arange(n_dim), + data, local_sampler, model, n_local_steps=50, diff --git a/test/unit/test_kernels.py b/test/unit/test_kernels.py index 62c5e21..98780e4 100644 --- a/test/unit/test_kernels.py +++ b/test/unit/test_kernels.py @@ -85,7 +85,11 @@ def test_HMC_acceptance_rate(self): HMC_obj = HMC( log_posterior, True, - {"step_size": 0.0000001, "n_leapfrog": 5, "condition_matrix": jnp.eye(n_dim)}, 
+ { + "step_size": 0.0000001, + "n_leapfrog": 5, + "condition_matrix": jnp.eye(n_dim), + }, ) n_chains = 100 @@ -94,7 +98,7 @@ def test_HMC_acceptance_rate(self): initial_position = ( jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 ) - initial_PE = - jax.vmap(HMC_obj.potential)(initial_position, None) + initial_PE = -jax.vmap(HMC_obj.potential)(initial_position, None) result = HMC_obj.kernel_vmap(rng_key_set[1], initial_position, initial_PE, None) @@ -118,7 +122,9 @@ def test_HMC_close_gaussian(self): result = HMC_obj.sample(rng_key_set[1], 10000, initial_position, None) - assert jnp.isclose(jnp.mean(result[1]), 0, atol=3e-2) # sqrt(N) is the expected error, but we can get unlucky + assert jnp.isclose( + jnp.mean(result[1]), 0, atol=3e-2 + ) # sqrt(N) is the expected error, but we can get unlucky assert jnp.isclose(jnp.var(result[1]), 1, atol=3e-2)
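For reference, below is a minimal end-to-end sketch of the API after this change, assembled from the quickstart and test hunks above; the hyperparameter values are illustrative only (taken from test_flowHMC.py and test_quickstart.py). Keys are still created with jax.random.PRNGKey: the PRNGKeyArray name used in the new annotations is the jaxtyping type for such keys, not a constructor.

import jax
import jax.numpy as jnp

from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.sampler.MALA import MALA
from flowMC.sampler.Sampler import Sampler
from flowMC.utils.PRNG_keys import initialize_rng_keys


def log_posterior(x, data: dict):
    # data is now passed around as a dict and indexed by name,
    # matching the updated tests in this diff.
    return -0.5 * jnp.sum((x - data["data"]) ** 2)


n_dim = 5
n_chains = 10
data = {"data": jnp.arange(n_dim)}

rng_key_set = initialize_rng_keys(n_chains, seed=42)
initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim))

# Keys are created with jax.random.PRNGKey; PRNGKeyArray is only the annotation.
model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, jax.random.PRNGKey(21))
local_sampler = MALA(log_posterior, True, {"step_size": 1e-1})

nf_sampler = Sampler(
    n_dim,
    rng_key_set,
    data,
    local_sampler,
    model,
    n_local_steps=50,
    n_global_steps=50,
    n_epochs=30,
    learning_rate=1e-2,
    batch_size=10000,
    n_chains=n_chains,
)

nf_sampler.sample(initial_position, data)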