Commit da8cbc3

Merge pull request #151 from kazewong/134-add-precommit-script-and-fix-formatting

134 add precommit script and fix formatting
kazewong authored Mar 26, 2024
2 parents cb7ef6e + eaeba8c commit da8cbc3
Showing 29 changed files with 433 additions and 281 deletions.
2 changes: 1 addition & 1 deletion docs/quickstart.md
@@ -53,7 +53,7 @@ n_chains = 10
 rng_key_set = initialize_rng_keys(n_chains, seed=42)
 initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1
-model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, jax.random.PRNGKey(21))
+model = MaskedCouplingRQSpline(n_dim, 3, [64, 64], 8, PRNGKeyArray(21))
 step_size = 1e-1
 local_sampler = MALA(log_posterior, True, {"step_size": step_size})
2 changes: 1 addition & 1 deletion example/dualmoon.py
@@ -38,7 +38,7 @@ def target_dualmoon(x, data):
 data = jnp.zeros(n_dim)

 rng_key_set = initialize_rng_keys(n_chains, 42)
-model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10))
+model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))

 initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1

2 changes: 1 addition & 1 deletion example/non_jax_likelihood.py
@@ -46,7 +46,7 @@ def neal_funnel(x):
 data = jnp.zeros(n_dim)

 rng_key_set = initialize_rng_keys(n_chains, 42)
-model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10))
+model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))

 initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1

2 changes: 1 addition & 1 deletion example/notebook/analyzingChains.ipynb
@@ -110,7 +110,7 @@
 "print(\"Initializing chains, normalizing flow model and local MCMC sampler\")\n",
 "\n",
 "initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1\n",
-"model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10))\n",
+"model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))\n",
 "MALA_Sampler = MALA(target_dualmoon, True, {\"step_size\": step_size})\n",
 "\n",
 "print(\"Initializing samplers classes\")\n",
2 changes: 1 addition & 1 deletion example/notebook/dualmoon.ipynb
@@ -141,7 +141,7 @@
 "data = jnp.zeros(n_dim)\n",
 "\n",
 "model = MaskedCouplingRQSpline(\n",
-"    n_dim, n_layers, hidden_size, num_bins, jax.random.PRNGKey(10)\n",
+"    n_dim, n_layers, hidden_size, num_bins, PRNGKeyArray(10)\n",
 ")"
 ]
 },
2 changes: 1 addition & 1 deletion example/notebook/maximizing_likelihood.ipynb
@@ -164,7 +164,7 @@
 "\n",
 "\n",
 "rng_key_set = initialize_rng_keys(n_chains, 42)\n",
-"model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10))\n",
+"model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))\n",
 "\n",
 "initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1\n",
 "\n",
2 changes: 1 addition & 1 deletion example/notebook/mog_pretrain.ipynb
@@ -135,7 +135,7 @@
 "## To use instead the more powerful RQSpline:\n",
 "from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline\n",
 "\n",
-"model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, jax.random.PRNGKey(10))\n",
+"model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, PRNGKeyArray(10))\n",
 "\n",
 "\n",
 "# Local sampler\n",
4 changes: 2 additions & 2 deletions example/notebook/normalizingFlow.ipynb
@@ -101,7 +101,7 @@
 "n_layers = 10\n",
 "n_hidden = 100\n",
 "\n",
-"key, subkey = jax.random.split(jax.random.PRNGKey(0), 2)\n",
+"key, subkey = jax.random.split(PRNGKeyArray(0), 2)\n",
 "\n",
 "model = RealNVP(\n",
 "    n_feature,\n",
@@ -246,7 +246,7 @@
 "n_hiddens = [64, 64]\n",
 "n_bins = 8\n",
 "\n",
-"key, subkey = jax.random.split(jax.random.PRNGKey(1))\n",
+"key, subkey = jax.random.split(PRNGKeyArray(1))\n",
 "\n",
 "model = MaskedCouplingRQSpline(\n",
 "    n_feature,\n",
4 changes: 2 additions & 2 deletions example/train_normalizing_flow.py
@@ -23,7 +23,7 @@
 data = make_moons(n_samples=20000, noise=0.05)
 data = jnp.array(data[0])

-key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3)
+key1, rng, init_rng = jax.random.split(PRNGKeyArray(0), 3)

 model = MaskedCouplingRQSpline(
     2,
@@ -40,4 +40,4 @@

 key, model, loss = train_flow(rng, model, data, num_epochs, batch_size, verbose=True)

-nf_samples = model.sample(jax.random.PRNGKey(124098), 5000)
+nf_samples = model.sample(PRNGKeyArray(124098), 5000)
1 change: 1 addition & 0 deletions ruff.toml
@@ -0,0 +1 @@
+ignore = ["F722"]
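For context: F722 is the pyflakes "syntax error in forward annotation" check, which fires on jaxtyping's shape-string annotations such as the ones introduced in base.py and common.py below, so projects using jaxtyping commonly ignore it. A minimal sketch of the kind of annotation involved (not part of the diff):

from jaxtyping import Array, Float

def forward(x: Float[Array, "n_dim"]) -> Float[Array, "n_dim"]:
    # The shape string "n_dim" looks like a forward type reference to pyflakes,
    # which is why rule F722 is ignored in ruff.toml above.
    return x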
4 changes: 2 additions & 2 deletions setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = flowMC
-version = 0.2.4
+version = 0.2.5
 author = Kaze Wong, Marylou Gabrié, Dan Foreman-Mackey
 author_email = [email protected]
 url = https://github.com/kazewong/flowMC
@@ -23,7 +23,7 @@ install_requires =
     tqdm
     corner

-python_requires = >=3.9
+python_requires = >=3.10

 [options.packages.find]
 where=src
29 changes: 14 additions & 15 deletions src/flowMC/nfmodel/base.py
@@ -1,5 +1,4 @@
 from abc import abstractmethod, abstractproperty
-from typing import Tuple
 import equinox as eqx
 import jax
 from jaxtyping import Array, PRNGKeyArray, Float
@@ -15,48 +14,48 @@ class NFModel(eqx.Module):
     def __init__(self):
         return NotImplemented

-    def __call__(self, x: Array) -> Tuple[Array, Array]:
+    def __call__(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]:
         """
         Forward pass of the model.
         Args:
-            x (Array): Input data.
+            x (Float[Array, "n_dim"]): Input data.
         Returns:
-            Tuple[Array, Array]: Output data and log determinant of the Jacobian.
+            tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian.
         """
         return self.forward(x)

     @abstractmethod
-    def log_prob(self, x: Array) -> Array:
+    def log_prob(self, x: Float[Array, "n_dim"]) -> Float:
         return NotImplemented

     @abstractmethod
-    def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array:
+    def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array:
         return NotImplemented

     @abstractmethod
-    def forward(self, x: Array) -> Tuple[Array, Array]:
+    def forward(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]:
         """
         Forward pass of the model.
         Args:
-            x (Array): Input data.
+            x (Float[Array, "n_dim"]): Input data.
         Returns:
-            Tuple[Array, Array]: Output data and log determinant of the Jacobian."""
+            tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian."""
         return NotImplemented

     @abstractmethod
-    def inverse(self, x: Array) -> Tuple[Array, Array]:
+    def inverse(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]:
         """
         Inverse pass of the model.
         Args:
-            x (Array): Input data.
+            x (Float[Array, "n_dim"]): Input data.
         Returns:
-            Tuple[Array, Array]: Output data and log determinant of the Jacobian."""
+            tuple[Float[Array, "n_dim"], Float]: Output data and log determinant of the Jacobian."""
         return NotImplemented

     @abstractproperty
@@ -81,15 +80,15 @@ class Bijection(eqx.Module):
     def __init__(self):
         return NotImplemented

-    def __call__(self, x: Array) -> Tuple[Array, Array]:
+    def __call__(self, x: Array) -> tuple[Array, Array]:
         return self.forward(x)

     @abstractmethod
-    def forward(self, x: Array) -> Tuple[Array, Array]:
+    def forward(self, x: Array) -> tuple[Array, Array]:
         return NotImplemented

     @abstractmethod
-    def inverse(self, x: Array) -> Tuple[Array, Array]:
+    def inverse(self, x: Array) -> tuple[Array, Array]:
         return NotImplemented
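The base-class changes above swap typing.Tuple for the built-in tuple and annotate shapes with jaxtyping's Float[Array, "n_dim"]. A hypothetical sketch of a concrete Bijection written against these updated signatures; the Shift class and its field are invented for illustration:

import jax.numpy as jnp
from jaxtyping import Array, Float
from flowMC.nfmodel.base import Bijection

class Shift(Bijection):
    shift: Float[Array, "n_dim"]

    def __init__(self, shift: Float[Array, "n_dim"]):
        self.shift = shift

    def forward(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]:
        # A pure shift has a unit Jacobian, so the log-determinant is zero.
        return x + self.shift, jnp.array(0.0)

    def inverse(self, x: Float[Array, "n_dim"]) -> tuple[Float[Array, "n_dim"], Float]:
        return x - self.shift, jnp.array(0.0)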
67 changes: 34 additions & 33 deletions src/flowMC/nfmodel/common.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Iterable, Tuple
+from typing import Callable, List, Tuple

 import jax
 import jax.numpy as jnp
@@ -10,15 +10,15 @@
 import jax
 import jax.numpy as jnp
 import equinox as eqx
-from jaxtyping import Array
+from jaxtyping import Array, Float, PRNGKeyArray


 class MLP(eqx.Module):
     r"""Multilayer perceptron.
     Args:
-        shape (Iterable[int]): Shape of the MLP. The first element is the input dimension, the last element is the output dimension.
-        key (jax.random.PRNGKey): Random key.
+        shape (List[int]): Shape of the MLP. The first element is the input dimension, the last element is the output dimension.
+        key (PRNGKeyArray): Random key.
     Attributes:
         layers (List): List of layers.
@@ -29,9 +29,9 @@ class MLP(eqx.Module):

     def __init__(
         self,
-        shape: Iterable[int],
-        key: jax.random.PRNGKey,
-        scale: float = 1e-4,
+        shape: List[int],
+        key: PRNGKeyArray,
+        scale: Float = 1e-4,
         activation: Callable = jax.nn.relu,
         use_bias: bool = True,
     ):
@@ -52,7 +52,7 @@ def __init__(
             eqx.nn.Linear(shape[-2], shape[-1], key=subkey, use_bias=use_bias)
         )

-    def __call__(self, x: Array):
+    def __call__(self, x: Float[Array, "n_in"]) -> Float[Array, "n_out"]:
         for layer in self.layers:
             x = layer(x)
         return x
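A usage sketch for the MLP above, based only on the constructor and __call__ signatures shown in this hunk; the layer sizes and key are arbitrary illustrative values:

import jax
import jax.numpy as jnp
from flowMC.nfmodel.common import MLP

key = jax.random.PRNGKey(0)
mlp = MLP(shape=[2, 32, 32, 1], key=key, scale=1e-4, activation=jax.nn.relu)

x = jnp.ones(2)   # a single input of dimension 2, matching shape[0]
y = mlp(x)        # layers are applied in sequence; output dimension is shape[-1]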
@@ -83,25 +83,25 @@ class MaskedCouplingLayer(Bijection):
     """

-    _mask: Array
+    _mask: Float[Array, "n_dim"]
     bijector: Bijection

     @property
-    def mask(self):
+    def mask(self) -> Float[Array, "n_dim"]:
         return jax.lax.stop_gradient(self._mask)

-    def __init__(self, bijector: Bijection, mask: Array):
+    def __init__(self, bijector: Bijection, mask: Float[Array, "n_dim"]):
         self.bijector = bijector
         self._mask = mask

-    def forward(self, x: Array) -> Tuple[Array, Array]:
-        y, log_det = self.bijector(x, x * self.mask)
+    def forward(self, x: Float[Array, "n_dim"]) -> Tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]:
+        y, log_det = self.bijector(x, x * self.mask)  # type: ignore
         y = (1 - self.mask) * y + self.mask * x
         log_det = ((1 - self.mask) * log_det).sum()
         return y, log_det

-    def inverse(self, x: Array) -> Tuple[Array, Array]:
-        y, log_det = self.bijector.inverse(x, x * self.mask)
+    def inverse(self, x: Float[Array, "n_dim"]) -> Tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]:
+        y, log_det = self.bijector.inverse(x, x * self.mask)  # type: ignore
         y = (1 - self.mask) * y + self.mask * x
         log_det = ((1 - self.mask) * log_det).sum()
         return y, log_det
@@ -110,17 +110,17 @@ def inverse(self, x: Array) -> Tuple[Array, Array]:
 class MLPAffine(Bijection):
     scale_MLP: MLP
     shift_MLP: MLP
-    dt: float = 1
+    dt: Float = 1

-    def __init__(self, scale_MLP: MLP, shift_MLP: MLP, dt: float = 1):
+    def __init__(self, scale_MLP: MLP, shift_MLP: MLP, dt: Float = 1):
         self.scale_MLP = scale_MLP
         self.shift_MLP = shift_MLP
         self.dt = dt

-    def __call__(self, x: Array, condition_x: Array) -> Tuple[Array, Array]:
+    def __call__(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]:
         return self.forward(x, condition_x)

-    def forward(self, x: Array, condition_x: Array) -> Tuple[Array, Array]:
+    def forward(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]:
         # Note that this note output log_det as an array instead of a number.
         # This is because we need to sum over the log_det in the masked coupling layer.
         scale = jnp.tanh(self.scale_MLP(condition_x)) * self.dt
@@ -129,7 +129,7 @@ def forward(self, x: Array, condition_x: Array) -> Tuple[Array, Array]:
         y = (x + shift) * jnp.exp(scale)
         return y, log_det

-    def inverse(self, x: Array, condition_x: Array) -> Tuple[Array, Array]:
+    def inverse(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]:
         scale = jnp.tanh(self.scale_MLP(condition_x)) * self.dt
         shift = self.shift_MLP(condition_x) * self.dt
         log_det = -scale
@@ -141,19 +141,19 @@ class ScalarAffine(Bijection):
     scale: Array
     shift: Array

-    def __init__(self, scale: float, shift: float):
+    def __init__(self, scale: Float, shift: Float):
         self.scale = jnp.array(scale)
         self.shift = jnp.array(shift)

-    def __call__(self, x: Array, condition_x: Array) -> Tuple[Array, Array]:
+    def __call__(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]:
         return self.forward(x, condition_x)

-    def forward(self, x: Array, condition_x: Array) -> Tuple[Array, Array]:
+    def forward(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]:
         y = (x + self.shift) * jnp.exp(self.scale)
         log_det = self.scale
         return y, log_det

-    def inverse(self, x: Array, condition_x: Array) -> Tuple[Array, Array]:
+    def inverse(self, x: Float[Array, "n_dim"], condition_x: Float[Array, "n_cond"]) -> Tuple[Float[Array, "n_dim"], Float]:
         y = x * jnp.exp(-self.scale) - self.shift
         log_det = -self.scale
         return y, log_det
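To illustrate how the pieces above fit together, here is a hypothetical composition of an MLPAffine conditioner inside a MaskedCouplingLayer; the dimensions, key, and mask pattern are invented for this sketch, and the MLP output sizes are chosen to match the per-dimension scale and shift that MLPAffine.forward computes:

import jax
import jax.numpy as jnp
from flowMC.nfmodel.common import MLP, MLPAffine, MaskedCouplingLayer

n_dim = 4
key_scale, key_shift = jax.random.split(jax.random.PRNGKey(0))

# Conditioner networks map the masked input to per-dimension scale and shift.
scale_mlp = MLP([n_dim, 16, n_dim], key_scale)
shift_mlp = MLP([n_dim, 16, n_dim], key_shift)

mask = jnp.array([1.0, 1.0, 0.0, 0.0])  # masked half is frozen, the rest is transformed
layer = MaskedCouplingLayer(MLPAffine(scale_mlp, shift_mlp), mask)

x = jnp.arange(n_dim, dtype=jnp.float32)
y, log_det = layer.forward(x)   # masked entries pass through; log_det sums over the rest
x_back, _ = layer.inverse(y)    # recovers x up to floating-point error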
@@ -173,33 +173,33 @@ class Gaussian(Distribution):
         cov (Array): Covariance matrix.
     """

-    _mean: Array
-    _cov: Array
+    _mean: Float[Array, "n_dim"]
+    _cov: Float[Array, "n_dim n_dim"]
     learnable: bool = False

     @property
-    def mean(self) -> Array:
+    def mean(self) -> Float[Array, "n_dim"]:
         if self.learnable:
             return self._mean
         else:
             return jax.lax.stop_gradient(self._mean)

     @property
-    def cov(self) -> Array:
+    def cov(self) -> Float[Array, "n_dim n_dim"]:
         if self.learnable:
             return self._cov
         else:
             return jax.lax.stop_gradient(self._cov)

-    def __init__(self, mean: Array, cov: Array, learnable: bool = False):
+    def __init__(self, mean: Float[Array, "n_dim"], cov: Float[Array, "n_dim n_dim"], learnable: bool = False):
         self._mean = mean
         self._cov = cov
         self.learnable = learnable

-    def log_prob(self, x: Array) -> Array:
+    def log_prob(self, x: Float[Array, "n_dim"]) -> Float:
         return jax.scipy.stats.multivariate_normal.logpdf(x, self.mean, self.cov)

-    def sample(self, key: jax.random.PRNGKey, n_samples: int = 1) -> Array:
+    def sample(self, key: PRNGKeyArray, n_samples: int = 1) -> Float[Array, "n_samples n_dim"]:
         return jax.random.multivariate_normal(key, self.mean, self.cov, (n_samples,))
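A usage sketch for the Gaussian base distribution, following the __init__, log_prob, and sample signatures shown above; the 2-D values are illustrative:

import jax
import jax.numpy as jnp
from flowMC.nfmodel.common import Gaussian

dist = Gaussian(mean=jnp.zeros(2), cov=jnp.eye(2), learnable=False)

lp = dist.log_prob(jnp.array([0.5, -0.5]))                  # scalar log-density
draws = dist.sample(jax.random.PRNGKey(0), n_samples=100)   # array of shape (100, 2)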


@@ -212,14 +212,15 @@ def __init__(self, distributions: list[Distribution], partitions: dict):
         self.distributions = distributions
         self.partitions = partitions

-    def log_prob(self, x: Array) -> Array:
+    def log_prob(self, x: Float[Array, "n_dim"]) -> Float:
         log_prob = 0
         for dist, (_, ranges) in zip(self.distributions, self.partitions.items()):
             log_prob += dist.log_prob(x[ranges[0] : ranges[1]])
         return log_prob

-    def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array:
+    def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> dict[str, Float[Array, "n_samples n_dim"]]:
         samples = {}
         for dist, (key, _) in zip(self.distributions, self.partitions.items()):
             rng_key, sub_key = jax.random.split(rng_key)
             samples[key] = dist.sample(sub_key, n_samples=n_samples)
         return samples
