From 51d1ae8717f50032a9d147a444862e14c122ea6d Mon Sep 17 00:00:00 2001 From: Enrico Guiraud Date: Mon, 9 Oct 2023 18:23:11 -0600 Subject: [PATCH] Reformat markdown files with prettier --- ARCHITECTURE.md | 3 +++ README.md | 2 +- notes.md | 4 +++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index fe011e2..04ea3e2 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -35,6 +35,7 @@ In the future, `correctionlib` might also add an `evaluate` signature that takes as per [this discussion](https://github.com/cms-nanoAOD/correctionlib/issues/166). ## Alternative considered: Python code generation + Instead of implementing a generic compute graph evaluator through which we pass JAX arrays, we could instead implement a code generator that takes a correction's compute graph and produces code for a function that evaluates the graph. @@ -50,6 +51,7 @@ much more awkward than stepping through a compute graph walk, and in a sense thi kind of code generation is exactly what `jax.jit` does. ## Duplication of functionality w.r.t. correctionlib + Since JAX has to do a forward pass that actually computes the correction's output in order to compute the corresponding gradients, this package ends up being a Python-only reimplementation of correctionlib (or at least of a differentiable subset of the supported @@ -60,6 +62,7 @@ be more complicated to propagate gradients through the C++ correction implementa original correctionlib package. Which brings us to... ## correctionlib autodifferentiation in C++ + `correctionlib-gradients`, by design, only serves Python users. That simplifies development significantly and lets us move quickly as we experiment with features and find out about roadblocks. diff --git a/README.md b/README.md index fd3b56c..5fc260a 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![PyPI - Version](https://img.shields.io/pypi/v/correctionlib-gradients.svg)](https://pypi.org/project/correctionlib-gradients) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/correctionlib-gradients.svg)](https://pypi.org/project/correctionlib-gradients) ------ +--- **Table of Contents** diff --git a/notes.md b/notes.md index fe5c655..1c62e1c 100644 --- a/notes.md +++ b/notes.md @@ -4,14 +4,16 @@ Notes about specific quirks, warts, design decisions that might need revisiting, See ARCHITECTURE.md for a broader, higher-level description of the package. ## Floating point precision + JAX aggressively casts to float32. Maybe in the case of correctionlib we prefer double precision whenever possible? -It can be configured with `from jax import config; config.update("jax_enable_x64", True)` +It can be configured with `from jax import config; config.update("jax_enable_x64", True)` but it also seems wrong to set it at global scope behind the users' back. With things as they are now, `test_scalar` would fail with `jit=True` because of loss of precision if we didn't configure JAX as above at the start of the test. ## `jax.jit` and correctionlib compute graph + We offer a `jit` flag to control whether the correction evaluation (and the one of the gradient) should be pass through `jax.jit`. Corrections that take strings as input cannot be jitted though (integers are ok). We can probably at least give users a heads up in this case.