Skip to content

Commit

Permalink
Add array API inspection utilities (#592)
Browse files Browse the repository at this point in the history
* Add array API inspection utilities

* JAX default_device() returns None
  • Loading branch information
tomwhite authored Oct 7, 2024
1 parent 7ca5ae9 commit 5a79c9d
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 7 deletions.
10 changes: 5 additions & 5 deletions api_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
| | Multi-axis | :white_check_mark: | | |
| | Boolean array | :x: | | Shape is data dependent, [#73](https://github.com/cubed-dev/cubed/issues/73) |
| Indexing Functions | `take` | :white_check_mark: | 2022.12 | |
| Inspection | `capabilities` | :x: | 2023.12 | |
| | `default_device` | :x: | 2023.12 | |
| | `default_dtypes` | :x: | 2023.12 | |
| | `devices` | :x: | 2023.12 | |
| | `dtypes` | :x: | 2023.12 | |
| Inspection | `capabilities` | :white_check_mark: | 2023.12 | |
| | `default_device` | :white_check_mark: | 2023.12 | |
| | `default_dtypes` | :white_check_mark: | 2023.12 | |
| | `devices` | :white_check_mark: | 2023.12 | |
| | `dtypes` | :white_check_mark: | 2023.12 | |
| Linear Algebra Functions | `matmul` | :white_check_mark: | | |
| | `matrix_transpose` | :white_check_mark: | | |
| | `tensordot` | :white_check_mark: | | |
Expand Down
5 changes: 4 additions & 1 deletion cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@

__array_api_version__ = "2022.12"

__all__ += ["__array_api_version__"]
from .array_api.inspection import __array_namespace_info__

__all__ += ["__array_api_version__", "__array_namespace_info__"]


from .array_api.array_object import Array

Expand Down
4 changes: 3 additions & 1 deletion cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

__array_api_version__ = "2022.12"

__all__ += ["__array_api_version__"]
from .inspection import __array_namespace_info__

__all__ += ["__array_api_version__", "__array_namespace_info__"]

from .array_object import Array

Expand Down
24 changes: 24 additions & 0 deletions cubed/array_api/inspection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from cubed.backend_array_api import namespace as nxp


class __array_namespace_info__:
# capabilities are determined by Cubed, not the backend array API
def capabilities(self):
return {
"boolean indexing": False,
"data-dependent shapes": False,
}

# devices and dtypes are determined by the backend array API

def default_device(self):
return nxp.__array_namespace_info__().default_device()

def default_dtypes(self, *, device=None):
return nxp.__array_namespace_info__().default_dtypes(device=device)

def devices(self):
return nxp.__array_namespace_info__().devices()

def dtypes(self, *, device=None, kind=None):
return nxp.__array_namespace_info__().dtypes(device=device, kind=kind)
31 changes: 31 additions & 0 deletions cubed/tests/test_inspection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import cubed.array_api as xp

info = xp.__array_namespace_info__()


def test_capabilities():
capabilities = info.capabilities()
assert capabilities["boolean indexing"] is False
assert capabilities["data-dependent shapes"] is False


def test_default_device():
assert (
info.default_device() is None or info.default_device() == xp.asarray(0).device
)


def test_default_dtypes():
dtypes = info.default_dtypes()
assert dtypes["real floating"] == xp.asarray(0.0).dtype
assert dtypes["complex floating"] == xp.asarray(0.0j).dtype
assert dtypes["integral"] == xp.asarray(0).dtype
assert dtypes["indexing"] == xp.argmax(xp.zeros(10)).dtype


def test_devices():
assert len(info.devices()) > 0


def test_dtypes():
assert len(info.dtypes()) > 0

0 comments on commit 5a79c9d

Please sign in to comment.