-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Kaushik Kulkarni <[email protected]>
- Loading branch information
1 parent
e53fa90
commit 9e72bf0
Showing
5 changed files
with
353 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
""" | ||
.. currentmodule:: arraycontext | ||
A mod :`cupy`-based array context. | ||
.. autoclass:: CupyArrayContext | ||
""" | ||
__copyright__ = """ | ||
Copyright (C) 2024 University of Illinois Board of Trustees | ||
""" | ||
|
||
__license__ = """ | ||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
The above copyright notice and this permission notice shall be included in | ||
all copies or substantial portions of the Software. | ||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
THE SOFTWARE. | ||
""" | ||
|
||
from collections.abc import Mapping | ||
|
||
|
||
try: | ||
import cupy as cp # type: ignore[import] | ||
except ModuleNotFoundError: | ||
pass | ||
|
||
import loopy as lp | ||
|
||
from arraycontext.container.traversal import ( | ||
rec_map_array_container, with_array_context) | ||
from arraycontext.context import ArrayContext | ||
|
||
|
||
class CupyArrayContext(ArrayContext): | ||
""" | ||
A :class:`ArrayContext` that uses :mod:`cupy.ndarray` to represent arrays | ||
.. automethod:: __init__ | ||
""" | ||
def __init__(self): | ||
super().__init__() | ||
self._loopy_transform_cache: \ | ||
Mapping["lp.TranslationUnit", "lp.TranslationUnit"] = {} | ||
|
||
self.array_types = (cp.ndarray,) | ||
|
||
def _get_fake_numpy_namespace(self): | ||
from .fake_numpy import CupyFakeNumpyNamespace | ||
return CupyFakeNumpyNamespace(self) | ||
|
||
# {{{ ArrayContext interface | ||
|
||
def clone(self): | ||
return type(self)() | ||
|
||
def empty(self, shape, dtype): | ||
return cp.empty(shape, dtype=dtype) | ||
|
||
def zeros(self, shape, dtype): | ||
return cp.zeros(shape, dtype) | ||
|
||
def from_numpy(self, np_array): | ||
return cp.array(np_array) | ||
|
||
def to_numpy(self, array): | ||
return cp.asnumpy(array) | ||
|
||
def call_loopy(self, t_unit, **kwargs): | ||
t_unit = t_unit.copy(target=lp.ExecutableCTarget()) | ||
try: | ||
t_unit = self._loopy_transform_cache[t_unit] | ||
except KeyError: | ||
orig_t_unit = t_unit | ||
t_unit = self.transform_loopy_program(t_unit) | ||
self._loopy_transform_cache[orig_t_unit] = t_unit | ||
del orig_t_unit | ||
|
||
_, result = t_unit(**kwargs) | ||
|
||
return result | ||
|
||
def freeze(self, array): | ||
def _freeze(ary): | ||
return cp.asnumpy(ary) | ||
|
||
return with_array_context(rec_map_array_container(_freeze, array), actx=None) | ||
|
||
def thaw(self, array): | ||
def _thaw(ary): | ||
return cp.array(ary) | ||
|
||
return with_array_context(rec_map_array_container(_thaw, array), actx=self) | ||
|
||
# }}} | ||
|
||
def transform_loopy_program(self, t_unit): | ||
raise ValueError("CupyArrayContext does not implement " | ||
"transform_loopy_program. Sub-classes are supposed " | ||
"to implement it.") | ||
|
||
def tag(self, tags, array): | ||
# No tagging support in CupyArrayContext | ||
return array | ||
|
||
def tag_axis(self, iaxis, tags, array): | ||
return array | ||
|
||
def einsum(self, spec, *args, arg_names=None, tagged=()): | ||
return cp.einsum(spec, *args) | ||
|
||
@property | ||
def permits_inplace_modification(self): | ||
return True | ||
|
||
@property | ||
def supports_nonscalar_broadcasting(self): | ||
return True | ||
|
||
@property | ||
def permits_advanced_indexing(self): | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
__copyright__ = """ | ||
Copyright (C) 2024 University of Illinois Board of Trustees | ||
""" | ||
|
||
__license__ = """ | ||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
The above copyright notice and this permission notice shall be included in | ||
all copies or substantial portions of the Software. | ||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
THE SOFTWARE. | ||
""" | ||
from functools import partial, reduce | ||
|
||
import cupy as cp # type: ignore[import] | ||
|
||
from arraycontext.container import is_array_container | ||
from arraycontext.container.traversal import ( | ||
multimap_reduce_array_container, rec_map_array_container, | ||
rec_map_reduce_array_container, rec_multimap_array_container, | ||
rec_multimap_reduce_array_container) | ||
from arraycontext.fake_numpy import ( | ||
BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace) | ||
|
||
|
||
class CupyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): | ||
# Everything is implemented in the base class for now. | ||
pass | ||
|
||
|
||
_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", | ||
"sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", | ||
"sqrt", "concatenate", "transpose", | ||
"ones_like", "maximum", "minimum", "where", "conj", "arctan2", | ||
} | ||
|
||
|
||
class CupyFakeNumpyNamespace(BaseFakeNumpyNamespace): | ||
""" | ||
A :mod:`numpy` mimic for :class:`CupyArrayContext`. | ||
""" | ||
def _get_fake_numpy_linalg_namespace(self): | ||
return CupyFakeNumpyLinalgNamespace(self._array_context) | ||
|
||
def __getattr__(self, name): | ||
|
||
if name in _NUMPY_UFUNCS: | ||
from functools import partial | ||
return partial(rec_multimap_array_container, | ||
getattr(cp, name)) | ||
|
||
raise NotImplementedError | ||
|
||
def sum(self, a, axis=None, dtype=None): | ||
return rec_map_reduce_array_container(sum, partial(cp.sum, | ||
axis=axis, | ||
dtype=dtype), | ||
a) | ||
|
||
def min(self, a, axis=None): | ||
return rec_map_reduce_array_container( | ||
partial(reduce, cp.minimum), partial(cp.amin, axis=axis), a) | ||
|
||
def max(self, a, axis=None): | ||
return rec_map_reduce_array_container( | ||
partial(reduce, cp.maximum), partial(cp.amax, axis=axis), a) | ||
|
||
def stack(self, arrays, axis=0): | ||
return rec_multimap_array_container( | ||
lambda *args: cp.stack(args, axis=axis), | ||
*arrays) | ||
|
||
def broadcast_to(self, array, shape): | ||
return rec_map_array_container(partial(cp.broadcast_to, shape=shape), array) | ||
|
||
# {{{ relational operators | ||
|
||
def equal(self, x, y): | ||
return rec_multimap_array_container(cp.equal, x, y) | ||
|
||
def not_equal(self, x, y): | ||
return rec_multimap_array_container(cp.not_equal, x, y) | ||
|
||
def greater(self, x, y): | ||
return rec_multimap_array_container(cp.greater, x, y) | ||
|
||
def greater_equal(self, x, y): | ||
return rec_multimap_array_container(cp.greater_equal, x, y) | ||
|
||
def less(self, x, y): | ||
return rec_multimap_array_container(cp.less, x, y) | ||
|
||
def less_equal(self, x, y): | ||
return rec_multimap_array_container(cp.less_equal, x, y) | ||
|
||
# }}} | ||
|
||
def ravel(self, a, order="C"): | ||
return rec_map_array_container(partial(cp.ravel, order=order), a) | ||
|
||
def vdot(self, x, y, dtype=None): | ||
if dtype is not None: | ||
raise NotImplementedError("only 'dtype=None' supported.") | ||
|
||
return rec_multimap_reduce_array_container(sum, cp.vdot, x, y) | ||
|
||
def any(self, a): | ||
return rec_map_reduce_array_container(partial(reduce, cp.logical_or), | ||
lambda subary: cp.any(subary), a) | ||
|
||
def all(self, a): | ||
return rec_map_reduce_array_container(partial(reduce, cp.logical_and), | ||
lambda subary: cp.all(subary), a) | ||
|
||
def array_equal(self, a, b): | ||
if type(a) is not type(b): | ||
return False | ||
elif not is_array_container(a): | ||
if a.shape != b.shape: | ||
return False | ||
else: | ||
return cp.all(cp.equal(a, b)) | ||
else: | ||
try: | ||
return multimap_reduce_array_container(partial(reduce, | ||
cp.logical_and), | ||
self.array_equal, a, b) | ||
except TypeError: | ||
return True | ||
|
||
def zeros_like(self, ary): | ||
return rec_multimap_array_container(cp.zeros_like, ary) | ||
|
||
def reshape(self, a, newshape, order="C"): | ||
return rec_map_array_container( | ||
lambda ary: ary.reshape(newshape, order=order), | ||
a) | ||
|
||
def arange(self, *args, **kwargs): | ||
return cp.arange(*args, **kwargs) | ||
|
||
def linspace(self, *args, **kwargs): | ||
return cp.linspace(*args, **kwargs) | ||
|
||
# vim: fdm=marker |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.