diff --git a/lab/types.py b/lab/types.py index 65d686d..52bcc01 100644 --- a/lab/types.py +++ b/lab/types.py @@ -82,10 +82,10 @@ def _module_attr(module, attr): _ag_tensor = ModuleType("autograd.tracer", "Box") # Define JAX module types. -if sys.version_info.minor <= 7: +if sys.version_info.minor <= 7: # pragma: specific no cover 3.8 3.9 3.10 3.11 # `jax` 0.4 deprecated Python 3.7 support. Rely on older JAX versions. _jax_tensor = ModuleType("jax.interpreters.xla", "DeviceArray") -else: +else: # pragma: specific no cover 3.7 _jax_tensor = ModuleType("jaxlib.xla_extension", "ArrayImpl") _jax_tracer = ModuleType("jax.core", "Tracer") _jax_dtype = ModuleType("jax._src.numpy.lax_numpy", "_ScalarMeta")