diff --git a/test/test_jax.py b/test/test_jax.py index 5214735e1..1ce9a30cb 100644 --- a/test/test_jax.py +++ b/test/test_jax.py @@ -25,7 +25,7 @@ import pytato as pt pytest.importorskip("jax") -from jax.config import config +from jax import config config.update("jax_enable_x64", True)