Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flatten entire array containers #91

Merged
merged 23 commits into from
Oct 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
917c656
add flatten to numpy for an entire container
alexfikl Sep 25, 2021
768a007
add flatten and unflatten
alexfikl Sep 27, 2021
3e81dd2
mention that serialize_container should be deterministic
alexfikl Sep 27, 2021
18e08ab
rename argument in unflatten_to_numpy
alexfikl Sep 27, 2021
da86855
remove flatten_to_numpy and unflatten_from_numpy
alexfikl Sep 27, 2021
5478997
raise if flattened container does not have homogeneous dtypes
alexfikl Sep 27, 2021
1d914ab
update tests with more array container layouts
alexfikl Sep 27, 2021
d859cd8
complain if unflattened strides do not match
alexfikl Sep 27, 2021
91e8950
improve docs for serialize_container
alexfikl Sep 27, 2021
8fb7d28
remove unused memoize_in
alexfikl Sep 27, 2021
82a8459
make failed reshape message less opaque
alexfikl Sep 28, 2021
c56e8f5
hardcode flatten and unflatten in c order
alexfikl Sep 28, 2021
087351c
add tests for flatten edge cases
alexfikl Sep 30, 2021
85c8b39
slice forwards when unflattening
alexfikl Sep 30, 2021
4d0fe4d
add missing space in exception message
alexfikl Sep 30, 2021
5dc9da2
add some more tests for unflatten
alexfikl Sep 30, 2021
7b03174
update test skip condition
alexfikl Oct 1, 2021
853f501
improve docs
alexfikl Oct 7, 2021
c826993
unflatten: better check that template and ary sizes match
alexfikl Oct 7, 2021
c50ee3e
update xfail link
alexfikl Oct 18, 2021
004e85f
pyopencl: remove unused astype
alexfikl Oct 18, 2021
f16ccab
Merge branch 'main' into flatten-to-numpy
alexfikl Oct 19, 2021
6bc8d56
Merge branch 'main' into flatten-to-numpy
alexfikl Oct 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
rec_map_reduce_array_container,
rec_multimap_reduce_array_container,
thaw, freeze,
flatten, unflatten,
from_numpy, to_numpy)

from .impl.pyopencl import PyOpenCLArrayContext
Expand Down Expand Up @@ -92,6 +93,7 @@
"map_reduce_array_container", "multimap_reduce_array_container",
"rec_map_reduce_array_container", "rec_multimap_reduce_array_container",
"thaw", "freeze",
"flatten", "unflatten",
"from_numpy", "to_numpy",

"PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",
Expand Down
6 changes: 5 additions & 1 deletion arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]:
r"""Serialize the array container into an iterable over its components.

The order of the components and their identifiers are entirely under
the control of the container class.
the control of the container class. However, the order is required to be
deterministic, i.e. two calls to :func:`serialize_container` on
array containers of the same types with the same number of
sub-arrays must result in an iterable with the keys in the same
order.

If *ary* is mutable, the serialization function is not required to ensure
that the serialization result reflects the array state at the time of the
Expand Down
130 changes: 130 additions & 0 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
.. autofunction:: freeze
.. autofunction:: thaw

Flattening and unflattening
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: flatten
.. autofunction:: unflatten

Numpy conversion
~~~~~~~~~~~~~~~~
.. autofunction:: from_numpy
Expand Down Expand Up @@ -493,6 +498,131 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
# }}}


# {{{ flatten / unflatten

def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
"""Convert all arrays in the :class:`~arraycontext.ArrayContainer`
into single flat array of a type :attr:`arraycontext.ArrayContext.array_types`.

