Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add NumpyArrayContext #235

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike,
tag_axes)
from .impl.jax import EagerJAXArrayContext
from .impl.numpy import NumpyArrayContext
from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
from .loopy import make_loopy_program
Expand Down Expand Up @@ -101,6 +102,8 @@
"PytatoJAXArrayContext",
"EagerJAXArrayContext",

"NumpyArrayContext",

"make_loopy_program",

"PytestArrayContextFactory",
Expand Down
6 changes: 5 additions & 1 deletion arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,11 @@ def is_array_container(ary: Any) -> bool:
"cheaper option, see is_array_container_type.",
DeprecationWarning, stacklevel=2)
return (serialize_container.dispatch(ary.__class__)
is not serialize_container.__wrapped__) # type:ignore[attr-defined]
is not serialize_container.__wrapped__ # type:ignore[attr-defined]
# numpy values with scalar elements aren't array containers
and not (isinstance(ary, np.ndarray)
and ary.dtype.kind != "O")
)


@singledispatch
Expand Down
32 changes: 21 additions & 11 deletions arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
"""

from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -214,6 +213,15 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
if rel_comparison is None:
raise TypeError("rel_comparison must be specified")

if bcast_numpy_array:
from warnings import warn
warn("'bcast_numpy_array=True' is deprecated and will be unsupported"
" from December 2021", DeprecationWarning, stacklevel=2)

if _bcast_actx_array_type:
raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'"
" cannot be both set.")

if rel_comparison and eq_comparison is None:
eq_comparison = True

Expand Down Expand Up @@ -265,6 +273,7 @@ def wrap(cls: Any) -> Any:

if cls_has_array_context_attr is None:
if hasattr(cls, "array_context"):
from warnings import warn
cls_has_array_context_attr = _FailSafe
warn(f"{cls} has an .array_context attribute, but it does not "
"set _cls_has_array_context_attr to True when calling "
Expand Down Expand Up @@ -484,16 +493,17 @@ def {fname}(arg1):
bcast_actx_ary_types = ()

gen(f"""
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg2,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_same_cls_init_args})
if {numpy_pred("arg2")}:
result = np.empty_like(arg2, dtype=object)
for i in np.ndindex(arg2.shape):
result[i] = {op_str.format("arg1", "arg2[i]")}
return result

if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg2,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_same_cls_init_args})
return NotImplemented
""")
gen(f"cls.__{dunder_name}__ = {fname}")
Expand Down Expand Up @@ -530,16 +540,16 @@ def {fname}(arg1):
def {fname}(arg2, arg1):
# assert other.__cls__ is not cls

if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg1,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_init_args})
if {numpy_pred("arg1")}:
result = np.empty_like(arg1, dtype=object)
for i in np.ndindex(arg1.shape):
result[i] = {op_str.format("arg1[i]", "arg2")}
return result
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg1,
{tup_str(outer_bcast_type_names
+ bcast_actx_ary_types)}):
return cls({bcast_init_args})
return NotImplemented

cls.__r{dunder_name}__ = {fname}""")
Expand Down
134 changes: 134 additions & 0 deletions arraycontext/impl/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
.. currentmodule:: arraycontext


A mod :`numpy`-based array context.

.. autoclass:: NumpyArrayContext
"""
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from typing import Dict

import numpy as np

import loopy as lp

from arraycontext.container.traversal import (
rec_map_array_container, with_array_context)
from arraycontext.context import ArrayContext


class NumpyArrayContext(ArrayContext):
"""
A :class:`ArrayContext` that uses :mod:`numpy.ndarray` to represent arrays


