From d8787e92c7eaa213569c97c21a72c93bc0b1d6fd Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 17 Jan 2023 08:14:57 -0600 Subject: [PATCH] Implement PytatoSplitArrayContext --- arraycontext/__init__.py | 3 + .../impl/pytato/split_actx/__init__.py | 133 +++++++ arraycontext/impl/pytato/split_actx/utils.py | 364 ++++++++++++++++++ 3 files changed, 500 insertions(+) create mode 100644 arraycontext/impl/pytato/split_actx/__init__.py create mode 100644 arraycontext/impl/pytato/split_actx/utils.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index b01b9917..e6367f2b 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -53,6 +53,7 @@ from .impl.jax import EagerJAXArrayContext from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext +from .impl.pytato.split_actx import SplitPytatoPyOpenCLArrayContext from .loopy import make_loopy_program # deprecated, remove in 2022. from .metadata import _FirstAxisIsElementsTag @@ -98,6 +99,8 @@ "outer", "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", + "SplitPytatoPyOpenCLArrayContext", + "PytatoJAXArrayContext", "EagerJAXArrayContext", diff --git a/arraycontext/impl/pytato/split_actx/__init__.py b/arraycontext/impl/pytato/split_actx/__init__.py new file mode 100644 index 00000000..30fc3758 --- /dev/null +++ b/arraycontext/impl/pytato/split_actx/__init__.py @@ -0,0 +1,133 @@ +""" +.. 
# arraycontext/impl/pytato/split_actx/__init__.py
"""
.. autoclass:: SplitPytatoPyOpenCLArrayContext

"""

__copyright__ = """
Copyright (C) 2023 Kaushik Kulkarni
Copyright (C) 2023 Andreas Kloeckner
Copyright (C) 2022 Matthias Diener
Copyright (C) 2022 Matt Smith
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import sys
from typing import TYPE_CHECKING

import loopy as lp

from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext


if TYPE_CHECKING or getattr(sys, "_BUILDING_SPHINX_DOCS", False):
    import pytato


class SplitPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
    """
    A :class:`~arraycontext.PytatoPyOpenCLArrayContext` that overrides
    :meth:`transform_dag` and :meth:`transform_loopy_program`.

    .. note::

        Refer to :meth:`transform_dag` and :meth:`transform_loopy_program` for
        details on the transformation algorithm provided by this array
        context.

    .. warning::

        High compile times are expected for expression graphs with a large
        number of nodes.
    """

    def transform_dag(self,
                      dag: "pytato.DictOfNamedArrays") -> "pytato.DictOfNamedArrays":
        r"""
        Returns a transformed version of *dag*. The transformations are
        applied in the following order:

        #. Data wrappers are deduplicated via
           :func:`pytato.transform.deduplicate_data_wrappers`.
        #. Every :class:`pytato.array.Einsum` and every one of its arguments
           is tagged with :class:`pytato.tags.ImplStored`, except for those
           that are instances of :class:`pytato.array.InputArgumentBase`,
           which are left untouched.
        #. The result is passed through
           :func:`pytato.transform.materialize_with_mpms`.
        """
        import pytato as pt

        from .utils import get_inputs_and_outputs_of_einsum

        # Step 1. Collapse equivalent nodes in DAG.
        # -----------------------------------------
        # type-ignore-reason: mypy is right pytato provides imprecise types.
        dag = pt.transform.deduplicate_data_wrappers(dag)  # type: ignore[assignment]

        # Step 2. Materialize einsum inputs/outputs.
        # ------------------------------------------
        einsum_inputs, einsum_outputs = get_inputs_and_outputs_of_einsum(dag)
        arrays_to_store = einsum_inputs | einsum_outputs

        def materialize_einsum(expr: pt.transform.ArrayOrNames
                               ) -> pt.transform.ArrayOrNames:
            # Anything that is not an einsum operand/result is passed through;
            # so are input arguments, which are returned without a new tag.
            if (expr not in arrays_to_store
                    or isinstance(expr, pt.InputArgumentBase)):
                return expr

            return expr.tagged(pt.tags.ImplStored())

        # type-ignore-reason: mypy is right pytato provides imprecise types.
        dag = pt.transform.map_and_copy(dag,  # type: ignore[assignment]
                                        materialize_einsum)

        # Step 3. MPMS materialize
        # ------------------------
        return pt.transform.materialize_with_mpms(dag)

    def transform_loopy_program(self,
                                t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
        r"""
        Returns a transformed version of *t_unit*. The transformations are
        applied in the following order:

        #. The iteration domain of the statements in *t_unit* is split across
           work-items by ``split_iteration_domain_across_work_items``, which
           receives the OpenCL device of ``self.queue``.
        #. Global barriers are inserted by
           ``add_gbarrier_between_disjoint_loop_nests``.
           Although one could relax this constraint by letting :mod:`loopy`
           compute where to insert the global barriers, that is not
           guaranteed to be profitable for performance since we do not
           attempt any further loop-fusion and/or array contraction.
        #. ``alias_global_temporaries`` is invoked on the result to reduce
           the peak memory used by the transformed program.
        """
        from .utils import (
            add_gbarrier_between_disjoint_loop_nests,
            alias_global_temporaries,
            split_iteration_domain_across_work_items,
        )

        # Step 1. Split the iteration across work-items
        # ---------------------------------------------
        split_t_unit = split_iteration_domain_across_work_items(
            t_unit, self.queue.device)

        # Step 2. Add a global barrier between individual loop nests.
        # ------------------------------------------------------
        t_unit_with_barriers = add_gbarrier_between_disjoint_loop_nests(
            split_t_unit)

        # Step 3. Alias global temporaries with disjoint live intervals
        # -------------------------------------------------------------
        return alias_global_temporaries(t_unit_with_barriers)

# vim: fdm=marker
# arraycontext/impl/pytato/split_actx/utils.py
__copyright__ = """
Copyright (C) 2023 Kaushik Kulkarni
Copyright (C) 2023 Andreas Kloeckner
Copyright (C) 2022 Matthias Diener
Copyright (C) 2022 Matt Smith
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import logging
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, FrozenSet, List, Set, Tuple, Union

