Skip to content

Commit

Permalink
Reformat markdown files with prettier
Browse files Browse the repository at this point in the history
  • Loading branch information
eguiraud committed Oct 10, 2023
1 parent 6ec796b commit 51d1ae8
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
3 changes: 3 additions & 0 deletions ARCHITECTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ In the future, `correctionlib` might also add an `evaluate` signature that takes
as per [this discussion](https://github.com/cms-nanoAOD/correctionlib/issues/166).

## Alternative considered: Python code generation

Instead of implementing a generic compute graph evaluator through which we pass
JAX arrays, we could instead implement a code generator that takes a correction's
compute graph and produces code for a function that evaluates the graph.
Expand All @@ -50,6 +51,7 @@ much more awkward than stepping through a compute graph walk, and in a sense thi
kind of code generation is exactly what `jax.jit` does.

## Duplication of functionality w.r.t. correctionlib

Since JAX has to do a forward pass that actually computes the correction's output in
order to compute the corresponding gradients, this package ends up being a Python-only
reimplementation of correctionlib (or at least of a differentiable subset of the supported
Expand All @@ -60,6 +62,7 @@ be more complicated to propagate gradients through the C++ correction implementa
original correctionlib package. Which brings us to...

## correctionlib autodifferentiation in C++

`correctionlib-gradients`, by design, only serves Python users.
That simplifies development significantly and lets us move quickly as we experiment
with features and find out about roadblocks.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
[![PyPI - Version](https://img.shields.io/pypi/v/correctionlib-gradients.svg)](https://pypi.org/project/correctionlib-gradients)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/correctionlib-gradients.svg)](https://pypi.org/project/correctionlib-gradients)

-----
---

**Table of Contents**

Expand Down
4 changes: 3 additions & 1 deletion notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ Notes about specific quirks, warts, design decisions that might need revisiting,
See ARCHITECTURE.md for a broader, higher-level description of the package.

## Floating point precision

JAX aggressively casts to float32.
Maybe in the case of correctionlib we prefer double precision whenever possible?
It can be configured with `from jax import config; config.update("jax_enable_x64", True)`
It can be configured with `from jax import config; config.update("jax_enable_x64", True)`
but it also seems wrong to set it at global scope behind the users' back.
With things as they are now, `test_scalar` would fail with `jit=True` because of loss of precision
if we didn't configure JAX as above at the start of the test.

## `jax.jit` and correctionlib compute graph

We offer a `jit` flag to control whether the correction evaluation (and the one of the gradient)
should be pass through `jax.jit`. Corrections that take strings as input cannot be jitted though
(integers are ok). We can probably at least give users a heads up in this case.

0 comments on commit 51d1ae8

Please sign in to comment.