Skip to content

Commit

Permalink
get mpi_communicator from actx instead of dcoll
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Aug 19, 2024
1 parent cdd16aa commit b6625bd
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
30 changes: 19 additions & 11 deletions grudge/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from pytools import memoize_in

import grudge.dof_desc as dof_desc
from grudge.array_context import MPIBasedArrayContext
from grudge.discretization import DiscretizationCollection


Expand Down Expand Up @@ -128,16 +129,17 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> Scalar:
:class:`~arraycontext.ArrayContainer`.
:returns: a device scalar denoting the nodal sum.
"""
comm = dcoll.mpi_communicator
if comm is None:
from arraycontext import get_container_context_recursively
actx = get_container_context_recursively(vec)

if not isinstance(actx, MPIBasedArrayContext):
return nodal_sum_loc(dcoll, dd, vec)

comm = actx.mpi_communicator

# NOTE: Do not move, we do not want to import mpi4py in single-rank computations
from mpi4py import MPI

from arraycontext import get_container_context_recursively
actx = get_container_context_recursively(vec)

return actx.from_numpy(
comm.allreduce(actx.to_numpy(nodal_sum_loc(dcoll, dd, vec)), op=MPI.SUM))

Expand Down Expand Up @@ -174,13 +176,16 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scal
:arg initial: an optional initial value. Defaults to `numpy.inf`.
:returns: a device scalar denoting the nodal minimum.
"""
comm = dcoll.mpi_communicator
if comm is None:
from arraycontext import get_container_context_recursively
actx = get_container_context_recursively(vec)

if not isinstance(actx, MPIBasedArrayContext):
return nodal_min_loc(dcoll, dd, vec, initial=initial)

comm = actx.mpi_communicator

# NOTE: Do not move, we do not want to import mpi4py in single-rank computations
from mpi4py import MPI
actx = vec.array_context

return actx.from_numpy(
comm.allreduce(
Expand Down Expand Up @@ -231,13 +236,16 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scal
:arg initial: an optional initial value. Defaults to `-numpy.inf`.
:returns: a device scalar denoting the nodal maximum.
"""
comm = dcoll.mpi_communicator
if comm is None:
from arraycontext import get_container_context_recursively
actx = get_container_context_recursively(vec)

if not isinstance(actx, MPIBasedArrayContext):
return nodal_max_loc(dcoll, dd, vec, initial=initial)

comm = actx.mpi_communicator

# NOTE: Do not move, we do not want to import mpi4py in single-rank computations
from mpi4py import MPI
actx = vec.array_context

return actx.from_numpy(
comm.allreduce(
Expand Down
2 changes: 1 addition & 1 deletion grudge/trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def __init__(self,
bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_rank))

local_bdry_data = project(dcoll, volume_dd, bdry_dd, array_container)
comm = dcoll.mpi_communicator
comm = actx.mpi_communicator
assert comm is not None

self.dcoll = dcoll
Expand Down

0 comments on commit b6625bd

Please sign in to comment.