import loopy as lp
import loopy.match as lp_match
import pymbolic.primitives as prim
import pytato as pt
from loopy.translation_unit import for_each_kernel
from pymbolic.mapper.optimize import optimize_mapper


logger = logging.getLogger(__name__)


# {{{ EinsumInputOutputCollector

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class EinsumInputOutputCollector(pt.transform.CachedWalkMapper):
    """
    Walks an expression graph and records every :class:`pytato.array.Einsum`
    it visits in :attr:`collected_outputs` and every argument of such an
    einsum in :attr:`collected_inputs`.

    .. note::

        We deliberately avoid using :class:`pytato.transform.CombineMapper` since
        the mapper's caching structure would still lead to recomputing
        the union of sets for the results of a revisited node.
    """
    def __init__(self) -> None:
        self.collected_outputs: Set[pt.Array] = set()
        self.collected_inputs: Set[pt.Array] = set()
        super().__init__()

    # type-ignore-reason: dropped the extra `*args, **kwargs`.
    def get_cache_key(self,  # type: ignore[override]
                      expr: pt.transform.ArrayOrNames) -> int:
        # nodes are cached by object identity
        return id(expr)

    # type-ignore-reason: dropped the extra `*args, **kwargs`.
    def post_visit(self, expr: Any) -> None:  # type: ignore[override]
        if isinstance(expr, pt.Einsum):
            self.collected_outputs.add(expr)
            self.collected_inputs.update(expr.args)


def get_inputs_and_outputs_of_einsum(
        expr: pt.DictOfNamedArrays) -> Tuple[FrozenSet[pt.Array],
                                             FrozenSet[pt.Array]]:
    """
    Returns a tuple ``(inputs, outputs)`` where *outputs* is the set of all
    :class:`pytato.array.Einsum` nodes found while walking *expr* and *inputs*
    is the set of all arguments of those einsums.
    """
    mapper = EinsumInputOutputCollector()
    mapper(expr)
    return frozenset(mapper.collected_inputs), frozenset(mapper.collected_outputs)

# }}}


# {{{ split_iteration_domain_across_work_items

def get_iname_length(kernel: lp.LoopKernel, iname: str) -> Union[float, int]:
    """
    Returns the static maximum of the size of *iname*'s bounds in *kernel*,
    or :data:`math.inf` if that maximum is infinite.
    """
    from loopy.isl_helpers import static_max_of_pw_aff
    max_domain_size = static_max_of_pw_aff(kernel.get_iname_bounds(iname).size,
                                           constants_only=False).max_val()
    if max_domain_size.is_infty():
        return math.inf

    return max_domain_size.to_python()


