Skip to content

Commit

Permalink
Allow specifying leaf class in recursive map and map-reduce (#128)
Browse files Browse the repository at this point in the history
* allow specifying leaf class in recursive map and map-reduce

* revert broken changes to decorators

* add leaf_class to decorators

* add tests for [multi]mapped_over_array_containers

Co-authored-by: Andreas Klöckner <[email protected]>
  • Loading branch information
majosm and inducer authored Dec 27, 2021
1 parent 5c64c75 commit 1810b0e
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 32 deletions.
77 changes: 53 additions & 24 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,46 +256,70 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:

def rec_map_array_container(
f: Callable[[Any], Any],
ary: ArrayOrContainerT) -> ArrayOrContainerT:
ary: ArrayOrContainerT,
leaf_class: Optional[type] = None) -> ArrayOrContainerT:
r"""Applies *f* recursively to an :class:`ArrayContainer`.
For a non-recursive version see :func:`map_array_container`.
:param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
or an instance of a base array type.
"""
return _map_array_container_impl(f, ary, recursive=True)
return _map_array_container_impl(f, ary, leaf_cls=leaf_class, recursive=True)


def mapped_over_array_containers(
f: Callable[[Any], Any]) -> Callable[[ArrayOrContainerT], ArrayOrContainerT]:
f: Optional[Callable[[Any], Any]] = None,
leaf_class: Optional[type] = None) -> Union[
Callable[[ArrayOrContainerT], ArrayOrContainerT],
Callable[
[Callable[[Any], Any]],
Callable[[ArrayOrContainerT], ArrayOrContainerT]]]:
"""Decorator around :func:`rec_map_array_container`."""
wrapper = partial(rec_map_array_container, f)
update_wrapper(wrapper, f)
return wrapper
def decorator(g: Callable[[Any], Any]) -> Callable[
[ArrayOrContainerT], ArrayOrContainerT]:
wrapper = partial(rec_map_array_container, g, leaf_class=leaf_class)
update_wrapper(wrapper, g)
return wrapper
if f is not None:
return decorator(f)
else:
return decorator


def rec_multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
def rec_multimap_array_container(
f: Callable[..., Any],
*args: Any,
leaf_class: Optional[type] = None) -> Any:
r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
For a non-recursive version see :func:`multimap_array_container`.
:param args: all :class:`ArrayContainer` arguments must be of the same
type and with the same structure (same number of components, etc.).
"""
return _multimap_array_container_impl(f, *args, recursive=True)
return _multimap_array_container_impl(
f, *args, leaf_cls=leaf_class, recursive=True)


def multimapped_over_array_containers(
f: Callable[..., Any]) -> Callable[..., Any]:
f: Optional[Callable[..., Any]] = None,
leaf_class: Optional[type] = None) -> Union[
Callable[..., Any],
Callable[[Callable[..., Any]], Callable[..., Any]]]:
"""Decorator around :func:`rec_multimap_array_container`."""
# can't use functools.partial, because its result is insufficiently
# function-y to be used as a method definition.
def wrapper(*args: Any) -> Any:
return rec_multimap_array_container(f, *args)
def decorator(g: Callable[..., Any]) -> Callable[..., Any]:
# can't use functools.partial, because its result is insufficiently
# function-y to be used as a method definition.
def wrapper(*args: Any) -> Any:
return rec_multimap_array_container(g, *args, leaf_class=leaf_class)
update_wrapper(wrapper, g)
return wrapper
if f is not None:
return decorator(f)
else:
return decorator

update_wrapper(wrapper, f)
return wrapper

# }}}

Expand Down Expand Up @@ -401,7 +425,8 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
def rec_map_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[[Any], Any],
ary: ArrayOrContainerT) -> "DeviceArray":
ary: ArrayOrContainerT,
leaf_class: Optional[type] = None) -> "DeviceArray":
"""Perform a map-reduce over array containers recursively.
:param reduce_func: callable used to reduce over the components of *ary*
Expand Down Expand Up @@ -440,22 +465,26 @@ def rec_map_reduce_array_container(
or any other such traversal.
"""
def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
try:
iterable = serialize_container(_ary)
except NotAnArrayContainerError:
if type(_ary) is leaf_class:
return map_func(_ary)
else:
return reduce_func([
rec(subary) for _, subary in iterable
])
try:
iterable = serialize_container(_ary)
except NotAnArrayContainerError:
return map_func(_ary)
else:
return reduce_func([
rec(subary) for _, subary in iterable
])

return rec(ary)


def rec_multimap_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[..., Any],
*args: Any) -> "DeviceArray":
*args: Any,
leaf_class: Optional[type] = None) -> "DeviceArray":
r"""Perform a map-reduce over multiple array containers recursively.
:param reduce_func: callable used to reduce over the components of any
Expand All @@ -478,7 +507,7 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any

return _multimap_array_container_impl(
map_func, *args,
reduce_func=_reduce_wrapper, leaf_cls=None, recursive=True)
reduce_func=_reduce_wrapper, leaf_cls=leaf_class, recursive=True)

