Skip to content

Commit

Permalink
Implement CupyArrayContext
Browse files Browse the repository at this point in the history
Co-authored-by: Kaushik Kulkarni <[email protected]>
  • Loading branch information
matthiasdiener and kaushikcfd committed Feb 16, 2024
1 parent e53fa90 commit 55b4cbf
Show file tree
Hide file tree
Showing 6 changed files with 354 additions and 11 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ jobs:
. ./ci-support-v0
build_py_project_in_conda_env
python -m pip install mypy pytest
conda install cupy
./run-mypy.sh
pytest3_pocl:
Expand Down
5 changes: 4 additions & 1 deletion arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike,
tag_axes)
from .impl.cupy import CupyArrayContext
from .impl.jax import EagerJAXArrayContext
from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
Expand Down Expand Up @@ -106,7 +107,9 @@
"PytestArrayContextFactory",
"PytestPyOpenCLArrayContextFactory",
"pytest_generate_tests_for_array_contexts",
"pytest_generate_tests_for_pyopencl_array_context"
"pytest_generate_tests_for_pyopencl_array_context",

"CupyArrayContext",
)


Expand Down
136 changes: 136 additions & 0 deletions arraycontext/impl/cupy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""
.. currentmodule:: arraycontext
A mod :`cupy`-based array context.
.. autoclass:: CupyArrayContext
"""
__copyright__ = """
Copyright (C) 2024 University of Illinois Board of Trustees
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from collections.abc import Mapping


try:
import cupy as cp # type: ignore[import-untyped]
except ModuleNotFoundError:
pass

import loopy as lp

from arraycontext.container.traversal import (
rec_map_array_container, with_array_context)
from arraycontext.context import ArrayContext


class CupyArrayContext(ArrayContext):
"""
A :class:`ArrayContext` that uses :mod:`cupy.ndarray` to represent arrays
.. automethod:: __init__
"""
def __init__(self):
super().__init__()
self._loopy_transform_cache: \
Mapping["lp.TranslationUnit", "lp.TranslationUnit"] = {}

self.array_types = (cp.ndarray,)

def _get_fake_numpy_namespace(self):
from .fake_numpy import CupyFakeNumpyNamespace
return CupyFakeNumpyNamespace(self)

# {{{ ArrayContext interface

def clone(self):
return type(self)()

def empty(self, shape, dtype):
return cp.empty(shape, dtype=dtype)

def zeros(self, shape, dtype):
return cp.zeros(shape, dtype)

def from_numpy(self, np_array):
return cp.array(np_array)

def to_numpy(self, array):
return cp.asnumpy(array)

def call_loopy(self, t_unit, **kwargs):
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
try:
t_unit = self._loopy_transform_cache[t_unit]
except KeyError:
orig_t_unit = t_unit
t_unit = self.transform_loopy_program(t_unit)
self._loopy_transform_cache[orig_t_unit] = t_unit
del orig_t_unit

_, result = t_unit(**kwargs)

return result

def freeze(self, array):
def _freeze(ary):
return cp.asnumpy(ary)

return with_array_context(rec_map_array_container(_freeze, array), actx=None)

def thaw(self, array):
def _thaw(ary):
return cp.array(ary)

return with_array_context(rec_map_array_container(_thaw, array), actx=self)

# }}}

def transform_loopy_program(self, t_unit):
raise ValueError("CupyArrayContext does not implement "
"transform_loopy_program. Sub-classes are supposed "
"to implement it.")

def tag(self, tags, array):
# No tagging support in CupyArrayContext
return array

def tag_axis(self, iaxis, tags, array):
return array

def einsum(self, spec, *args, arg_names=None, tagged=()):
return cp.einsum(spec, *args)

@property
def permits_inplace_modification(self):
return True

@property
def supports_nonscalar_broadcasting(self):
return True

