From 90091a23a00112ea63522163e2118bc79931bb47 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 30 May 2023 07:26:36 -0500 Subject: [PATCH] tests pick list fusion --- test/test_pytato_transforms.py | 225 +++++++++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 test/test_pytato_transforms.py diff --git a/test/test_pytato_transforms.py b/test/test_pytato_transforms.py new file mode 100644 index 00000000..a05e32bf --- /dev/null +++ b/test/test_pytato_transforms.py @@ -0,0 +1,225 @@ +import numpy as np # noqa: F401 +import pyopencl as cl +from typing import Union +from meshmode.mesh import BTAG_ALL +from meshmode.mesh.generation import generate_regular_rect_mesh +from arraycontext.metadata import NameHint +from meshmode.array_context import (PytatoPyOpenCLArrayContext, + PyOpenCLArrayContext) +from pytato.transform import CombineMapper +from pytato.array import (Placeholder, DataWrapper, SizeParam, IndexBase, + Array, DictOfNamedArrays) +from meshmode.discretization.connection import (FACE_RESTR_INTERIOR, + FACE_RESTR_ALL) +from pytools.obj_array import make_obj_array +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl as pytest_generate_tests) +import grudge +import grudge.op as op + + +# {{{ utilities for test_push_indirections_* + +class _IndexeeArraysMaterializedChecker(CombineMapper[bool]): + def combine(self, *args: bool) -> bool: + return all(args) + + def map_placeholder(self, expr: Placeholder) -> bool: + return True + + def map_data_wrapper(self, expr: DataWrapper) -> bool: + return True + + def map_size_param(self, expr: SizeParam) -> bool: + return True + + def _map_index_base(self, expr: IndexBase) -> bool: + from grudge.pytato_transforms.pytato_indirection_transforms import ( + _is_materialized) + return self.combine( + _is_materialized(expr.array) or isinstance(expr.array, IndexBase), + self.rec(expr.array) + ) + + +def are_all_indexees_materialized_nodes( + expr: Union[Array, DictOfNamedArrays]) -> bool: + """ + Returns *True* only if all indexee arrays are either materialized nodes, + OR, other indexing nodes that have materialized indexees. + """ + return _IndexeeArraysMaterializedChecker()(expr) + + +class _IndexerArrayDatawrapperChecker(CombineMapper[bool]): + def combine(self, *args: bool) -> bool: + return all(args) + + def map_placeholder(self, expr: Placeholder) -> bool: + return True + + def map_data_wrapper(self, expr: DataWrapper) -> bool: + return True + + def map_size_param(self, expr: SizeParam) -> bool: + return True + + def _map_index_base(self, expr: IndexBase) -> bool: + return self.combine( + *[isinstance(idx, DataWrapper) + for idx in expr.indices + if isinstance(idx, Array)], + super()._map_index_base(expr), + ) + + +def are_all_indexer_arrays_datawrappers( + expr: Union[Array, DictOfNamedArrays]) -> bool: + """ + Returns *True* only if all indexer arrays are instances of + :class:`~pytato.array.DataWrapper`. + """ + return _IndexerArrayDatawrapperChecker()(expr) + +# }}} + + +def _evaluate_dict_of_named_arrays(actx, dict_of_named_arrays): + container = make_obj_array([dict_of_named_arrays._data[name] + for name in sorted(dict_of_named_arrays.keys())]) + + evaluated_container = actx.thaw(actx.freeze(container)) + + return {name: evaluated_container[i] + for i, name in enumerate(sorted(dict_of_named_arrays.keys()))} + + +class FluxOptimizerActx(PytatoPyOpenCLArrayContext): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.check_completed = False + + def transform_dag(self, dag): + from grudge.pytato_transforms.pytato_indirection_transforms import ( + fuse_dof_pick_lists, fold_constant_indirections) + from pytato.tags import PrefixNamed + + if ( + len(dag) == 1 + and PrefixNamed("flux_container") in list(dag._data.values())[0].tags + ): + assert not are_all_indexer_arrays_datawrappers(dag) + self.check_completed = True + + dag = fuse_dof_pick_lists(dag) + dag = fold_constant_indirections( + dag, lambda x: _evaluate_dict_of_named_arrays(self, x)) + + if ( + len(dag) == 1 + and PrefixNamed("flux_container") in list(dag._data.values())[0].tags + ): + assert are_all_indexer_arrays_datawrappers(dag) + self.check_completed = True + + return dag + + +# {{{ test_resampling_indirections_are_fused_0 + +def _compute_flux_0(dcoll, actx, u): + u_interior_tpair, = op.interior_trace_pairs(dcoll, u) + flux_on_interior_faces = u_interior_tpair.avg + flux_on_all_faces = op.project( + dcoll, FACE_RESTR_INTERIOR, FACE_RESTR_ALL, flux_on_interior_faces) + + flux_on_all_faces = actx.tag(NameHint("flux_container"), flux_on_all_faces) + return flux_on_all_faces + + +def test_resampling_indirections_are_fused_0(ctx_factory): + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + + ref_actx = PyOpenCLArrayContext(cq) + actx = FluxOptimizerActx(cq) + + dim = 3 + nel_1d = 4 + mesh = generate_regular_rect_mesh( + a=(-0.5,)*dim, + b=(0.5,)*dim, + nelements_per_axis=(nel_1d,)*dim, + boundary_tag_to_face={"bdry": ["-x", "+x", + "-y", "+y", + "-z", "+z"]} + ) + dcoll = grudge.make_discretization_collection(ref_actx, mesh, order=2) + + x, _, _ = dcoll.nodes() + compiled_flux_0 = actx.compile(lambda ary: _compute_flux_0(dcoll, actx, ary)) + + ref_output = ref_actx.to_numpy( + _compute_flux_0(dcoll, ref_actx, ref_actx.thaw(x))) + output = actx.to_numpy( + compiled_flux_0(actx.thaw(x))) + + np.testing.assert_allclose(ref_output[0], output[0]) + assert actx.check_completed + +# }}} + + +# {{{ test_resampling_indirections_are_fused_1 + +def _compute_flux_1(dcoll, actx, u): + u_interior_tpair, = op.interior_trace_pairs(dcoll, u) + flux_on_interior_faces = u_interior_tpair.avg + flux_on_bdry = op.project(dcoll, "vol", BTAG_ALL, u) + flux_on_all_faces = ( + op.project(dcoll, + FACE_RESTR_INTERIOR, + FACE_RESTR_ALL, + flux_on_interior_faces) + + op.project(dcoll, BTAG_ALL, FACE_RESTR_ALL, flux_on_bdry) + ) + + result = op.inverse_mass(dcoll, op.face_mass(dcoll, flux_on_all_faces)) + + result = actx.tag(NameHint("flux_container"), result) + return result + + +def test_resampling_indirections_are_fused_1(ctx_factory): + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + + ref_actx = PyOpenCLArrayContext(cq) + actx = FluxOptimizerActx(cq) + + dim = 3 + nel_1d = 4 + mesh = generate_regular_rect_mesh( + a=(-0.5,)*dim, + b=(0.5,)*dim, + nelements_per_axis=(nel_1d,)*dim, + boundary_tag_to_face={"bdry": ["-x", "+x", + "-y", "+y", + "-z", "+z"]} + ) + dcoll = grudge.make_discretization_collection(ref_actx, mesh, order=2) + + x, _, _ = dcoll.nodes() + compiled_flux_1 = actx.compile(lambda ary: _compute_flux_1(dcoll, actx, ary)) + + ref_output = ref_actx.to_numpy( + _compute_flux_1(dcoll, ref_actx, ref_actx.thaw(x))) + output = actx.to_numpy( + compiled_flux_1(actx.thaw(x))) + + np.testing.assert_allclose(ref_output[0], output[0]) + assert actx.check_completed + +# }}} + +# vim: fdm=marker