Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mfschubert committed Oct 11, 2024
1 parent d9444ba commit f56fe50
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 141 deletions.
172 changes: 67 additions & 105 deletions .github/workflows/build-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,37 @@ on:
schedule:
- cron: "0 13 * * 1" # Every Monday at 9AM EST

jobs:
debug:
jobs:
lint-and-typecheck:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Setup python
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"
cache-dependency-path: pyproject.toml
- name: Setup environment
run: |
python -m pip install --upgrade pip
pip install ".[dev,jeig]"
- name: Lint Python files
run: |
find . -name "*.py" | xargs black --check
find . -name "*.py" | xargs isort --profile black --check-only
- name: Typecheck with mypy
run: |
mypy src
mypy examples
- name: Validate docstrings
run: |
darglint src --strictness=short --ignore-raise=ValueError
darglint examples --strictness=short --ignore-raise=ValueError
test-fmmax:
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -26,11 +53,10 @@ jobs:
python -m pip install --upgrade pip
pip install ".[tests,dev]"
- name: Test fmmax
run: pytest tests/fmmax/test_debug.py
run: pytest tests/fmmax

debug-jeig:
test-fmmax-jeig:
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -45,105 +71,41 @@ jobs:
python -m pip install --upgrade pip
pip install ".[tests,dev,jeig]"
- name: Test fmmax
run: pytest tests/fmmax/test_debug.py

# lint-and-typecheck:
# runs-on: ubuntu-latest
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# - name: Setup python
# uses: actions/setup-python@v5
# with:
# python-version: "3.10"
# cache: "pip"
# cache-dependency-path: pyproject.toml
# - name: Setup environment
# run: |
# python -m pip install --upgrade pip
# pip install ".[dev,jeig]"
# - name: Lint Python files
# run: |
# find . -name "*.py" | xargs black --check
# find . -name "*.py" | xargs isort --profile black --check-only
# - name: Typecheck with mypy
# run: |
# mypy src
# mypy examples
# - name: Validate docstrings
# run: |
# darglint src --strictness=short --ignore-raise=ValueError
# darglint examples --strictness=short --ignore-raise=ValueError

# test-fmmax:
# runs-on: ubuntu-latest
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# - name: Setup python
# uses: actions/setup-python@v5
# with:
# python-version: "3.10"
# cache: "pip"
# cache-dependency-path: pyproject.toml
# - name: Setup environment
# run: |
# python -m pip install --upgrade pip
# pip install ".[tests,dev]"
# - name: Test fmmax
# run: pytest tests/fmmax
run: pytest tests/fmmax

# test-fmmax-jeig:
# runs-on: ubuntu-latest
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# - name: Setup python
# uses: actions/setup-python@v5
# with:
# python-version: "3.10"
# cache: "pip"
# cache-dependency-path: pyproject.toml
# - name: Setup environment
# run: |
# python -m pip install --upgrade pip
# pip install ".[tests,dev,jeig]"
# - name: Test fmmax
# run: pytest tests/fmmax

# test-grcwa:
# runs-on: ubuntu-latest
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# - name: Setup python
# uses: actions/setup-python@v5
# with:
# python-version: "3.10"
# cache: "pip"
# cache-dependency-path: pyproject.toml
# - name: Setup environment
# run: |
# python -m pip install --upgrade pip
# pip install ".[tests,dev]"
# - name: Test grcwa
# run: pytest tests/grcwa
test-grcwa:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Setup python
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"
cache-dependency-path: pyproject.toml
- name: Setup environment
run: |
python -m pip install --upgrade pip
pip install ".[tests,dev]"
- name: Test grcwa
run: pytest tests/grcwa

# test-examples:
# runs-on: ubuntu-latest
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# - name: Setup python
# uses: actions/setup-python@v5
# with:
# python-version: "3.10"
# cache: "pip"
# cache-dependency-path: pyproject.toml
# - name: Setup environment
# run: |
# python -m pip install --upgrade pip
# pip install ".[tests,dev]"
# - name: Test examples
# run: python -m pytest tests/examples
test-examples:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Setup python
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"
cache-dependency-path: pyproject.toml
- name: Setup environment
run: |
python -m pip install --upgrade pip
pip install ".[tests,dev]"
- name: Test examples
run: python -m pytest tests/examples

29 changes: 15 additions & 14 deletions src/fmmax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,22 @@ def eig(

def _eig_jax(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Eigendecomposition using `jax.numpy.linalg.eig`."""
eigval, eigvec = jax.pure_callback(
_eig_jax_cpu,
(
jnp.ones(matrix.shape[:-1], dtype=complex), # Eigenvalues
jnp.ones(matrix.shape, dtype=complex), # Eigenvectors
),
matrix.astype(complex),
vectorized=True,
)
return jnp.asarray(eigval), jnp.asarray(eigvec)
if jax.devices()[0] == jax.devices("cpu")[0]:
return jnp.linalg.eig(matrix)
else:
return jax.pure_callback(
_eig_jax_cpu,
(
jnp.ones(matrix.shape[:-1], dtype=complex), # Eigenvalues
jnp.ones(matrix.shape, dtype=complex), # Eigenvectors
),
matrix.astype(complex),
vectorized=True,
)


# Define jax eigendecomposition that runs on CPU. Note that the compilation takes
# place at module import time. If the `jit` is inside a function, deadlocks can occur.
with jax.default_device(jax.devices("cpu")[0]):
_eig_jax_cpu = jax.jit(jnp.linalg.eig)

Expand All @@ -158,10 +162,7 @@ def _eig(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
if _JEIG_AVAILABLE:
return jeig.eig(matrix)
else:
if jax.devices()[0] == jax.devices("cpu")[0]:
return jnp.linalg.eig(matrix)
else:
return _eig_jax(matrix)
return _eig_jax(matrix)


def _eig_fwd(
Expand Down
22 changes: 0 additions & 22 deletions tests/fmmax/test_debug.py

This file was deleted.

0 comments on commit f56fe50

Please sign in to comment.