WIP: Eager tag negotiation #233

Draft
wants to merge 1 commit into base: main
21 changes: 18 additions & 3 deletions grudge/array_context.py
@@ -32,7 +32,8 @@
# {{{ imports

from typing import (
-    TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type)
+    TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type,
+    Dict)
from dataclasses import dataclass

from meshmode.array_context import (
@@ -63,6 +64,8 @@
import pyopencl.tools
from mpi4py import MPI

from grudge.trace_pair import CommunicationTag


class PyOpenCLArrayContext(_PyOpenCLArrayContextBase):
"""Inherits from :class:`meshmode.array_context.PyOpenCLArrayContext`. Extends it
@@ -233,13 +236,18 @@ class MPIPyOpenCLArrayContext(PyOpenCLArrayContext, MPIBasedArrayContext):

.. autofunction:: __init__
"""
_source_rank_sym_tag_to_num_tag: Dict[Tuple[int, CommunicationTag], int]
_dest_rank_sym_tag_to_num_tag: Dict[Tuple[int, CommunicationTag], int]
_dest_rank_to_taken_num_tag: Dict[int, int]
mpi_base_tag: int

def __init__(self,
mpi_communicator,
queue: "pyopencl.CommandQueue",
*, allocator: Optional["pyopencl.tools.AllocatorInterface"] = None,
wait_event_queue_length: Optional[int] = None,
-            force_device_scalars: bool = False) -> None:
+            force_device_scalars: bool = False,
+            mpi_base_tag: int) -> None:
"""
See :class:`arraycontext.impl.pyopencl.PyOpenCLArrayContext` for most
arguments.
@@ -250,13 +258,20 @@ def __init__(self,

self.mpi_communicator = mpi_communicator

self.mpi_base_tag = mpi_base_tag

self._source_rank_sym_tag_to_num_tag = {}
self._dest_rank_sym_tag_to_num_tag = {}
        self._dest_rank_to_taken_num_tag = {}

def clone(self):
# type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member
# pylint: disable=no-member
return type(self)(self.mpi_communicator, self.queue,
allocator=self.allocator,
wait_event_queue_length=self._wait_event_queue_length,
-                force_device_scalars=self._force_device_scalars)
+                force_device_scalars=self._force_device_scalars,
+                mpi_base_tag=self.mpi_base_tag)

# }}}

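For orientation, a minimal usage sketch of the constructor extended above. Only MPIPyOpenCLArrayContext and its new mpi_base_tag keyword come from this diff; the CL context/queue setup and the tag value are illustrative assumptions.

    # Hypothetical usage sketch for the new mpi_base_tag argument.
    from mpi4py import MPI
    import pyopencl as cl

    from grudge.array_context import MPIPyOpenCLArrayContext

    cl_ctx = cl.create_some_context()
    queue = cl.CommandQueue(cl_ctx)

    # mpi_base_tag marks the start of a tag range reserved for grudge's
    # eager communication; tags below it remain free for the application.
    actx = MPIPyOpenCLArrayContext(MPI.COMM_WORLD, queue, mpi_base_tag=4000)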
153 changes: 105 additions & 48 deletions grudge/trace_pair.py
@@ -46,9 +46,7 @@
"""


-from typing import List, Hashable, Optional, Type, Any
-
-from pytools.persistent_dict import KeyBuilder
+from typing import List, Hashable, Dict, Tuple, TYPE_CHECKING, Callable

from arraycontext import (
ArrayContainer,
@@ -75,6 +73,9 @@
import numpy as np
import grudge.dof_desc as dof_desc

if TYPE_CHECKING:
import mpi4py.MPI


# {{{ trace pair container class

@@ -310,27 +311,102 @@ def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
# }}}


-# {{{ distributed-memory functionality
+# {{{ generic distributed support

CommunicationTag = Hashable


@memoize_on_first_arg
def connected_ranks(dcoll: DiscretizationCollection):
from meshmode.distributed import get_connected_partitions
return get_connected_partitions(dcoll._volume_discr.mesh)

# }}}


# {{{ eager distributed

@dataclass
class _EagerMPITags:
send_mpi_tag: int
recv_mpi_tag: int


@dataclass
class _EagerMPIState:
mpi_communicator: "mpi4py.MPI.Comm"

-class _RankBoundaryCommunication:
-    base_comm_tag = 1273

    # The base tag is used for tag negotiation; tags above it are
    # assignable to individual eager communications.
    tag_negotiation_tag: int
    first_assignable_tag: int

    source_rank_sym_tag_to_num_tag: Dict[Tuple[int, CommunicationTag], int]
    dest_rank_sym_tag_to_num_tag: Dict[Tuple[int, CommunicationTag], int]
    dest_rank_to_taken_num_tag: Dict[int, int]


class _EagerSymbolicTagNegotiator:
    # Why does agreeing on tag mappings require communication at all?
    # Different ranks may encounter the symbolic tags in different
    # orders, so a purely local assignment could pair up mismatched
    # numeric tags. (Granted, as long as eager communication is expected
    # to complete inside cross_rank_trace_pairs, that situation would
    # deadlock anyway. But still.) See the standalone sketch after this
    # hunk for the handshake in isolation.

def __init__(self, eager_mpi_state: _EagerMPIState, sym_tag: CommunicationTag,
remote_rank: int,
continuation: Callable[
[_EagerMPITags], "_RankBoundaryCommunicationEager"]):
self.eager_mpi_state = eager_mpi_state
self.sym_tag = sym_tag
self.remote_rank = remote_rank
self.continuation = continuation

rank_n_tag = (remote_rank, sym_tag)
assert rank_n_tag not in eager_mpi_state.source_rank_sym_tag_to_num_tag
assert rank_n_tag not in eager_mpi_state.dest_rank_sym_tag_to_num_tag

        # Claim the next unused numeric tag for this destination rank.
        self.send_num_tag = eager_mpi_state.dest_rank_to_taken_num_tag.get(
                remote_rank, eager_mpi_state.first_assignable_tag - 1) + 1
        eager_mpi_state.dest_rank_to_taken_num_tag[remote_rank] = self.send_num_tag
        eager_mpi_state.dest_rank_sym_tag_to_num_tag[rank_n_tag] = self.send_num_tag

comm = eager_mpi_state.mpi_communicator
self.send_req = comm.isend((sym_tag, self.send_num_tag),
remote_rank, tag=eager_mpi_state.tag_negotiation_tag)
        self.recv_req = comm.irecv(
                source=remote_rank, tag=eager_mpi_state.tag_negotiation_tag)

def finish(self):
recv_sym_tag: CommunicationTag
recv_num_tag: int
recv_sym_tag, recv_num_tag = self.recv_req.wait()
self.send_req.wait()

self.eager_mpi_state.source_rank_sym_tag_to_num_tag[
self.remote_rank, recv_sym_tag] = recv_num_tag

        # FIXME If this assertion holds, the whole tag negotiation process
        # is pointless. Unless there is a way to have eager communication
        # for more than one tag pending at the same time (which, for now,
        # there isn't), this whole endeavor is thoroughly unnecessary.
assert recv_sym_tag == self.sym_tag

return self.continuation(_EagerMPITags(
send_mpi_tag=self.send_num_tag, recv_mpi_tag=recv_num_tag))


class _RankBoundaryCommunicationEager:
def __init__(self,
-            dcoll: DiscretizationCollection,
-            array_container: ArrayOrContainerT,
-            remote_rank, comm_tag: Optional[int] = None):
+            mpi_communicator,
+            dcoll: DiscretizationCollection,
+            array_container: ArrayOrContainerT,
+            *, remote_rank: int, send_mpi_tag: int, recv_mpi_tag: int):
actx = get_container_context_recursively(array_container)
assert actx is not None

btag = BTAG_PARTITION(remote_rank)

local_bdry_data = project(dcoll, "vol", btag, array_container)
-        comm = dcoll.mpi_communicator
+        comm = mpi_communicator

self.dcoll = dcoll
self.array_context = actx
self.remote_btag = btag
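To make the negotiation above concrete: a standalone two-rank sketch of the handshake, using only mpi4py. The tag constants and the "flux" symbolic tag are illustrative assumptions, not values from this diff.

    # Run with: mpiexec -n 2 python negotiation_sketch.py
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    peer = 1 - comm.Get_rank()       # the other rank in a two-rank job

    TAG_NEGOTIATION_TAG = 50         # reserved for exchanging tag choices
    FIRST_ASSIGNABLE_TAG = 51        # first numeric tag handed out

    sym_tag = "flux"                 # symbolic (hashable) communication tag
    my_num_tag = FIRST_ASSIGNABLE_TAG

    # Tell the peer which numeric tag this rank picked for sym_tag and
    # learn which one the peer picked.
    send_req = comm.isend((sym_tag, my_num_tag), peer,
                          tag=TAG_NEGOTIATION_TAG)
    recv_req = comm.irecv(source=peer, tag=TAG_NEGOTIATION_TAG)

    peer_sym_tag, peer_num_tag = recv_req.wait()
    send_req.wait()
    assert peer_sym_tag == sym_tag

    # Data for sym_tag is now sent with my_num_tag and received with
    # peer_num_tag, so the two sides' Isend/Irecv tags line up.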
@@ -339,10 +415,6 @@ def __init__(self,
self.local_bdry_data_np = \
to_numpy(flatten(self.local_bdry_data, actx), actx)

-        self.comm_tag = self.base_comm_tag
-        if comm_tag is not None:
-            self.comm_tag += comm_tag

        # Here, we initialize both send and receive operations through
# mpi4py `Request` (MPI_Request) instances for comm.Isend (MPI_Isend)
# and comm.Irecv (MPI_Irecv) respectively. These initiate non-blocking
@@ -364,11 +436,11 @@ def __init__(self,
# as well, just in case.
        self.send_req = comm.Isend(self.local_bdry_data_np,
                remote_rank,
-                tag=self.comm_tag)
+                tag=send_mpi_tag)
        self.remote_data_host_numpy = np.empty_like(self.local_bdry_data_np)
        self.recv_req = comm.Irecv(self.remote_data_host_numpy,
                remote_rank,
-                tag=self.comm_tag)
+                tag=recv_mpi_tag)

def finish(self):
# Wait for the nonblocking receive request to complete before
@@ -393,15 +465,18 @@ def finish(self):
interior=self.local_bdry_data,
exterior=swapped_remote_bdry_data)

# }}}
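As a companion sketch for the eager exchange above: two ranks pairing up the negotiated tags. Buffer shape and tag values are illustrative assumptions; the point is the send_mpi_tag/recv_mpi_tag split.

    # Run with: mpiexec -n 2 python exchange_sketch.py
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    peer = 1 - rank

    my_tag = 101 + rank              # tag this rank assigned (send side)
    peer_tag = 101 + peer            # tag the peer assigned (recv side)

    send_buf = np.full(8, rank, dtype=np.float64)
    recv_buf = np.empty_like(send_buf)

    # Send with the locally chosen tag, receive with the peer's choice.
    send_req = comm.Isend(send_buf, peer, tag=my_tag)
    recv_req = comm.Irecv(recv_buf, peer, tag=peer_tag)

    recv_req.Wait()
    send_req.Wait()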

-from pytato import make_distributed_recv, staple_distributed_send

# {{{ lazy distributed

class _RankBoundaryCommunicationLazy:
def __init__(self,
dcoll: DiscretizationCollection,
array_container: ArrayOrContainerT,
remote_rank: int, comm_tag: Hashable):
remote_rank: int, comm_tag: CommunicationTag):
from pytato import make_distributed_recv, staple_distributed_send

if comm_tag is None:
raise ValueError("lazy communication requires 'tag' to be supplied")

@@ -433,16 +508,15 @@ def finish(self):
interior=self.local_bdry_data,
exterior=bdry_conn(self.remote_data))

# }}}
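Since the lazy path requires an explicit tag, a short sketch of what supplying one might look like from a caller. The _GradTag class name and the dcoll/vec variables are assumed context, not part of this diff; any hashable object that is distinct per communication site works.

    # Hypothetical caller-side symbolic tag for one communication site.
    class _GradTag:
        pass

    pairs = cross_rank_trace_pairs(dcoll, vec, comm_tag=_GradTag)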

-class _TagKeyBuilder(KeyBuilder):
-    def update_for_type(self, key_hash, key: Type[Any]):
-        self.rec(key_hash, (key.__module__, key.__name__, key.__name__,))

# {{{ cross_rank_trace_pairs

def cross_rank_trace_pairs(
dcoll: DiscretizationCollection, ary,
-        comm_tag: Hashable = None,
-        tag: Hashable = None) -> List[TracePair]:
+        comm_tag: CommunicationTag = None,
+        tag: CommunicationTag = None) -> List[TracePair]:
r"""Get a :class:`list` of *ary* trace pairs for each partition boundary.

For each partition boundary, the field data values in *ary* are
@@ -481,6 +555,12 @@ def cross_rank_trace_pairs(
comm_tag = tag
del tag

# {{{


# }}}


if isinstance(ary, Number):
# NOTE: Assumed that the same number is passed on every rank
return [TracePair(BTAG_PARTITION(remote_rank), interior=ary, exterior=ary)
@@ -493,30 +573,7 @@
if isinstance(actx, MPIPytatoArrayContextBase):
rbc = _RankBoundaryCommunicationLazy
else:
-        rbc = _RankBoundaryCommunication
-        if comm_tag is not None:
-            num_tag: Optional[int] = None
-            if isinstance(comm_tag, int):
-                num_tag = comm_tag
-
-            if num_tag is None:
-                # FIXME: This isn't guaranteed to be correct.
-                # See here for discussion:
-                # - https://github.com/illinois-ceesd/mirgecom/issues/617#issuecomment-1057082716  # noqa
-                # - https://github.com/inducer/grudge/pull/222
-                from mpi4py import MPI
-                tag_ub = MPI.COMM_WORLD.Get_attr(MPI.TAG_UB)
-                key_builder = _TagKeyBuilder()
-                digest = key_builder(comm_tag)
-                num_tag = sum(ord(ch) << i for i, ch in enumerate(digest)) % tag_ub
-
-                from warnings import warn
-                warn("Encountered unknown symbolic tag "
-                        f"'{comm_tag}', assigning a value of '{num_tag}'. "
-                        "This is a temporary workaround, please ensure that "
-                        "tags are sufficiently distinct for your use case.")
-
-            comm_tag = num_tag
+        rbc = partial(_RankBoundaryCommunicationEager,

# Initialize and post all sends/receives
rank_bdry_communcators = [