@for_each_kernel
def split_iteration_domain_across_work_items(kernel, cl_device):
    """
    Returns a copy of *kernel* in which the inames of each
    :class:`loopy.Assignment` are split and tagged with hardware axes.

    The group count is ``4 * cl_device.max_compute_units`` and the local sizes
    are 16 (``l.0``) and 4 (``l.1``). An instruction is skipped if any of its
    inames was already seen on an earlier instruction. Call, barrier and no-op
    instructions are left untouched.

    * With a single iname, that iname is split three times so that the
      resulting inames are tagged ``l.0``, ``l.1`` and ``g.0``.
    * With two or more inames, the loop with the largest length is split
      and tagged ``l.1``/``g.0`` and the second largest is split and tagged
      ``l.0``.

    :raises NotImplementedError: for any other instruction type.
    """
    # TODO: memoize_on_disk?
    ngroups = cl_device.max_compute_units * 4  # '4' to overfill the device
    l_one_size = 4
    l_zero_size = 16
    inames_already_split: Set[str] = set()

    # NOTE: this iterates over the instructions of the kernel that was passed
    # in; 'kernel' itself gets rebound to the transformed kernel in the loop.
    for insn in kernel.instructions:
        if insn.within_inames & inames_already_split:
            continue

        inames_already_split.update(insn.within_inames)

        if isinstance(insn, lp.CallInstruction):
            # must be a callable kernel, don't touch.
            pass
        elif isinstance(insn, lp.Assignment):
            if len(insn.within_inames) == 0:
                # iteration domain is singleton
                continue

            if len(insn.within_inames) == 1:
                iname, = insn.within_inames

                kernel = lp.split_iname(kernel, iname,
                                        ngroups * l_zero_size * l_one_size)
                kernel = lp.split_iname(kernel, f"{iname}_inner",
                                        l_zero_size, inner_tag="l.0")
                kernel = lp.split_iname(kernel, f"{iname}_inner_outer",
                                        l_one_size, inner_tag="l.1",
                                        outer_tag="g.0")
                continue

            # NOTE(review): '.index' raises ValueError for an iname that does
            # not appear as a bare index of the assignee -- presumably that
            # cannot happen for the kernels this is applied to; confirm.
            iname_pos_in_assignee = {
                iname: insn.assignee.index_tuple.index(prim.Variable(iname))
                for iname in insn.within_inames}

            # Pick the loop with largest loop count. In case of ties, look at the
            # iname position in the assignee and pick the iname indexing over
            # leading axis for the work-group hardware iname.
            sorted_inames = sorted(insn.within_inames,
                                   key=lambda iname: (get_iname_length(kernel,
                                                                       iname),
                                                      -iname_pos_in_assignee[iname]))
            smaller_loop, bigger_loop = sorted_inames[-2], sorted_inames[-1]

            kernel = lp.split_iname(kernel, f"{bigger_loop}",
                                    l_one_size * ngroups)
            kernel = lp.split_iname(kernel, f"{bigger_loop}_inner",
                                    l_one_size, inner_tag="l.1", outer_tag="g.0")
            kernel = lp.split_iname(kernel, smaller_loop,
                                    l_zero_size, inner_tag="l.0")
        elif isinstance(insn, (lp.BarrierInstruction, lp.NoOpInstruction)):
            pass
        else:
            raise NotImplementedError(type(insn))

    return kernel

# }}}


# {{{ add_gbarrier_between_disjoint_loop_nests

@dataclass(frozen=True, eq=True)
class _LoopNest:
    # inames of the loop nest
    inames: FrozenSet[str]
    # ids of the instructions nested within 'inames'
    insns_in_loop_nest: FrozenSet[str]


def _is_a_perfect_loop_nest(kernel: lp.LoopKernel,
                            inames: FrozenSet[str]) -> bool:
    """
    Returns *True* only if every iname in *inames* nests exactly the same set
    of instructions in *kernel*. An empty *inames* is considered perfect.
    """
    if not inames:
        return True

    iname_to_insns = kernel.iname_to_insns()
    insn_ids_in_template_iname = iname_to_insns[next(iter(inames))]
    return all(iname_to_insns[iname] == insn_ids_in_template_iname
               for iname in inames)


def _get_loop_nest(kernel: lp.LoopKernel, insn: lp.InstructionBase) -> _LoopNest:
    """
    Returns the :class:`_LoopNest` that *insn* belongs to. *insn* must be
    nested within a perfect loop nest.
    """
    assert _is_a_perfect_loop_nest(kernel, insn.within_inames), (
        f"Instruction '{insn.id}' is not nested within a perfect loop nest.")

    if insn.within_inames:
        # the loop nest is perfect => any iname yields the same instructions
        any_iname_in_nest = next(iter(insn.within_inames))
        return _LoopNest(insn.within_inames,
                         frozenset(kernel.iname_to_insns()[any_iname_in_nest]))

    # we treat a loop nest with 0-depth in a special manner by putting each such
    # instruction into a separate loop nest.
    return _LoopNest(frozenset(),
                     frozenset([insn.id]))


