diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 8d371434..2127c5d2 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -480,7 +480,7 @@ def rec_map_reduce_array_container( or any other such traversal. """ - def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: + def rec(_ary: ArrayOrContainerT) -> Optional[ArrayOrContainerT]: if type(_ary) is leaf_class: return map_func(_ary) else: @@ -489,11 +489,22 @@ def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: except NotAnArrayContainerError: return map_func(_ary) else: - return reduce_func([ - rec(subary) for _, subary in iterable - ]) + subary_results = [ + rec(subary) for _, subary in iterable] + filtered_subary_results = [ + result for result in subary_results + if result is not None] + if len(filtered_subary_results) > 0: + return reduce_func(filtered_subary_results) + else: + return None - return rec(ary) + result = rec(ary) + + if result is None: + raise ValueError("cannot reduce empty array container") + + return result def rec_multimap_reduce_array_container( @@ -519,12 +530,23 @@ def rec_multimap_reduce_array_container( # NOTE: this wrapper matches the signature of `deserialize_container` # to make plugging into `_multimap_array_container_impl` easier def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any: - return reduce_func([subary for _, subary in iterable]) + filtered_subary_results = [ + result for _, result in iterable + if result is not None] + if len(filtered_subary_results) > 0: + return reduce_func(filtered_subary_results) + else: + return None - return _multimap_array_container_impl( + result = _multimap_array_container_impl( map_func, *args, reduce_func=_reduce_wrapper, leaf_cls=leaf_class, recursive=True) + if result is None: + raise ValueError("cannot reduce empty array container") + + return result + # }}}