@property
def permits_advanced_indexing(self):
return True
156 changes: 156 additions & 0 deletions arraycontext/impl/cupy/fake_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
__copyright__ = """
Copyright (C) 2024 University of Illinois Board of Trustees
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from functools import partial, reduce

import cupy as cp # type: ignore[import-untyped]

from arraycontext.container import is_array_container
from arraycontext.container.traversal import (
multimap_reduce_array_container, rec_map_array_container,
rec_map_reduce_array_container, rec_multimap_array_container,
rec_multimap_reduce_array_container)
from arraycontext.fake_numpy import (
BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace)


class CupyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
# Everything is implemented in the base class for now.
pass


_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan",
"sinh", "cosh", "tanh", "exp", "log", "log10", "isnan",
"sqrt", "concatenate", "transpose",
"ones_like", "maximum", "minimum", "where", "conj", "arctan2",
}


class CupyFakeNumpyNamespace(BaseFakeNumpyNamespace):
"""
A :mod:`numpy` mimic for :class:`CupyArrayContext`.
"""
def _get_fake_numpy_linalg_namespace(self):
return CupyFakeNumpyLinalgNamespace(self._array_context)

def __getattr__(self, name):

if name in _NUMPY_UFUNCS:
from functools import partial
return partial(rec_multimap_array_container,
getattr(cp, name))

raise NotImplementedError

def sum(self, a, axis=None, dtype=None):
return rec_map_reduce_array_container(sum, partial(cp.sum,
axis=axis,
dtype=dtype),
a)

def min(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, cp.minimum), partial(cp.amin, axis=axis), a)

def max(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, cp.maximum), partial(cp.amax, axis=axis), a)

def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: cp.stack(args, axis=axis),
*arrays)

def broadcast_to(self, array, shape):
return rec_map_array_container(partial(cp.broadcast_to, shape=shape), array)

# {{{ relational operators

def equal(self, x, y):
return rec_multimap_array_container(cp.equal, x, y)

def not_equal(self, x, y):
return rec_multimap_array_container(cp.not_equal, x, y)

def greater(self, x, y):
return rec_multimap_array_container(cp.greater, x, y)

def greater_equal(self, x, y):
return rec_multimap_array_container(cp.greater_equal, x, y)

def less(self, x, y):
return rec_multimap_array_container(cp.less, x, y)

def less_equal(self, x, y):
return rec_multimap_array_container(cp.less_equal, x, y)

# }}}

def ravel(self, a, order="C"):
return rec_map_array_container(partial(cp.ravel, order=order), a)

def vdot(self, x, y, dtype=None):
if dtype is not None:
raise NotImplementedError("only 'dtype=None' supported.")

return rec_multimap_reduce_array_container(sum, cp.vdot, x, y)

def any(self, a):
return rec_map_reduce_array_container(partial(reduce, cp.logical_or),
lambda subary: cp.any(subary), a)

def all(self, a):
return rec_map_reduce_array_container(partial(reduce, cp.logical_and),
lambda subary: cp.all(subary), a)

def array_equal(self, a, b):
if type(a) is not type(b):
return False
elif not is_array_container(a):
if a.shape != b.shape:
return False
else:
return cp.all(cp.equal(a, b))
else:
try:
return multimap_reduce_array_container(partial(reduce,
cp.logical_and),
self.array_equal, a, b)
except TypeError:
return True

def zeros_like(self, ary):
return rec_multimap_array_container(cp.zeros_like, ary)

def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: ary.reshape(newshape, order=order),
a)

def arange(self, *args, **kwargs):
return cp.arange(*args, **kwargs)

def linspace(self, *args, **kwargs):
return cp.linspace(*args, **kwargs)

# vim: fdm=marker
18 changes: 18 additions & 0 deletions arraycontext/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,23 @@ def __str__(self):
return "<PytatoJAXArrayContext>"


class _PytestCupyArrayContextFactory(PytestArrayContextFactory):
@classmethod
def is_available(cls) -> bool:
try:
import cupy # type: ignore[import-untyped] # noqa: F401
return True
except ImportError:
return False

def __call__(self):
from arraycontext import CupyArrayContext
return CupyArrayContext()

def __str__(self):
return "<CupyArrayContext>"


_ARRAY_CONTEXT_FACTORY_REGISTRY: \
Dict[str, Type[PytestArrayContextFactory]] = {
"pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
Expand All @@ -232,6 +249,7 @@ def __str__(self):
"pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
"pytato:jax": _PytestPytatoJaxArrayContextFactory,
"eagerjax": _PytestEagerJaxArrayContextFactory,
"cupy": _PytestCupyArrayContextFactory,
}


Expand Down
Loading

0 comments on commit 55b4cbf

Please sign in to comment.