diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index cf9640fe..db3633d1 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -526,7 +526,7 @@ def _flatten(subary: ArrayOrContainerT) -> None: f"got {subary.dtype}, expected {common_dtype}") try: - flat_subary = actx.np.ravel(subary, order="A") + flat_subary = actx.np.ravel(subary, order="C") except ValueError as exc: # NOTE: we can't do much if the array context fails to ravel, # since it is the one responsible for the actual memory layout @@ -580,7 +580,8 @@ def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT: flat_subary = ary[offset - template_subary.size:offset] try: - subary = actx.np.reshape(flat_subary, template_subary.shape) + subary = actx.np.reshape(flat_subary, + template_subary.shape, order="C") except ValueError as exc: # NOTE: we can't do much if the array context fails to reshape, # since it is the one responsible for the actual memory layout diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index c5b57ef5..c60f33d7 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -172,8 +172,10 @@ def stack(self, arrays, axis=0): queue=self._array_context.queue), *arrays) - def reshape(self, a, newshape): - return cl_array.reshape(a, newshape) + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: ary.reshape(newshape, order=order), + a) def concatenate(self, arrays, axis=0): return cl_array.concatenate( diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 01efaec8..62b5e200 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -64,8 +64,9 @@ def __getattr__(self, name): return super().__getattr__(name) - def reshape(self, a, newshape): - return rec_multimap_array_container(pt.reshape, a, newshape) + def reshape(self, a, newshape, order="C"): + return rec_multimap_array_container( + partial(pt.reshape, order=order), a, newshape) def transpose(self, a, axes=None): return rec_multimap_array_container(pt.transpose, a, axes)