Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use hash of comm_tag if not numeric #222

Merged
merged 6 commits into from
Mar 2, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions examples/wave/wave-op-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,7 @@ def main(ctx_factory, dim=2, order=3,
else:
actx = actx_class(comm, queue,
allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)),
force_device_scalars=True,
comm_tag_to_mpi_tag={
_WaveStateTag: 1234,
})
force_device_scalars=True)

from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
mesh_dist = MPIMeshDistributor(comm)
Expand Down
19 changes: 3 additions & 16 deletions grudge/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
# {{{ imports

from typing import (
TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional,
Hashable, Type)
TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type)
from dataclasses import dataclass

from meshmode.array_context import (
Expand Down Expand Up @@ -240,36 +239,24 @@ def __init__(self,
queue: "pyopencl.CommandQueue",
*, allocator: Optional["pyopencl.tools.AllocatorInterface"] = None,
wait_event_queue_length: Optional[int] = None,
force_device_scalars: bool = False,
comm_tag_to_mpi_tag: Optional[Mapping[Hashable, int]] = None) -> None:
force_device_scalars: bool = False) -> None:
"""
See :class:`arraycontext.impl.pyopencl.PyOpenCLArrayContext` for most
arguments.

:arg comm_tag_to_mpi_tag: A mapping from symbolic tags used
in the *comm_tag* argument of
:func:`grudge.trace_pair.cross_rank_trace_pairs` to numeric values
to be used with MPI.
"""
super().__init__(queue, allocator=allocator,
wait_event_queue_length=wait_event_queue_length,
force_device_scalars=force_device_scalars)

self.mpi_communicator = mpi_communicator

if comm_tag_to_mpi_tag is None:
comm_tag_to_mpi_tag = {}

self.comm_tag_to_mpi_tag = comm_tag_to_mpi_tag

def clone(self):
# type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member
# pylint: disable=no-member
return type(self)(self.mpi_communicator, self.queue,
allocator=self.allocator,
wait_event_queue_length=self._wait_event_queue_length,
force_device_scalars=self._force_device_scalars,
comm_tag_to_mpi_tag=self.comm_tag_to_mpi_tag)
force_device_scalars=self._force_device_scalars)

# }}}

Expand Down
22 changes: 13 additions & 9 deletions grudge/trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,16 +492,20 @@ def cross_rank_trace_pairs(
if isinstance(comm_tag, int):
num_tag = comm_tag

from grudge.array_context import MPIPyOpenCLArrayContext
if isinstance(actx, MPIPyOpenCLArrayContext):
num_tag = actx.comm_tag_to_mpi_tag.get(comm_tag)

if num_tag is None:
raise ValueError("Encountered unknown symbolic tag "
f"'{comm_tag}'. To make this symbolic tag work, "
f"use 'grudge.array_context.MPIPyOpenCLArrayContext' and "
"assign this tag a numerical value via its "
"comm_tag_to_mpi_tag attribute.")
# FIXME: This isn't guaranteed to be correct.
# See here for discussion:
# - https://github.com/illinois-ceesd/mirgecom/issues/617#issuecomment-1057082716 # noqa
# - https://github.com/inducer/grudge/pull/222
from mpi4py import MPI
tag_ub = MPI.COMM_WORLD.Get_attr(MPI.TAG_UB)
matthiasdiener marked this conversation as resolved.
Show resolved Hide resolved
num_tag = hash(comm_tag) % tag_ub

from warnings import warn
warn("Encountered unknown symbolic tag "
f"'{comm_tag}', assigning a value of '{num_tag}'. "
"This is a temporary workaround, please ensure that "
"tags are sufficiently distinct for your use case.")

comm_tag = num_tag

Expand Down
8 changes: 5 additions & 3 deletions test/test_mpi_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@
from pytools.obj_array import flat_obj_array

import grudge.op as op
from testlib import SimpleTag


class SimpleTag:
pass


# {{{ mpi test infrastructure
Expand Down Expand Up @@ -86,8 +89,7 @@ def run_test_with_mpi_inner():
if actx_class is MPIPytatoArrayContext:
actx = actx_class(comm, queue, mpi_base_tag=15000)
elif actx_class is MPIPyOpenCLArrayContext:
actx = actx_class(comm, queue, force_device_scalars=True,
comm_tag_to_mpi_tag={SimpleTag: 15000})
actx = actx_class(comm, queue, force_device_scalars=True)
else:
raise ValueError("unknown actx_class")

Expand Down
5 changes: 0 additions & 5 deletions test/testlib.py

This file was deleted.