From 9b84b34c4b6681a25f41449cf0688e8da0976530 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 22 Mar 2024 15:26:44 -0400 Subject: [PATCH 01/23] Update type hint for rng_key in sample method --- src/flowMC/nfmodel/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flowMC/nfmodel/base.py b/src/flowMC/nfmodel/base.py index da89048..f57e49a 100644 --- a/src/flowMC/nfmodel/base.py +++ b/src/flowMC/nfmodel/base.py @@ -32,7 +32,7 @@ def log_prob(self, x: Array) -> Array: return NotImplemented @abstractmethod - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: return NotImplemented @abstractmethod From 08b8d070cd88b743e9efac963b6e5e8d7bd85de1 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 22 Mar 2024 15:28:21 -0400 Subject: [PATCH 02/23] Fix type hinting in NFModel and Bijection classes --- src/flowMC/nfmodel/base.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/flowMC/nfmodel/base.py b/src/flowMC/nfmodel/base.py index f57e49a..08a732d 100644 --- a/src/flowMC/nfmodel/base.py +++ b/src/flowMC/nfmodel/base.py @@ -1,5 +1,4 @@ from abc import abstractmethod, abstractproperty -from typing import Tuple import equinox as eqx import jax from jaxtyping import Array, PRNGKeyArray, Float @@ -15,7 +14,7 @@ class NFModel(eqx.Module): def __init__(self): return NotImplemented - def __call__(self, x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Array) -> tuple[Array, Array]: """ Forward pass of the model. @@ -23,7 +22,7 @@ def __call__(self, x: Array) -> Tuple[Array, Array]: x (Array): Input data. Returns: - Tuple[Array, Array]: Output data and log determinant of the Jacobian. + tuple[Array, Array]: Output data and log determinant of the Jacobian. """ return self.forward(x) @@ -36,7 +35,7 @@ def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: return NotImplemented @abstractmethod - def forward(self, x: Array) -> Tuple[Array, Array]: + def forward(self, x: Array) -> tuple[Array, Array]: """ Forward pass of the model. @@ -44,11 +43,11 @@ def forward(self, x: Array) -> Tuple[Array, Array]: x (Array): Input data. Returns: - Tuple[Array, Array]: Output data and log determinant of the Jacobian.""" + tuple[Array, Array]: Output data and log determinant of the Jacobian.""" return NotImplemented @abstractmethod - def inverse(self, x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Array) -> tuple[Array, Array]: """ Inverse pass of the model. @@ -56,7 +55,7 @@ def inverse(self, x: Array) -> Tuple[Array, Array]: x (Array): Input data. 
Returns: - Tuple[Array, Array]: Output data and log determinant of the Jacobian.""" + tuple[Array, Array]: Output data and log determinant of the Jacobian.""" return NotImplemented @abstractproperty @@ -81,15 +80,15 @@ class Bijection(eqx.Module): def __init__(self): return NotImplemented - def __call__(self, x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Array) -> tuple[Array, Array]: return self.forward(x) @abstractmethod - def forward(self, x: Array) -> Tuple[Array, Array]: + def forward(self, x: Array) -> tuple[Array, Array]: return NotImplemented @abstractmethod - def inverse(self, x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Array) -> tuple[Array, Array]: return NotImplemented From 3808b87c8d076ce04b66786b90c9312e7c5a4fc0 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 22 Mar 2024 15:47:25 -0400 Subject: [PATCH 03/23] Update NFProposal class signature --- src/flowMC/sampler/NF_proposal.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/flowMC/sampler/NF_proposal.py b/src/flowMC/sampler/NF_proposal.py index 803d98f..13e0ea8 100644 --- a/src/flowMC/sampler/NF_proposal.py +++ b/src/flowMC/sampler/NF_proposal.py @@ -34,7 +34,9 @@ def kernel( log_prob_proposal: Float[Array, "1"], log_prob_nf_initial: Float[Array, "1"], log_prob_nf_proposal: Float[Array, "1"], - ) -> tuple[Float[Array, "ndim"], Float[Array, "1"], Int[Array, "1"]]: + ) -> tuple[ + Float[Array, "ndim"], Float[Array, "1"], Float[Array, "1"], Int[Array, "1"] + ]: rng_key, subkey = random.split(rng_key) ratio = (log_prob_proposal - log_prob_initial) - ( @@ -47,14 +49,23 @@ def kernel( log_prob_nf = jnp.where(do_accept, log_prob_nf_proposal, log_prob_nf_initial) return position, log_prob, log_prob_nf, do_accept - def update( - self, i, state - ) -> tuple[ + def update(self, i: int, state: tuple[PRNGKeyArray, + Float[Array, "nstep ndim"], + Float[Array, "nstep ndim"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], + Int[Array, "nstep 1"] + ]) -> tuple[ PRNGKeyArray, Float[Array, "nstep ndim"], + Float[Array, "nstep ndim"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], Float[Array, "nstep 1"], Int[Array, "n_step 1"], - PyTree, ]: ( key, @@ -100,6 +111,7 @@ def sample( verbose: bool = False, mode: str = "training", ) -> tuple[ + PRNGKeyArray, Float[Array, "n_chains n_steps ndim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], From 2db2b83d7ed339c5ba8330682d507919a60fc154 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 22 Mar 2024 15:48:14 -0400 Subject: [PATCH 04/23] Update jaxtyping import and PRNGKey type in MLP class --- src/flowMC/nfmodel/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/flowMC/nfmodel/common.py b/src/flowMC/nfmodel/common.py index 0e9cca5..8399756 100644 --- a/src/flowMC/nfmodel/common.py +++ b/src/flowMC/nfmodel/common.py @@ -10,7 +10,7 @@ import jax import jax.numpy as jnp import equinox as eqx -from jaxtyping import Array +from jaxtyping import Array, PRNGKeyArray class MLP(eqx.Module): @@ -18,7 +18,7 @@ class MLP(eqx.Module): Args: shape (Iterable[int]): Shape of the MLP. The first element is the input dimension, the last element is the output dimension. - key (jax.random.PRNGKey): Random key. + key (PRNGKeyArray): Random key. Attributes: layers (List): List of layers. 
@@ -30,7 +30,7 @@ class MLP(eqx.Module): def __init__( self, shape: Iterable[int], - key: jax.random.PRNGKey, + key: PRNGKeyArray, scale: float = 1e-4, activation: Callable = jax.nn.relu, use_bias: bool = True, @@ -199,7 +199,7 @@ def __init__(self, mean: Array, cov: Array, learnable: bool = False): def log_prob(self, x: Array) -> Array: return jax.scipy.stats.multivariate_normal.logpdf(x, self.mean, self.cov) - def sample(self, key: jax.random.PRNGKey, n_samples: int = 1) -> Array: + def sample(self, key: PRNGKeyArray, n_samples: int = 1) -> Array: return jax.random.multivariate_normal(key, self.mean, self.cov, (n_samples,)) @@ -218,7 +218,7 @@ def log_prob(self, x: Array) -> Array: log_prob += dist.log_prob(x[ranges[0] : ranges[1]]) return log_prob - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: samples = {} for dist, (key, _) in zip(self.distributions, self.partitions.items()): rng_key, sub_key = jax.random.split(rng_key) From 715cba4a80ae339651f0ac3d850118054fc60a90 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 22 Mar 2024 16:10:05 -0400 Subject: [PATCH 05/23] Fix PRNGKeyArray import in code --- docs/quickstart.md | 2 +- example/dualmoon.py | 2 +- example/non_jax_likelihood.py | 2 +- example/notebook/analyzingChains.ipynb | 2 +- example/notebook/dualmoon.ipynb | 2 +- example/notebook/maximizing_likelihood.ipynb | 2 +- example/notebook/mog_pretrain.ipynb | 2 +- example/notebook/normalizingFlow.ipynb | 4 +- example/train_normalizing_flow.py | 4 +- src/flowMC/nfmodel/base.py | 20 +++++----- src/flowMC/nfmodel/common.py | 24 +++++------ src/flowMC/nfmodel/realNVP.py | 42 +++++++++----------- src/flowMC/nfmodel/rqSpline.py | 6 +-- src/flowMC/sampler/HMC.py | 36 +++++++++-------- src/flowMC/sampler/Proposal_Base.py | 4 +- src/flowMC/utils/EvolutionaryOptimizer.py | 39 +++++++++++++----- src/flowMC/utils/PRNG_keys.py | 2 +- test/integration/test_HMC.py | 2 +- test/integration/test_MALA.py | 2 +- test/integration/test_RWMCMC.py | 2 +- test/integration/test_flowHMC.py | 37 +++++++++-------- test/integration/test_normalizingFlow.py | 8 ++-- test/integration/test_quickstart.py | 2 +- test/unit/test_kernels.py | 16 +++++--- test/unit/test_nf.py | 12 +++--- 25 files changed, 148 insertions(+), 128 deletions(-) diff --git a/docs/quickstart.md b/docs/quickstart.md index 06b57fe..f31cb62 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -53,7 +53,7 @@ n_chains = 10 rng_key_set = initialize_rng_keys(n_chains, seed=42) initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, jax.random.PRNGKey(21)) +model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, PRNGKeyArray(21)) step_size = 1e-1 local_sampler = MALA(log_posterior, True, {"step_size": step_size}) diff --git a/example/dualmoon.py b/example/dualmoon.py index cdfec16..a8fc842 100644 --- a/example/dualmoon.py +++ b/example/dualmoon.py @@ -38,7 +38,7 @@ def target_dualmoon(x, data): data = jnp.zeros(n_dim) rng_key_set = initialize_rng_keys(n_chains, 42) -model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10)) +model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10)) initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 diff --git a/example/non_jax_likelihood.py b/example/non_jax_likelihood.py index 4bd6567..4272ec6 100644 --- a/example/non_jax_likelihood.py +++ b/example/non_jax_likelihood.py @@ -46,7 
+46,7 @@ def neal_funnel(x): data = jnp.zeros(n_dim) rng_key_set = initialize_rng_keys(n_chains, 42) -model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10)) +model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10)) initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 diff --git a/example/notebook/analyzingChains.ipynb b/example/notebook/analyzingChains.ipynb index 4e5cddc..5941368 100644 --- a/example/notebook/analyzingChains.ipynb +++ b/example/notebook/analyzingChains.ipynb @@ -110,7 +110,7 @@ "print(\"Initializing chains, normalizing flow model and local MCMC sampler\")\n", "\n", "initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1\n", - "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10))\n", + "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))\n", "MALA_Sampler = MALA(target_dualmoon, True, {\"step_size\": step_size})\n", "\n", "print(\"Initializing samplers classes\")\n", diff --git a/example/notebook/dualmoon.ipynb b/example/notebook/dualmoon.ipynb index 7ddc60c..53533eb 100644 --- a/example/notebook/dualmoon.ipynb +++ b/example/notebook/dualmoon.ipynb @@ -141,7 +141,7 @@ "data = jnp.zeros(n_dim)\n", "\n", "model = MaskedCouplingRQSpline(\n", - " n_dim, n_layers, hidden_size, num_bins, jax.random.PRNGKey(10)\n", + " n_dim, n_layers, hidden_size, num_bins, PRNGKeyArray(10)\n", ")" ] }, diff --git a/example/notebook/maximizing_likelihood.ipynb b/example/notebook/maximizing_likelihood.ipynb index 148cef7..083d18a 100644 --- a/example/notebook/maximizing_likelihood.ipynb +++ b/example/notebook/maximizing_likelihood.ipynb @@ -164,7 +164,7 @@ "\n", "\n", "rng_key_set = initialize_rng_keys(n_chains, 42)\n", - "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10))\n", + "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))\n", "\n", "initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1\n", "\n", diff --git a/example/notebook/mog_pretrain.ipynb b/example/notebook/mog_pretrain.ipynb index 87a63b7..f236f36 100644 --- a/example/notebook/mog_pretrain.ipynb +++ b/example/notebook/mog_pretrain.ipynb @@ -135,7 +135,7 @@ "## To use instead the more powerful RQSpline:\n", "from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline\n", "\n", - "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10))\n", + "model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))\n", "\n", "\n", "# Local sampler\n", diff --git a/example/notebook/normalizingFlow.ipynb b/example/notebook/normalizingFlow.ipynb index 01c304d..0d9a1ff 100644 --- a/example/notebook/normalizingFlow.ipynb +++ b/example/notebook/normalizingFlow.ipynb @@ -101,7 +101,7 @@ "n_layers = 10\n", "n_hidden = 100\n", "\n", - "key, subkey = jax.random.split(jax.random.PRNGKey(0), 2)\n", + "key, subkey = jax.random.split(PRNGKeyArray(0), 2)\n", "\n", "model = RealNVP(\n", " n_feature,\n", @@ -246,7 +246,7 @@ "n_hiddens = [64, 64]\n", "n_bins = 8\n", "\n", - "key, subkey = jax.random.split(jax.random.PRNGKey(1))\n", + "key, subkey = jax.random.split(PRNGKeyArray(1))\n", "\n", "model = MaskedCouplingRQSpline(\n", " n_feature,\n", diff --git a/example/train_normalizing_flow.py b/example/train_normalizing_flow.py index 24e71c8..842d7d5 100644 --- a/example/train_normalizing_flow.py +++ b/example/train_normalizing_flow.py @@ -23,7 +23,7 @@ data = make_moons(n_samples=20000, noise=0.05) data = jnp.array(data[0]) 
-key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3) +key1, rng, init_rng = jax.random.split(PRNGKeyArray(0), 3) model = MaskedCouplingRQSpline( 2, @@ -40,4 +40,4 @@ key, model, loss = train_flow(rng, model, data, num_epochs, batch_size, verbose=True) -nf_samples = model.sample(jax.random.PRNGKey(124098), 5000) +nf_samples = model.sample(PRNGKeyArray(124098), 5000) diff --git a/src/flowMC/nfmodel/base.py b/src/flowMC/nfmodel/base.py index 08a732d..99fdf5e 100644 --- a/src/flowMC/nfmodel/base.py +++ b/src/flowMC/nfmodel/base.py @@ -14,20 +14,20 @@ class NFModel(eqx.Module): def __init__(self): return NotImplemented - def __call__(self, x: Array) -> tuple[Array, Array]: + def __call__(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]: """ Forward pass of the model. Args: - x (Array): Input data. + x (Float[Array, "n_dim"]): Input data. Returns: - tuple[Array, Array]: Output data and log determinant of the Jacobian. + tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian. """ return self.forward(x) @abstractmethod - def log_prob(self, x: Array) -> Array: + def log_prob(self, x: Float[Array, "n_dim"]) -> Float: return NotImplemented @abstractmethod @@ -35,27 +35,27 @@ def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: return NotImplemented @abstractmethod - def forward(self, x: Array) -> tuple[Array, Array]: + def forward(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]: """ Forward pass of the model. Args: - x (Array): Input data. + x (Float[Array, "n_dim"]): Input data. Returns: - tuple[Array, Array]: Output data and log determinant of the Jacobian.""" + tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.""" return NotImplemented @abstractmethod - def inverse(self, x: Array) -> tuple[Array, Array]: + def inverse(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]: """ Inverse pass of the model. Args: - x (Array): Input data. + x (Float[Array, "n_dim"]): Input data. Returns: - tuple[Array, Array]: Output data and log determinant of the Jacobian.""" + tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.""" return NotImplemented @abstractproperty diff --git a/src/flowMC/nfmodel/common.py b/src/flowMC/nfmodel/common.py index 8399756..170af47 100644 --- a/src/flowMC/nfmodel/common.py +++ b/src/flowMC/nfmodel/common.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Iterable, Tuple +from typing import Callable, List, Tuple import jax import jax.numpy as jnp @@ -10,14 +10,14 @@ import jax import jax.numpy as jnp import equinox as eqx -from jaxtyping import Array, PRNGKeyArray +from jaxtyping import Array, Float, PRNGKeyArray class MLP(eqx.Module): r"""Multilayer perceptron. Args: - shape (Iterable[int]): Shape of the MLP. The first element is the input dimension, the last element is the output dimension. + shape (List[int]): Shape of the MLP. The first element is the input dimension, the last element is the output dimension. key (PRNGKeyArray): Random key. 
Attributes: @@ -29,7 +29,7 @@ class MLP(eqx.Module): def __init__( self, - shape: Iterable[int], + shape: List[int], key: PRNGKeyArray, scale: float = 1e-4, activation: Callable = jax.nn.relu, @@ -52,7 +52,7 @@ def __init__( eqx.nn.Linear(shape[-2], shape[-1], key=subkey, use_bias=use_bias) ) - def __call__(self, x: Array): + def __call__(self, x: Float[Array, "n_in"]) -> Float[Array, "n_out"]: for layer in self.layers: x = layer(x) return x @@ -83,24 +83,24 @@ class MaskedCouplingLayer(Bijection): """ - _mask: Array + _mask: Float[Array, "n_dim"] bijector: Bijection @property def mask(self): return jax.lax.stop_gradient(self._mask) - def __init__(self, bijector: Bijection, mask: Array): + def __init__(self, bijector: Bijection, mask: Float[Array, "n_dim"]): self.bijector = bijector self._mask = mask - def forward(self, x: Array) -> Tuple[Array, Array]: + def forward(self, x: Float[Array, "n_dim"]) -> Tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]: y, log_det = self.bijector(x, x * self.mask) y = (1 - self.mask) * y + self.mask * x log_det = ((1 - self.mask) * log_det).sum() return y, log_det - def inverse(self, x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Float[Array, "n_dim"]) -> Tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]: y, log_det = self.bijector.inverse(x, x * self.mask) y = (1 - self.mask) * y + self.mask * x log_det = ((1 - self.mask) * log_det).sum() @@ -117,10 +117,10 @@ def __init__(self, scale_MLP: MLP, shift_MLP: MLP, dt: float = 1): self.shift_MLP = shift_MLP self.dt = dt - def __call__(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Float[Array, "n_dim"], condition_x: Array) -> Tuple[Array, float]: return self.forward(x, condition_x) - def forward(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def forward(self, x: Array, condition_x: Array) -> Tuple[Array, float]: # Note that this note output log_det as an array instead of a number. # This is because we need to sum over the log_det in the masked coupling layer. scale = jnp.tanh(self.scale_MLP(condition_x)) * self.dt @@ -129,7 +129,7 @@ def forward(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: y = (x + shift) * jnp.exp(scale) return y, log_det - def inverse(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Array, condition_x: Array) -> Tuple[Array, float]: scale = jnp.tanh(self.scale_MLP(condition_x)) * self.dt shift = self.shift_MLP(condition_x) * self.dt log_det = -scale diff --git a/src/flowMC/nfmodel/realNVP.py b/src/flowMC/nfmodel/realNVP.py index b4d2714..26599b3 100644 --- a/src/flowMC/nfmodel/realNVP.py +++ b/src/flowMC/nfmodel/realNVP.py @@ -5,7 +5,7 @@ import equinox as eqx from flowMC.nfmodel.base import NFModel, Distribution from flowMC.nfmodel.common import MLP, MaskedCouplingLayer, MLPAffine, Gaussian -from jaxtyping import Array +from jaxtyping import Array, Float, PRNGKeyArray from functools import partial @@ -19,22 +19,22 @@ class AffineCoupling(eqx.Module): n_features: (int) The number of features in the input. n_hidden: (int) The number of hidden units in the MLP. mask: (ndarray) Alternating mask for the affine coupling layer. - dt: (float) Scaling factor for the affine coupling layer. + dt: (Float) Scaling factor for the affine coupling layer. 
""" _mask: Array - scale_MLP: eqx.Module - translate_MLP: eqx.Module - dt: float = 1 + scale_MLP: MLP + translate_MLP: MLP + dt: Float = 1 def __init__( self, n_features: int, n_hidden: int, mask: Array, - key: jax.random.PRNGKey, - dt: float = 1, - scale: float = 1e-4, + key: PRNGKeyArray, + dt: Float = 1, + scale: Float = 1e-4, ): self._mask = mask self.dt = dt @@ -102,7 +102,7 @@ class RealNVP(NFModel): n_layer: (int) The number of affine coupling layers. n_features: (int) The number of features in the input. n_hidden: (int) The number of hidden units in the MLP. - dt: (float) Scaling factor for the affine coupling layer. + dt: (Float) Scaling factor for the affine coupling layer. Properties: data_mean: (ndarray) Mean of Gaussian base distribution @@ -112,8 +112,8 @@ class RealNVP(NFModel): base_dist: Distribution affine_coupling: List[MaskedCouplingLayer] _n_features: int - _data_mean: Array - _data_cov: Array + _data_mean: Array | None + _data_cov: Array | None @property def n_features(self): @@ -128,16 +128,11 @@ def data_cov(self): return jax.lax.stop_gradient(self._data_cov) def __init__( - self, - n_features: int, - n_layer: int, - n_hidden: int, - key: jax.random.PRNGKey, - **kwargs + self, n_features: int, n_layer: int, n_hidden: int, key: PRNGKeyArray, **kwargs ): if kwargs.get("base_dist") is not None: - self.base_dist = kwargs.get("base_dist") + self.base_dist = kwargs.get("base_dist") # type: ignore else: self.base_dist = Gaussian( jnp.zeros(n_features), jnp.eye(n_features), learnable=False @@ -177,18 +172,18 @@ def __init__( else: self._data_cov = jnp.eye(n_features) - def __call__(self, x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Array) -> Tuple[Array, Float]: return self.forward(x) - def forward(self, x: Array) -> Tuple[Array, Array]: - log_det = 0 + def forward(self, x: Array) -> Tuple[Array, Float]: + log_det = 0.0 for i in range(len(self.affine_coupling)): x, log_det_i = self.affine_coupling[i](x) log_det += log_det_i return x, log_det @partial(jax.vmap, in_axes=(None, 0)) - def inverse(self, x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Array) -> Tuple[Array, Float]: """From latent space to data space""" log_det = 0.0 for layer in reversed(self.affine_coupling): @@ -197,8 +192,7 @@ def inverse(self, x: Array) -> Tuple[Array, Array]: return x, log_det @eqx.filter_jit - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: - + def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: samples = self.base_dist.sample(rng_key, n_samples) samples = self.inverse(samples)[0] samples = samples * jnp.sqrt(jnp.diag(self.data_cov)) + self.data_mean diff --git a/src/flowMC/nfmodel/rqSpline.py b/src/flowMC/nfmodel/rqSpline.py index fa0d0a5..c5a55c3 100644 --- a/src/flowMC/nfmodel/rqSpline.py +++ b/src/flowMC/nfmodel/rqSpline.py @@ -340,7 +340,7 @@ class MaskedCouplingRQSpline(NFModel): num_layers (int): Number of layers in the conditioner. hidden_size (Sequence[int]): Hidden size of the conditioner. num_bins (int): Number of bins in the spline. - key (jax.random.PRNGKey): Random key for initialization. + key (PRNGKeyArray): Random key for initialization. spline_range (Sequence[float]): Range of the spline. Defaults to (-10.0, 10.0). 
Properties: @@ -373,7 +373,7 @@ def __init__( n_layers: int, hidden_size: Sequence[int], num_bins: int, - key: jax.random.PRNGKey, + key: PRNGKeyArray, spline_range: Sequence[float] = (-10.0, 10.0), **kwargs ): @@ -441,7 +441,7 @@ def inverse(self, x: Array) -> Tuple[Array, Array]: return x, log_det @eqx.filter_jit - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: samples = self.base_dist.sample(rng_key, n_samples) samples = self.inverse(samples)[0] samples = samples * jnp.sqrt(jnp.diag(self.data_cov)) + self.data_mean diff --git a/src/flowMC/sampler/HMC.py b/src/flowMC/sampler/HMC.py index 26113eb..849356b 100644 --- a/src/flowMC/sampler/HMC.py +++ b/src/flowMC/sampler/HMC.py @@ -42,7 +42,7 @@ def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable: self.n_leapfrog = 10 print("n_leapfrog not specified, using default value 10") - coefs = jnp.ones((self.n_leapfrog+2, 2)) + coefs = jnp.ones((self.n_leapfrog + 2, 2)) coefs = coefs.at[0].set(jnp.array([0, 0.5])) coefs = coefs.at[-1].set(jnp.array([1, 0.5])) self.leapfrog_coefs = coefs @@ -50,10 +50,9 @@ def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable: self.kinetic = lambda p, metric: 0.5 * (p**2 * metric).sum() self.grad_kinetic = jax.grad(self.kinetic) - def get_initial_hamiltonian( self, - rng_key: jax.random.PRNGKey, + rng_key: PRNGKeyArray, position: jnp.array, data: jnp.array, params: dict, @@ -67,16 +66,18 @@ def get_initial_hamiltonian( jax.random.normal(rng_key, shape=position.shape) * params["condition_matrix"] ** -0.5 ) - return self.potential(position, data) + self.kinetic(momentum, params["condition_matrix"]) + return self.potential(position, data) + self.kinetic( + momentum, params["condition_matrix"] + ) def leapfrog_kernel(self, carry, extras): position, momentum, data, metric, index = carry - position = position + self.params["step_size"] * self.leapfrog_coefs[index][0] * self.grad_kinetic( - momentum, metric - ) - momentum = momentum - self.params["step_size"] * self.leapfrog_coefs[index][1] * self.grad_potential( - position, data - ) + position = position + self.params["step_size"] * self.leapfrog_coefs[index][ + 0 + ] * self.grad_kinetic(momentum, metric) + momentum = momentum - self.params["step_size"] * self.leapfrog_coefs[index][ + 1 + ] * self.grad_potential(position, data) index = index + 1 return (position, momentum, data, metric, index), extras @@ -84,7 +85,7 @@ def leapfrog_step(self, position, momentum, data, metric): (position, momentum, data, metric, index), _ = jax.lax.scan( self.leapfrog_kernel, (position, momentum, data, metric, 0), - jnp.arange(self.n_leapfrog+2), + jnp.arange(self.n_leapfrog + 2), ) return position, momentum @@ -112,13 +113,16 @@ def kernel( ) momentum = jnp.dot( jax.random.normal(key1, shape=position.shape), - jnp.linalg.cholesky(jnp.linalg.inv(self.params["condition_matrix"])).T) - H = - log_prob + self.kinetic(momentum, self.params["condition_matrix"]) + jnp.linalg.cholesky(jnp.linalg.inv(self.params["condition_matrix"])).T, + ) + H = -log_prob + self.kinetic(momentum, self.params["condition_matrix"]) proposed_position, proposed_momentum = self.leapfrog_step( position, momentum, data, self.params["condition_matrix"] ) proposed_PE = self.potential(proposed_position, data) - proposed_ham = proposed_PE + self.kinetic(proposed_momentum, self.params["condition_matrix"]) + proposed_ham = proposed_PE + self.kinetic( + proposed_momentum, 
self.params["condition_matrix"] + ) log_acc = H - proposed_ham log_uniform = jnp.log(jax.random.uniform(key2)) @@ -129,9 +133,7 @@ def kernel( return position, log_prob, do_accept - def update( - self, i, state - ) -> tuple[ + def update(self, i, state) -> tuple[ PRNGKeyArray, Float[Array, "nstep ndim"], Float[Array, "nstep 1"], diff --git a/src/flowMC/sampler/Proposal_Base.py b/src/flowMC/sampler/Proposal_Base.py index c697892..84241f9 100644 --- a/src/flowMC/sampler/Proposal_Base.py +++ b/src/flowMC/sampler/Proposal_Base.py @@ -70,9 +70,7 @@ def kernel( """ @abstractmethod - def update( - self, i, state - ) -> tuple[ + def update(self, i, state) -> tuple[ PRNGKeyArray, Float[Array, "nstep ndim"], Float[Array, "nstep 1"], diff --git a/src/flowMC/utils/EvolutionaryOptimizer.py b/src/flowMC/utils/EvolutionaryOptimizer.py index 4a44a07..be58f14 100644 --- a/src/flowMC/utils/EvolutionaryOptimizer.py +++ b/src/flowMC/utils/EvolutionaryOptimizer.py @@ -1,11 +1,11 @@ from evosax import CMA_ES import jax import jax.numpy as jnp +from jaxtyping import PRNGKeyArray import tqdm class EvolutionaryOptimizer: - """ A wrapper class for the evosax package. Note that we do not aim to solve any generic optimization problem, @@ -45,7 +45,7 @@ def __init__(self, ndims, popsize=100, verbose=False): self.history = [] self.state = None - def optimize(self, objective, bound, n_loops = 100, seed = 9527, keep_history_step = 0): + def optimize(self, objective, bound, n_loops=100, seed=9527, keep_history_step=0): """ Optimize the objective function. @@ -64,29 +64,46 @@ def optimize(self, objective, bound, n_loops = 100, seed = 9527, keep_history_st ------- None """ - rng = jax.random.PRNGKey(seed) + rng = PRNGKeyArray(seed) key, subkey = jax.random.split(rng) - progress_bar = tqdm.tqdm(range(n_loops), "Generation: ") if self.verbose else range(n_loops) + progress_bar = ( + tqdm.tqdm(range(n_loops), "Generation: ") + if self.verbose + else range(n_loops) + ) self.bound = bound self.state = self.strategy.initialize(key, self.es_params) if keep_history_step > 0: self.history = [] for i in progress_bar: - subkey, self.state, theta = self.optimize_step(subkey, self.state, objective, bound) - if i%keep_history_step == 0: self.history.append(theta) - if self.verbose: progress_bar.set_description(f"Generation: {i}, Fitness: {self.state.best_fitness:.4f}") + subkey, self.state, theta = self.optimize_step( + subkey, self.state, objective, bound + ) + if i % keep_history_step == 0: + self.history.append(theta) + if self.verbose: + progress_bar.set_description( + f"Generation: {i}, Fitness: {self.state.best_fitness:.4f}" + ) self.history = jnp.array(self.history) else: for i in progress_bar: - subkey, self.state, _ = self.optimize_step(subkey, self.state, objective, bound) - if self.verbose: progress_bar.set_description(f"Generation: {i}, Fitness: {self.state.best_fitness:.4f}") + subkey, self.state, _ = self.optimize_step( + subkey, self.state, objective, bound + ) + if self.verbose: + progress_bar.set_description( + f"Generation: {i}, Fitness: {self.state.best_fitness:.4f}" + ) - def optimize_step(self, key: jax.random.PRNGKey, state, objective: callable, bound): + def optimize_step(self, key: PRNGKeyArray, state, objective: callable, bound): key, subkey = jax.random.split(key) x, state = self.strategy.ask(subkey, state, self.es_params) theta = x * (bound[:, 1] - bound[:, 0]) + bound[:, 0] fitness = objective(theta) - state = self.strategy.tell(x, fitness.astype(jnp.float32), state, self.es_params) + state = 
self.strategy.tell( + x, fitness.astype(jnp.float32), state, self.es_params + ) return key, state, theta def get_result(self): diff --git a/src/flowMC/utils/PRNG_keys.py b/src/flowMC/utils/PRNG_keys.py index 1ce51a6..f69e99c 100644 --- a/src/flowMC/utils/PRNG_keys.py +++ b/src/flowMC/utils/PRNG_keys.py @@ -16,7 +16,7 @@ def initialize_rng_keys(n_chains, seed=42): rng_keys_nf (Device Array): RNG keys for the normalizing flow global sampler. init_rng_keys_nf (Device Array): RNG keys for initializing wieght of the normalizing flow model. """ - rng_key = jax.random.PRNGKey(seed) + rng_key = PRNGKeyArray(seed) rng_key_init, rng_key_mcmc, rng_key_nf = jax.random.split(rng_key, 3) rng_keys_mcmc = jax.random.split(rng_key_mcmc, n_chains) diff --git a/test/integration/test_HMC.py b/test/integration/test_HMC.py index bceeeec..f556c81 100644 --- a/test/integration/test_HMC.py +++ b/test/integration/test_HMC.py @@ -73,7 +73,7 @@ def dual_moon_pe(x, data): initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, jax.random.PRNGKey(10)) +model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, PRNGKeyArray(10)) print("Initializing sampler class") diff --git a/test/integration/test_MALA.py b/test/integration/test_MALA.py index 4aa2440..5fb459d 100644 --- a/test/integration/test_MALA.py +++ b/test/integration/test_MALA.py @@ -70,7 +70,7 @@ def dual_moon_pe(x, data): initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, jax.random.PRNGKey(10)) +model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, PRNGKeyArray(10)) print("Initializing sampler class") diff --git a/test/integration/test_RWMCMC.py b/test/integration/test_RWMCMC.py index 3ac1c42..2ff8e0f 100644 --- a/test/integration/test_RWMCMC.py +++ b/test/integration/test_RWMCMC.py @@ -70,7 +70,7 @@ def dual_moon_pe(x, data): initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, jax.random.PRNGKey(10)) +model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, PRNGKeyArray(10)) print("Initializing sampler class") diff --git a/test/integration/test_flowHMC.py b/test/integration/test_flowHMC.py index 4535af0..04661ef 100644 --- a/test/integration/test_flowHMC.py +++ b/test/integration/test_flowHMC.py @@ -18,6 +18,7 @@ def log_posterior(x, data): term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 return -(term1 - logsumexp(term2) - logsumexp(term3)) + n_dim = 8 n_chains = 15 n_local_steps = 30 @@ -29,7 +30,7 @@ def log_posterior(x, data): rng_key_set = initialize_rng_keys(n_chains, seed=42) initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 4, jax.random.PRNGKey(10)) +model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 4, PRNGKeyArray(10)) local_sampler = MALA(log_posterior, True, {"step_size": step_size}) @@ -46,7 +47,7 @@ def log_posterior(x, data): ) n_steps = 50 -rng_key, *subkeys = jax.random.split(jax.random.PRNGKey(0), 3) +rng_key, *subkeys = jax.random.split(PRNGKeyArray(0), 3) n_chains = initial_position.shape[0] n_dim = initial_position.shape[-1] @@ -60,18 +61,20 @@ def log_posterior(x, data): momentum = jax.random.normal(subkeys[0], shape=initial_position.shape) -nf_sampler = Sampler(n_dim, - rng_key_set, - data, - local_sampler, - model, - n_local_steps = 50, - n_global_steps = 50, - n_epochs = 30, - learning_rate = 1e-2, - batch_size = 10000, - 
n_chains = n_chains, - global_sampler = flowHMC_sampler, - verbose = True,) - -nf_sampler.sample(initial_position, data) \ No newline at end of file +nf_sampler = Sampler( + n_dim, + rng_key_set, + data, + local_sampler, + model, + n_local_steps=50, + n_global_steps=50, + n_epochs=30, + learning_rate=1e-2, + batch_size=10000, + n_chains=n_chains, + global_sampler=flowHMC_sampler, + verbose=True, +) + +nf_sampler.sample(initial_position, data) diff --git a/test/integration/test_normalizingFlow.py b/test/integration/test_normalizingFlow.py index d799fdc..1f6c12c 100644 --- a/test/integration/test_normalizingFlow.py +++ b/test/integration/test_normalizingFlow.py @@ -9,7 +9,7 @@ def test_realNVP(): - key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3) + key1, rng, init_rng = jax.random.split(PRNGKeyArray(0), 3) data = jax.random.normal(key1, (100, 2)) num_epochs = 5 @@ -25,7 +25,7 @@ def test_realNVP(): rng, best_model, state, loss_values = train_flow( rng, model, data, state, num_epochs, batch_size, verbose=True ) - rng_key_nf = jax.random.PRNGKey(124098) + rng_key_nf = PRNGKeyArray(124098) model.sample(rng_key_nf, 10000) @@ -37,7 +37,7 @@ def test_rqSpline(): learning_rate = 0.001 momentum = 0.9 - key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3) + key1, rng, init_rng = jax.random.split(PRNGKeyArray(0), 3) data = jax.random.normal(key1, (batch_size, n_dim)) n_layers = 4 @@ -60,5 +60,5 @@ def test_rqSpline(): rng, best_model, state, loss_values = train_flow( rng, model, data, state, num_epochs, batch_size, verbose=True ) - rng_key_nf = jax.random.PRNGKey(124098) + rng_key_nf = PRNGKeyArray(124098) model.sample(rng_key_nf, 10000) diff --git a/test/integration/test_quickstart.py b/test/integration/test_quickstart.py index 3353c0b..851705b 100644 --- a/test/integration/test_quickstart.py +++ b/test/integration/test_quickstart.py @@ -18,7 +18,7 @@ def log_posterior(x, data): rng_key_set = initialize_rng_keys(n_chains, seed=42) initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, jax.random.PRNGKey(21)) +model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, PRNGKeyArray(21)) step_size = 1e-1 local_sampler = MALA(log_posterior, True, {"step_size": step_size}) diff --git a/test/unit/test_kernels.py b/test/unit/test_kernels.py index 62c5e21..bf184b7 100644 --- a/test/unit/test_kernels.py +++ b/test/unit/test_kernels.py @@ -85,7 +85,11 @@ def test_HMC_acceptance_rate(self): HMC_obj = HMC( log_posterior, True, - {"step_size": 0.0000001, "n_leapfrog": 5, "condition_matrix": jnp.eye(n_dim)}, + { + "step_size": 0.0000001, + "n_leapfrog": 5, + "condition_matrix": jnp.eye(n_dim), + }, ) n_chains = 100 @@ -94,7 +98,7 @@ def test_HMC_acceptance_rate(self): initial_position = ( jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 ) - initial_PE = - jax.vmap(HMC_obj.potential)(initial_position, None) + initial_PE = -jax.vmap(HMC_obj.potential)(initial_position, None) result = HMC_obj.kernel_vmap(rng_key_set[1], initial_position, initial_PE, None) @@ -118,7 +122,9 @@ def test_HMC_close_gaussian(self): result = HMC_obj.sample(rng_key_set[1], 10000, initial_position, None) - assert jnp.isclose(jnp.mean(result[1]), 0, atol=3e-2) # sqrt(N) is the expected error, but we can get unlucky + assert jnp.isclose( + jnp.mean(result[1]), 0, atol=3e-2 + ) # sqrt(N) is the expected error, but we can get unlucky assert jnp.isclose(jnp.var(result[1]), 1, atol=3e-2) @@ -248,7 +254,7 @@ def 
test_Gaussian_random_walk_close_gaussian(self): class TestNF: def test_NF_kernel(self): - key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3) + key1, rng, init_rng = jax.random.split(PRNGKeyArray(0), 3) data = jax.random.normal(key1, (100, 2)) num_epochs = 5 @@ -272,7 +278,7 @@ def test_NF_kernel(self): rng, self.model, state, loss_values = train_flow( rng, model, data, state, num_epochs, batch_size, verbose=True ) - key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(1), 3) + key1, rng, init_rng = jax.random.split(PRNGKeyArray(1), 3) n_dim = 2 n_chains = 1 diff --git a/test/unit/test_nf.py b/test/unit/test_nf.py index 3a454ec..d162ec7 100644 --- a/test/unit/test_nf.py +++ b/test/unit/test_nf.py @@ -9,7 +9,7 @@ def test_affine_coupling_forward_and_inverse(): n_hidden = 4 x = jnp.array([[1.0, 2.0], [3.0, 4.0]]) mask = jnp.where(jnp.arange(n_features) % 2 == 0, 1.0, 0.0) - key = jax.random.PRNGKey(0) + key = PRNGKeyArray(0) dt = 0.5 layer = AffineCoupling(n_features, n_hidden, mask, key, dt) @@ -26,7 +26,7 @@ def test_realnvp(): n_layer = 2 x = jnp.array([[1, 2, 3], [4, 5, 6]]) - rng_key, rng_subkey = jax.random.split(jax.random.PRNGKey(0), 2) + rng_key, rng_subkey = jax.random.split(PRNGKeyArray(0), 2) model = RealNVP(n_features, n_layer, n_hidden, rng_key) y, log_det = jax.vmap(model)(x) @@ -41,7 +41,7 @@ def test_realnvp(): assert jnp.allclose(x, y_inv) assert jnp.allclose(log_det, -log_det_inv) - rng_key = jax.random.PRNGKey(0) + rng_key = PRNGKeyArray(0) samples = model.sample(rng_key, 2) assert samples.shape == (2, 3) @@ -57,14 +57,14 @@ def test_rqspline(): n_layer = 2 n_bins = 8 - rng_key, rng_subkey = jax.random.split(jax.random.PRNGKey(0), 2) + rng_key, rng_subkey = jax.random.split(PRNGKeyArray(0), 2) model = MaskedCouplingRQSpline( - n_features, n_layer, hidden_layes, n_bins, jax.random.PRNGKey(10) + n_features, n_layer, hidden_layes, n_bins, PRNGKeyArray(10) ) jnp.array([[1, 2, 3], [4, 5, 6]]) - rng_key = jax.random.PRNGKey(0) + rng_key = PRNGKeyArray(0) samples = model.sample(rng_key, 2) assert samples.shape == (2, 3) From ffad6e1434dba9056f0def48f59660a6210e936e Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 25 Mar 2024 11:45:17 -0400 Subject: [PATCH 06/23] Fix type annotations in common.py --- src/flowMC/nfmodel/common.py | 45 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/src/flowMC/nfmodel/common.py b/src/flowMC/nfmodel/common.py index 170af47..b4d59bf 100644 --- a/src/flowMC/nfmodel/common.py +++ b/src/flowMC/nfmodel/common.py @@ -31,7 +31,7 @@ def __init__( self, shape: List[int], key: PRNGKeyArray, - scale: float = 1e-4, + scale: Float = 1e-4, activation: Callable = jax.nn.relu, use_bias: bool = True, ): @@ -87,7 +87,7 @@ class MaskedCouplingLayer(Bijection): bijector: Bijection @property - def mask(self): + def mask(self) -> Float[Array, "n_dim"]: return jax.lax.stop_gradient(self._mask) def __init__(self, bijector: Bijection, mask: Float[Array, "n_dim"]): @@ -95,13 +95,13 @@ def __init__(self, bijector: Bijection, mask: Float[Array, "n_dim"]): self._mask = mask def forward(self, x: Float[Array, "n_dim"]) -> Tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]: - y, log_det = self.bijector(x, x * self.mask) + y, log_det = self.bijector(x, x * self.mask) # type: ignore y = (1 - self.mask) * y + self.mask * x log_det = ((1 - self.mask) * log_det).sum() return y, log_det def inverse(self, x: Float[Array, "n_dim"]) -> Tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]: - y, log_det = 
self.bijector.inverse(x, x * self.mask) + y, log_det = self.bijector.inverse(x, x * self.mask) # type: ignore y = (1 - self.mask) * y + self.mask * x log_det = ((1 - self.mask) * log_det).sum() return y, log_det @@ -110,17 +110,17 @@ def inverse(self, x: Float[Array, "n_dim"]) -> Tuple[Float[Array, "n_dim"], Floa class MLPAffine(Bijection): scale_MLP: MLP shift_MLP: MLP - dt: float = 1 + dt: Float = 1 - def __init__(self, scale_MLP: MLP, shift_MLP: MLP, dt: float = 1): + def __init__(self, scale_MLP: MLP, shift_MLP: MLP, dt: Float = 1): self.scale_MLP = scale_MLP self.shift_MLP = shift_MLP self.dt = dt - def __call__(self, x: Float[Array, "n_dim"], condition_x: Array) -> Tuple[Array, float]: + def __call__(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: return self.forward(x, condition_x) - def forward(self, x: Array, condition_x: Array) -> Tuple[Array, float]: + def forward(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: # Note that this note output log_det as an array instead of a number. # This is because we need to sum over the log_det in the masked coupling layer. scale = jnp.tanh(self.scale_MLP(condition_x)) * self.dt @@ -129,7 +129,7 @@ def forward(self, x: Array, condition_x: Array) -> Tuple[Array, float]: y = (x + shift) * jnp.exp(scale) return y, log_det - def inverse(self, x: Array, condition_x: Array) -> Tuple[Array, float]: + def inverse(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: scale = jnp.tanh(self.scale_MLP(condition_x)) * self.dt shift = self.shift_MLP(condition_x) * self.dt log_det = -scale @@ -141,19 +141,19 @@ class ScalarAffine(Bijection): scale: Array shift: Array - def __init__(self, scale: float, shift: float): + def __init__(self, scale: Float, shift: Float): self.scale = jnp.array(scale) self.shift = jnp.array(shift) - def __call__(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: return self.forward(x, condition_x) - def forward(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def forward(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: y = (x + self.shift) * jnp.exp(self.scale) log_det = self.scale return y, log_det - def inverse(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]: y = x * jnp.exp(-self.scale) - self.shift log_det = -self.scale return y, log_det @@ -173,33 +173,33 @@ class Gaussian(Distribution): cov (Array): Covariance matrix. 
""" - _mean: Array - _cov: Array + _mean: Float[Array, "n_dim"] + _cov: Float[Array, "n_dim n_dim"] learnable: bool = False @property - def mean(self) -> Array: + def mean(self) -> Float[Array, "n_dim"]: if self.learnable: return self._mean else: return jax.lax.stop_gradient(self._mean) @property - def cov(self) -> Array: + def cov(self) -> Float[Array, "n_dim n_dim"]: if self.learnable: return self._cov else: return jax.lax.stop_gradient(self._cov) - def __init__(self, mean: Array, cov: Array, learnable: bool = False): + def __init__(self, mean: Float[Array, "n_dim"], cov: Float[Array, "n_dim n_dim"], learnable: bool = False): self._mean = mean self._cov = cov self.learnable = learnable - def log_prob(self, x: Array) -> Array: + def log_prob(self, x: Float[Array, "n_dim"]) -> Float: return jax.scipy.stats.multivariate_normal.logpdf(x, self.mean, self.cov) - def sample(self, key: PRNGKeyArray, n_samples: int = 1) -> Array: + def sample(self, key: PRNGKeyArray, n_samples: int = 1) -> Float[Array, "n_samples n_dim"]: return jax.random.multivariate_normal(key, self.mean, self.cov, (n_samples,)) @@ -212,14 +212,15 @@ def __init__(self, distributions: list[Distribution], partitions: dict): self.distributions = distributions self.partitions = partitions - def log_prob(self, x: Array) -> Array: + def log_prob(self, x: Float[Array, "n_dim"]) -> Float: log_prob = 0 for dist, (_, ranges) in zip(self.distributions, self.partitions.items()): log_prob += dist.log_prob(x[ranges[0] : ranges[1]]) return log_prob - def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: + def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> dict[str, Float[Array, "n_samples n_dim"]]: samples = {} for dist, (key, _) in zip(self.distributions, self.partitions.items()): rng_key, sub_key = jax.random.split(rng_key) samples[key] = dist.sample(sub_key, n_samples=n_samples) + return samples \ No newline at end of file From b0335e0eed39eac70b92a8eae6a82524a169f1d4 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 25 Mar 2024 11:47:09 -0400 Subject: [PATCH 07/23] Fix import statement in rqSpline.py and EvolutionaryOptimizer.py --- src/flowMC/nfmodel/rqSpline.py | 2 +- src/flowMC/utils/EvolutionaryOptimizer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/flowMC/nfmodel/rqSpline.py b/src/flowMC/nfmodel/rqSpline.py index c5a55c3..d532644 100644 --- a/src/flowMC/nfmodel/rqSpline.py +++ b/src/flowMC/nfmodel/rqSpline.py @@ -1,7 +1,7 @@ from typing import Sequence, Tuple import jax import jax.numpy as jnp -from jaxtyping import Array +from jaxtyping import Array, PRNGKeyArray import equinox as eqx from flowMC.nfmodel.base import NFModel, Bijection, Distribution diff --git a/src/flowMC/utils/EvolutionaryOptimizer.py b/src/flowMC/utils/EvolutionaryOptimizer.py index be58f14..ff0926c 100644 --- a/src/flowMC/utils/EvolutionaryOptimizer.py +++ b/src/flowMC/utils/EvolutionaryOptimizer.py @@ -64,7 +64,7 @@ def optimize(self, objective, bound, n_loops=100, seed=9527, keep_history_step=0 ------- None """ - rng = PRNGKeyArray(seed) + rng = jax.random.PRNGKey(seed) key, subkey = jax.random.split(rng) progress_bar = ( tqdm.tqdm(range(n_loops), "Generation: ") From da67e9fc41867077c8748ee46ca52082f11bf469 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 25 Mar 2024 12:05:07 -0400 Subject: [PATCH 08/23] Refactor ProposalBase class initialization and precompilation --- src/flowMC/sampler/Proposal_Base.py | 69 ++++++++++++++++++----------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git 
a/src/flowMC/sampler/Proposal_Base.py b/src/flowMC/sampler/Proposal_Base.py index 84241f9..82ef058 100644 --- a/src/flowMC/sampler/Proposal_Base.py +++ b/src/flowMC/sampler/Proposal_Base.py @@ -7,7 +7,7 @@ @jax.tree_util.register_pytree_node_class class ProposalBase: - def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable: + def __init__(self, logpdf: Callable, jit: bool, params: dict): """ Initialize the sampler class """ @@ -21,38 +21,56 @@ def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable: in_axes=(None, (0, 0, 0, 0, None)), out_axes=(0, 0, 0, 0, None), ) - if self.jit is True: - self.logpdf_vmap = jax.jit(self.logpdf_vmap) - self.kernel = jax.jit(self.kernel) - self.kernel_vmap = jax.jit(self.kernel_vmap) - self.update = jax.jit(self.update) - self.update_vmap = jax.jit(self.update_vmap) + # if self.jit is True: + # self.logpdf_vmap = jax.jit(self.logpdf_vmap) + # self.kernel = jax.jit(self.kernel) + # self.kernel_vmap = jax.jit(self.kernel_vmap) + # self.update = jax.jit(self.update) + # self.update_vmap = jax.jit(self.update_vmap) def precompilation(self, n_chains, n_dims, n_step, data): if self.jit is True: print("jit is requested, precompiling kernels and update...") + key = jax.random.split(jax.random.PRNGKey(0), n_chains) + + self.logpdf_vmap = jax.jit(self.logpdf_vmap).lower(jnp.ones((n_chains, n_dims)), data).compile() + self.kernel_vmap = jax.jit(self.kernel_vmap).lower( + key, + jnp.ones((n_chains, n_dims)), + jnp.ones((n_chains, 1)), + data, + ).compile() + self.update_vmap = jax.jit(self.update_vmap).lower( + 1, + ( + key, + jnp.ones((n_chains, n_step, n_dims)), + jnp.ones((n_chains, n_step, 1)), + jnp.zeros((n_chains, n_step, 1)), + data, + ), + ).compile() else: print("jit is not requested, compiling only vmap functions...") - - key = jax.random.split(jax.random.PRNGKey(0), n_chains) - - self.logpdf_vmap(jnp.ones((n_chains, n_dims)), data) - self.kernel_vmap( - key, - jnp.ones((n_chains, n_dims)), - jnp.ones((n_chains, 1)), - data, - ) - self.update_vmap( - 1, - ( + key = jax.random.split(jax.random.PRNGKey(0), n_chains) + self.logpdf_vmap = self.logpdf_vmap(jnp.ones((n_chains, n_dims)), data) + self.kernel_vmap( key, - jnp.ones((n_chains, n_step, n_dims)), - jnp.ones((n_chains, n_step, 1)), - jnp.zeros((n_chains, n_step, 1)), + jnp.ones((n_chains, n_dims)), + jnp.ones((n_chains, 1)), data, - ), - ) + ) + self.update_vmap( + 1, + ( + key, + jnp.ones((n_chains, n_step, n_dims)), + jnp.ones((n_chains, n_step, 1)), + jnp.zeros((n_chains, n_step, 1)), + data, + ), + ) + @abstractmethod def kernel( @@ -61,7 +79,6 @@ def kernel( position: Float[Array, "nstep ndim"], log_prob: Float[Array, "nstep 1"], data: PyTree, - params: dict, ) -> tuple[ Float[Array, "nstep ndim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"] ]: From cdc1a42f51a1036802b02075262fa3861a66834d Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 25 Mar 2024 12:14:08 -0400 Subject: [PATCH 09/23] Fix PRNGKeyArray import and update data dictionary --- test/integration/test_HMC.py | 10 +++++----- test/integration/test_MALA.py | 2 +- test/integration/test_RWMCMC.py | 2 +- test/integration/test_flowHMC.py | 9 +++++---- test/integration/test_normalizingFlow.py | 6 +++--- test/integration/test_quickstart.py | 2 +- 6 files changed, 16 insertions(+), 15 deletions(-) diff --git a/test/integration/test_HMC.py b/test/integration/test_HMC.py index f556c81..1f3ee4c 100644 --- a/test/integration/test_HMC.py +++ b/test/integration/test_HMC.py @@ -3,14 +3,14 @@ import jax import 
jax.numpy as jnp from jax.scipy.special import logsumexp +from jaxtyping import Float, Array - -def dual_moon_pe(x, data): +def dual_moon_pe(x: Float[Array, "n_dim"], data: dict): """ Term 2 and 3 separate the distribution and smear it along the first and second dimension """ print("compile count") - term1 = 0.5 * ((jnp.linalg.norm(x - data) - 2) / 0.1) ** 2 + term1 = 0.5 * ((jnp.linalg.norm(x - data['data']) - 2) / 0.1) ** 2 term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2 term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 return -(term1 - logsumexp(term2) - logsumexp(term3)) @@ -22,7 +22,7 @@ def dual_moon_pe(x, data): step_size = 0.1 n_leapfrog = 10 -data = jnp.arange(5) +data = {'data':jnp.arange(5)} rng_key_set = initialize_rng_keys(n_chains, seed=42) @@ -73,7 +73,7 @@ def dual_moon_pe(x, data): initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, PRNGKeyArray(10)) +model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, jax.random.PRNGKey(10)) print("Initializing sampler class") diff --git a/test/integration/test_MALA.py b/test/integration/test_MALA.py index 5fb459d..4aa2440 100644 --- a/test/integration/test_MALA.py +++ b/test/integration/test_MALA.py @@ -70,7 +70,7 @@ def dual_moon_pe(x, data): initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, PRNGKeyArray(10)) +model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, jax.random.PRNGKey(10)) print("Initializing sampler class") diff --git a/test/integration/test_RWMCMC.py b/test/integration/test_RWMCMC.py index 2ff8e0f..3ac1c42 100644 --- a/test/integration/test_RWMCMC.py +++ b/test/integration/test_RWMCMC.py @@ -70,7 +70,7 @@ def dual_moon_pe(x, data): initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, PRNGKeyArray(10)) +model = MaskedCouplingRQSpline(2, 4, [32, 32], 4, jax.random.PRNGKey(10)) print("Initializing sampler class") diff --git a/test/integration/test_flowHMC.py b/test/integration/test_flowHMC.py index 04661ef..a5e5733 100644 --- a/test/integration/test_flowHMC.py +++ b/test/integration/test_flowHMC.py @@ -3,6 +3,7 @@ from flowMC.utils.PRNG_keys import initialize_rng_keys import jax import jax.numpy as jnp +from jaxtyping import Float, Array from jax.scipy.special import logsumexp from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline from flowMC.sampler.Sampler import Sampler @@ -13,7 +14,7 @@ def log_posterior(x, data): Term 2 and 3 separate the distribution and smear it along the first and second dimension """ print("compile count") - term1 = 0.5 * ((jnp.linalg.norm(x - data) - 2) / 0.1) ** 2 + term1 = 0.5 * ((jnp.linalg.norm(x - data['data']) - 2) / 0.1) ** 2 term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2 term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 return -(term1 - logsumexp(term2) - logsumexp(term3)) @@ -25,12 +26,12 @@ def log_posterior(x, data): step_size = 0.1 n_leapfrog = 3 -data = jnp.arange(n_dim) +data = {'data':jnp.arange(n_dim)} rng_key_set = initialize_rng_keys(n_chains, seed=42) initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 4, PRNGKeyArray(10)) +model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 4, jax.random.PRNGKey(10)) local_sampler = MALA(log_posterior, True, {"step_size": step_size}) @@ -47,7 +48,7 @@ def log_posterior(x, 
data): ) n_steps = 50 -rng_key, *subkeys = jax.random.split(PRNGKeyArray(0), 3) +rng_key, *subkeys = jax.random.split(jax.random.PRNGKey(0), 3) n_chains = initial_position.shape[0] n_dim = initial_position.shape[-1] diff --git a/test/integration/test_normalizingFlow.py b/test/integration/test_normalizingFlow.py index 1f6c12c..cc2e28b 100644 --- a/test/integration/test_normalizingFlow.py +++ b/test/integration/test_normalizingFlow.py @@ -9,7 +9,7 @@ def test_realNVP(): - key1, rng, init_rng = jax.random.split(PRNGKeyArray(0), 3) + key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3) data = jax.random.normal(key1, (100, 2)) num_epochs = 5 @@ -25,7 +25,7 @@ def test_realNVP(): rng, best_model, state, loss_values = train_flow( rng, model, data, state, num_epochs, batch_size, verbose=True ) - rng_key_nf = PRNGKeyArray(124098) + rng_key_nf = jax.random.PRNGKey(124098) model.sample(rng_key_nf, 10000) @@ -37,7 +37,7 @@ def test_rqSpline(): learning_rate = 0.001 momentum = 0.9 - key1, rng, init_rng = jax.random.split(PRNGKeyArray(0), 3) + key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3) data = jax.random.normal(key1, (batch_size, n_dim)) n_layers = 4 diff --git a/test/integration/test_quickstart.py b/test/integration/test_quickstart.py index 851705b..3353c0b 100644 --- a/test/integration/test_quickstart.py +++ b/test/integration/test_quickstart.py @@ -18,7 +18,7 @@ def log_posterior(x, data): rng_key_set = initialize_rng_keys(n_chains, seed=42) initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1 -model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, PRNGKeyArray(21)) +model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, jax.random.PRNGKey(21)) step_size = 1e-1 local_sampler = MALA(log_posterior, True, {"step_size": step_size}) From 2151b42185479a91e78667ab600297a9aea99021 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 25 Mar 2024 12:16:04 -0400 Subject: [PATCH 10/23] Enable just-in-time (JIT) compilation --- src/flowMC/sampler/Proposal_Base.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/flowMC/sampler/Proposal_Base.py b/src/flowMC/sampler/Proposal_Base.py index 82ef058..abe1ed4 100644 --- a/src/flowMC/sampler/Proposal_Base.py +++ b/src/flowMC/sampler/Proposal_Base.py @@ -21,12 +21,12 @@ def __init__(self, logpdf: Callable, jit: bool, params: dict): in_axes=(None, (0, 0, 0, 0, None)), out_axes=(0, 0, 0, 0, None), ) - # if self.jit is True: - # self.logpdf_vmap = jax.jit(self.logpdf_vmap) - # self.kernel = jax.jit(self.kernel) - # self.kernel_vmap = jax.jit(self.kernel_vmap) - # self.update = jax.jit(self.update) - # self.update_vmap = jax.jit(self.update_vmap) + if self.jit is True: + self.logpdf_vmap = jax.jit(self.logpdf_vmap) + self.kernel = jax.jit(self.kernel) + self.kernel_vmap = jax.jit(self.kernel_vmap) + self.update = jax.jit(self.update) + self.update_vmap = jax.jit(self.update_vmap) def precompilation(self, n_chains, n_dims, n_step, data): if self.jit is True: From 47040298d164069b13bcf0e1151f52b4d7900e41 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 10:08:49 -0400 Subject: [PATCH 11/23] Refactor sampler classes and update type annotations --- src/flowMC/sampler/Gaussian_random_walk.py | 2 +- src/flowMC/sampler/MALA.py | 32 +++++++----- src/flowMC/sampler/Proposal_Base.py | 59 +++++++++++++++------- 3 files changed, 60 insertions(+), 33 deletions(-) diff --git a/src/flowMC/sampler/Gaussian_random_walk.py b/src/flowMC/sampler/Gaussian_random_walk.py index 
a45429d..3262743 100644 --- a/src/flowMC/sampler/Gaussian_random_walk.py +++ b/src/flowMC/sampler/Gaussian_random_walk.py @@ -21,7 +21,7 @@ def __init__( logpdf: Callable, jit: bool, params: dict, - ) -> Callable: + ): super().__init__(logpdf, jit, params) self.params = params self.logpdf = logpdf diff --git a/src/flowMC/sampler/MALA.py b/src/flowMC/sampler/MALA.py index d932086..1088ebc 100644 --- a/src/flowMC/sampler/MALA.py +++ b/src/flowMC/sampler/MALA.py @@ -5,8 +5,7 @@ from tqdm import tqdm from flowMC.sampler.Proposal_Base import ProposalBase from functools import partialmethod -from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray - +from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray, Bool class MALA(ProposalBase): """ @@ -18,13 +17,20 @@ class MALA(ProposalBase): params: dictionary of parameters for the sampler """ - def __init__(self, logpdf: Callable, jit: bool, params: dict, use_autotune=False): + def __init__(self, logpdf: Callable, jit: Bool, params: dict, use_autotune=False): super().__init__(logpdf, jit, params) - self.params = params - self.logpdf = logpdf - self.use_autotune = use_autotune + self.params: PyTree = params + self.logpdf: Callable = logpdf + self.use_autotune: Bool = use_autotune - def body(self, carry, this_key): + def body( + self, + carry: tuple[Float[Array, "n_dim"], float, dict], + this_key: PRNGKeyArray, + ) -> tuple[ + tuple[Float[Array, "n_dim"], float, dict], + tuple[Float[Array, "n_dim"], Float[Array, "1"], Float[Array, "n_dim"]], + ]: print("Compiling MALA body") this_position, dt, data = carry dt2 = dt * dt @@ -58,7 +64,7 @@ def kernel( key1, key2 = jax.random.split(rng_key) - dt = self.params["step_size"] + dt: Float = self.params["step_size"] dt2 = dt * dt _, (proposal, logprob, d_logprob) = jax.lax.scan( @@ -74,16 +80,14 @@ def kernel( ) log_uniform = jnp.log(jax.random.uniform(key2)) - do_accept = log_uniform < ratio + do_accept: Bool[Array, "n_dim"] = log_uniform < ratio position = jnp.where(do_accept, proposal[0], position) log_prob = jnp.where(do_accept, logprob[1], logprob[0]) return position, log_prob, do_accept - def update( - self, i, state - ) -> tuple[ + def update(self, i, state) -> tuple[ PRNGKeyArray, Float[Array, "nstep ndim"], Float[Array, "nstep 1"], @@ -157,7 +161,7 @@ def mala_sampler_autotune( max_iter (int): maximal number of iterations to tune the step size """ - tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) + tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) # type: ignore counter = 0 position, log_prob, do_accept = self.kernel_vmap( @@ -179,5 +183,5 @@ def mala_sampler_autotune( rng_key, initial_position, log_prob, data ) acceptance_rate = jnp.mean(do_accept) - tqdm.__init__ = partialmethod(tqdm.__init__, disable=False) + tqdm.__init__ = partialmethod(tqdm.__init__, disable=False) # type: ignore return params diff --git a/src/flowMC/sampler/Proposal_Base.py b/src/flowMC/sampler/Proposal_Base.py index abe1ed4..c78f898 100644 --- a/src/flowMC/sampler/Proposal_Base.py +++ b/src/flowMC/sampler/Proposal_Base.py @@ -7,7 +7,9 @@ @jax.tree_util.register_pytree_node_class class ProposalBase: - def __init__(self, logpdf: Callable, jit: bool, params: dict): + def __init__( + self, logpdf: Callable[[Float[Array, "n_dim"], PyTree], Float], jit: bool, params: dict + ): """ Initialize the sampler class """ @@ -33,23 +35,35 @@ def precompilation(self, n_chains, n_dims, n_step, data): print("jit is requested, precompiling kernels and update...") key = jax.random.split(jax.random.PRNGKey(0), 
n_chains) - self.logpdf_vmap = jax.jit(self.logpdf_vmap).lower(jnp.ones((n_chains, n_dims)), data).compile() - self.kernel_vmap = jax.jit(self.kernel_vmap).lower( - key, - jnp.ones((n_chains, n_dims)), - jnp.ones((n_chains, 1)), - data, - ).compile() - self.update_vmap = jax.jit(self.update_vmap).lower( - 1, - ( + self.logpdf_vmap = ( + jax.jit(self.logpdf_vmap) + .lower(jnp.ones((n_chains, n_dims)), data) + .compile() + ) + self.kernel_vmap = ( + jax.jit(self.kernel_vmap) + .lower( key, - jnp.ones((n_chains, n_step, n_dims)), - jnp.ones((n_chains, n_step, 1)), - jnp.zeros((n_chains, n_step, 1)), + jnp.ones((n_chains, n_dims)), + jnp.ones((n_chains, 1)), data, - ), - ).compile() + ) + .compile() + ) + self.update_vmap = ( + jax.jit(self.update_vmap) + .lower( + 1, + ( + key, + jnp.ones((n_chains, n_step, n_dims)), + jnp.ones((n_chains, n_step, 1)), + jnp.zeros((n_chains, n_step, 1)), + data, + ), + ) + .compile() + ) else: print("jit is not requested, compiling only vmap functions...") key = jax.random.split(jax.random.PRNGKey(0), n_chains) @@ -71,7 +85,6 @@ def precompilation(self, n_chains, n_dims, n_step, data): ), ) - @abstractmethod def kernel( self, @@ -87,7 +100,17 @@ def kernel( """ @abstractmethod - def update(self, i, state) -> tuple[ + def update( + self, + i: Float, + state: tuple[ + PRNGKeyArray, + Float[Array, "nstep ndim"], + Float[Array, "nstep 1"], + Int[Array, "n_step 1"], + PyTree, + ], + ) -> tuple[ PRNGKeyArray, Float[Array, "nstep ndim"], Float[Array, "nstep 1"], From 58315e660e76d75be54651fc8b0b757b570d7ea8 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 10:12:32 -0400 Subject: [PATCH 12/23] Update logpdf type hints in GaussianRandomWalk and MALA --- src/flowMC/sampler/Gaussian_random_walk.py | 9 ++++----- src/flowMC/sampler/MALA.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/flowMC/sampler/Gaussian_random_walk.py b/src/flowMC/sampler/Gaussian_random_walk.py index 3262743..11848e7 100644 --- a/src/flowMC/sampler/Gaussian_random_walk.py +++ b/src/flowMC/sampler/Gaussian_random_walk.py @@ -18,7 +18,7 @@ class GaussianRandomWalk(ProposalBase): def __init__( self, - logpdf: Callable, + logpdf: Callable[[Float[Array, "n_dim"], PyTree], Float], jit: bool, params: dict, ): @@ -50,9 +50,10 @@ def kernel( """ key1, key2 = jax.random.split(rng_key) - move_proposal = ( + move_proposal: Float[Array, "n_dim"] = ( jax.random.normal(key1, shape=position.shape) * self.params["step_size"] ) + proposal = position + move_proposal proposal_log_prob = self.logpdf(proposal, data) @@ -63,9 +64,7 @@ def kernel( log_prob = jnp.where(do_accept, proposal_log_prob, log_prob) return position, log_prob, do_accept - def update( - self, i, state - ) -> tuple[ + def update(self, i, state) -> tuple[ PRNGKeyArray, Float[Array, "nstep ndim"], Float[Array, "nstep 1"], diff --git a/src/flowMC/sampler/MALA.py b/src/flowMC/sampler/MALA.py index 1088ebc..2f27bc2 100644 --- a/src/flowMC/sampler/MALA.py +++ b/src/flowMC/sampler/MALA.py @@ -17,7 +17,7 @@ class MALA(ProposalBase): params: dictionary of parameters for the sampler """ - def __init__(self, logpdf: Callable, jit: Bool, params: dict, use_autotune=False): + def __init__(self, logpdf: Callable[[Float[Array, "n_dim"], PyTree], Float], jit: Bool, params: dict, use_autotune=False): super().__init__(logpdf, jit, params) self.params: PyTree = params self.logpdf: Callable = logpdf From 47816a2e28541e93d4a7c16c6c382ad32402a64f Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 10:15:11 -0400 Subject: 
[PATCH 13/23] Add PRNGKeyArray to MALA sampler --- src/flowMC/sampler/MALA.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/flowMC/sampler/MALA.py b/src/flowMC/sampler/MALA.py index 2f27bc2..33574b1 100644 --- a/src/flowMC/sampler/MALA.py +++ b/src/flowMC/sampler/MALA.py @@ -122,6 +122,7 @@ def sample( data: PyTree, verbose: bool = False, ) -> tuple[ + PRNGKeyArray, Float[Array, "n_chains n_steps ndim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], From 40ab84538a9d990684418bdce0bad687381a12bf Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 10:17:06 -0400 Subject: [PATCH 14/23] Fix variable naming in GaussianRandomWalk class --- src/flowMC/sampler/Gaussian_random_walk.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/flowMC/sampler/Gaussian_random_walk.py b/src/flowMC/sampler/Gaussian_random_walk.py index 11848e7..07c57e7 100644 --- a/src/flowMC/sampler/Gaussian_random_walk.py +++ b/src/flowMC/sampler/Gaussian_random_walk.py @@ -29,22 +29,22 @@ def __init__( def kernel( self, rng_key: PRNGKeyArray, - position: Float[Array, "ndim"], + position: Float[Array, "n_dim"], log_prob: Float[Array, "1"], data: PyTree, - ) -> tuple[Float[Array, "ndim"], Float[Array, "1"], Int[Array, "1"]]: + ) -> tuple[Float[Array, "n_dim"], Float[Array, "1"], Int[Array, "1"]]: """ Random walk gaussian kernel. This is a kernel that only evolve a single chain. Args: rng_key (PRNGKeyArray): Jax PRNGKey - position (Float[Array, "ndim"]): current position of the chain + position (Float[Array, "n_dim"]): current position of the chain log_prob (Float[Array, "1"]): current log-probability of the chain data (PyTree): data to be passed to the logpdf function Returns: - position (Float[Array, "ndim"]): new position of the chain + position (Float[Array, "n_dim"]): new position of the chain log_prob (Float[Array, "1"]): new log-probability of the chain do_accept (Int[Array, "1"]): whether the new position is accepted """ @@ -55,7 +55,7 @@ def kernel( ) proposal = position + move_proposal - proposal_log_prob = self.logpdf(proposal, data) + proposal_log_prob: Float[Array, "n_dim"] = self.logpdf(proposal, data) log_uniform = jnp.log(jax.random.uniform(key2)) do_accept = log_uniform < proposal_log_prob - log_prob @@ -66,7 +66,7 @@ def kernel( def update(self, i, state) -> tuple[ PRNGKeyArray, - Float[Array, "nstep ndim"], + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"], PyTree, @@ -85,11 +85,12 @@ def sample( self, rng_key: PRNGKeyArray, n_steps: int, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data: PyTree, verbose: bool = False, ) -> tuple[ - Float[Array, "n_chains n_steps ndim"], + PRNGKeyArray, + Float[Array, "n_chains n_steps n_dim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], ]: From 1bef42ce8fa870da016996a31fbe41a4084b1f6a Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 10:34:55 -0400 Subject: [PATCH 15/23] Fix formatting and type annotations --- ruff.toml | 1 + src/flowMC/sampler/HMC.py | 54 ++++++++++++++++++++--------- src/flowMC/sampler/MALA.py | 47 +++++++++++++++---------- src/flowMC/sampler/NF_proposal.py | 43 ++++++++++++----------- src/flowMC/sampler/Proposal_Base.py | 17 +++++---- 5 files changed, 100 insertions(+), 62 deletions(-) create mode 100644 ruff.toml diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..bf64e2d --- /dev/null +++ b/ruff.toml @@ -0,0 +1 @@ +ignore = 
["F722"] \ No newline at end of file diff --git a/src/flowMC/sampler/HMC.py b/src/flowMC/sampler/HMC.py index 849356b..879bf26 100644 --- a/src/flowMC/sampler/HMC.py +++ b/src/flowMC/sampler/HMC.py @@ -17,11 +17,20 @@ class HMC(ProposalBase): params: dictionary of parameters for the sampler """ - def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable: + def __init__( + self, + logpdf: Callable[[Float[Array, " n_dim"], PyTree], Float], + jit: bool, + params: dict, + ): super().__init__(logpdf, jit, params) - self.potential = lambda x, data: -logpdf(x, data) - self.grad_potential = jax.grad(self.potential) + self.potential: Callable[ + [Float[Array, " n_dim"], PyTree], Float + ] = lambda x, data: -logpdf(x, data) + self.grad_potential: Callable[ + [Float[Array, " n_dim"], PyTree], Float[Array, " n_dim"] + ] = jax.grad(self.potential) self.params = params if "condition_matrix" in params: @@ -47,15 +56,17 @@ def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable: coefs = coefs.at[-1].set(jnp.array([1, 0.5])) self.leapfrog_coefs = coefs - self.kinetic = lambda p, metric: 0.5 * (p**2 * metric).sum() + self.kinetic: Callable[ + [Float[Array, " n_dim"], Float[Array, " n_dim n_dim"]], Float + ] = (lambda p, metric: 0.5 * (p**2 * metric).sum()) self.grad_kinetic = jax.grad(self.kinetic) def get_initial_hamiltonian( self, rng_key: PRNGKeyArray, - position: jnp.array, - data: jnp.array, - params: dict, + position: Float[Array, " n_dim"], + data: PyTree, + params: PyTree, ): """ Compute the value of the Hamiltonian from positions with initial momentum draw @@ -81,7 +92,13 @@ def leapfrog_kernel(self, carry, extras): index = index + 1 return (position, momentum, data, metric, index), extras - def leapfrog_step(self, position, momentum, data, metric): + def leapfrog_step( + self, + position: Float[Array, " n_dim"], + momentum: Float[Array, " n_dim"], + data: PyTree, + metric: Float[Array, " n_dim n_dim"], + ) -> tuple[Float[Array, " n_dim"], Float[Array, " n_dim"]]: (position, momentum, data, metric, index), _ = jax.lax.scan( self.leapfrog_kernel, (position, momentum, data, metric, 0), @@ -92,22 +109,22 @@ def leapfrog_step(self, position, momentum, data, metric): def kernel( self, rng_key: PRNGKeyArray, - position: Float[Array, "ndim"], + position: Float[Array, " n_dim"], log_prob: Float[Array, "1"], data: PyTree, - ) -> tuple[Float[Array, "ndim"], Float[Array, "1"], Int[Array, "1"]]: + ) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]: """ Note that since the potential function is the negative log likelihood, hamiltonian is going down, but the likelihood value should go up. 
Args: rng_key (n_chains, 2): random key - position (n_chains, n_dim): current position + position (n_chains, n_dim): current position PE (n_chains, ): Potential energy of the current position """ key1, key2 = jax.random.split(rng_key) - momentum = ( + momentum: Float[Array, " n_dim"] = ( jax.random.normal(key1, shape=position.shape) * self.params["condition_matrix"] ** -0.5 ) @@ -129,13 +146,15 @@ def kernel( do_accept = log_uniform < log_acc position = jnp.where(do_accept, proposed_position, position) - log_prob = jnp.where(do_accept, -proposed_PE, log_prob) + log_prob = jnp.where(do_accept, -proposed_PE, log_prob)  # type: ignore return position, log_prob, do_accept - def update(self, i, state) -> tuple[ + def update( + self, i, state + ) -> tuple[ PRNGKeyArray, - Float[Array, "nstep ndim"], + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"], PyTree, @@ -154,11 +173,12 @@ def sample( self, rng_key: PRNGKeyArray, n_steps: int, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data: PyTree, verbose: bool = False, ) -> tuple[ - Float[Array, "n_chains n_steps ndim"], + PRNGKeyArray, + Float[Array, "n_chains n_steps n_dim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], ]: diff --git a/src/flowMC/sampler/MALA.py b/src/flowMC/sampler/MALA.py index 33574b1..8339fb9 100644 --- a/src/flowMC/sampler/MALA.py +++ b/src/flowMC/sampler/MALA.py @@ -7,9 +7,11 @@ from functools import partialmethod from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray, Bool + class MALA(ProposalBase): """ - Metropolis-adjusted Langevin algorithm sampler class builiding the mala_sampler method + Metropolis-adjusted Langevin algorithm sampler class + building the mala_sampler method Args: logpdf: target logpdf function @@ -17,7 +19,13 @@ class MALA(ProposalBase): params: dictionary of parameters for the sampler """ - def __init__(self, logpdf: Callable[[Float[Array, "n_dim"], PyTree], Float], jit: Bool, params: dict, use_autotune=False): + def __init__( + self, + logpdf: Callable[[Float[Array, " n_dim"], PyTree], Float], + jit: Bool, + params: dict, + use_autotune=False, + ): super().__init__(logpdf, jit, params) self.params: PyTree = params self.logpdf: Callable = logpdf @@ -25,11 +33,11 @@ def __init__( def body( self, - carry: tuple[Float[Array, "n_dim"], float, dict], + carry: tuple[Float[Array, " n_dim"], float, dict], this_key: PRNGKeyArray, ) -> tuple[ - tuple[Float[Array, "n_dim"], float, dict], - tuple[Float[Array, "n_dim"], Float[Array, "1"], Float[Array, "n_dim"]], + tuple[Float[Array, " n_dim"], float, dict], + tuple[Float[Array, " n_dim"], Float[Array, "1"], Float[Array, " n_dim"]], ]: print("Compiling MALA body") this_position, dt, data = carry @@ -42,22 +50,22 @@ def body( def kernel( self, rng_key: PRNGKeyArray, - position: Float[Array, "ndim"], + position: Float[Array, " n_dim"], log_prob: Float[Array, "1"], data: PyTree, - ) -> tuple[Float[Array, "ndim"], Float[Array, "1"], Int[Array, "1"]]: + ) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]: """ Metropolis-adjusted Langevin algorithm kernel. This is a kernel that only evolve a single chain.
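# --- Illustrative sketch (not part of the patch, and not flowMC's implementation) ---
# The Langevin proposal and Metropolis-Hastings correction that a MALA kernel
# of this shape performs for a single chain. The standard-normal target and the
# name `step_size` are assumptions for the example.
import jax
import jax.numpy as jnp


def log_prob_fn(x):
    return -0.5 * jnp.sum(x**2)


grad_log_prob = jax.grad(log_prob_fn)


def mala_kernel(rng_key, x, step_size=0.1):
    key_noise, key_accept = jax.random.split(rng_key)
    # Drift along the score plus Gaussian diffusion:
    # x' = x + (dt^2 / 2) * grad log p(x) + dt * noise, with dt2 = dt * dt.
    noise = jax.random.normal(key_noise, x.shape)
    x_prop = x + 0.5 * step_size**2 * grad_log_prob(x) + step_size * noise

    def log_q(x_to, x_from):
        # log q(x_to | x_from), dropping constants that cancel in the ratio.
        mean = x_from + 0.5 * step_size**2 * grad_log_prob(x_from)
        return -0.5 * jnp.sum((x_to - mean) ** 2) / step_size**2

    # The proposal is asymmetric, so the Hastings correction is required.
    ratio = (log_prob_fn(x_prop) - log_prob_fn(x)) + (
        log_q(x, x_prop) - log_q(x_prop, x)
    )
    do_accept = jnp.log(jax.random.uniform(key_accept)) < ratio
    return jnp.where(do_accept, x_prop, x), do_accept
# -------------------------------------------------------------------------------------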
Args: rng_key (PRNGKeyArray): Jax PRNGKey - position (Float[Array, "ndim"]): current position of the chain + position (Float[Array, " n_dim"]): current position of the chain log_prob (Float[Array, "1"]): current log-probability of the chain data (PyTree): data to be passed to the logpdf function Returns: - position (Float[Array, "ndim"]): new position of the chain + position (Float[Array, " n_dim"]): new position of the chain log_prob (Float[Array, "1"]): new log-probability of the chain do_accept (Int[Array, "1"]): whether the new position is accepted """ @@ -80,16 +88,18 @@ def kernel( ) log_uniform = jnp.log(jax.random.uniform(key2)) - do_accept: Bool[Array, "n_dim"] = log_uniform < ratio + do_accept: Bool[Array, " n_dim"] = log_uniform < ratio position = jnp.where(do_accept, proposal[0], position) log_prob = jnp.where(do_accept, logprob[1], logprob[0]) return position, log_prob, do_accept - def update(self, i, state) -> tuple[ + def update( + self, i, state + ) -> tuple[ PRNGKeyArray, - Float[Array, "nstep ndim"], + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"], PyTree, @@ -118,12 +128,12 @@ def sample( self, rng_key: PRNGKeyArray, n_steps: int, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data: PyTree, verbose: bool = False, ) -> tuple[ PRNGKeyArray, - Float[Array, "n_chains n_steps ndim"], + Float[Array, "n_chains n_steps n_dim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], ]: @@ -156,13 +166,13 @@ def mala_sampler_autotune( Args: mala_kernel_vmap (Callable): A MALA kernel rng_key: Jax PRNGKey - initial_position (n_chains, n_dim): initial position of the chains + initial_position (n_chains, n_dim): initial position of the chains log_prob (n_chains, ): log-probability of the initial position params (dict): parameters of the MALA kernel max_iter (int): maximal number of iterations to tune the step size """ - tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) # type: ignore + tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)  # type: ignore counter = 0 position, log_prob, do_accept = self.kernel_vmap( @@ -172,7 +182,8 @@ while (acceptance_rate <= 0.3) or (acceptance_rate >= 0.5): if counter > max_iter: print( - "Maximal number of iterations reached. Existing tuning with current parameters." + "Maximal number of iterations reached.\ + Exiting tuning with current parameters."
) break if acceptance_rate <= 0.3: @@ -184,5 +195,5 @@ def mala_sampler_autotune( rng_key, initial_position, log_prob, data ) acceptance_rate = jnp.mean(do_accept) - tqdm.__init__ = partialmethod(tqdm.__init__, disable=False) # type: ignore + tqdm.__init__ = partialmethod(tqdm.__init__, disable=False) # type: ignore return params diff --git a/src/flowMC/sampler/NF_proposal.py b/src/flowMC/sampler/NF_proposal.py index 13e0ea8..c935c23 100644 --- a/src/flowMC/sampler/NF_proposal.py +++ b/src/flowMC/sampler/NF_proposal.py @@ -3,16 +3,14 @@ from jax import random from tqdm import tqdm from flowMC.nfmodel.base import NFModel -from jaxtyping import Array, PRNGKeyArray, PyTree from typing import Callable from flowMC.sampler.Proposal_Base import ProposalBase -from jaxtyping import Array, Float, Int, PRNGKeyArray +from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree from math import ceil @jax.tree_util.register_pytree_node_class class NFProposal(ProposalBase): - model: NFModel def __init__( @@ -28,14 +26,14 @@ def __init__( def kernel( self, rng_key: PRNGKeyArray, - initial_position: Float[Array, "ndim"], - proposal_position: Float[Array, "ndim"], + initial_position: Float[Array, " n_dim"], + proposal_position: Float[Array, " n_dim"], log_prob_initial: Float[Array, "1"], log_prob_proposal: Float[Array, "1"], log_prob_nf_initial: Float[Array, "1"], log_prob_nf_proposal: Float[Array, "1"], ) -> tuple[ - Float[Array, "ndim"], Float[Array, "1"], Float[Array, "1"], Int[Array, "1"] + Float[Array, " n_dim"], Float[Array, "1"], Float[Array, "1"], Int[Array, "1"] ]: rng_key, subkey = random.split(rng_key) @@ -49,18 +47,23 @@ def kernel( log_prob_nf = jnp.where(do_accept, log_prob_nf_proposal, log_prob_nf_initial) return position, log_prob, log_prob_nf, do_accept - def update(self, i: int, state: tuple[PRNGKeyArray, - Float[Array, "nstep ndim"], - Float[Array, "nstep ndim"], - Float[Array, "nstep 1"], - Float[Array, "nstep 1"], - Float[Array, "nstep 1"], - Float[Array, "nstep 1"], - Int[Array, "nstep 1"] - ]) -> tuple[ + def update( + self, + i: int, + state: tuple[ + PRNGKeyArray, + Float[Array, "nstep n_dim"], + Float[Array, "nstep n_dim"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], + Float[Array, "nstep 1"], + Int[Array, "nstep 1"], + ], + ) -> tuple[ PRNGKeyArray, - Float[Array, "nstep ndim"], - Float[Array, "nstep ndim"], + Float[Array, "nstep n_dim"], + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Float[Array, "nstep 1"], Float[Array, "nstep 1"], @@ -106,13 +109,13 @@ def sample( self, rng_key: PRNGKeyArray, n_steps: int, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data: PyTree, verbose: bool = False, mode: str = "training", ) -> tuple[ PRNGKeyArray, - Float[Array, "n_chains n_steps ndim"], + Float[Array, "n_chains n_steps n_dim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], ]: @@ -169,7 +172,7 @@ def sample( def sample_flow( self, rng_key: PRNGKeyArray, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data, n_steps: int, ): diff --git a/src/flowMC/sampler/Proposal_Base.py b/src/flowMC/sampler/Proposal_Base.py index c78f898..4358ba2 100644 --- a/src/flowMC/sampler/Proposal_Base.py +++ b/src/flowMC/sampler/Proposal_Base.py @@ -8,7 +8,10 @@ @jax.tree_util.register_pytree_node_class class ProposalBase: def __init__( - self, logpdf: Callable[[Float[Array, "n_dim"], PyTree], Float], jit: bool, params: dict + self, + 
logpdf: Callable[[Float[Array, " n_dim"], PyTree], Float], + jit: bool, + params: dict, ): """ Initialize the sampler class @@ -89,11 +92,11 @@ def precompilation(self, n_chains, n_dims, n_step, data): def kernel( self, rng_key: PRNGKeyArray, - position: Float[Array, "nstep ndim"], + position: Float[Array, "nstep n_dim"], log_prob: Float[Array, "nstep 1"], data: PyTree, ) -> tuple[ - Float[Array, "nstep ndim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"] + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"] ]: """ Kernel for one step in the proposal cycle. @@ -105,14 +108,14 @@ def update( i: Float, state: tuple[ PRNGKeyArray, - Float[Array, "nstep ndim"], + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"], PyTree, ], ) -> tuple[ PRNGKeyArray, - Float[Array, "nstep ndim"], + Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"], PyTree, @@ -126,11 +129,11 @@ def sample( self, rng_key: PRNGKeyArray, n_steps: int, - initial_position: Float[Array, "n_chains ndim"], + initial_position: Float[Array, "n_chains n_dim"], data: PyTree, verbose: bool = False, ) -> tuple[ - Float[Array, "n_chains n_steps ndim"], + Float[Array, "n_chains n_steps n_dim"], Float[Array, "n_chains n_steps 1"], Int[Array, "n_chains n_steps 1"], ]: From 51b71f6fe1aad43cb52e2a2b16e45e311965ce7b Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 10:46:01 -0400 Subject: [PATCH 16/23] Refactor make_training_loop function in utils.py --- src/flowMC/nfmodel/utils.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/flowMC/nfmodel/utils.py b/src/flowMC/nfmodel/utils.py index 4881671..a64507b 100644 --- a/src/flowMC/nfmodel/utils.py +++ b/src/flowMC/nfmodel/utils.py @@ -1,13 +1,15 @@ import jax import jax.numpy as jnp # JAX NumPy -from tqdm import trange +from tqdm import trange, tqdm import optax import equinox as eqx from typing import Callable, Tuple -from jaxtyping import Array, PRNGKeyArray +from jaxtyping import Array, PRNGKeyArray, Float -def make_training_loop(optim: optax.GradientTransformation) -> Callable: +def make_training_loop( + optim: optax.GradientTransformation, +) -> tuple[Callable, Callable, Callable]: """ Create a function that trains an NF model. @@ -24,7 +26,9 @@ def loss_fn(model, x): return -jnp.mean(model.log_prob(x)) @eqx.filter_jit - def train_step(model, x, opt_state): + def train_step( + model: eqx.Module, x: Float[Array, "n_batch n_dim"], opt_state: optax.OptState + ) -> Tuple[Float, eqx.Module, optax.OptState]: """Train for a single step. Args: @@ -42,7 +46,13 @@ def train_step(model, x, opt_state): model = eqx.apply_updates(model, updates) return loss, model, opt_state - def train_epoch(rng, model, state, train_ds, batch_size): + def train_epoch( + rng: PRNGKeyArray, + model: eqx.Module, + state: optax.OptState, + train_ds: Float[Array, "n_example n_dim"], + batch_size: Float, + )-> Tuple[Float, eqx.Module, optax.OptState]: """Train for a single epoch.""" train_ds_size = len(train_ds) steps_per_epoch = train_ds_size // batch_size @@ -67,7 +77,7 @@ def train_flow( num_epochs: int, batch_size: int, verbose: bool = True, - ) -> Tuple[PRNGKeyArray, eqx.Module, Array]: + ) -> Tuple[PRNGKeyArray, eqx.Module, optax.OptState, Array]: """Train a normalizing flow model. 
Args: @@ -100,6 +110,7 @@ def train_flow( best_model = model best_loss = loss_values[epoch] if verbose: + assert isinstance(pbar, tqdm) if num_epochs > 10: if epoch % int(num_epochs / 10) == 0: pbar.set_description(f"Training NF, current loss: {value:.3f}") From d573de3a1de9fc5fd106747befde9090e2cdabff Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 10:50:49 -0400 Subject: [PATCH 17/23] Add Float type hint to _data_mean and _data_cov --- src/flowMC/nfmodel/rqSpline.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/flowMC/nfmodel/rqSpline.py b/src/flowMC/nfmodel/rqSpline.py index d532644..9af5611 100644 --- a/src/flowMC/nfmodel/rqSpline.py +++ b/src/flowMC/nfmodel/rqSpline.py @@ -1,7 +1,7 @@ from typing import Sequence, Tuple import jax import jax.numpy as jnp -from jaxtyping import Array, PRNGKeyArray +from jaxtyping import Array, PRNGKeyArray, Float import equinox as eqx from flowMC.nfmodel.base import NFModel, Bijection, Distribution @@ -352,8 +352,8 @@ class MaskedCouplingRQSpline(NFModel): base_dist: Distribution layers: list[eqx.Module] _n_features: int - _data_mean: Array - _data_cov: Array + _data_mean: Float[Array, " n_dim"] + _data_cov: Float[Array, " n_dim n_dim"] @property def n_features(self): @@ -371,27 +371,33 @@ def __init__( self, n_features: int, n_layers: int, - hidden_size: Sequence[int], + hidden_size: list[int], num_bins: int, key: PRNGKeyArray, - spline_range: Sequence[float] = (-10.0, 10.0), + spline_range: tuple[float, float] = (-10.0, 10.0), **kwargs ): if kwargs.get("base_dist") is not None: - self.base_dist = kwargs.get("base_dist") + dist = kwargs.get("base_dist") + assert isinstance(dist, Distribution) + self.base_dist = dist else: self.base_dist = Gaussian( jnp.zeros(n_features), jnp.eye(n_features), learnable=False ) if kwargs.get("data_mean") is not None: - self._data_mean = kwargs.get("data_mean") + data_mean = kwargs.get("data_mean") + assert isinstance(data_mean, Array) + self._data_mean = data_mean else: self._data_mean = jnp.zeros(n_features) if kwargs.get("data_cov") is not None: - self._data_cov = kwargs.get("data_cov") + data_cov = kwargs.get("data_cov") + assert isinstance(data_cov, Array) + self._data_cov = data_cov else: self._data_cov = jnp.eye(n_features) From 7f96f0811afb3ea53acb3025f408146e7c49a0af Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 11:46:05 -0400 Subject: [PATCH 18/23] Fix type hints in rqSpline.py --- src/flowMC/nfmodel/rqSpline.py | 54 +++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/src/flowMC/nfmodel/rqSpline.py b/src/flowMC/nfmodel/rqSpline.py index 9af5611..e865209 100644 --- a/src/flowMC/nfmodel/rqSpline.py +++ b/src/flowMC/nfmodel/rqSpline.py @@ -1,4 +1,3 @@ -from typing import Sequence, Tuple import jax import jax.numpy as jnp from jaxtyping import Array, PRNGKeyArray, Float @@ -21,7 +20,7 @@ def _normalize_bin_sizes( @partial(jax.vmap, in_axes=(0, None)) def _normalize_knot_slopes( - unnormalized_knot_slopes: Array, min_knot_slope: float + unnormalized_knot_slopes: Array, min_knot_slope: Float ) -> Array: """Make knot slopes be no less than `min_knot_slope`.""" # The offset is such that the normalized knot slope will be equal to 1 @@ -34,7 +33,7 @@ def _normalize_knot_slopes( @partial(jax.vmap, in_axes=(0, 0, 0, 0)) def _rational_quadratic_spline_fwd( x: Array, x_pos: Array, y_pos: Array, knot_slopes: Array -) -> Tuple[Array, Array]: +) -> tuple[Array, Array]: """Applies a rational-quadratic 
spline to a scalar. Args: @@ -110,7 +109,9 @@ def _rational_quadratic_spline_fwd( # If x is outside the spline range, we default to a linear transformation. y = jnp.where(below_range, (x - x_pos[0]) * knot_slopes[0] + y_pos[0], y) - y = jnp.where(above_range, (x - x_pos[-1]) * knot_slopes[-1] + y_pos[-1], y) + y = jnp.where( + above_range, (x - x_pos[-1]) * knot_slopes[-1] + y_pos[-1], y + ) # type: ignore logdet = jnp.where(below_range, jnp.log(knot_slopes[0]), logdet) logdet = jnp.where(above_range, jnp.log(knot_slopes[-1]), logdet) return y, logdet @@ -146,7 +147,7 @@ def _safe_quadratic_root(a: Array, b: Array, c: Array) -> Array: @partial(jax.vmap, in_axes=(0, 0, 0, 0)) def _rational_quadratic_spline_inv( y: Array, x_pos: Array, y_pos: Array, knot_slopes: Array -) -> Tuple[Array, Array]: +) -> tuple[Array, Array]: """Applies the inverse of a rational-quadratic spline to a scalar. Args: @@ -217,20 +218,21 @@ def _rational_quadratic_spline_inv( # If y is outside the spline range, we default to a linear transformation. x = jnp.where(below_range, (y - y_pos[0]) / knot_slopes[0] + x_pos[0], x) - x = jnp.where(above_range, (y - y_pos[-1]) / knot_slopes[-1] + x_pos[-1], x) + x = jnp.where( + above_range, (y - y_pos[-1]) / knot_slopes[-1] + x_pos[-1], x + ) # type: ignore logdet = jnp.where(below_range, -jnp.log(knot_slopes[0]), logdet) logdet = jnp.where(above_range, -jnp.log(knot_slopes[-1]), logdet) return x, logdet class RQSpline(Bijection): - _range_min: float _range_max: float _num_bins: int _min_bin_size: float _min_knot_slope: float - conditioner: eqx.Module + conditioner: MLP """A rational-quadratic spline bijection. @@ -279,13 +281,12 @@ def dtype(self): def __init__( self, - conditioner: eqx.Module, + conditioner: MLP, range_min: float, range_max: float, min_bin_size: float = 1e-4, min_knot_slope: float = 1e-4, ): - self._range_min = range_min self._range_max = range_max self._min_bin_size = min_bin_size @@ -294,7 +295,11 @@ def __init__( self.conditioner = conditioner - def get_params(self, x: Array) -> Array: + def get_params( + self, x: Float[Array, " n_condition"] + ) -> tuple[ + Float[Array, " n_param"], Float[Array, " n_param"], Float[Array, " n_param"] + ]: params = self.conditioner(x).reshape(-1, self._num_bins * 3 + 1) unnormalized_bin_widths = params[:, : self._num_bins] unnormalized_bin_heights = params[:, self._num_bins : 2 * self._num_bins] @@ -320,14 +325,14 @@ def get_params(self, x: Array) -> Array: ) return x_pos, y_pos, knot_slopes - def __call__(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Array, condition_x: Array) -> tuple[Array, Array]: return self.forward(x, condition_x) - def forward(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def forward(self, x: Array, condition_x: Array) -> tuple[Array, Array]: x_pos, y_pos, knot_slopes = self.get_params(condition_x) return _rational_quadratic_spline_fwd(x, x_pos, y_pos, knot_slopes) - def inverse(self, x: Array, condition_x: Array) -> Tuple[Array, Array]: + def inverse(self, x: Array, condition_x: Array) -> tuple[Array, Array]: x_pos, y_pos, knot_slopes = self.get_params(condition_x) return _rational_quadratic_spline_inv(x, x_pos, y_pos, knot_slopes) @@ -350,7 +355,7 @@ class MaskedCouplingRQSpline(NFModel): """ base_dist: Distribution - layers: list[eqx.Module] + layers: list[Bijection] _n_features: int _data_mean: Float[Array, " n_dim"] _data_cov: Float[Array, " n_dim n_dim"] @@ -377,7 +382,6 @@ def __init__( spline_range: tuple[float, float] = (-10.0, 10.0), 
**kwargs ): - if kwargs.get("base_dist") is not None: dist = kwargs.get("base_dist") assert isinstance(dist, Distribution) @@ -427,10 +431,14 @@ def __init__( mask = jnp.logical_not(mask) self.layers = layers - def __call__(self, x: Array) -> Tuple[Array, Array]: + def __call__( + self, x: Float[Array, " n_dim"] + ) -> tuple[Float[Array, " n_dim"], Float]: return self.forward(x) - def forward(self, x: Array) -> Tuple[Array, Array]: + def forward( + self, x: Float[Array, " n_dim"] + ) -> tuple[Float[Array, " n_dim"], Float]: log_det = 0.0 for layer in self.layers: x, log_det_i = layer(x) @@ -438,7 +446,9 @@ def forward(self, x: Array) -> Tuple[Array, Array]: return x, log_det @partial(jax.vmap, in_axes=(None, 0)) - def inverse(self, x: Array) -> Tuple[Array, Array]: + def inverse( + self, x: Float[Array, " n_dim"] + ) -> tuple[Float[Array, " n_dim"], Float]: """From latent space to data space""" log_det = 0.0 for layer in reversed(self.layers): @@ -447,7 +457,9 @@ def inverse(self, x: Array) -> Tuple[Array, Array]: return x, log_det @eqx.filter_jit - def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> Float[Array, "n_samples n_dim"]: samples = self.base_dist.sample(rng_key, n_samples) samples = self.inverse(samples)[0] samples = samples * jnp.sqrt(jnp.diag(self.data_cov)) + self.data_mean @@ -455,7 +467,7 @@ def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: @eqx.filter_jit @partial(jax.vmap, in_axes=(None, 0)) - def log_prob(self, x: Array) -> Array: + def log_prob(self, x: Float[Array, "n_sample n_dim"]) -> Float[Array, " n_sample"]: """From data space to latent space""" x = (x - self.data_mean) / jnp.sqrt(jnp.diag(self.data_cov)) y, log_det = self.__call__(x) From 2183470d9b8dbee79a0da342c5b735fd23bd0c91 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 11:49:26 -0400 Subject: [PATCH 19/23] Update version and Python requirement --- setup.cfg | 4 ++-- src/flowMC/utils/PRNG_keys.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 686aca5..22a133c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = flowMC -version = 0.2.4 +version = 0.2.5 author = Kaze Wong, Marylou GabriƩ, Dan Foreman-Mackey author_email = kazewong.physics@gmail.com url = https://github.com/kazewong/flowMC @@ -23,7 +23,7 @@ install_requires = tqdm corner -python_requires = >=3.9 +python_requires = >=3.10 [options.packages.find] where=src diff --git a/src/flowMC/utils/PRNG_keys.py b/src/flowMC/utils/PRNG_keys.py index f69e99c..1ce51a6 100644 --- a/src/flowMC/utils/PRNG_keys.py +++ b/src/flowMC/utils/PRNG_keys.py @@ -16,7 +16,7 @@ def initialize_rng_keys(n_chains, seed=42): rng_keys_nf (Device Array): RNG keys for the normalizing flow global sampler. init_rng_keys_nf (Device Array): RNG keys for initializing wieght of the normalizing flow model. 
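# --- Illustrative sketch of the key-handling convention used throughout this series ---
# `PRNGKeyArray` is a jaxtyping annotation, not a constructor, which is why these
# patches replace `PRNGKeyArray(seed)` with `jax.random.PRNGKey(seed)`. A root key is
# built from an integer seed and fanned out with `jax.random.split`; assuming 5 chains:
import jax

rng_key = jax.random.PRNGKey(42)  # root key built from an integer seed
rng_key_init, rng_key_mcmc, rng_key_nf = jax.random.split(rng_key, 3)
rng_keys_mcmc = jax.random.split(rng_key_mcmc, 5)  # one independent key per chain
print(rng_keys_mcmc.shape)  # (5, 2) with the legacy uint32 key representation
# ---------------------------------------------------------------------------------------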
""" - rng_key = PRNGKeyArray(seed) + rng_key = jax.random.PRNGKey(seed) rng_key_init, rng_key_mcmc, rng_key_nf = jax.random.split(rng_key, 3) rng_keys_mcmc = jax.random.split(rng_key_mcmc, n_chains) From df9e9906ae91a0ba0132c5325d78c8b2ec2d0ed8 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 12:03:18 -0400 Subject: [PATCH 20/23] Fix array shape in ProposalBase class --- src/flowMC/sampler/Proposal_Base.py | 12 ++++++------ test/integration/test_MALA.py | 1 - 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/flowMC/sampler/Proposal_Base.py b/src/flowMC/sampler/Proposal_Base.py index 4358ba2..e89589d 100644 --- a/src/flowMC/sampler/Proposal_Base.py +++ b/src/flowMC/sampler/Proposal_Base.py @@ -48,7 +48,7 @@ def precompilation(self, n_chains, n_dims, n_step, data): .lower( key, jnp.ones((n_chains, n_dims)), - jnp.ones((n_chains, 1)), + jnp.ones((n_chains, )), data, ) .compile() @@ -60,8 +60,8 @@ def precompilation(self, n_chains, n_dims, n_step, data): ( key, jnp.ones((n_chains, n_step, n_dims)), - jnp.ones((n_chains, n_step, 1)), - jnp.zeros((n_chains, n_step, 1)), + jnp.ones((n_chains, n_step, )), + jnp.zeros((n_chains, n_step, )), data, ), ) @@ -74,7 +74,7 @@ def precompilation(self, n_chains, n_dims, n_step, data): self.kernel_vmap( key, jnp.ones((n_chains, n_dims)), - jnp.ones((n_chains, 1)), + jnp.ones((n_chains, )), data, ) self.update_vmap( @@ -82,8 +82,8 @@ def precompilation(self, n_chains, n_dims, n_step, data): ( key, jnp.ones((n_chains, n_step, n_dims)), - jnp.ones((n_chains, n_step, 1)), - jnp.zeros((n_chains, n_step, 1)), + jnp.ones((n_chains, n_step, )), + jnp.zeros((n_chains, n_step, )), data, ), ) diff --git a/test/integration/test_MALA.py b/test/integration/test_MALA.py index 4aa2440..2415808 100644 --- a/test/integration/test_MALA.py +++ b/test/integration/test_MALA.py @@ -9,7 +9,6 @@ def dual_moon_pe(x, data): """ Term 2 and 3 separate the distribution and smear it along the first and second dimension """ - print("compile count") term1 = 0.5 * ((jnp.linalg.norm(x - data) - 2) / 0.1) ** 2 term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2 term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 From 0dbe3c98e6a1c44d4ad24b0231fd1ff0b14154a0 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 12:26:32 -0400 Subject: [PATCH 21/23] Fix data type in Sampler.py and update test files --- src/flowMC/sampler/Sampler.py | 2 +- test/integration/test_HMC.py | 4 ++-- test/integration/test_MALA.py | 21 +++++++++++++-------- test/integration/test_RWMCMC.py | 19 +++++++++++-------- test/integration/test_normalizingFlow.py | 2 +- 5 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/flowMC/sampler/Sampler.py b/src/flowMC/sampler/Sampler.py index 81de1d5..afa13cb 100644 --- a/src/flowMC/sampler/Sampler.py +++ b/src/flowMC/sampler/Sampler.py @@ -71,7 +71,7 @@ def __init__( self, n_dim: int, rng_key_set: Tuple, - data: jnp.ndarray, + data: dict, local_sampler: ProposalBase, nf_model: NFModel, **kwargs, diff --git a/test/integration/test_HMC.py b/test/integration/test_HMC.py index 1f3ee4c..92f2f13 100644 --- a/test/integration/test_HMC.py +++ b/test/integration/test_HMC.py @@ -40,7 +40,6 @@ def dual_moon_pe(x: Float[Array, "n_dim"], data: dict): initial_PE = HMC_sampler.logpdf_vmap(initial_position, data) -HMC_sampler.precompilation(n_chains, n_dim, n_local_steps, data) initial_position = jnp.repeat(initial_position[:, None], n_local_steps, 1) initial_PE = jnp.repeat(initial_PE[:, None], n_local_steps, 1) @@ -80,7 +79,7 
@@ def dual_moon_pe(x: Float[Array, "n_dim"], data: dict): nf_sampler = Sampler( n_dim, rng_key_set, - jnp.arange(5), + data, HMC_sampler, model, n_loop_training=n_loop_training, @@ -89,6 +88,7 @@ def dual_moon_pe(x: Float[Array, "n_dim"], data: dict): n_global_steps=n_global_steps, n_chains=n_chains, use_global=False, + precompile=True, ) nf_sampler.sample(initial_position, data) diff --git a/test/integration/test_MALA.py b/test/integration/test_MALA.py index 2415808..1b34888 100644 --- a/test/integration/test_MALA.py +++ b/test/integration/test_MALA.py @@ -3,13 +3,15 @@ import jax import jax.numpy as jnp from jax.scipy.special import logsumexp +from jaxtyping import Float, Array -def dual_moon_pe(x, data): +def dual_moon_pe(x: Float[Array, "n_dim"], data: dict): """ Term 2 and 3 separate the distribution and smear it along the first and second dimension """ - term1 = 0.5 * ((jnp.linalg.norm(x - data) - 2) / 0.1) ** 2 + print("compile count") + term1 = 0.5 * ((jnp.linalg.norm(x - data["data"]) - 2) / 0.1) ** 2 term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2 term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 return -(term1 - logsumexp(term2) - logsumexp(term3)) @@ -21,7 +23,7 @@ def dual_moon_pe(x, data): step_size = 0.01 n_leapfrog = 10 -data = jnp.arange(5) +data = {"data": jnp.arange(5)} rng_key_set = initialize_rng_keys(n_chains, seed=42) @@ -29,9 +31,6 @@ def dual_moon_pe(x, data): MALA_Sampler = MALA(dual_moon_pe, True, {"step_size": step_size}) -MALA_Sampler.precompilation(n_chains, n_dim, n_local_steps, data) - - initial_position = jnp.repeat(initial_position[:, None], n_local_steps, 1) initial_logp = jnp.repeat( jax.vmap(dual_moon_pe, in_axes=(0, None))(initial_position[:, 0], data)[:, None], @@ -43,7 +42,12 @@ def dual_moon_pe(x, data): rng_key_set[1], initial_position, initial_logp, - jnp.zeros((n_chains, n_local_steps, 1)), + jnp.zeros( + ( + n_chains, + n_local_steps, + ) + ), data, ) @@ -76,7 +80,7 @@ def dual_moon_pe(x, data): nf_sampler = Sampler( n_dim, rng_key_set, - jnp.arange(5), + data, MALA_Sampler, model, n_loop_training=n_loop_training, @@ -86,6 +90,7 @@ def dual_moon_pe(x, data): n_chains=n_chains, # local_autotune=mala_sampler_autotune, use_global=False, + precompile=True, ) nf_sampler.sample(initial_position, data) diff --git a/test/integration/test_RWMCMC.py b/test/integration/test_RWMCMC.py index 3ac1c42..eb1da0d 100644 --- a/test/integration/test_RWMCMC.py +++ b/test/integration/test_RWMCMC.py @@ -3,14 +3,15 @@ import jax import jax.numpy as jnp from jax.scipy.special import logsumexp +from jaxtyping import Float, Array -def dual_moon_pe(x, data): +def dual_moon_pe(x: Float[Array, "n_dim"], data: dict): """ Term 2 and 3 separate the distribution and smear it along the first and second dimension """ print("compile count") - term1 = 0.5 * ((jnp.linalg.norm(x - data) - 2) / 0.1) ** 2 + term1 = 0.5 * ((jnp.linalg.norm(x - data["data"]) - 2) / 0.1) ** 2 term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2 term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 return -(term1 - logsumexp(term2) - logsumexp(term3)) @@ -22,7 +23,7 @@ def dual_moon_pe(x, data): step_size = 0.1 n_leapfrog = 10 -data = jnp.arange(5) +data = {"data": jnp.arange(5)} rng_key_set = initialize_rng_keys(n_chains, seed=42) @@ -30,26 +31,28 @@ def dual_moon_pe(x, data): RWMCMC_sampler = GaussianRandomWalk(dual_moon_pe, True, {"step_size": step_size}) -RWMCMC_sampler.precompilation(n_chains, n_dim, n_local_steps, data) - initial_position = 
jnp.repeat(initial_position[:, None], n_local_steps, 1) initial_logp = jnp.repeat( jax.vmap(dual_moon_pe, in_axes=(0, None))(initial_position[:, 0], data)[:, None], n_local_steps, 1, -)[..., None] +) state = ( rng_key_set[1], initial_position, initial_logp, - jnp.zeros((n_chains, n_local_steps, 1)), + jnp.zeros( + ( + n_chains, + n_local_steps, + ) + ), data, ) RWMCMC_sampler.update_vmap(1, state) - state = RWMCMC_sampler.sample( rng_key_set[1], n_local_steps, initial_position[:, 0], data ) diff --git a/test/integration/test_normalizingFlow.py b/test/integration/test_normalizingFlow.py index cc2e28b..d799fdc 100644 --- a/test/integration/test_normalizingFlow.py +++ b/test/integration/test_normalizingFlow.py @@ -60,5 +60,5 @@ def test_rqSpline(): rng, best_model, state, loss_values = train_flow( rng, model, data, state, num_epochs, batch_size, verbose=True ) - rng_key_nf = PRNGKeyArray(124098) + rng_key_nf = jax.random.PRNGKey(124098) model.sample(rng_key_nf, 10000) From 50d99d80d567770ad36e5dade5c09b4a503621ed Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 12:29:47 -0400 Subject: [PATCH 22/23] Update log_posterior function to accept data as a dictionary --- test/integration/test_quickstart.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/integration/test_quickstart.py b/test/integration/test_quickstart.py index 3353c0b..72066ea 100644 --- a/test/integration/test_quickstart.py +++ b/test/integration/test_quickstart.py @@ -7,11 +7,11 @@ from flowMC.nfmodel.utils import * -def log_posterior(x, data): - return -0.5 * jnp.sum((x - data) ** 2) +def log_posterior(x, data: dict): + return -0.5 * jnp.sum((x - data['data']) ** 2) -data = jnp.arange(5) +data = {'data':jnp.arange(5)} n_dim = 5 n_chains = 10 @@ -25,7 +25,7 @@ def log_posterior(x, data): nf_sampler = Sampler( n_dim, rng_key_set, - jnp.arange(n_dim), + data, local_sampler, model, n_local_steps=50, From eaeba8c380ee39d47e88e30dea1148e5ba4ac08b Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 26 Mar 2024 12:33:00 -0400 Subject: [PATCH 23/23] Fix PRNGKeyArray usage in test_kernels.py and test_nf.py --- test/unit/test_kernels.py | 4 ++-- test/unit/test_nf.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/unit/test_kernels.py b/test/unit/test_kernels.py index bf184b7..98780e4 100644 --- a/test/unit/test_kernels.py +++ b/test/unit/test_kernels.py @@ -254,7 +254,7 @@ def test_Gaussian_random_walk_close_gaussian(self): class TestNF: def test_NF_kernel(self): - key1, rng, init_rng = jax.random.split(PRNGKeyArray(0), 3) + key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3) data = jax.random.normal(key1, (100, 2)) num_epochs = 5 @@ -278,7 +278,7 @@ def test_NF_kernel(self): rng, self.model, state, loss_values = train_flow( rng, model, data, state, num_epochs, batch_size, verbose=True ) - key1, rng, init_rng = jax.random.split(PRNGKeyArray(1), 3) + key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(1), 3) n_dim = 2 n_chains = 1 diff --git a/test/unit/test_nf.py b/test/unit/test_nf.py index d162ec7..3a454ec 100644 --- a/test/unit/test_nf.py +++ b/test/unit/test_nf.py @@ -9,7 +9,7 @@ def test_affine_coupling_forward_and_inverse(): n_hidden = 4 x = jnp.array([[1.0, 2.0], [3.0, 4.0]]) mask = jnp.where(jnp.arange(n_features) % 2 == 0, 1.0, 0.0) - key = PRNGKeyArray(0) + key = jax.random.PRNGKey(0) dt = 0.5 layer = AffineCoupling(n_features, n_hidden, mask, key, dt) @@ -26,7 +26,7 @@ def test_realnvp(): n_layer = 2 x = jnp.array([[1, 2, 3], [4, 
5, 6]]) - rng_key, rng_subkey = jax.random.split(PRNGKeyArray(0), 2) + rng_key, rng_subkey = jax.random.split(jax.random.PRNGKey(0), 2) model = RealNVP(n_features, n_layer, n_hidden, rng_key) y, log_det = jax.vmap(model)(x) @@ -41,7 +41,7 @@ def test_realnvp(): assert jnp.allclose(x, y_inv) assert jnp.allclose(log_det, -log_det_inv) - rng_key = PRNGKeyArray(0) + rng_key = jax.random.PRNGKey(0) samples = model.sample(rng_key, 2) assert samples.shape == (2, 3) @@ -57,14 +57,14 @@ def test_rqspline(): n_layer = 2 n_bins = 8 - rng_key, rng_subkey = jax.random.split(PRNGKeyArray(0), 2) + rng_key, rng_subkey = jax.random.split(jax.random.PRNGKey(0), 2) model = MaskedCouplingRQSpline( - n_features, n_layer, hidden_layes, n_bins, PRNGKeyArray(10) + n_features, n_layer, hidden_layes, n_bins, jax.random.PRNGKey(10) ) jnp.array([[1, 2, 3], [4, 5, 6]]) - rng_key = PRNGKeyArray(0) + rng_key = jax.random.PRNGKey(0) samples = model.sample(rng_key, 2) assert samples.shape == (2, 3)
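Taken together, the series leaves flowMC with two user-facing conventions: `data` is passed as a dict and consumed as `data["data"]` inside the log-posterior, and every PRNG key is built with `jax.random.PRNGKey` (`PRNGKeyArray` survives only as a jaxtyping annotation). A minimal end-to-end sketch assembled from the tests in the diffs above; keyword arguments beyond those visible in the diffs are illustrative, not canonical:

import jax
import jax.numpy as jnp
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.sampler.MALA import MALA
from flowMC.sampler.Sampler import Sampler
from flowMC.utils.PRNG_keys import initialize_rng_keys


def log_posterior(x, data: dict):
    # data is now a dict; the array lives under the "data" key.
    return -0.5 * jnp.sum((x - data["data"]) ** 2)


data = {"data": jnp.arange(5)}
n_dim = 5
n_chains = 10

rng_key_set = initialize_rng_keys(n_chains, seed=42)
initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1
model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, jax.random.PRNGKey(21))
local_sampler = MALA(log_posterior, True, {"step_size": 1e-1})

nf_sampler = Sampler(
    n_dim,
    rng_key_set,
    data,  # the dict, not a raw array
    local_sampler,
    model,
    n_local_steps=50,
    n_global_steps=50,
    n_chains=n_chains,
)
nf_sampler.sample(initial_position, data)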