The operation requires :attr:`arraycontext.ArrayContext.np` to have
``ravel`` and ``concatenate`` methods implemented. The order in which the
individual leaf arrays appear in the final array is dependent on the order
given by :func:`~arraycontext.serialize_container`.
"""
common_dtype = None
result: List[Any] = []

def _flatten(subary: ArrayOrContainerT) -> None:
nonlocal common_dtype

try:
iterable = serialize_container(subary)
except TypeError:
if common_dtype is None:
common_dtype = subary.dtype

if subary.dtype != common_dtype:
raise ValueError("arrays in container have different dtypes: "
f"got {subary.dtype}, expected {common_dtype}")

try:
flat_subary = actx.np.ravel(subary, order="C")
except ValueError as exc:
# NOTE: we can't do much if the array context fails to ravel,
# since it is the one responsible for the actual memory layout
if hasattr(subary, "strides"):
strides_msg = f" and strides {subary.strides}"
else:
strides_msg = ""

raise NotImplementedError(
f"'{type(actx).__name__}.np.ravel' failed to reshape "
f"an array with shape {subary.shape}{strides_msg}. "
"This functionality needs to be implemented by the "
"array context.") from exc

result.append(flat_subary)
else:
for _, isubary in iterable:
_flatten(isubary)

_flatten(ary)

return actx.np.concatenate(result)


def unflatten(
template: ArrayOrContainerT, ary: Any,
actx: ArrayContext) -> ArrayOrContainerT:
"""Unflatten an array *ary* produced by :func:`flatten` back into an
:class:`~arraycontext.ArrayContainer`.

