From ad93513b96e8e852af7862ee3c1e96a0d1dfd552 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Wed, 1 Feb 2023 00:47:25 -0800 Subject: [PATCH] Fix get_batch_axes() --- jaxlie/utils/_utils.py | 4 +++- setup.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/jaxlie/utils/_utils.py b/jaxlie/utils/_utils.py index 4c1573a..e39bfdf 100644 --- a/jaxlie/utils/_utils.py +++ b/jaxlie/utils/_utils.py @@ -46,7 +46,9 @@ def _wrap(cls: Type[T]) -> Type[T]: # JIT all methods. for f in filter( - lambda f: not f.startswith("_") and callable(getattr(cls, f)), + lambda f: not f.startswith("_") + and callable(getattr(cls, f)) + and f != "get_batch_axes", # Avoid returning tracers. dir(cls), ): setattr(cls, f, jax.jit(getattr(cls, f))) diff --git a/setup.py b/setup.py index 01207a7..c0e0fa2 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="jaxlie", - version="1.3.0", + version="1.3.1", description="Matrix Lie groups in Jax", long_description=long_description, long_description_content_type="text/markdown",