Skip to content

Commit

Permalink
add tests for [multi]mapped_over_array_containers
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Dec 16, 2021
1 parent b395fa8 commit 442a119
Showing 1 changed file with 86 additions and 8 deletions.
94 changes: 86 additions & 8 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,59 @@ def test_container_scalar_map(actx_factory):
assert result is not None


def test_container_map(actx_factory):
actx = actx_factory()
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
_get_test_containers(actx)

# {{{ check

def _check_allclose(f, arg1, arg2, atol=2.0e-14):
from arraycontext import NotAnArrayContainerError
try:
arg1_iterable = serialize_container(arg1)
arg2_iterable = serialize_container(arg2)
except NotAnArrayContainerError:
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
else:
arg1_subarrays = [
subarray for _, subarray in arg1_iterable]
arg2_subarrays = [
subarray for _, subarray in arg2_iterable]
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
_check_allclose(f, subarray1, subarray2)

def func(x):
return x + 1

from arraycontext import rec_map_array_container
result = rec_map_array_container(func, 1)
assert result == 2

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = rec_map_array_container(func, ary)
_check_allclose(func, ary, result)

from arraycontext import mapped_over_array_containers

@mapped_over_array_containers
def mapped_func(x):
return func(x)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = mapped_func(ary)
_check_allclose(func, ary, result)

@mapped_over_array_containers(leaf_class=DOFArray)
def check_leaf(x):
assert isinstance(x, DOFArray)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
check_leaf(ary)

# }}}


def test_container_multimap(actx_factory):
actx = actx_factory()
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
Expand All @@ -764,7 +817,19 @@ def test_container_multimap(actx_factory):
# {{{ check

def _check_allclose(f, arg1, arg2, atol=2.0e-14):
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
from arraycontext import NotAnArrayContainerError
try:
arg1_iterable = serialize_container(arg1)
arg2_iterable = serialize_container(arg2)
except NotAnArrayContainerError:
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
else:
arg1_subarrays = [
subarray for _, subarray in arg1_iterable]
arg2_subarrays = [
subarray for _, subarray in arg2_iterable]
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
_check_allclose(f, subarray1, subarray2)

def func_all_scalar(x, y):
return x + y
Expand All @@ -779,17 +844,30 @@ def func_multiple_scalar(a, subary1, b, subary2):
result = rec_multimap_array_container(func_all_scalar, 1, 2)
assert result == 3

from functools import partial
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = rec_multimap_array_container(func_first_scalar, 1, ary)
rec_multimap_array_container(
partial(_check_allclose, lambda x: 1 + x),
ary, result)
_check_allclose(lambda x: 1 + x, ary, result)

result = rec_multimap_array_container(func_multiple_scalar, 2, ary, 2, ary)
rec_multimap_array_container(
partial(_check_allclose, lambda x: 4 * x),
ary, result)
_check_allclose(lambda x: 4 * x, ary, result)

from arraycontext import multimapped_over_array_containers

@multimapped_over_array_containers
def mapped_func(a, subary1, b, subary2):
return func_multiple_scalar(a, subary1, b, subary2)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
result = mapped_func(2, ary, 2, ary)
_check_allclose(lambda x: 4 * x, ary, result)

@multimapped_over_array_containers(leaf_class=DOFArray)
def check_leaf(a, subary1, b, subary2):
assert isinstance(subary1, DOFArray)
assert isinstance(subary2, DOFArray)

for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
check_leaf(2, ary, 2, ary)

with pytest.raises(AssertionError):
rec_multimap_array_container(func_multiple_scalar, 2, ary_dof, 2, dc_of_dofs)
Expand Down

0 comments on commit 442a119

Please sign in to comment.