diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index aa6de375..07c15446 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -256,7 +256,8 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: def rec_map_array_container( f: Callable[[Any], Any], - ary: ArrayOrContainerT) -> ArrayOrContainerT: + ary: ArrayOrContainerT, + leaf_class: Optional[type] = None) -> ArrayOrContainerT: r"""Applies *f* recursively to an :class:`ArrayContainer`. For a non-recursive version see :func:`map_array_container`. @@ -264,18 +265,32 @@ def rec_map_array_container( :param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s, or an instance of a base array type. """ - return _map_array_container_impl(f, ary, recursive=True) + return _map_array_container_impl(f, ary, leaf_cls=leaf_class, recursive=True) def mapped_over_array_containers( - f: Callable[[Any], Any]) -> Callable[[ArrayOrContainerT], ArrayOrContainerT]: + f: Optional[Callable[[Any], Any]] = None, + leaf_class: Optional[type] = None) -> Union[ + Callable[[ArrayOrContainerT], ArrayOrContainerT], + Callable[ + [Callable[[Any], Any]], + Callable[[ArrayOrContainerT], ArrayOrContainerT]]]: """Decorator around :func:`rec_map_array_container`.""" - wrapper = partial(rec_map_array_container, f) - update_wrapper(wrapper, f) - return wrapper + def decorator(g: Callable[[Any], Any]) -> Callable[ + [ArrayOrContainerT], ArrayOrContainerT]: + wrapper = partial(rec_map_array_container, g, leaf_class=leaf_class) + update_wrapper(wrapper, g) + return wrapper + if f is not None: + return decorator(f) + else: + return decorator -def rec_multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: +def rec_multimap_array_container( + f: Callable[..., Any], + *args: Any, + leaf_class: Optional[type] = None) -> Any: r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s. For a non-recursive version see :func:`multimap_array_container`. @@ -283,19 +298,28 @@ def rec_multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: :param args: all :class:`ArrayContainer` arguments must be of the same type and with the same structure (same number of components, etc.). """ - return _multimap_array_container_impl(f, *args, recursive=True) + return _multimap_array_container_impl( + f, *args, leaf_cls=leaf_class, recursive=True) def multimapped_over_array_containers( - f: Callable[..., Any]) -> Callable[..., Any]: + f: Optional[Callable[..., Any]] = None, + leaf_class: Optional[type] = None) -> Union[ + Callable[..., Any], + Callable[[Callable[..., Any]], Callable[..., Any]]]: """Decorator around :func:`rec_multimap_array_container`.""" - # can't use functools.partial, because its result is insufficiently - # function-y to be used as a method definition. - def wrapper(*args: Any) -> Any: - return rec_multimap_array_container(f, *args) + def decorator(g: Callable[..., Any]) -> Callable[..., Any]: + # can't use functools.partial, because its result is insufficiently + # function-y to be used as a method definition. + def wrapper(*args: Any) -> Any: + return rec_multimap_array_container(g, *args, leaf_class=leaf_class) + update_wrapper(wrapper, g) + return wrapper + if f is not None: + return decorator(f) + else: + return decorator - update_wrapper(wrapper, f) - return wrapper # }}} @@ -401,7 +425,8 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any def rec_map_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[[Any], Any], - ary: ArrayOrContainerT) -> "DeviceArray": + ary: ArrayOrContainerT, + leaf_class: Optional[type] = None) -> "DeviceArray": """Perform a map-reduce over array containers recursively. :param reduce_func: callable used to reduce over the components of *ary* @@ -440,14 +465,17 @@ def rec_map_reduce_array_container( or any other such traversal. """ def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: - try: - iterable = serialize_container(_ary) - except NotAnArrayContainerError: + if type(_ary) is leaf_class: return map_func(_ary) else: - return reduce_func([ - rec(subary) for _, subary in iterable - ]) + try: + iterable = serialize_container(_ary) + except NotAnArrayContainerError: + return map_func(_ary) + else: + return reduce_func([ + rec(subary) for _, subary in iterable + ]) return rec(ary) @@ -455,7 +483,8 @@ def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: def rec_multimap_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[..., Any], - *args: Any) -> "DeviceArray": + *args: Any, + leaf_class: Optional[type] = None) -> "DeviceArray": r"""Perform a map-reduce over multiple array containers recursively. :param reduce_func: callable used to reduce over the components of any @@ -478,7 +507,7 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any return _multimap_array_container_impl( map_func, *args, - reduce_func=_reduce_wrapper, leaf_cls=None, recursive=True) + reduce_func=_reduce_wrapper, leaf_cls=leaf_class, recursive=True) # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 0fe5480d..ac7ebbf5 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -756,6 +756,59 @@ def test_container_scalar_map(actx_factory): assert result is not None +def test_container_map(actx_factory): + actx = actx_factory() + ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \ + _get_test_containers(actx) + + # {{{ check + + def _check_allclose(f, arg1, arg2, atol=2.0e-14): + from arraycontext import NotAnArrayContainerError + try: + arg1_iterable = serialize_container(arg1) + arg2_iterable = serialize_container(arg2) + except NotAnArrayContainerError: + assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol + else: + arg1_subarrays = [ + subarray for _, subarray in arg1_iterable] + arg2_subarrays = [ + subarray for _, subarray in arg2_iterable] + for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays): + _check_allclose(f, subarray1, subarray2) + + def func(x): + return x + 1 + + from arraycontext import rec_map_array_container + result = rec_map_array_container(func, 1) + assert result == 2 + + for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: + result = rec_map_array_container(func, ary) + _check_allclose(func, ary, result) + + from arraycontext import mapped_over_array_containers + + @mapped_over_array_containers + def mapped_func(x): + return func(x) + + for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: + result = mapped_func(ary) + _check_allclose(func, ary, result) + + @mapped_over_array_containers(leaf_class=DOFArray) + def check_leaf(x): + assert isinstance(x, DOFArray) + + for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: + check_leaf(ary) + + # }}} + + def test_container_multimap(actx_factory): actx = actx_factory() ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \ @@ -764,7 +817,19 @@ def test_container_multimap(actx_factory): # {{{ check def _check_allclose(f, arg1, arg2, atol=2.0e-14): - assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol + from arraycontext import NotAnArrayContainerError + try: + arg1_iterable = serialize_container(arg1) + arg2_iterable = serialize_container(arg2) + except NotAnArrayContainerError: + assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol + else: + arg1_subarrays = [ + subarray for _, subarray in arg1_iterable] + arg2_subarrays = [ + subarray for _, subarray in arg2_iterable] + for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays): + _check_allclose(f, subarray1, subarray2) def func_all_scalar(x, y): return x + y @@ -779,17 +844,30 @@ def func_multiple_scalar(a, subary1, b, subary2): result = rec_multimap_array_container(func_all_scalar, 1, 2) assert result == 3 - from functools import partial for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: result = rec_multimap_array_container(func_first_scalar, 1, ary) - rec_multimap_array_container( - partial(_check_allclose, lambda x: 1 + x), - ary, result) + _check_allclose(lambda x: 1 + x, ary, result) result = rec_multimap_array_container(func_multiple_scalar, 2, ary, 2, ary) - rec_multimap_array_container( - partial(_check_allclose, lambda x: 4 * x), - ary, result) + _check_allclose(lambda x: 4 * x, ary, result) + + from arraycontext import multimapped_over_array_containers + + @multimapped_over_array_containers + def mapped_func(a, subary1, b, subary2): + return func_multiple_scalar(a, subary1, b, subary2) + + for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: + result = mapped_func(2, ary, 2, ary) + _check_allclose(lambda x: 4 * x, ary, result) + + @multimapped_over_array_containers(leaf_class=DOFArray) + def check_leaf(a, subary1, b, subary2): + assert isinstance(subary1, DOFArray) + assert isinstance(subary2, DOFArray) + + for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: + check_leaf(2, ary, 2, ary) with pytest.raises(AssertionError): rec_multimap_array_container(func_multiple_scalar, 2, ary_dof, 2, dc_of_dofs)