Skip to content

Commit

Permalink
fix for jax > 0.4.31 (#136)
Browse files Browse the repository at this point in the history
* fix for jax > 0.4.31

* remove outdated comment

* improve comments
  • Loading branch information
mfschubert authored Oct 11, 2024
1 parent a342a8b commit 695ed9b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 22 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/build-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ on:
jobs:
lint-and-typecheck:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand Down Expand Up @@ -39,6 +40,7 @@ jobs:
test-fmmax:
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -57,6 +59,7 @@ jobs:

test-fmmax-jeig:
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -75,6 +78,7 @@ jobs:

test-grcwa:
runs-on: ubuntu-latest
timeout-minutes: 2
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -93,6 +97,7 @@ jobs:

test-examples:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -108,4 +113,3 @@ jobs:
pip install ".[tests,dev]"
- name: Test examples
run: python -m pytest tests/examples

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ maintainers = [

# TODO add gpu channels
dependencies = [
"jax <= 0.4.31",
"jax <= 0.4.34",
"jaxlib",
"numpy",
]
Expand Down
42 changes: 22 additions & 20 deletions src/fmmax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,32 +135,34 @@ def eig(
return _eig(matrix)


def _eig_host_jax(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""

def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
# We force this computation to be performed on the cpu by jit-ing and
# explicitly specifying the device.
with jax.default_device(jax.devices("cpu")[0]):
return jax.jit(jnp.linalg.eig)(matrix)

return jax.pure_callback(
_eig_cpu,
(
jnp.ones(matrix.shape[:-1], dtype=complex), # Eigenvalues
jnp.ones(matrix.shape, dtype=complex), # Eigenvectors
),
matrix.astype(complex),
vectorized=True,
)
def _eig_jax(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Eigendecomposition using `jax.numpy.linalg.eig`."""
# If using CPU backend, using `pure_callback` to call a jit-compiled version of
# `jnp.linalg.eig` is flaky and can cause deadlocks. Directly call it instead.
if jax.devices()[0] == jax.devices("cpu")[0]:
return jnp.linalg.eig(matrix)
else:
return jax.pure_callback(
_eig_jax_cpu,
(
jnp.ones(matrix.shape[:-1], dtype=complex), # Eigenvalues
jnp.ones(matrix.shape, dtype=complex), # Eigenvectors
),
matrix.astype(complex),
vectorized=True,
)


with jax.default_device(jax.devices("cpu")[0]):
_eig_jax_cpu = jax.jit(jnp.linalg.eig)


def _eig(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Eigendecomposition using `jeig` if available, and `_eig_host_jax` if not."""
"""Eigendecomposition using `jeig` if available, and `_eig_jax` if not."""
if _JEIG_AVAILABLE:
return jeig.eig(matrix)
else:
return _eig_host_jax(matrix)
return _eig_jax(matrix)


def _eig_fwd(
Expand Down

0 comments on commit 695ed9b

Please sign in to comment.