diff --git a/docs/communityExamples.md b/docs/communityExamples.md index 4effc93..6252ae2 100644 --- a/docs/communityExamples.md +++ b/docs/communityExamples.md @@ -9,4 +9,5 @@ If you find flowMC useful, please consider contributing your example to this pag - [jim - A JAX-based gravitational-wave inference toolkit](https://github.com/kazewong/jim) - [Bayeux - Stitching together models and samplers](https://github.com/jax-ml/bayeux) - - [Colab example](https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing) \ No newline at end of file + - [Colab example](https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing) +- [Markovian Flow Matching: Accelerating MCMC with Continuous Normalizing Flows](https://arxiv.org/pdf/2405.14392) \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index bd32912..a99d077 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,8 +15,8 @@ packages_dir= =src packages = find: install_requires = - jax>=0.4.12 - jaxlib>=0.4.12 + jax>=0.4.24 + jaxlib>=0.4.24 equinox>=0.10.6 optax>=0.1.5 evosax>=0.1.4 diff --git a/src/flowMC/nfmodel/base.py b/src/flowMC/nfmodel/base.py index 57986d0..111862d 100644 --- a/src/flowMC/nfmodel/base.py +++ b/src/flowMC/nfmodel/base.py @@ -191,6 +191,28 @@ def train( return rng, best_model, best_state, loss_values + def to_precision(self, precision: str = "float32"): + """Convert all parameters to a given precision. + + !!! warning + This function is **experimental** and may change in the future. + + Args: + precision (str): Precision to convert to. + + Returns: + eqx.Module: Model with parameters converted to the given precision. + """ + + precisions_dict = {"float16": jnp.float16, "bfloat16": jnp.bfloat16, "float32": jnp.float32, "float64": jnp.float64} + try: + precision_format = precisions_dict[precision.lower()] + except KeyError: + raise ValueError(f"Precision {precision} not supported. Choose from {precisions_dict.keys()}") + dynamic_model, static_model = eqx.partition(self, eqx.is_array) + dynamic_model = jax.tree.map(lambda x: x.astype(precision_format), dynamic_model) + return eqx.combine(dynamic_model, static_model) + class Bijection(eqx.Module): """