Skip to content

Commit

Permalink
cpu-specific codepath
Browse files Browse the repository at this point in the history
  • Loading branch information
mfschubert committed Oct 11, 2024
1 parent d3798d6 commit d9444ba
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 25 deletions.
8 changes: 4 additions & 4 deletions src/fmmax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
except ModuleNotFoundError:
_JEIG_AVAILABLE = False


EIG_EPS_RELATIVE = 1e-12
EIG_EPS_MINIMUM = 1e-24

Expand Down Expand Up @@ -150,8 +149,6 @@ def _eig_jax(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
return jnp.asarray(eigval), jnp.asarray(eigvec)


# Define jax eigendecomposition that runs on CPU. Note that the compilation takes
# place at module import time. If the `jit` is inside a function, deadlocks can occur.
with jax.default_device(jax.devices("cpu")[0]):
_eig_jax_cpu = jax.jit(jnp.linalg.eig)

Expand All @@ -161,7 +158,10 @@ def _eig(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
if _JEIG_AVAILABLE:
return jeig.eig(matrix)
else:
return _eig_jax(matrix)
if jax.devices()[0] == jax.devices("cpu")[0]:
return jnp.linalg.eig(matrix)
else:
return _eig_jax(matrix)


def _eig_fwd(
Expand Down
22 changes: 1 addition & 21 deletions tests/fmmax/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,7 @@
jax.config.update("jax_enable_x64", True)


def _eig_jax(matrix):
"""Eigendecomposition using `jax.numpy.linalg.eig`."""
eigval, eigvec = jax.pure_callback(
_eig_jax_cpu,
(
jnp.ones(matrix.shape[:-1], dtype=complex), # Eigenvalues
jnp.ones(matrix.shape, dtype=complex), # Eigenvectors
),
matrix.astype(complex),
vectorized=True,
)
return jnp.asarray(eigval), jnp.asarray(eigvec)


# Define jax eigendecomposition that runs on CPU. Note that the compilation takes
# place at module import time. If the `jit` is inside a function, deadlocks can occur.
with jax.default_device(jax.devices("cpu")[0]):
_eig_jax_cpu = jax.jit(jnp.linalg.eig)


class DebugTest(unittest.TestCase):
def test_simple(self):
for _ in range(10):
_eig_jax(jnp.ones((440, 440)))
utils.eig(jnp.ones((440, 440)))

0 comments on commit d9444ba

Please sign in to comment.