.. automethod:: __init__
"""
def __init__(self):
super().__init__()
self._loopy_transform_cache: \
Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {}

self.array_types = (np.ndarray,)

def _get_fake_numpy_namespace(self):
from .fake_numpy import NumpyFakeNumpyNamespace
return NumpyFakeNumpyNamespace(self)

# {{{ ArrayContext interface

def clone(self):
return type(self)()

def empty(self, shape, dtype):
return np.empty(shape, dtype=dtype)

def zeros(self, shape, dtype):
return np.zeros(shape, dtype)

def from_numpy(self, np_array):
# Uh oh...
return np_array

def to_numpy(self, array):
# Uh oh...
return array

def call_loopy(self, t_unit, **kwargs):
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
try:
t_unit = self._loopy_transform_cache[t_unit]
except KeyError:
orig_t_unit = t_unit
t_unit = self.transform_loopy_program(t_unit)
self._loopy_transform_cache[orig_t_unit] = t_unit
del orig_t_unit

_, result = t_unit(**kwargs)

return result

def freeze(self, array):
def _freeze(ary):
return ary

return with_array_context(rec_map_array_container(_freeze, array), actx=None)

def thaw(self, array):
def _thaw(ary):
return ary

return with_array_context(rec_map_array_container(_thaw, array), actx=self)

# }}}

def transform_loopy_program(self, t_unit):
raise ValueError("NumpyArrayContext does not implement "
"transform_loopy_program. Sub-classes are supposed "
"to implement it.")

def tag(self, tags, array):
# Numpy doesn't support tagging
return array

def tag_axis(self, iaxis, tags, array):
return array

def einsum(self, spec, *args, arg_names=None, tagged=()):
return np.einsum(spec, *args)

@property
def permits_inplace_modification(self):
return True

@property
def supports_nonscalar_broadcasting(self):
return True

@property
def permits_advanced_indexing(self):
return True
156 changes: 156 additions & 0 deletions arraycontext/impl/numpy/fake_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from functools import partial, reduce

import numpy as np

from arraycontext.container import is_array_container
from arraycontext.container.traversal import (
multimap_reduce_array_container, rec_map_array_container,
rec_map_reduce_array_container, rec_multimap_array_container,
rec_multimap_reduce_array_container)
from arraycontext.fake_numpy import (
BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace)


class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
# Everything is implemented in the base class for now.
pass


_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan",
"sinh", "cosh", "tanh", "exp", "log", "log10", "isnan",
"sqrt", "concatenate", "transpose",
"ones_like", "maximum", "minimum", "where", "conj", "arctan2",
}


class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace):
"""
A :mod:`numpy` mimic for :class:`NumpyArrayContext`.
"""
def _get_fake_numpy_linalg_namespace(self):
return NumpyFakeNumpyLinalgNamespace(self._array_context)

def __getattr__(self, name):

if name in _NUMPY_UFUNCS:
from functools import partial
return partial(rec_multimap_array_container,
getattr(np, name))

raise NotImplementedError

def sum(self, a, axis=None, dtype=None):
return rec_map_reduce_array_container(sum, partial(np.sum,
axis=axis,
dtype=dtype),
a)

def min(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, np.minimum), partial(np.amin, axis=axis), a)

def max(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, np.maximum), partial(np.amax, axis=axis), a)

def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: np.stack(arrays=args, axis=axis),
*arrays)

def broadcast_to(self, array, shape):
return rec_map_array_container(partial(np.broadcast_to, shape=shape), array)

# {{{ relational operators

def equal(self, x, y):
return rec_multimap_array_container(np.equal, x, y)

def not_equal(self, x, y):
return rec_multimap_array_container(np.not_equal, x, y)

def greater(self, x, y):
return rec_multimap_array_container(np.greater, x, y)

def greater_equal(self, x, y):
return rec_multimap_array_container(np.greater_equal, x, y)

def less(self, x, y):
return rec_multimap_array_container(np.less, x, y)

def less_equal(self, x, y):
return rec_multimap_array_container(np.less_equal, x, y)

# }}}

def ravel(self, a, order="C"):
return rec_map_array_container(partial(np.ravel, order=order), a)

def vdot(self, x, y, dtype=None):
if dtype is not None:
raise NotImplementedError("only 'dtype=None' supported.")

return rec_multimap_reduce_array_container(sum, np.vdot, x, y)

def any(self, a):
return rec_map_reduce_array_container(partial(reduce, np.logical_or),
lambda subary: np.any(subary), a)

def all(self, a):
return rec_map_reduce_array_container(partial(reduce, np.logical_and),
lambda subary: np.all(subary), a)

def array_equal(self, a, b):
if type(a) is not type(b):
return False
elif not is_array_container(a):
if a.shape != b.shape:
return False
else:
return np.all(np.equal(a, b))
else:
try:
return multimap_reduce_array_container(partial(reduce,
np.logical_and),
self.array_equal, a, b)
except TypeError:
return True

def zeros_like(self, ary):
return rec_multimap_array_container(np.zeros_like, ary)

def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: ary.reshape(newshape, order=order),
a)

def arange(self, *args, **kwargs):
return np.arange(*args, **kwargs)

def linspace(self, *args, **kwargs):
return np.linspace(*args, **kwargs)

# vim: fdm=marker
Loading
Loading