Merge pull request #171 from kazewong/170-lower-precision-training
170 lower precision training
kazewong authored May 26, 2024
2 parents 94eee83 + fe28277 commit 21b1bcc
Showing 3 changed files with 26 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/communityExamples.md
```diff
@@ -9,4 +9,5 @@ If you find flowMC useful, please consider contributing your example to this page

 - [jim - A JAX-based gravitational-wave inference toolkit](https://github.com/kazewong/jim)
 - [Bayeux - Stitching together models and samplers](https://github.com/jax-ml/bayeux)
-- [Colab example](https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing)
+- [Colab example](https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing)
+- [Markovian Flow Matching: Accelerating MCMC with Continuous Normalizing Flows](https://arxiv.org/pdf/2405.14392)
```
4 changes: 2 additions & 2 deletions setup.cfg
```diff
@@ -15,8 +15,8 @@ packages_dir=
 =src
 packages = find:
 install_requires =
-    jax>=0.4.12
-    jaxlib>=0.4.12
+    jax>=0.4.24
+    jaxlib>=0.4.24
     equinox>=0.10.6
     optax>=0.1.5
     evosax>=0.1.4
```
22 changes: 22 additions & 0 deletions src/flowMC/nfmodel/base.py
```diff
@@ -191,6 +191,28 @@ def train(

         return rng, best_model, best_state, loss_values

+    def to_precision(self, precision: str = "float32"):
+        """Convert all parameters to a given precision.
+        !!! warning
+            This function is **experimental** and may change in the future.
+        Args:
+            precision (str): Precision to convert to.
+        Returns:
+            eqx.Module: Model with parameters converted to the given precision.
+        """
+
+        precisions_dict = {"float16": jnp.float16, "bfloat16": jnp.bfloat16, "float32": jnp.float32, "float64": jnp.float64}
+        try:
+            precision_format = precisions_dict[precision.lower()]
+        except KeyError:
+            raise ValueError(f"Precision {precision} not supported. Choose from {precisions_dict.keys()}")
+        dynamic_model, static_model = eqx.partition(self, eqx.is_array)
+        dynamic_model = jax.tree.map(lambda x: x.astype(precision_format), dynamic_model)
+        return eqx.combine(dynamic_model, static_model)
+
+
 class Bijection(eqx.Module):
     """
```
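For readers skimming the diff, the new `to_precision` method follows equinox's standard partition → cast → combine pattern: split the model into its array leaves and everything else, cast only the arrays, and stitch the two halves back together. Below is a minimal, self-contained sketch of the same pattern; `TinyModel` is a made-up stand-in for illustration and is not part of flowMC:

```python
import jax
import jax.numpy as jnp
import equinox as eqx


class TinyModel(eqx.Module):
    """Hypothetical stand-in for an NFModel subclass; not part of flowMC."""

    weight: jax.Array
    name: str  # non-array field, untouched by the cast

    def __init__(self):
        self.weight = jnp.ones((2, 2))  # float32 by default
        self.name = "tiny"


model = TinyModel()

# The same partition -> cast -> combine pattern used by to_precision:
dynamic, static = eqx.partition(model, eqx.is_array)  # arrays vs. everything else
dynamic = jax.tree.map(lambda x: x.astype(jnp.bfloat16), dynamic)
model_bf16 = eqx.combine(dynamic, static)

print(model_bf16.weight.dtype)  # bfloat16
print(model_bf16.name)          # "tiny" (unchanged)
```

After this change, the equivalent call on a flowMC model is `model = model.to_precision("bfloat16")`. Note that casting to `float64` only takes effect if 64-bit mode is enabled via `jax.config.update("jax_enable_x64", True)`, since JAX defaults to 32-bit arrays.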
