From 644e631b57fb976f8116673626d0eac4e6fa6b53 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 24 Jun 2024 16:45:37 -0500 Subject: [PATCH] Mesh Distribution: determinism fixes --- examples/parallel-vtkhdf.py | 3 +- meshmode/discretization/connection/direct.py | 8 +- meshmode/distributed.py | 97 ++++++++------------ meshmode/mesh/processing.py | 15 +-- setup.py | 2 +- 5 files changed, 55 insertions(+), 70 deletions(-) diff --git a/examples/parallel-vtkhdf.py b/examples/parallel-vtkhdf.py index 0d93ac7f..a928176f 100644 --- a/examples/parallel-vtkhdf.py +++ b/examples/parallel-vtkhdf.py @@ -56,8 +56,7 @@ def main(*, ambient_dim: int) -> None: parts = [part_id_to_part[i] for i in range(comm.size)] local_mesh = comm.scatter(parts) else: - # Reason for type-ignore: presumed faulty type annotation in mpi4py - local_mesh = comm.scatter(None) # type: ignore[arg-type] + local_mesh = comm.scatter(None) logger.info("[%4d] distributing mesh: finished", comm.rank) diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index 49ff8c97..d466c0d5 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -249,6 +249,9 @@ class DiscretizationConnectionElementGroup: def __init__(self, batches): self.batches = batches + def __repr__(self): + return f"{type(self).__name__}({self.batches})" + # }}} @@ -488,9 +491,10 @@ def _per_target_group_pick_info( if batch.from_group_index == source_group_index] # {{{ find and weed out duplicate dof pick lists + from pytools import unique - dof_pick_lists = list({tuple(batch_dof_pick_lists[bi]) - for bi in batch_indices_for_this_source_group}) + dof_pick_lists = list(unique(tuple(batch_dof_pick_lists[bi]) + for bi in batch_indices_for_this_source_group)) dof_pick_list_to_index = { p_ind: i for i, p_ind in enumerate(dof_pick_lists)} # shape: (number of pick lists, nunit_dofs_tgt) diff --git a/meshmode/distributed.py b/meshmode/distributed.py index ee0403ab..4e98bec9 100644 --- a/meshmode/distributed.py +++ b/meshmode/distributed.py @@ -36,7 +36,7 @@ """ from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Mapping, Sequence, Set, Union, cast +from typing import TYPE_CHECKING, Any, Hashable, List, Mapping, Sequence from warnings import warn import numpy as np @@ -231,19 +231,9 @@ class MPIBoundaryCommSetupHelper: def __init__(self, mpi_comm: "mpi4py.MPI.Intracomm", actx: ArrayContext, - inter_rank_bdry_info: Union[ - # new-timey - Sequence[InterRankBoundaryInfo], - # old-timey, for compatibility - Mapping[int, DirectDiscretizationConnection], - ], + inter_rank_bdry_info: Sequence[InterRankBoundaryInfo], bdry_grp_factory: ElementGroupFactory): """ - :arg local_bdry_conns: A :class:`dict` mapping remote part to - `local_bdry_conn`, where `local_bdry_conn` is a - :class:`~meshmode.discretization.connection.DirectDiscretizationConnection` - that performs data exchange from the volume to the faces adjacent to - part `i_remote_part`. :arg bdry_grp_factory: Group factory to use when creating the remote-to-local boundary connections """ @@ -251,30 +241,7 @@ def __init__(self, self.array_context = actx self.i_local_rank = mpi_comm.Get_rank() - # {{{ normalize inter_rank_bdry_info - - self._using_old_timey_interface = False - - if isinstance(inter_rank_bdry_info, dict): - self._using_old_timey_interface = True - warn("Using the old-timey interface of MPIBoundaryCommSetupHelper. " - "That's deprecated and will stop working in July 2022. 
" - "Use the currently documented interface instead.", - DeprecationWarning, stacklevel=2) - - inter_rank_bdry_info = [ - InterRankBoundaryInfo( - local_part_id=self.i_local_rank, - remote_part_id=remote_rank, - remote_rank=remote_rank, - local_boundary_connection=conn - ) - for remote_rank, conn in inter_rank_bdry_info.items()] - - # }}} - - self.inter_rank_bdry_info = cast( - Sequence[InterRankBoundaryInfo], inter_rank_bdry_info) + self.inter_rank_bdry_info = inter_rank_bdry_info self.bdry_grp_factory = bdry_grp_factory @@ -289,9 +256,13 @@ def __enter__(self): # the pickling ourselves. # to know when we're done - self.pending_recv_identifiers = { + self.pending_recv_identifiers = [ (irbi.local_part_id, irbi.remote_part_id) - for irbi in self.inter_rank_bdry_info} + for irbi in self.inter_rank_bdry_info] + + assert len(self.pending_recv_identifiers) \ + == len(self.inter_rank_bdry_info) \ + == len(set(self.pending_recv_identifiers)) self.send_reqs = [ self._internal_mpi_comm.isend( @@ -327,14 +298,20 @@ def complete_some(self): status = MPI.Status() - # Wait for any receive - data = [self._internal_mpi_comm.recv(status=status)] - source_ranks = [status.source] - - # Complete any other available receives while we're at it - while self._internal_mpi_comm.iprobe(): - data.append(self._internal_mpi_comm.recv(status=status)) - source_ranks.append(status.source) + # Wait for all receives + nrecvs = len(self.pending_recv_identifiers) + data = [None] * nrecvs + source_ranks = [None] * nrecvs + + while nrecvs > 0: + r = self._internal_mpi_comm.recv(status=status) + key = (r[1], r[0]) + loc = self.pending_recv_identifiers.index(key) + assert data[loc] is None + assert source_ranks[loc] is None + data[loc] = r + source_ranks[loc] = status.source + nrecvs -= 1 remote_to_local_bdry_conns = {} @@ -357,10 +334,7 @@ def complete_some(self): irbi = part_ids_to_irbi[local_part_id, remote_part_id] assert i_src_rank == irbi.remote_rank - if self._using_old_timey_interface: - key = remote_part_id - else: - key = (remote_part_id, local_part_id) + key = (remote_part_id, local_part_id) remote_to_local_bdry_conns[key] = ( make_partition_connection( @@ -374,9 +348,9 @@ def complete_some(self): self.pending_recv_identifiers.remove((local_part_id, remote_part_id)) - if not self.pending_recv_identifiers: - MPI.Request.waitall(self.send_reqs) - logger.info("bdry comm rank %d comm end", self.i_local_rank) + assert not self.pending_recv_identifiers + MPI.Request.waitall(self.send_reqs) + logger.info("bdry comm rank %d comm end", self.i_local_rank) return remote_to_local_bdry_conns @@ -432,30 +406,35 @@ def get_partition_by_pymetis(mesh, num_parts, *, connectivity="facial", **kwargs return np.array(p) -def membership_list_to_map(membership_list): +def membership_list_to_map( + membership_list: np.ndarray[Any, Any] + ) -> Mapping[Hashable, np.ndarray]: """ Convert a :class:`numpy.ndarray` that maps an index to a key into a :class:`dict` that maps a key to a set of indices (with each set of indices stored as a sorted :class:`numpy.ndarray`). """ + from pytools import unique return { entry: np.where(membership_list == entry)[0] - for entry in set(membership_list)} + for entry in unique(list(membership_list))} # FIXME: Move somewhere else, since it's not strictly limited to distributed? 
-def get_connected_parts(mesh: Mesh) -> "Set[PartID]": +def get_connected_parts(mesh: Mesh) -> "Sequence[PartID]": """For a local mesh part in *mesh*, determine the set of connected parts.""" assert mesh.facial_adjacency_groups is not None - return { + from pytools import unique + + return tuple(unique(list( grp.part_id for fagrp_list in mesh.facial_adjacency_groups for grp in fagrp_list - if isinstance(grp, InterPartAdjacencyGroup)} + if isinstance(grp, InterPartAdjacencyGroup)))) -def get_connected_partitions(mesh: Mesh) -> "Set[PartID]": +def get_connected_partitions(mesh: Mesh) -> "Sequence[PartID]": warn( "get_connected_partitions is deprecated and will stop working in June 2023. " "Use get_connected_parts instead.", DeprecationWarning, stacklevel=2) diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py index 1b2a64dd..c28877da 100644 --- a/meshmode/mesh/processing.py +++ b/meshmode/mesh/processing.py @@ -25,7 +25,7 @@ from dataclasses import dataclass, replace from functools import reduce from typing import ( - Callable, Dict, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Union) + Callable, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union) import numpy as np import numpy.linalg as la @@ -184,7 +184,7 @@ def _get_connected_parts( mesh: Mesh, part_id_to_part_index: Mapping[PartID, int], global_elem_to_part_elem: np.ndarray, - self_part_id: PartID) -> Set[PartID]: + self_part_id: PartID) -> Sequence[PartID]: """ Find the parts that are connected to the current part. @@ -196,10 +196,11 @@ def _get_connected_parts( :func:`_compute_global_elem_to_part_elem`` for details. :arg self_part_id: The identifier of the part currently being created. - :returns: A :class:`set` of identifiers of the neighboring parts. + :returns: A sequence of identifiers of the neighboring parts. """ self_part_index = part_id_to_part_index[self_part_id] + # This set is not used in a way that will cause nondeterminism. connected_part_indices = set() for igrp, facial_adj_list in enumerate(mesh.facial_adjacency_groups): @@ -223,10 +224,12 @@ def _get_connected_parts( elements_are_self & neighbors_are_other] + elem_base_j, 0]) - return { + result = tuple( part_id for part_id, part_index in part_id_to_part_index.items() - if part_index in connected_part_indices} + if part_index in connected_part_indices) + assert len(set(result)) == len(result) + return result def _create_self_to_self_adjacency_groups( @@ -305,7 +308,7 @@ def _create_self_to_other_adjacency_groups( self_part_id: PartID, self_mesh_groups: List[MeshElementGroup], self_mesh_group_elem_base: List[int], - connected_parts: Set[PartID]) -> List[List[InterPartAdjacencyGroup]]: + connected_parts: Sequence[PartID]) -> List[List[InterPartAdjacencyGroup]]: """ Create self-to-other adjacency groups for the partitioned mesh. diff --git a/setup.py b/setup.py index c8a1d7e8..a78d0afb 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def main(): "numpy", "modepy>=2020.2", "gmsh_interop", - "pytools>=2020.4.1", + "pytools>=2024.1.1", # 2019.1 is required for the Firedrake CIs, which use an very specific # version of Loopy.
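
Not part of the patch itself: a minimal sketch of the ordering behavior the changes above rely on. The recurring edit in this series swaps set/dict-comprehension deduplication for pytools.unique, which keeps first-occurrence order; the pytools>=2024.1.1 bump in setup.py presumably guarantees that unique is available. The part_ids list and the membership_list values below are made-up illustration data, not anything from meshmode.

# Minimal sketch (illustration only): why order-preserving deduplication
# helps determinism.  A set discards the order in which items were first
# seen, and its iteration order is an implementation detail (for
# str-valued keys it can even differ between runs under hash
# randomization).  pytools.unique keeps first-occurrence order, so ranks
# that traverse the same adjacency data in the same order also agree on
# the resulting sequence.
import numpy as np
from pytools import unique

# Hypothetical part identifiers, repeated as they might appear while
# walking facial adjacency groups.
part_ids = ["vol1", "vol0", "vol1", "vol2", "vol0"]

set_order = list({p for p in part_ids})    # iteration order unspecified
stable_order = list(unique(part_ids))      # ["vol1", "vol0", "vol2"]

# The same idea as the patched membership_list_to_map: keys come out in
# first-occurrence order, each mapped to the sorted indices where they
# occur, i.e. 2 -> [0, 2], 0 -> [1, 4], 1 -> [3].
membership_list = np.array([2, 0, 2, 1, 0])
membership_map = {
    entry: np.where(membership_list == entry)[0]
    for entry in unique(list(membership_list))}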