From 94c695eacc08969e4ed7f96526d36904e820a05f Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 15 Aug 2024 22:39:25 -0500 Subject: [PATCH] get mpi_communicator from actx instead of dcoll --- grudge/reductions.py | 30 +++++++++++++++++++----------- grudge/trace_pair.py | 2 +- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/grudge/reductions.py b/grudge/reductions.py index 48780078..6dcde131 100644 --- a/grudge/reductions.py +++ b/grudge/reductions.py @@ -80,6 +80,7 @@ from pytools import memoize_in import grudge.dof_desc as dof_desc +from grudge.array_context import MPIBasedArrayContext from grudge.discretization import DiscretizationCollection @@ -128,16 +129,17 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> Scalar: :class:`~arraycontext.ArrayContainer`. :returns: a device scalar denoting the nodal sum. """ - comm = dcoll.mpi_communicator - if comm is None: + from arraycontext import get_container_context_recursively + actx = get_container_context_recursively(vec) + + if not isinstance(actx, MPIBasedArrayContext): return nodal_sum_loc(dcoll, dd, vec) + comm = actx.mpi_communicator + # NOTE: Do not move, we do not want to import mpi4py in single-rank computations from mpi4py import MPI - from arraycontext import get_container_context_recursively - actx = get_container_context_recursively(vec) - return actx.from_numpy( comm.allreduce(actx.to_numpy(nodal_sum_loc(dcoll, dd, vec)), op=MPI.SUM)) @@ -174,13 +176,16 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scal :arg initial: an optional initial value. Defaults to `numpy.inf`. :returns: a device scalar denoting the nodal minimum. """ - comm = dcoll.mpi_communicator - if comm is None: + from arraycontext import get_container_context_recursively + actx = get_container_context_recursively(vec) + + if not isinstance(actx, MPIBasedArrayContext): return nodal_min_loc(dcoll, dd, vec, initial=initial) + comm = actx.mpi_communicator + # NOTE: Do not move, we do not want to import mpi4py in single-rank computations from mpi4py import MPI - actx = vec.array_context return actx.from_numpy( comm.allreduce( @@ -231,13 +236,16 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scal :arg initial: an optional initial value. Defaults to `-numpy.inf`. :returns: a device scalar denoting the nodal maximum. """ - comm = dcoll.mpi_communicator - if comm is None: + from arraycontext import get_container_context_recursively + actx = get_container_context_recursively(vec) + + if not isinstance(actx, MPIBasedArrayContext): return nodal_max_loc(dcoll, dd, vec, initial=initial) + comm = actx.mpi_communicator + # NOTE: Do not move, we do not want to import mpi4py in single-rank computations from mpi4py import MPI - actx = vec.array_context return actx.from_numpy( comm.allreduce( diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 6ac3ad63..e831b0a9 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -410,7 +410,7 @@ def __init__(self, bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_rank)) local_bdry_data = project(dcoll, volume_dd, bdry_dd, array_container) - comm = dcoll.mpi_communicator + comm = actx.mpi_communicator assert comm is not None self.dcoll = dcoll