The order and sizes of each slice into *ary* are determined by the
array container *template*.
"""
# NOTE: https://github.com/python/mypy/issues/7057
offset = 0

def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:
nonlocal offset

try:
inducer marked this conversation as resolved.
Show resolved Hide resolved
iterable = serialize_container(template_subary)
except TypeError:
if (offset + template_subary.size) > ary.size:
raise ValueError("'template' and 'ary' sizes do not match: "
"'template' is too large")

if template_subary.dtype != ary.dtype:
raise ValueError("'template' dtype does not match 'ary': "
f"got {template_subary.dtype}, expected {ary.dtype}")

inducer marked this conversation as resolved.
Show resolved Hide resolved
flat_subary = ary[offset:offset + template_subary.size]
try:
subary = actx.np.reshape(flat_subary,
template_subary.shape, order="C")
except ValueError as exc:
# NOTE: we can't do much if the array context fails to reshape,
# since it is the one responsible for the actual memory layout
raise NotImplementedError(
f"'{type(actx).__name__}.np.reshape' failed to reshape "
f"the flat array into shape {template_subary.shape}. "
"This functionality needs to be implemented by the "
"array context.") from exc

if hasattr(template_subary, "strides"):
if template_subary.strides != subary.strides:
raise ValueError(
f"strides do not match template: got {subary.strides}, "
f"expected {template_subary.strides}")

offset += template_subary.size
return subary
else:
return deserialize_container(template_subary, [
(key, _unflatten(isubary)) for key, isubary in iterable
])

if not isinstance(ary, actx.array_types):
raise TypeError("'ary' does not have a type supported by the provided "
f"array context: got '{type(ary).__name__}', expected one of "
f"{actx.array_types}")

if ary.ndim != 1:
raise ValueError(
"only one dimensional arrays can be unflattened: "
f"'ary' has shape {ary.shape}")

result = _unflatten(template)
if offset != ary.size:
raise ValueError("'template' and 'ary' sizes do not match: "
"'ary' is too large")

return result

# }}}


# {{{ numpy conversion

def from_numpy(ary: Any, actx: ArrayContext) -> Any:
Expand Down
6 changes: 4 additions & 2 deletions arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,10 @@ def stack(self, arrays, axis=0):
queue=self._array_context.queue),
*arrays)

def reshape(self, a, newshape):
return cl_array.reshape(a, newshape)
def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: ary.reshape(newshape, order=order),
a)

def concatenate(self, arrays, axis=0):
return cl_array.concatenate(
Expand Down
6 changes: 4 additions & 2 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ def __getattr__(self, name):

return super().__getattr__(name)

def reshape(self, a, newshape):
return rec_multimap_array_container(pt.reshape, a, newshape)
def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: pt.reshape(a, newshape, order=order),
a)

def transpose(self, a, axes=None):
return rec_multimap_array_container(pt.transpose, a, axes)
Expand Down
101 changes: 87 additions & 14 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def imag(self):

@serialize_container.register(DOFArray)
def _serialize_dof_container(ary: DOFArray):
return enumerate(ary.data)
return list(enumerate(ary.data))


@deserialize_container.register(DOFArray)
Expand Down Expand Up @@ -203,17 +203,27 @@ def randn(shape, dtype):
rng = np.random.default_rng()
dtype = np.dtype(dtype)

if shape == 0:
ashape = 1
else:
ashape = shape

if dtype.kind == "c":
dtype = np.dtype(f"<f{dtype.itemsize // 2}")
return rng.standard_normal(shape, dtype) \
+ 1j * rng.standard_normal(shape, dtype)
r = rng.standard_normal(ashape, dtype) \
+ 1j * rng.standard_normal(ashape, dtype)
elif dtype.kind == "f":
return rng.standard_normal(shape, dtype)
r = rng.standard_normal(ashape, dtype)
elif dtype.kind == "i":
return rng.integers(0, 128, shape, dtype)
r = rng.integers(0, 512, ashape, dtype)
else:
raise TypeError(dtype.kind)

if shape == 0:
return np.array(r[0])

return r


def assert_close_to_numpy(actx, op, args):
assert np.allclose(
Expand Down Expand Up @@ -672,11 +682,14 @@ def array_context(self):
return self.mass.array_context


def _get_test_containers(actx, ambient_dim=2, size=50_000):
if size == 0:
x = DOFArray(actx, (actx.from_numpy(np.array(np.random.randn())),))
else:
x = DOFArray(actx, (actx.from_numpy(np.random.randn(size)),))
def _get_test_containers(actx, ambient_dim=2, shapes=50_000):
from numbers import Number
if isinstance(shapes, (Number, tuple)):
shapes = [shapes]

x = DOFArray(actx, tuple([
actx.from_numpy(randn(shape, np.float64))
for shape in shapes]))

# pylint: disable=unexpected-keyword-arg, no-value-for-parameter
dataclass_of_dofs = MyContainer(
Expand Down Expand Up @@ -705,7 +718,7 @@ def _get_test_containers(actx, ambient_dim=2, size=50_000):
def test_container_scalar_map(actx_factory):
actx = actx_factory()

arys = _get_test_containers(actx, size=0)
arys = _get_test_containers(actx, shapes=0)
arys += (np.pi,)

from arraycontext import (
Expand Down Expand Up @@ -877,16 +890,76 @@ def test_container_norm(actx_factory, ord):
# }}}


# {{{ test flatten and unflatten

@pytest.mark.parametrize("shapes", [
0, # tests device scalars when flattening
512,
[(128, 67)],
[(127, 67), (18, 0)], # tests 0-sized arrays
[(64, 7), (154, 12)]
])
def test_flatten_array_container(actx_factory, shapes):
if np.prod(shapes) == 0:
# https://github.com/inducer/loopy/pull/497
# NOTE: only fails for the pytato array context at the moment
pytest.xfail("strides do not match in subary")

actx = actx_factory()

from arraycontext import flatten, unflatten
arys = _get_test_containers(actx, shapes=shapes)

for ary in arys:
flat = flatten(ary, actx)
assert flat.ndim == 1

ary_roundtrip = unflatten(ary, flat, actx)

from arraycontext import rec_multimap_reduce_array_container
assert rec_multimap_reduce_array_container(
np.prod,
lambda x, y: x.shape == y.shape,
ary, ary_roundtrip)

assert actx.to_numpy(
actx.np.linalg.norm(ary - ary_roundtrip)
) < 1.0e-15


def test_flatten_array_container_failure(actx_factory):
actx = actx_factory()

from arraycontext import flatten, unflatten
ary = _get_test_containers(actx, shapes=512)[0]
flat_ary = flatten(ary, actx)

with pytest.raises(TypeError):
# cannot unflatten from a numpy array
unflatten(ary, actx.to_numpy(flat_ary), actx)

with pytest.raises(ValueError):
# cannot unflatten non-flat arrays
unflatten(ary, flat_ary.reshape(2, -1), actx)

with pytest.raises(ValueError):
# cannot unflatten partially
unflatten(ary, flat_ary[:-1], actx)

# }}}


# {{{ test from_numpy and to_numpy

def test_numpy_conversion(actx_factory):
actx = actx_factory()

nelements = 42
ac = MyContainer(
name="test_numpy_conversion",
mass=np.random.rand(42),
momentum=make_obj_array([np.random.rand(42) for _ in range(3)]),
enthalpy=np.random.rand(42),
mass=np.random.rand(nelements, nelements),
momentum=make_obj_array([np.random.rand(nelements) for _ in range(3)]),
enthalpy=np.array(np.random.rand()),
)

from arraycontext import from_numpy, to_numpy
Expand Down