From 9e72bf09a8871b55a35cd0d6f61180e639733c7b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 16 Feb 2024 16:40:22 -0600 Subject: [PATCH] Implement CupyArrayContext Co-authored-by: Kaushik Kulkarni --- arraycontext/__init__.py | 5 +- arraycontext/impl/cupy/__init__.py | 136 +++++++++++++++++++++++ arraycontext/impl/cupy/fake_numpy.py | 156 +++++++++++++++++++++++++++ arraycontext/pytest.py | 18 ++++ test/test_arraycontext.py | 49 +++++++-- 5 files changed, 353 insertions(+), 11 deletions(-) create mode 100644 arraycontext/impl/cupy/__init__.py create mode 100644 arraycontext/impl/cupy/fake_numpy.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index b01b9917..a0a9deb6 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -50,6 +50,7 @@ Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike, tag_axes) +from .impl.cupy import CupyArrayContext from .impl.jax import EagerJAXArrayContext from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext @@ -106,7 +107,9 @@ "PytestArrayContextFactory", "PytestPyOpenCLArrayContextFactory", "pytest_generate_tests_for_array_contexts", - "pytest_generate_tests_for_pyopencl_array_context" + "pytest_generate_tests_for_pyopencl_array_context", + + "CupyArrayContext", ) diff --git a/arraycontext/impl/cupy/__init__.py b/arraycontext/impl/cupy/__init__.py new file mode 100644 index 00000000..5c93c31b --- /dev/null +++ b/arraycontext/impl/cupy/__init__.py @@ -0,0 +1,136 @@ +""" +.. currentmodule:: arraycontext + + +A mod :`cupy`-based array context. + +.. autoclass:: CupyArrayContext +""" +__copyright__ = """ +Copyright (C) 2024 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from collections.abc import Mapping + + +try: + import cupy as cp # type: ignore[import] +except ModuleNotFoundError: + pass + +import loopy as lp + +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) +from arraycontext.context import ArrayContext + + +class CupyArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :mod:`cupy.ndarray` to represent arrays + + + .. automethod:: __init__ + """ + def __init__(self): + super().__init__() + self._loopy_transform_cache: \ + Mapping["lp.TranslationUnit", "lp.TranslationUnit"] = {} + + self.array_types = (cp.ndarray,) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import CupyFakeNumpyNamespace + return CupyFakeNumpyNamespace(self) + + # {{{ ArrayContext interface + + def clone(self): + return type(self)() + + def empty(self, shape, dtype): + return cp.empty(shape, dtype=dtype) + + def zeros(self, shape, dtype): + return cp.zeros(shape, dtype) + + def from_numpy(self, np_array): + return cp.array(np_array) + + def to_numpy(self, array): + return cp.asnumpy(array) + + def call_loopy(self, t_unit, **kwargs): + t_unit = t_unit.copy(target=lp.ExecutableCTarget()) + try: + t_unit = self._loopy_transform_cache[t_unit] + except KeyError: + orig_t_unit = t_unit + t_unit = self.transform_loopy_program(t_unit) + self._loopy_transform_cache[orig_t_unit] = t_unit + del orig_t_unit + + _, result = t_unit(**kwargs) + + return result + + def freeze(self, array): + def _freeze(ary): + return cp.asnumpy(ary) + + return with_array_context(rec_map_array_container(_freeze, array), actx=None) + + def thaw(self, array): + def _thaw(ary): + return cp.array(ary) + + return with_array_context(rec_map_array_container(_thaw, array), actx=self) + + # }}} + + def transform_loopy_program(self, t_unit): + raise ValueError("CupyArrayContext does not implement " + "transform_loopy_program. Sub-classes are supposed " + "to implement it.") + + def tag(self, tags, array): + # No tagging support in CupyArrayContext + return array + + def tag_axis(self, iaxis, tags, array): + return array + + def einsum(self, spec, *args, arg_names=None, tagged=()): + return cp.einsum(spec, *args) + + @property + def permits_inplace_modification(self): + return True + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True diff --git a/arraycontext/impl/cupy/fake_numpy.py b/arraycontext/impl/cupy/fake_numpy.py new file mode 100644 index 00000000..0c4461e6 --- /dev/null +++ b/arraycontext/impl/cupy/fake_numpy.py @@ -0,0 +1,156 @@ +__copyright__ = """ +Copyright (C) 2024 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +import cupy as cp # type: ignore[import] + +from arraycontext.container import is_array_container +from arraycontext.container.traversal import ( + multimap_reduce_array_container, rec_map_array_container, + rec_map_reduce_array_container, rec_multimap_array_container, + rec_multimap_reduce_array_container) +from arraycontext.fake_numpy import ( + BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace) + + +class CupyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", + "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", + "sqrt", "concatenate", "transpose", + "ones_like", "maximum", "minimum", "where", "conj", "arctan2", + } + + +class CupyFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`numpy` mimic for :class:`CupyArrayContext`. + """ + def _get_fake_numpy_linalg_namespace(self): + return CupyFakeNumpyLinalgNamespace(self._array_context) + + def __getattr__(self, name): + + if name in _NUMPY_UFUNCS: + from functools import partial + return partial(rec_multimap_array_container, + getattr(cp, name)) + + raise NotImplementedError + + def sum(self, a, axis=None, dtype=None): + return rec_map_reduce_array_container(sum, partial(cp.sum, + axis=axis, + dtype=dtype), + a) + + def min(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, cp.minimum), partial(cp.amin, axis=axis), a) + + def max(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, cp.maximum), partial(cp.amax, axis=axis), a) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: cp.stack(args, axis=axis), + *arrays) + + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(cp.broadcast_to, shape=shape), array) + + # {{{ relational operators + + def equal(self, x, y): + return rec_multimap_array_container(cp.equal, x, y) + + def not_equal(self, x, y): + return rec_multimap_array_container(cp.not_equal, x, y) + + def greater(self, x, y): + return rec_multimap_array_container(cp.greater, x, y) + + def greater_equal(self, x, y): + return rec_multimap_array_container(cp.greater_equal, x, y) + + def less(self, x, y): + return rec_multimap_array_container(cp.less, x, y) + + def less_equal(self, x, y): + return rec_multimap_array_container(cp.less_equal, x, y) + + # }}} + + def ravel(self, a, order="C"): + return rec_map_array_container(partial(cp.ravel, order=order), a) + + def vdot(self, x, y, dtype=None): + if dtype is not None: + raise NotImplementedError("only 'dtype=None' supported.") + + return rec_multimap_reduce_array_container(sum, cp.vdot, x, y) + + def any(self, a): + return rec_map_reduce_array_container(partial(reduce, cp.logical_or), + lambda subary: cp.any(subary), a) + + def all(self, a): + return rec_map_reduce_array_container(partial(reduce, cp.logical_and), + lambda subary: cp.all(subary), a) + + def array_equal(self, a, b): + if type(a) is not type(b): + return False + elif not is_array_container(a): + if a.shape != b.shape: + return False + else: + return cp.all(cp.equal(a, b)) + else: + try: + return multimap_reduce_array_container(partial(reduce, + cp.logical_and), + self.array_equal, a, b) + except TypeError: + return True + + def zeros_like(self, ary): + return rec_multimap_array_container(cp.zeros_like, ary) + + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: ary.reshape(newshape, order=order), + a) + + def arange(self, *args, **kwargs): + return cp.arange(*args, **kwargs) + + def linspace(self, *args, **kwargs): + return cp.linspace(*args, **kwargs) + +# vim: fdm=marker diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 4fce5885..548b6bda 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -224,6 +224,23 @@ def __str__(self): return "" +class _PytestCupyArrayContextFactory(PytestArrayContextFactory): + @classmethod + def is_available(cls) -> bool: + try: + import cupy # type: ignore[import] # noqa: F401 + return True + except ImportError: + return False + + def __call__(self): + from arraycontext import CupyArrayContext + return CupyArrayContext() + + def __str__(self): + return "" + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, @@ -232,6 +249,7 @@ def __str__(self): "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "cupy": _PytestCupyArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index e53f4295..712d5f24 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -30,13 +30,15 @@ from pytools.obj_array import make_obj_array from arraycontext import ( # noqa: F401 - ArrayContainer, ArrayContext, EagerJAXArrayContext, FirstAxisIsElementsTag, - PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, - deserialize_container, pytest_generate_tests_for_array_contexts, - serialize_container, tag_axes, with_array_context, with_container_arithmetic) + ArrayContainer, ArrayContext, CupyArrayContext, EagerJAXArrayContext, + FirstAxisIsElementsTag, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, + dataclass_array_container, deserialize_container, + pytest_generate_tests_for_array_contexts, serialize_container, tag_axes, + with_array_context, with_container_arithmetic) from arraycontext.pytest import ( - _PytestEagerJaxArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, - _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory) + _PytestCupyArrayContextFactory, _PytestEagerJaxArrayContextFactory, + _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoJaxArrayContextFactory, + _PytestPytatoPyOpenCLArrayContextFactory) logger = logging.getLogger(__name__) @@ -84,6 +86,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestCupyArrayContextFactory, ]) @@ -412,6 +415,11 @@ def test_array_context_np_like(actx_factory, sym_name, n_args, dtype): actx, lambda _np, *_args: getattr(_np, sym_name)(*_args), args) for c in (42.0,) + _get_test_containers(actx): + if (isinstance(actx, CupyArrayContext) + and isinstance(c, (int, float, complex))): + # CupyArrayContext does not support zeros_like/ones_like with + # Python scalars. + continue result = getattr(actx.np, sym_name)(c) result = actx.thaw(actx.freeze(result)) @@ -935,11 +943,13 @@ def _check_allclose(f, arg1, arg2, atol=5.0e-14): with pytest.raises(TypeError): dc_of_dofs + ary_dof - bcast_result = ary_dof + bcast_dc_of_dofs - bcast_dc_of_dofs + ary_dof + if not isinstance(actx, CupyArrayContext): + # CupyArrayContext does not support operations between numpy and cupy arrays. + bcast_result = ary_dof + bcast_dc_of_dofs + bcast_dc_of_dofs + ary_dof - assert actx.to_numpy(actx.np.linalg.norm(bcast_result.mass - - 2*ary_of_dofs)) < 1e-8 + assert actx.to_numpy(actx.np.linalg.norm(bcast_result.mass + - 2*ary_of_dofs)) < 1e-8 mock_gradient = MyContainerDOFBcast( name="yo", @@ -1129,6 +1139,9 @@ def test_flatten_with_leaf_class(actx_factory): def test_numpy_conversion(actx_factory): actx = actx_factory() + if isinstance(actx, CupyArrayContext): + pytest.skip("Irrelevant tests for CupyArrayContext. " + "Also, CupyArrayContextdoes not support object arrays.") nelements = 42 ac = MyContainer( @@ -1219,6 +1232,8 @@ def scale_and_orthogonalize(alpha, vel): def test_actx_compile(actx_factory): actx = actx_factory() + if isinstance(actx, CupyArrayContext): + pytest.skip("CupyArrayContext does not support object arrays") compiled_rhs = actx.compile(scale_and_orthogonalize) @@ -1236,6 +1251,8 @@ def test_actx_compile(actx_factory): def test_actx_compile_python_scalar(actx_factory): actx = actx_factory() + if isinstance(actx, CupyArrayContext): + pytest.skip("CupyArrayContext does not support object arrays") compiled_rhs = actx.compile(scale_and_orthogonalize) @@ -1253,6 +1270,8 @@ def test_actx_compile_python_scalar(actx_factory): def test_actx_compile_kwargs(actx_factory): actx = actx_factory() + if isinstance(actx, CupyArrayContext): + pytest.skip("CupyArrayContext does not support object arrays") compiled_rhs = actx.compile(scale_and_orthogonalize) @@ -1273,6 +1292,8 @@ def test_actx_compile_with_tuple_output_keys(actx_factory): # key stringification logic. from arraycontext import from_numpy, to_numpy actx = actx_factory() + if isinstance(actx, CupyArrayContext): + pytest.skip("CupyArrayContext does not support object arrays") def my_rhs(scale, vel): result = np.empty((1, 1), dtype=object) @@ -1337,6 +1358,8 @@ def array_context(self): def test_leaf_array_type_broadcasting(actx_factory): # test support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() + if isinstance(actx, CupyArrayContext): + pytest.skip("CupyArrayContext has no leaf array type broadcasting support") foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, ))) bar = foo + 4 @@ -1550,6 +1573,9 @@ def test_tagging(actx_factory): if isinstance(actx, EagerJAXArrayContext): pytest.skip("Eager JAX has no tagging support") + if isinstance(actx, CupyArrayContext): + pytest.skip("CupyArrayContext has no tagging support") + from pytools.tag import Tag class ExampleTag(Tag): @@ -1596,6 +1622,9 @@ def test_linspace(actx_factory, args, kwargs): actx = actx_factory() + if isinstance(actx, CupyArrayContext) and kwargs.get("dtype") == np.complex128: + pytest.skip("CupyArrayContext does not support complex args to linspace") + actx_linspace = actx.to_numpy(actx.np.linspace(*args, **kwargs)) np_linspace = np.linspace(*args, **kwargs)