From e9450ed1cde4879d0b4783b12b7b4d1dbcd1b8ca Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sat, 25 Sep 2021 16:44:20 -0500 Subject: [PATCH] add flatten to numpy for an entire container --- arraycontext/__init__.py | 4 +- arraycontext/container/traversal.py | 65 +++++++++++++++++++++++++++++ test/test_arraycontext.py | 36 ++++++++++++++-- 3 files changed, 101 insertions(+), 4 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 61203029..f1f9762c 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -58,7 +58,8 @@ rec_map_reduce_array_container, rec_multimap_reduce_array_container, thaw, freeze, - from_numpy, to_numpy) + from_numpy, to_numpy, + flatten_to_numpy, unflatten_from_numpy) from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoPyOpenCLArrayContext @@ -93,6 +94,7 @@ "rec_map_reduce_array_container", "rec_multimap_reduce_array_container", "thaw", "freeze", "from_numpy", "to_numpy", + "flatten_to_numpy", "unflatten_from_numpy", "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index ea5fce9c..1bf9fbcf 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -27,6 +27,8 @@ ~~~~~~~~~~~~~~~~ .. autofunction:: from_numpy .. autofunction:: to_numpy +.. autofunction:: flatten_to_numpy +.. autofunction:: unflatten_from_numpy """ __copyright__ = """ @@ -520,6 +522,69 @@ def to_numpy(ary: Any, actx: ArrayContext) -> Any: """ return rec_map_array_container(actx.to_numpy, ary) + +def flatten_to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> np.ndarray: + """Convert all arrays in the :class:`~arraycontext.ArrayContainer` + to host :mod:`numpy` arrays, flatten them using :func:`~numpy.ravel` + and concatenate them into a single :class:`~numpy.ndarray`. + + The order in which the individual leaf arrays appear in the final array is + dependent on the order given by :func:`~arraycontext.serialize_container`. + """ + def _flatten_to_numpy(subary): + try: + iterable = serialize_container(subary) + except TypeError: + result.append(actx.to_numpy(subary).ravel()) + else: + for _, isubary in iterable: + _flatten_to_numpy(isubary) + + result = [] + _flatten_to_numpy(ary) + + return np.concatenate(result) + + +def unflatten_from_numpy( + template: ArrayOrContainerT, ary: np.ndarray, + actx: ArrayContext) -> ArrayOrContainerT: + """Unflatten an :class:`~numpy.ndarray` produced by :func:`flatten_to_numpy` + back into an :class:`~arraycontext.ArrayContainer`. + + The order and sizes of each slice into *ary* are determined by the + array container *template*. + """ + def _unflatten_from_numpy(subary: ArrayOrContainerT) -> ArrayOrContainerT: + nonlocal offset + + try: + iterable = serialize_container(subary) + except TypeError: + # NOTE: the max is needed to handle device scalars with size == 0 + offset += max(1, subary.size) + if offset > ary.size: + raise ValueError("'template' and 'ary' sizes do not match") + + # FIXME: subary can be F-contiguous and ary will always be C-contiguous + return actx.from_numpy( + ary[offset - subary.size:offset] + .astype(subary.dtype, copy=False) + .reshape(subary.shape) + ) + else: + return deserialize_container(subary, [ + (key, _unflatten_from_numpy(isubary)) for key, isubary in iterable + ]) + + if ary.ndim != 1: + raise ValueError( + "only one dimensional arrays can be unflattened: " + f"'ary' has shape {ary.shape}") + + offset = 0 + return _unflatten_from_numpy(template) + # }}} # vim: foldmethod=marker diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 88000981..dfa238d0 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -882,13 +882,16 @@ def test_container_norm(actx_factory, ord): def test_numpy_conversion(actx_factory): actx = actx_factory() + nelements = 42 ac = MyContainer( name="test_numpy_conversion", - mass=np.random.rand(42), - momentum=make_obj_array([np.random.rand(42) for _ in range(3)]), - enthalpy=np.random.rand(42), + mass=np.random.rand(nelements, nelements), + momentum=make_obj_array([np.random.rand(nelements) for _ in range(3)]), + enthalpy=np.array(np.random.rand()), ) + # {{{ to/from_numpy + from arraycontext import from_numpy, to_numpy ac_actx = from_numpy(ac, actx) ac_roundtrip = to_numpy(ac_actx, actx) @@ -907,6 +910,33 @@ def test_numpy_conversion(actx_factory): with pytest.raises(ValueError): to_numpy(ac, actx) + # }}} + + # {{{ un/flatten + + from arraycontext import flatten_to_numpy, unflatten_from_numpy + ac_flat = flatten_to_numpy(ac_actx, actx) + assert ac_flat.shape == (nelements**2 + 3 * nelements + 1,) + + ac_roundtrip = unflatten_from_numpy(ac_actx, ac_flat, actx) + for name in ("mass", "momentum", "enthalpy"): + field = getattr(ac_actx, name) + field_roundtrip = getattr(ac_roundtrip, name) + + assert field.dtype == field_roundtrip.dtype + assert field.shape == field_roundtrip.shape + assert np.linalg.norm( + np.linalg.norm(to_numpy(field - field_roundtrip, actx)) + ) < 1.0e-15 + + with pytest.raises(ValueError): + unflatten_from_numpy(ac_actx, ac_flat[:-12], actx) + + with pytest.raises(ValueError): + unflatten_from_numpy(ac_actx, ac_flat.reshape(2, -1), actx) + + # }}} + # }}}