-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add array API inspection utilities (#592)
* Add array API inspection utilities * JAX default_device() returns None
- Loading branch information
Showing
5 changed files
with
67 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from cubed.backend_array_api import namespace as nxp | ||
|
||
|
||
class __array_namespace_info__: | ||
# capabilities are determined by Cubed, not the backend array API | ||
def capabilities(self): | ||
return { | ||
"boolean indexing": False, | ||
"data-dependent shapes": False, | ||
} | ||
|
||
# devices and dtypes are determined by the backend array API | ||
|
||
def default_device(self): | ||
return nxp.__array_namespace_info__().default_device() | ||
|
||
def default_dtypes(self, *, device=None): | ||
return nxp.__array_namespace_info__().default_dtypes(device=device) | ||
|
||
def devices(self): | ||
return nxp.__array_namespace_info__().devices() | ||
|
||
def dtypes(self, *, device=None, kind=None): | ||
return nxp.__array_namespace_info__().dtypes(device=device, kind=kind) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import cubed.array_api as xp | ||
|
||
info = xp.__array_namespace_info__() | ||
|
||
|
||
def test_capabilities(): | ||
capabilities = info.capabilities() | ||
assert capabilities["boolean indexing"] is False | ||
assert capabilities["data-dependent shapes"] is False | ||
|
||
|
||
def test_default_device(): | ||
assert ( | ||
info.default_device() is None or info.default_device() == xp.asarray(0).device | ||
) | ||
|
||
|
||
def test_default_dtypes(): | ||
dtypes = info.default_dtypes() | ||
assert dtypes["real floating"] == xp.asarray(0.0).dtype | ||
assert dtypes["complex floating"] == xp.asarray(0.0j).dtype | ||
assert dtypes["integral"] == xp.asarray(0).dtype | ||
assert dtypes["indexing"] == xp.argmax(xp.zeros(10)).dtype | ||
|
||
|
||
def test_devices(): | ||
assert len(info.devices()) > 0 | ||
|
||
|
||
def test_dtypes(): | ||
assert len(info.dtypes()) > 0 |