# }}}

Expand Down
94 changes: 86 additions & 8 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,59 @@ def test_container_scalar_map(actx_factory):
assert result is not None


def test_container_map(actx_factory):
actx = actx_factory()
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
_get_test_containers(actx)

# {{{ check

def _check_allclose(f, arg1, arg2, atol=2.0e-14):
from arraycontext import NotAnArrayContainerError
try:
arg1_iterable = serialize_container(arg1)
arg2_iterable = serialize_container(arg2)
except NotAnArrayContainerError:
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
else:
arg1_subarrays = [
subarray for _, subarray in arg1_iterable]
arg2_subarrays = [
subarray for _, subarray in arg2_iterable]
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
_check_allclose(f, subarray1, subarray2)

def func(x):
return x + 1

from arraycontext import rec_map_array_container
result = rec_map_array_container(func, 1)
assert result == 2

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = rec_map_array_container(func, ary)
_check_allclose(func, ary, result)

from arraycontext import mapped_over_array_containers

@mapped_over_array_containers
def mapped_func(x):
return func(x)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = mapped_func(ary)
_check_allclose(func, ary, result)

@mapped_over_array_containers(leaf_class=DOFArray)
def check_leaf(x):
assert isinstance(x, DOFArray)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
check_leaf(ary)

# }}}


def test_container_multimap(actx_factory):
actx = actx_factory()
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
Expand All @@ -764,7 +817,19 @@ def test_container_multimap(actx_factory):
# {{{ check

def _check_allclose(f, arg1, arg2, atol=2.0e-14):
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
from arraycontext import NotAnArrayContainerError
try:
arg1_iterable = serialize_container(arg1)
arg2_iterable = serialize_container(arg2)
except NotAnArrayContainerError:
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
else:
arg1_subarrays = [
subarray for _, subarray in arg1_iterable]
arg2_subarrays = [
subarray for _, subarray in arg2_iterable]
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
_check_allclose(f, subarray1, subarray2)

def func_all_scalar(x, y):
return x + y
Expand All @@ -779,17 +844,30 @@ def func_multiple_scalar(a, subary1, b, subary2):
result = rec_multimap_array_container(func_all_scalar, 1, 2)
assert result == 3

from functools import partial
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = rec_multimap_array_container(func_first_scalar, 1, ary)
rec_multimap_array_container(
partial(_check_allclose, lambda x: 1 + x),
ary, result)
_check_allclose(lambda x: 1 + x, ary, result)

result = rec_multimap_array_container(func_multiple_scalar, 2, ary, 2, ary)
rec_multimap_array_container(
partial(_check_allclose, lambda x: 4 * x),
ary, result)
_check_allclose(lambda x: 4 * x, ary, result)

from arraycontext import multimapped_over_array_containers

@multimapped_over_array_containers
def mapped_func(a, subary1, b, subary2):
return func_multiple_scalar(a, subary1, b, subary2)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = mapped_func(2, ary, 2, ary)
_check_allclose(lambda x: 4 * x, ary, result)

@multimapped_over_array_containers(leaf_class=DOFArray)
def check_leaf(a, subary1, b, subary2):
assert isinstance(subary1, DOFArray)
assert isinstance(subary2, DOFArray)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
check_leaf(2, ary, 2, ary)

with pytest.raises(AssertionError):
rec_multimap_array_container(func_multiple_scalar, 2, ary_dof, 2, dc_of_dofs)
Expand Down

0 comments on commit 1810b0e

Please sign in to comment.