Mesh Distribution: determinism fixes (#416)
* Mesh Distribution: determinism fixes

* misc fixes

* restore old-timey interface (needed for test)

* use dict for pending_recv_identifiers

* add some simple determinism tests

* add comment regarding deterministic order

* reset example type ignore
matthiasdiener authored Jul 2, 2024
1 parent 6bb638f commit c316e23
Showing 5 changed files with 63 additions and 36 deletions.
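The common thread in these changes: iterating over a Python set yields an
unspecified order (for str elements it can even differ between interpreter
runs under hash randomization), so any numbering or rank pairing derived from
set iteration is nondeterministic. The diffs below switch to order-preserving
deduplication. A minimal illustrative sketch, not taken from the commit, with
the stdlib dict.fromkeys standing in for pytools.unique:

    parts = ["vol", "bdry", "vol", "piston", "bdry"]

    # Set-based dedup: element order is unspecified and, for str elements,
    # may change from run to run under hash randomization.
    nondeterministic = list(set(parts))

    # Insertion-order-preserving dedup: always first-occurrence order.
    deterministic = list(dict.fromkeys(parts))
    assert deterministic == ["vol", "bdry", "piston"]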
8 changes: 6 additions & 2 deletions meshmode/discretization/connection/direct.py
@@ -249,6 +249,9 @@ class DiscretizationConnectionElementGroup:
     def __init__(self, batches):
         self.batches = batches
 
+    def __repr__(self):
+        return f"{type(self).__name__}({self.batches})"
+
 # }}}


@@ -488,9 +491,10 @@ def _per_target_group_pick_info(
         if batch.from_group_index == source_group_index]
 
     # {{{ find and weed out duplicate dof pick lists
+    from pytools import unique
 
-    dof_pick_lists = list({tuple(batch_dof_pick_lists[bi])
-        for bi in batch_indices_for_this_source_group})
+    dof_pick_lists = list(unique(tuple(batch_dof_pick_lists[bi])
+        for bi in batch_indices_for_this_source_group))
     dof_pick_list_to_index = {
         p_ind: i for i, p_ind in enumerate(dof_pick_lists)}
     # shape: (number of pick lists, nunit_dofs_tgt)
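The hunk above keeps the dedup-then-number pattern but swaps the set
comprehension for pytools.unique, so dof_pick_lists (and with it every index
stored in dof_pick_list_to_index) comes out in reproducible first-occurrence
order. A self-contained sketch of the same pattern on made-up pick lists,
again with dict.fromkeys standing in for pytools.unique:

    batch_dof_pick_lists = [[0, 2, 1], [3, 4, 5], [0, 2, 1]]

    # Deduplicate, keeping first-occurrence order (tuples, to be hashable).
    pick_lists = list(dict.fromkeys(tuple(pl) for pl in batch_dof_pick_lists))

    # Stable numbering: the same input always yields the same indices.
    pick_list_to_index = {pl: i for i, pl in enumerate(pick_lists)}
    assert pick_list_to_index == {(0, 2, 1): 0, (3, 4, 5): 1}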
65 changes: 39 additions & 26 deletions meshmode/distributed.py
@@ -36,7 +36,7 @@
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Mapping, Sequence, Set, Union, cast
from typing import TYPE_CHECKING, Any, Hashable, List, Mapping, Sequence, Union, cast
from warnings import warn

import numpy as np
@@ -239,11 +239,6 @@ def __init__(self,
             ],
             bdry_grp_factory: ElementGroupFactory):
         """
-        :arg local_bdry_conns: A :class:`dict` mapping remote part to
-            `local_bdry_conn`, where `local_bdry_conn` is a
-            :class:`~meshmode.discretization.connection.DirectDiscretizationConnection`
-            that performs data exchange from the volume to the faces adjacent to
-            part `i_remote_part`.
         :arg bdry_grp_factory: Group factory to use when creating the remote-to-local
             boundary connections
         """
@@ -290,8 +285,11 @@ def __enter__(self):

         # to know when we're done
         self.pending_recv_identifiers = {
-            (irbi.local_part_id, irbi.remote_part_id)
-            for irbi in self.inter_rank_bdry_info}
+            (irbi.local_part_id, irbi.remote_part_id): i
+            for i, irbi in enumerate(self.inter_rank_bdry_info)}
+
+        assert len(self.pending_recv_identifiers) \
+                == len(self.inter_rank_bdry_info)
 
         self.send_reqs = [
             self._internal_mpi_comm.isend(
@@ -327,14 +325,22 @@ def complete_some(self):

         status = MPI.Status()
 
-        # Wait for any receive
-        data = [self._internal_mpi_comm.recv(status=status)]
-        source_ranks = [status.source]
-
-        # Complete any other available receives while we're at it
-        while self._internal_mpi_comm.iprobe():
-            data.append(self._internal_mpi_comm.recv(status=status))
-            source_ranks.append(status.source)
+        # Wait for all receives
+        # Note: This is inefficient, but ensures a deterministic order of
+        # boundary setup.
+        nrecvs = len(self.pending_recv_identifiers)
+        data = [None] * nrecvs
+        source_ranks = [None] * nrecvs
+
+        while nrecvs > 0:
+            r = self._internal_mpi_comm.recv(status=status)
+            key = (r[1], r[0])
+            loc = self.pending_recv_identifiers[key]
+            assert data[loc] is None
+            assert source_ranks[loc] is None
+            data[loc] = r
+            source_ranks[loc] = status.source
+            nrecvs -= 1
 
         remote_to_local_bdry_conns = {}
 
@@ -372,11 +378,11 @@ def complete_some(self):
                         group_factory=self.bdry_grp_factory),
                     remote_group_infos=remote_group_infos))
 
-            self.pending_recv_identifiers.remove((local_part_id, remote_part_id))
+            del self.pending_recv_identifiers[(local_part_id, remote_part_id)]
 
-        if not self.pending_recv_identifiers:
-            MPI.Request.waitall(self.send_reqs)
-            logger.info("bdry comm rank %d comm end", self.i_local_rank)
+        assert not self.pending_recv_identifiers
+        MPI.Request.waitall(self.send_reqs)
+        logger.info("bdry comm rank %d comm end", self.i_local_rank)
 
         return remote_to_local_bdry_conns
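
complete_some previously returned whichever receives happened to have arrived,
and MPI arrival order is nondeterministic, so the order of boundary setup was
too. The new code waits for every expected message and files each one into a
slot preassigned via its (local_part_id, remote_part_id) key. A minimal
runnable sketch of that slot-filling pattern (a hypothetical standalone
mpi4py script, not the helper's actual code; run under mpirun):

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank, size = comm.rank, comm.size

    # Each rank sends one message to every other rank.
    send_reqs = [comm.isend((rank, dest), dest=dest)
                 for dest in range(size) if dest != rank]

    # One slot per expected message, fixed before anything arrives.
    expected = [src for src in range(size) if src != rank]
    slot_of = {src: i for i, src in enumerate(expected)}
    data = [None] * len(expected)

    status = MPI.Status()
    for _ in expected:
        msg = comm.recv(status=status)       # arrival order varies...
        data[slot_of[status.source]] = msg   # ...slot placement does not

    MPI.Request.waitall(send_reqs)
    assert all(d is not None for d in data)  # every slot filled exactly once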

@@ -432,30 +438,37 @@ def get_partition_by_pymetis(mesh, num_parts, *, connectivity="facial", **kwargs
     return np.array(p)
 
 
-def membership_list_to_map(membership_list):
+def membership_list_to_map(
+        membership_list: np.ndarray[Any, Any]
+        ) -> Mapping[Hashable, np.ndarray]:
     """
     Convert a :class:`numpy.ndarray` that maps an index to a key into a
     :class:`dict` that maps a key to a set of indices (with each set of indices
     stored as a sorted :class:`numpy.ndarray`).
     """
+    from pytools import unique
+
+    # FIXME: not clear why the sorted() call is necessary here
     return {
         entry: np.where(membership_list == entry)[0]
-        for entry in set(membership_list)}
+        for entry in sorted(unique(membership_list))}
 
 
 # FIXME: Move somewhere else, since it's not strictly limited to distributed?
-def get_connected_parts(mesh: Mesh) -> "Set[PartID]":
+def get_connected_parts(mesh: Mesh) -> "Sequence[PartID]":
     """For a local mesh part in *mesh*, determine the set of connected parts."""
     assert mesh.facial_adjacency_groups is not None
 
-    return {
+    from pytools import unique
+
+    return tuple(unique(
         grp.part_id
         for fagrp_list in mesh.facial_adjacency_groups
         for grp in fagrp_list
-        if isinstance(grp, InterPartAdjacencyGroup)}
+        if isinstance(grp, InterPartAdjacencyGroup)))
 
 
-def get_connected_partitions(mesh: Mesh) -> "Set[PartID]":
+def get_connected_partitions(mesh: Mesh) -> "Sequence[PartID]":
     warn(
         "get_connected_partitions is deprecated and will stop working in June 2023. "
         "Use get_connected_parts instead.", DeprecationWarning, stacklevel=2)
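With sorted(unique(...)), membership_list_to_map now returns its keys in
sorted order, and get_connected_parts returns a tuple with a reproducible
first-occurrence order; the test changes below assert exactly this. A small
usage sketch of the sorted-key behavior (arbitrary example data, assuming
meshmode is importable):

    import numpy as np

    from meshmode.distributed import membership_list_to_map

    # element -> part assignment for a 6-element mesh
    membership = np.array([2, 0, 1, 0, 2, 0])
    part_to_elements = membership_list_to_map(membership)

    # Keys come out sorted, so iteration order is reproducible...
    assert list(part_to_elements.keys()) == [0, 1, 2]
    # ...and np.where already yields each index array in sorted order.
    assert part_to_elements[0].tolist() == [1, 3, 5]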
15 changes: 9 additions & 6 deletions meshmode/mesh/processing.py
@@ -25,7 +25,7 @@
 from dataclasses import dataclass, replace
 from functools import reduce
 from typing import (
-    Callable, Dict, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Union)
+    Callable, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union)
 
 import numpy as np
 import numpy.linalg as la
@@ -184,7 +184,7 @@ def _get_connected_parts(
         mesh: Mesh,
         part_id_to_part_index: Mapping[PartID, int],
         global_elem_to_part_elem: np.ndarray,
-        self_part_id: PartID) -> Set[PartID]:
+        self_part_id: PartID) -> Sequence[PartID]:
     """
     Find the parts that are connected to the current part.
@@ -196,10 +196,11 @@ def _get_connected_parts(
         :func:`_compute_global_elem_to_part_elem`` for details.
     :arg self_part_id: The identifier of the part currently being created.
-    :returns: A :class:`set` of identifiers of the neighboring parts.
+    :returns: A sequence of identifiers of the neighboring parts.
     """
     self_part_index = part_id_to_part_index[self_part_id]
 
+    # This set is not used in a way that will cause nondeterminism.
     connected_part_indices = set()
 
     for igrp, facial_adj_list in enumerate(mesh.facial_adjacency_groups):
@@ -223,10 +224,12 @@ def _get_connected_parts(
                     elements_are_self & neighbors_are_other]
                 + elem_base_j, 0])
 
-    return {
+    result = tuple(
         part_id
         for part_id, part_index in part_id_to_part_index.items()
-        if part_index in connected_part_indices}
+        if part_index in connected_part_indices)
+    assert len(set(result)) == len(result)
+    return result
 
 
 def _create_self_to_self_adjacency_groups(
@@ -305,7 +308,7 @@ def _create_self_to_other_adjacency_groups(
         self_part_id: PartID,
         self_mesh_groups: List[MeshElementGroup],
         self_mesh_group_elem_base: List[int],
-        connected_parts: Set[PartID]) -> List[List[InterPartAdjacencyGroup]]:
+        connected_parts: Sequence[PartID]) -> List[List[InterPartAdjacencyGroup]]:
     """
     Create self-to-other adjacency groups for the partitioned mesh.
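_get_connected_parts still collects part indices into a set, which stays safe
because the set is only used for membership tests; the returned tuple's order
follows the part_id_to_part_index dict, so every rank that builds the same
dict sees the connected parts in the same order. A standalone sketch of that
return-value construction (made-up part IDs):

    part_id_to_part_index = {"vol": 0, "piston": 1, "wall": 2}
    connected_part_indices = {2, 0}  # unordered set: fine for membership tests

    result = tuple(
        part_id
        for part_id, part_index in part_id_to_part_index.items()
        if part_index in connected_part_indices)

    # Order follows the dict, not the set; the assert guards against duplicates.
    assert result == ("vol", "wall")
    assert len(set(result)) == len(result)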
2 changes: 1 addition & 1 deletion setup.py
@@ -41,7 +41,7 @@ def main():
"numpy",
"modepy>=2020.2",
"gmsh_interop",
"pytools>=2020.4.1",
"pytools>=2024.1.1",

# 2019.1 is required for the Firedrake CIs, which use an very specific
# version of Loopy.
9 changes: 8 additions & 1 deletion test/test_partition.py
@@ -387,6 +387,8 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
     part_id_to_part = partition_mesh(mesh,
                                      membership_list_to_map(
                                          np.random.randint(mpi_comm.size, size=mesh.nelements)))
+
+    assert list(part_id_to_part.keys()) == list(range(mpi_comm.size))
     parts = [part_id_to_part[i] for i in range(mpi_comm.size)]
 
     local_mesh = mpi_comm.scatter(parts)
@@ -424,6 +426,11 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
         conns = bdry_setup_helper.complete_some()
         if not conns:
             break
+
+        expected_keys = list(range(mpi_comm.size))
+        expected_keys.remove(mpi_comm.rank)
+        assert list(conns.keys()) == expected_keys
+
         for i_remote_part, conn in conns.items():
             check_connection(actx, conn)
             remote_to_local_bdry_conns[i_remote_part] = conn
@@ -455,7 +462,7 @@ def _test_connected_parts(mpi_comm, connected_parts):
     for i_remote_part in range(num_parts):
         if all_connected_masks[i_remote_part][mpi_comm.rank]:
             parts_connected_to_me.add(i_remote_part)
-    assert parts_connected_to_me == connected_parts
+    assert parts_connected_to_me == set(connected_parts)


# TODO
