Skip to content

Commit

Permalink
hardcode flatten and unflatten in c order
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Sep 28, 2021
1 parent 8f064ea commit 3ccf485
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
5 changes: 3 additions & 2 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def _flatten(subary: ArrayOrContainerT) -> None:
f"got {subary.dtype}, expected {common_dtype}")

try:
flat_subary = actx.np.ravel(subary, order="A")
flat_subary = actx.np.ravel(subary, order="C")
except ValueError as exc:
# NOTE: we can't do much if the array context fails to ravel,
# since it is the one responsible for the actual memory layout
Expand Down Expand Up @@ -580,7 +580,8 @@ def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:

flat_subary = ary[offset - template_subary.size:offset]
try:
subary = actx.np.reshape(flat_subary, template_subary.shape)
subary = actx.np.reshape(flat_subary,
template_subary.shape, order="C")
except ValueError as exc:
# NOTE: we can't do much if the array context fails to reshape,
# since it is the one responsible for the actual memory layout
Expand Down
6 changes: 4 additions & 2 deletions arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,10 @@ def stack(self, arrays, axis=0):
queue=self._array_context.queue),
*arrays)

def reshape(self, a, newshape):
return cl_array.reshape(a, newshape)
def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: ary.reshape(newshape, order=order),
a)

def concatenate(self, arrays, axis=0):
return cl_array.concatenate(
Expand Down
5 changes: 3 additions & 2 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ def __getattr__(self, name):

return super().__getattr__(name)

def reshape(self, a, newshape):
return rec_multimap_array_container(pt.reshape, a, newshape)
def reshape(self, a, newshape, order="C"):
return rec_multimap_array_container(
partial(pt.reshape, order=order), a, newshape)

def transpose(self, a, axes=None):
return rec_multimap_array_container(pt.transpose, a, axes)
Expand Down

0 comments on commit 3ccf485

Please sign in to comment.