eguiraud/correctionlib-gradients

correctionlib-gradients

A JAX-friendly, auto-differentiable, Python-only implementation of correctionlib correction evaluations.

Installation

pip install correctionlib-gradients

Usage

  1. construct a CorrectionWithGradient object from a correctionlib.schemav2.Correction
  2. there is no point 2: you can use CorrectionWithGradient.evaluate as a normal JAX-friendly, auto-differentiable function

Example

import jax
import jax.numpy as jnp

from correctionlib import schemav2
from correctionlib_gradients import CorrectionWithGradient

# given a correctionlib schema:
formula_schema = schemav2.Correction(
    name="x squared",
    version=2,
    inputs=[schemav2.Variable(name="x", type="real")],
    output=schemav2.Variable(name="a scale", type="real"),
    data=schemav2.Formula(
        nodetype="formula",
        expression="x * x",
        parser="TFormula",
        variables=["x"],
    ),
)

# construct a CorrectionWithGradient
c = CorrectionWithGradient(formula_schema)

# use c.evaluate as a JAX-friendly, auto-differentiable function
value, grad = jax.value_and_grad(c.evaluate)(3.0)
assert jnp.isclose(value, 9.0)
assert jnp.isclose(grad, 6.0)

# for Formula corrections, jax.jit and jax.vmap work too
xs = jnp.array([3.0, 4.0])
values, grads = jax.vmap(jax.jit(jax.value_and_grad(c.evaluate)))(xs)
assert jnp.allclose(values, jnp.array([9.0, 16.0]))
assert jnp.allclose(grads, jnp.array([6.0, 8.0]))

Supported types of corrections

Currently the following corrections from correctionlib.schemav2 are supported:

  • Formula, including parametric formulas
  • Binning with uniform or non-uniform bin edges and flow="clamp"; bin contents can be either:
    • all scalar values
    • all Formula or FormulaRef
  • scalar constants

Known limitations

Only the evaluation of Formula corrections is fully JAX traceable.

For other corrections, e.g. Binning, gradients can be computed (jax.grad works), but because JAX cannot trace the computation, utilities such as jax.jit and jax.vmap will not work. In these cases np.vectorize can be used as an alternative to jax.vmap.

License

correctionlib-gradients is distributed under the terms of the BSD 3-Clause license.

About

Automatic differentiation for high-energy physics correction factor calculations.