@dataclass(frozen=True)
class InsnIds(lp_match.MatchExpressionBase):
    """
    A match expression that matches the instructions whose id is in
    :attr:`insn_ids_to_match`.
    """
    insn_ids_to_match: FrozenSet[str]

    def __call__(self, kernel: lp.LoopKernel, matchable: lp.InstructionBase):
        return matchable.id in self.insn_ids_to_match


def _get_call_kernel_insn_ids(kernel: lp.LoopKernel) -> Tuple[FrozenSet[str], ...]:
    """
    Returns a sequence of collection of instruction ids where each entry in the
    sequence corresponds to the instructions in a call-kernel to launch.

    In this heuristic we simply draw kernel boundaries such that instructions
    belonging to disjoint loop-nest pairs are executed in different call
    kernels. The call kernels are returned in a topological order of the
    dependencies between their loop nests; ties are broken based on the sorted
    instruction ids of a loop nest, so that the result is deterministic.

    .. note::

        We require that every statement in *kernel* is nested within a perfect loop
        nest.
    """
    from pytools.graph import compute_topological_order

    # computed once per instruction instead of once per dependency edge
    insn_id_to_loop_nest: Dict[str, _LoopNest] = {
        insn.id: _get_loop_nest(kernel, insn)
        for insn in kernel.instructions}

    # maps a loop nest to the loop nests that depend on it
    loop_nest_dep_graph: Dict[_LoopNest, Set[_LoopNest]] = {
        loop_nest: set()
        for loop_nest in insn_id_to_loop_nest.values()}

    for insn in kernel.instructions:
        insn_loop_nest = insn_id_to_loop_nest[insn.id]
        for dep_id in insn.depends_on:
            dep_loop_nest = insn_id_to_loop_nest[dep_id]
            if insn_loop_nest != dep_loop_nest:
                loop_nest_dep_graph[dep_loop_nest].add(insn_loop_nest)

    # The hash of a '_LoopNest' (and hence the iteration order of the
    # successor sets) is derived from string hashes, therefore a 'key' is
    # passed to obtain an order that doesn't depend on them.
    toposorted_loop_nests: List[_LoopNest] = compute_topological_order(
        loop_nest_dep_graph,
        key=lambda loop_nest: tuple(sorted(loop_nest.insns_in_loop_nest)))

    return tuple(loop_nest.insns_in_loop_nest for loop_nest in toposorted_loop_nests)


