diff --git a/pytransform3d/_array_api_compat.py b/pytransform3d/_array_api_compat.py new file mode 100644 index 000000000..e985990e1 --- /dev/null +++ b/pytransform3d/_array_api_compat.py @@ -0,0 +1,9 @@ +import numpy as np + + +def array_namespace(*args): + try: + import array_api_compat + return array_api_compat.array_namespace(*args) + except ImportError: + return np diff --git a/pytransform3d/rotations/_utils.py b/pytransform3d/rotations/_utils.py index dca93a2b2..047d9493e 100644 --- a/pytransform3d/rotations/_utils.py +++ b/pytransform3d/rotations/_utils.py @@ -1,9 +1,9 @@ """Utility functions for rotations.""" import warnings import math -import array_api_compat import numpy as np from ._constants import unitz, eps, two_pi +from .._array_api_compat import array_namespace def norm_vector(v): @@ -19,7 +19,7 @@ def norm_vector(v): u : array, shape (n,) nd unit vector with norm 1 or the zero vector """ - xp = array_api_compat.array_namespace(v) + xp = array_namespace(v) norm = xp.linalg.norm(v) if norm == 0.0: diff --git a/setup.py b/setup.py index 1e7984762..e7ea962a5 100644 --- a/setup.py +++ b/setup.py @@ -24,10 +24,10 @@ ], license='BSD-3-Clause', packages=find_packages(), - install_requires=["numpy", "scipy", "matplotlib", "lxml", - "array-api-compat"], + install_requires=["numpy", "scipy", "matplotlib", "lxml"], extras_require={ - "all": ["pydot", "trimesh", "pycollada", "open3d"], + "all": ["pydot", "trimesh", "pycollada", "open3d", + "array-api-compat"], "doc": ["numpydoc", "sphinx", "sphinx-gallery", "sphinx-bootstrap-theme"], "test": ["pytest", "pytest-cov"]