diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 75a91cb5..6b3fd1dd 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -756,6 +756,59 @@ def test_container_scalar_map(actx_factory): assert result is not None +def test_container_map(actx_factory): + actx = actx_factory() + ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \ + _get_test_containers(actx) + + # {{{ check + + def _check_allclose(f, arg1, arg2, atol=2.0e-14): + from arraycontext import NotAnArrayContainerError + try: + arg1_iterable = serialize_container(arg1) + arg2_iterable = serialize_container(arg2) + except NotAnArrayContainerError: + assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol + else: + arg1_subarrays = [ + subarray for _, subarray in arg1_iterable] + arg2_subarrays = [ + subarray for _, subarray in arg2_iterable] + for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays): + _check_allclose(f, subarray1, subarray2) + + def func(x): + return x + 1 + + from arraycontext import rec_map_array_container + result = rec_map_array_container(func, 1) + assert result == 2 + + for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: + result = rec_map_array_container(func, ary) + _check_allclose(func, ary, result) + + from arraycontext import mapped_over_array_containers + + @mapped_over_array_containers + def mapped_func(x): + return func(x) + + for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: + result = mapped_func(ary) + _check_allclose(func, ary, result) + + @mapped_over_array_containers(leaf_class=DOFArray) + def check_leaf(x): + assert isinstance(x, DOFArray) + + for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: + check_leaf(ary) + + # }}} + + def test_container_multimap(actx_factory): actx = actx_factory() ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \ @@ -764,7 +817,19 @@ def test_container_multimap(actx_factory): # {{{ check def _check_allclose(f, arg1, arg2, atol=2.0e-14): - assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol + from arraycontext import NotAnArrayContainerError + try: + arg1_iterable = serialize_container(arg1) + arg2_iterable = serialize_container(arg2) + except NotAnArrayContainerError: + assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol + else: + arg1_subarrays = [ + subarray for _, subarray in arg1_iterable] + arg2_subarrays = [ + subarray for _, subarray in arg2_iterable] + for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays): + _check_allclose(f, subarray1, subarray2) def func_all_scalar(x, y): return x + y @@ -779,17 +844,30 @@ def func_multiple_scalar(a, subary1, b, subary2): result = rec_multimap_array_container(func_all_scalar, 1, 2) assert result == 3 - from functools import partial for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: result = rec_multimap_array_container(func_first_scalar, 1, ary) - rec_multimap_array_container( - partial(_check_allclose, lambda x: 1 + x), - ary, result) + _check_allclose(lambda x: 1 + x, ary, result) result = rec_multimap_array_container(func_multiple_scalar, 2, ary, 2, ary) - rec_multimap_array_container( - partial(_check_allclose, lambda x: 4 * x), - ary, result) + _check_allclose(lambda x: 4 * x, ary, result) + + from arraycontext import multimapped_over_array_containers + + @multimapped_over_array_containers + def mapped_func(a, subary1, b, subary2): + return func_multiple_scalar(a, subary1, b, subary2) + + for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: + result = mapped_func(2, ary, 2, ary) + _check_allclose(lambda x: 4 * x, ary, result) + + @multimapped_over_array_containers(leaf_class=DOFArray) + def check_leaf(a, subary1, b, subary2): + assert isinstance(subary1, DOFArray) + assert isinstance(subary2, DOFArray) + + for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: + check_leaf(2, ary, 2, ary) with pytest.raises(AssertionError): rec_multimap_array_container(func_multiple_scalar, 2, ary_dof, 2, dc_of_dofs)