diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 4fce5885..b1bbec95 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -189,7 +189,7 @@ def is_available(cls) -> bool: return False def __call__(self): - from jax.config import config + from jax import config from arraycontext import EagerJAXArrayContext config.update("jax_enable_x64", True) @@ -214,7 +214,7 @@ def is_available(cls) -> bool: return False def __call__(self): - from jax.config import config + from jax import config from arraycontext import PytatoJAXArrayContext config.update("jax_enable_x64", True)