Skip to content

Commit

Permalink
add flatten to numpy for an entire container
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Sep 26, 2021
1 parent 9314073 commit e9450ed
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 4 deletions.
4 changes: 3 additions & 1 deletion arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@
rec_map_reduce_array_container,
rec_multimap_reduce_array_container,
thaw, freeze,
from_numpy, to_numpy)
from_numpy, to_numpy,
flatten_to_numpy, unflatten_from_numpy)

from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoPyOpenCLArrayContext
Expand Down Expand Up @@ -93,6 +94,7 @@
"rec_map_reduce_array_container", "rec_multimap_reduce_array_container",
"thaw", "freeze",
"from_numpy", "to_numpy",
"flatten_to_numpy", "unflatten_from_numpy",

"PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",

Expand Down
65 changes: 65 additions & 0 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
~~~~~~~~~~~~~~~~
.. autofunction:: from_numpy
.. autofunction:: to_numpy
.. autofunction:: flatten_to_numpy
.. autofunction:: unflatten_from_numpy
"""

__copyright__ = """
Expand Down Expand Up @@ -520,6 +522,69 @@ def to_numpy(ary: Any, actx: ArrayContext) -> Any:
"""
return rec_map_array_container(actx.to_numpy, ary)


def flatten_to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> np.ndarray:
"""Convert all arrays in the :class:`~arraycontext.ArrayContainer`
to host :mod:`numpy` arrays, flatten them using :func:`~numpy.ravel`
and concatenate them into a single :class:`~numpy.ndarray`.
The order in which the individual leaf arrays appear in the final array is
dependent on the order given by :func:`~arraycontext.serialize_container`.
"""
def _flatten_to_numpy(subary):
try:
iterable = serialize_container(subary)
except TypeError:
result.append(actx.to_numpy(subary).ravel())
else:
for _, isubary in iterable:
_flatten_to_numpy(isubary)

result = []
_flatten_to_numpy(ary)

return np.concatenate(result)


def unflatten_from_numpy(
template: ArrayOrContainerT, ary: np.ndarray,
actx: ArrayContext) -> ArrayOrContainerT:
"""Unflatten an :class:`~numpy.ndarray` produced by :func:`flatten_to_numpy`
back into an :class:`~arraycontext.ArrayContainer`.
The order and sizes of each slice into *ary* are determined by the
array container *template*.
"""
def _unflatten_from_numpy(subary: ArrayOrContainerT) -> ArrayOrContainerT:
nonlocal offset

try:
iterable = serialize_container(subary)
except TypeError:
# NOTE: the max is needed to handle device scalars with size == 0
offset += max(1, subary.size)
if offset > ary.size:
raise ValueError("'template' and 'ary' sizes do not match")

# FIXME: subary can be F-contiguous and ary will always be C-contiguous
return actx.from_numpy(
ary[offset - subary.size:offset]
.astype(subary.dtype, copy=False)
.reshape(subary.shape)
)
else:
return deserialize_container(subary, [
(key, _unflatten_from_numpy(isubary)) for key, isubary in iterable
])

if ary.ndim != 1:
raise ValueError(
"only one dimensional arrays can be unflattened: "
f"'ary' has shape {ary.shape}")

offset = 0
return _unflatten_from_numpy(template)

# }}}

# vim: foldmethod=marker
36 changes: 33 additions & 3 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,13 +882,16 @@ def test_container_norm(actx_factory, ord):
def test_numpy_conversion(actx_factory):
actx = actx_factory()

nelements = 42
ac = MyContainer(
name="test_numpy_conversion",
mass=np.random.rand(42),
momentum=make_obj_array([np.random.rand(42) for _ in range(3)]),
enthalpy=np.random.rand(42),
mass=np.random.rand(nelements, nelements),
momentum=make_obj_array([np.random.rand(nelements) for _ in range(3)]),
enthalpy=np.array(np.random.rand()),
)

# {{{ to/from_numpy

from arraycontext import from_numpy, to_numpy
ac_actx = from_numpy(ac, actx)
ac_roundtrip = to_numpy(ac_actx, actx)
Expand All @@ -907,6 +910,33 @@ def test_numpy_conversion(actx_factory):
with pytest.raises(ValueError):
to_numpy(ac, actx)

# }}}

# {{{ un/flatten

from arraycontext import flatten_to_numpy, unflatten_from_numpy
ac_flat = flatten_to_numpy(ac_actx, actx)
assert ac_flat.shape == (nelements**2 + 3 * nelements + 1,)

ac_roundtrip = unflatten_from_numpy(ac_actx, ac_flat, actx)
for name in ("mass", "momentum", "enthalpy"):
field = getattr(ac_actx, name)
field_roundtrip = getattr(ac_roundtrip, name)

assert field.dtype == field_roundtrip.dtype
assert field.shape == field_roundtrip.shape
assert np.linalg.norm(
np.linalg.norm(to_numpy(field - field_roundtrip, actx))
) < 1.0e-15

with pytest.raises(ValueError):
unflatten_from_numpy(ac_actx, ac_flat[:-12], actx)

with pytest.raises(ValueError):
unflatten_from_numpy(ac_actx, ac_flat.reshape(2, -1), actx)

# }}}

# }}}


Expand Down

0 comments on commit e9450ed

Please sign in to comment.