diff --git a/api_status.md b/api_status.md index 2955ea0e..45cb046b 100644 --- a/api_status.md +++ b/api_status.md @@ -46,11 +46,11 @@ This table shows which parts of the the [Array API](https://data-apis.org/array- | | Multi-axis | :white_check_mark: | | | | | Boolean array | :x: | | Shape is data dependent, [#73](https://github.com/cubed-dev/cubed/issues/73) | | Indexing Functions | `take` | :white_check_mark: | 2022.12 | | -| Inspection | `capabilities` | :x: | 2023.12 | | -| | `default_device` | :x: | 2023.12 | | -| | `default_dtypes` | :x: | 2023.12 | | -| | `devices` | :x: | 2023.12 | | -| | `dtypes` | :x: | 2023.12 | | +| Inspection | `capabilities` | :white_check_mark: | 2023.12 | | +| | `default_device` | :white_check_mark: | 2023.12 | | +| | `default_dtypes` | :white_check_mark: | 2023.12 | | +| | `devices` | :white_check_mark: | 2023.12 | | +| | `dtypes` | :white_check_mark: | 2023.12 | | | Linear Algebra Functions | `matmul` | :white_check_mark: | | | | | `matrix_transpose` | :white_check_mark: | | | | | `tensordot` | :white_check_mark: | | | diff --git a/cubed/__init__.py b/cubed/__init__.py index 3cd2e6b4..a9e061d9 100644 --- a/cubed/__init__.py +++ b/cubed/__init__.py @@ -47,7 +47,10 @@ __array_api_version__ = "2022.12" -__all__ += ["__array_api_version__"] +from .array_api.inspection import __array_namespace_info__ + +__all__ += ["__array_api_version__", "__array_namespace_info__"] + from .array_api.array_object import Array diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py index ea0a8c2b..a0c9f8cc 100644 --- a/cubed/array_api/__init__.py +++ b/cubed/array_api/__init__.py @@ -2,7 +2,9 @@ __array_api_version__ = "2022.12" -__all__ += ["__array_api_version__"] +from .inspection import __array_namespace_info__ + +__all__ += ["__array_api_version__", "__array_namespace_info__"] from .array_object import Array diff --git a/cubed/array_api/inspection.py b/cubed/array_api/inspection.py new file mode 100644 index 00000000..cc35d1c5 --- /dev/null +++ b/cubed/array_api/inspection.py @@ -0,0 +1,24 @@ +from cubed.backend_array_api import namespace as nxp + + +class __array_namespace_info__: + # capabilities are determined by Cubed, not the backend array API + def capabilities(self): + return { + "boolean indexing": False, + "data-dependent shapes": False, + } + + # devices and dtypes are determined by the backend array API + + def default_device(self): + return nxp.__array_namespace_info__().default_device() + + def default_dtypes(self, *, device=None): + return nxp.__array_namespace_info__().default_dtypes(device=device) + + def devices(self): + return nxp.__array_namespace_info__().devices() + + def dtypes(self, *, device=None, kind=None): + return nxp.__array_namespace_info__().dtypes(device=device, kind=kind) diff --git a/cubed/tests/test_inspection.py b/cubed/tests/test_inspection.py new file mode 100644 index 00000000..7d47fe95 --- /dev/null +++ b/cubed/tests/test_inspection.py @@ -0,0 +1,31 @@ +import cubed.array_api as xp + +info = xp.__array_namespace_info__() + + +def test_capabilities(): + capabilities = info.capabilities() + assert capabilities["boolean indexing"] is False + assert capabilities["data-dependent shapes"] is False + + +def test_default_device(): + assert ( + info.default_device() is None or info.default_device() == xp.asarray(0).device + ) + + +def test_default_dtypes(): + dtypes = info.default_dtypes() + assert dtypes["real floating"] == xp.asarray(0.0).dtype + assert dtypes["complex floating"] == xp.asarray(0.0j).dtype + assert dtypes["integral"] == xp.asarray(0).dtype + assert dtypes["indexing"] == xp.argmax(xp.zeros(10)).dtype + + +def test_devices(): + assert len(info.devices()) > 0 + + +def test_dtypes(): + assert len(info.dtypes()) > 0