diff --git a/src/fmmax/utils.py b/src/fmmax/utils.py index 8a902c8..8c8061a 100644 --- a/src/fmmax/utils.py +++ b/src/fmmax/utils.py @@ -20,7 +20,6 @@ except ModuleNotFoundError: _JEIG_AVAILABLE = False - EIG_EPS_RELATIVE = 1e-12 EIG_EPS_MINIMUM = 1e-24 @@ -150,8 +149,6 @@ def _eig_jax(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: return jnp.asarray(eigval), jnp.asarray(eigvec) -# Define jax eigendecomposition that runs on CPU. Note that the compilation takes -# place at module import time. If the `jit` is inside a function, deadlocks can occur. with jax.default_device(jax.devices("cpu")[0]): _eig_jax_cpu = jax.jit(jnp.linalg.eig) @@ -161,7 +158,10 @@ def _eig(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: if _JEIG_AVAILABLE: return jeig.eig(matrix) else: - return _eig_jax(matrix) + if jax.devices()[0] == jax.devices("cpu")[0]: + return jnp.linalg.eig(matrix) + else: + return _eig_jax(matrix) def _eig_fwd( diff --git a/tests/fmmax/test_debug.py b/tests/fmmax/test_debug.py index 9a4de11..9fa7029 100644 --- a/tests/fmmax/test_debug.py +++ b/tests/fmmax/test_debug.py @@ -16,27 +16,7 @@ jax.config.update("jax_enable_x64", True) -def _eig_jax(matrix): - """Eigendecomposition using `jax.numpy.linalg.eig`.""" - eigval, eigvec = jax.pure_callback( - _eig_jax_cpu, - ( - jnp.ones(matrix.shape[:-1], dtype=complex), # Eigenvalues - jnp.ones(matrix.shape, dtype=complex), # Eigenvectors - ), - matrix.astype(complex), - vectorized=True, - ) - return jnp.asarray(eigval), jnp.asarray(eigvec) - - -# Define jax eigendecomposition that runs on CPU. Note that the compilation takes -# place at module import time. If the `jit` is inside a function, deadlocks can occur. -with jax.default_device(jax.devices("cpu")[0]): - _eig_jax_cpu = jax.jit(jnp.linalg.eig) - - class DebugTest(unittest.TestCase): def test_simple(self): for _ in range(10): - _eig_jax(jnp.ones((440, 440))) + utils.eig(jnp.ones((440, 440)))