def add_gbarrier_between_disjoint_loop_nests(
        t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
    """
    Returns a copy of *t_unit* with a barrier added (via
    :func:`loopy.add_barrier`) between every pair of consecutive call kernels
    of its default entrypoint, as computed by
    :func:`_get_call_kernel_insn_ids`.
    """
    kernel = t_unit.default_entrypoint

    call_kernel_insn_ids = _get_call_kernel_insn_ids(kernel)

    for insns_before, insns_after in zip(call_kernel_insn_ids[:-1],
                                         call_kernel_insn_ids[1:]):
        kernel = lp.add_barrier(kernel,
                                insn_before=InsnIds(insns_before),
                                insn_after=InsnIds(insns_after))

    return t_unit.with_kernel(kernel)

# }}}


# {{{ global temp var aliasing for disjoint live intervals

def alias_global_temporaries(t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
    """
    Returns a copy of *t_unit* in which global temporaries that have disjoint
    live intervals and the same size (in bytes) use the same
    :attr:`loopy.TemporaryVariable.base_storage`.

    The live interval of a temporary spans from the first to the last call
    kernel (see :func:`_get_call_kernel_insn_ids`) with an instruction that
    references it. The storage of a temporary is made available for re-use
    from the call kernel after the one in which it was last referenced.
    Temporaries that are not in the global address space, or that are not
    referenced by any instruction, are carried over unchanged.

    .. warning::

        This routine **assumes** that the entrypoint in *t_unit* has global
        barriers inserted as per :func:`_get_call_kernel_insn_ids`.
    """
    from loopy.kernel.data import AddressSpace

    # all loopy programs from pytato DAGs have exactly one entrypoint.
    kernel = t_unit.default_entrypoint

    temp_vars = frozenset(tv.name
                          for tv in kernel.temporary_variables.values()
                          if tv.address_space == AddressSpace.GLOBAL)

    call_kernel_insn_ids = _get_call_kernel_insn_ids(kernel)
    temp_to_live_interval_start: Dict[str, int] = {}
    temp_to_live_interval_end: Dict[str, int] = {}

    for icall_kernel, insn_ids in enumerate(call_kernel_insn_ids):
        for insn_id in insn_ids:
            for var in (kernel.id_to_insn[insn_id].dependency_names()
                        & temp_vars):
                # first sighting opens the interval, every sighting extends it
                temp_to_live_interval_start.setdefault(var, icall_kernel)
                temp_to_live_interval_end[var] = icall_kernel

    # Seeded with the names that are already in use in 'kernel' so that the
    # generated base storage names cannot clash with any of them.
    vng = kernel.get_var_name_generator()

    # {{{ get mappings from icall_kernel to temps that are just alive or dead

    icall_kernel_to_just_live_temp_vars: List[Set[str]] = [
        set() for _ in call_kernel_insn_ids]
    icall_kernel_to_just_dead_temp_vars: List[Set[str]] = [
        set() for _ in call_kernel_insn_ids]

    for tv_name, just_alive_idx in temp_to_live_interval_start.items():
        icall_kernel_to_just_live_temp_vars[just_alive_idx].add(tv_name)

    for tv_name, just_dead_idx in temp_to_live_interval_end.items():
        if just_dead_idx != (len(call_kernel_insn_ids) - 1):
            # we ignore the temporaries that died at the last kernel since we cannot
            # reclaim their memory
            icall_kernel_to_just_dead_temp_vars[just_dead_idx+1].add(tv_name)

    # }}}

    new_tvs: Dict[str, lp.TemporaryVariable] = {}
    # a mapping from size (in bytes) to the available base storages from temp
    # variables that were dead.
    nbytes_to_available_base_storage: Dict[int, Set[str]] = defaultdict(set)

    for just_dead_temps, to_be_allocated_temps in zip(
            icall_kernel_to_just_dead_temp_vars,
            icall_kernel_to_just_live_temp_vars):
        # reclaim base storage from the dead temporaries
        for tv_name in sorted(just_dead_temps):
            tv = new_tvs[tv_name]
            assert tv.base_storage is not None
            assert (tv.base_storage
                    not in nbytes_to_available_base_storage[tv.nbytes])
            nbytes_to_available_base_storage[tv.nbytes].add(tv.base_storage)

        # assign base storages to 'to_be_allocated_temps'
        for tv_name in sorted(to_be_allocated_temps):
            tv = kernel.temporary_variables[tv_name]
            assert tv.name not in new_tvs
            assert tv.base_storage is None
            available_base_storages = nbytes_to_available_base_storage[tv.nbytes]
            if available_base_storages:
                # 'min' for a deterministic choice
                base_storage = min(available_base_storages)
                available_base_storages.remove(base_storage)
            else:
                base_storage = vng("_actx_tmp_base")

            new_tvs[tv.name] = tv.copy(base_storage=base_storage)

    # Carry over whatever was not assigned a base storage above, i.e. the
    # non-global temporaries and the global temporaries that aren't referenced
    # by any instruction.
    for name, tv in kernel.temporary_variables.items():
        new_tvs.setdefault(name, tv)

    # {{{ report the change in memory requirement of the global temporaries

    # sizes before aliasing: one allocation per global temporary
    old_tmp_mem_requirement = sum(
        tv.nbytes
        for tv in kernel.temporary_variables.values()
        if tv.address_space == AddressSpace.GLOBAL)

    # sizes after aliasing: one allocation per base storage + one allocation
    # per global temporary without a base storage.
    new_global_tvs = [tv
                      for tv in new_tvs.values()
                      if tv.address_space == AddressSpace.GLOBAL]
    base_storage_to_nbytes = {tv.base_storage: tv.nbytes
                              for tv in new_global_tvs
                              if tv.base_storage is not None}
    new_tmp_mem_requirement = (
        sum(base_storage_to_nbytes.values())
        + sum(tv.nbytes
              for tv in new_global_tvs
              if tv.base_storage is None))

    logger.info(
        "[alias_global_temporaries]: Reduced memory requirement from "
        "%.1fMB to %.1fMB.",
        old_tmp_mem_requirement*1e-6, new_tmp_mem_requirement*1e-6)

    # }}}

    kernel = kernel.copy(temporary_variables=new_tvs)

    return t_unit.with_kernel(kernel)

# }}}

# vim: fdm=marker