Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Summarize changes to support prediction #1

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,554 changes: 1,554 additions & 0 deletions meshmode/array_context.py

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions meshmode/discretization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,8 @@ def empty(self, actx: ArrayContext,
f"in 2025. Use '{type(self).__name__}.zeros' instead.",
DeprecationWarning, stacklevel=2)

return self._new_array(actx, actx.np.zeros, dtype=dtype)
# return self._new_array(actx, actx.np.zeros, dtype=dtype)
return self._new_array(actx, actx.zeros, dtype=dtype)

def zeros(self, actx: ArrayContext,
dtype: Optional[np.dtype] = None) -> _DOFArray:
Expand All @@ -490,15 +491,17 @@ def zeros(self, actx: ArrayContext,
raise TypeError(
f"'actx' must be an ArrayContext, not '{type(actx).__name__}'")

return self._new_array(actx, actx.np.zeros, dtype=dtype)
# return self._new_array(actx, actx.np.zeros, dtype=dtype)
return self._new_array(actx, actx.zeros, dtype=dtype)

def empty_like(self, array: _DOFArray) -> _DOFArray:
warn(f"'{type(self).__name__}.empty_like' is deprecated and will be removed "
f"in 2025. Use '{type(self).__name__}.zeros_like' instead.",
DeprecationWarning, stacklevel=2)

actx = array.array_context
return self._new_array(actx, actx.np.zeros, dtype=array.entry_dtype)
# return self._new_array(actx, actx.np.zeros, dtype=array.entry_dtype)
return self._new_array(actx, actx.zeros, dtype=array.entry_dtype)

def zeros_like(self, array: _DOFArray) -> _DOFArray:
return self.zeros(array.array_context, dtype=array.entry_dtype)
Expand Down
53 changes: 36 additions & 17 deletions meshmode/discretization/connection/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
ConcurrentElementInameTag,
DiscretizationDOFAxisTag,
DiscretizationElementAxisTag,
DiscretizationDOFPickListAxisTag,
)


Expand Down Expand Up @@ -534,17 +535,22 @@ def _per_target_group_pick_info(
_FromGroupPickData(
from_group_index=source_group_index,
dof_pick_lists=actx.freeze(
actx.tag(NameHint("dof_pick_lists"),
actx.from_numpy(dof_pick_lists))),
actx.tag_axis(0, DiscretizationDOFPickListAxisTag(),
actx.tag(NameHint("dof_pick_lists"),
actx.from_numpy(dof_pick_lists)))),
dof_pick_list_indices=actx.freeze(
actx.tag(NameHint("dof_pick_list_indices"),
actx.from_numpy(dof_pick_list_indices))),
actx.tag_axis(0, DiscretizationElementAxisTag(),
actx.tag(NameHint("dof_pick_list_indices"),
actx.from_numpy(dof_pick_list_indices)))),
from_el_present=actx.freeze(
actx.tag(NameHint("from_el_present"),
actx.from_numpy(from_el_present.astype(np.int8)))),
actx.tag_axis(0, DiscretizationElementAxisTag(),
actx.tag(NameHint("from_el_present"),
actx.from_numpy(
from_el_present.astype(np.int8))))),
from_element_indices=actx.freeze(
actx.tag(NameHint("from_el_indices"),
actx.from_numpy(from_el_indices))),
actx.tag_axis(0, DiscretizationElementAxisTag(),
actx.tag(NameHint("from_el_indices"),
actx.from_numpy(from_el_indices)))),
is_surjective=from_el_present.all()
))

Expand Down Expand Up @@ -716,25 +722,29 @@ def group_pick_knl(is_surjective: bool):
group_pick_info = None

if group_pick_info is not None:
group_array_contributions = []
# group_array_contributions = []

if actx.permits_advanced_indexing and not _force_use_loopy:
for fgpd in group_pick_info:
from_element_indices = actx.thaw(fgpd.from_element_indices)

if ary[fgpd.from_group_index].size:
grp_ary_contrib = ary[fgpd.from_group_index][
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_element_indices, (-1, 1)),
actx, from_element_indices, (-1, 1))),
actx.thaw(fgpd.dof_pick_lists)[
actx.thaw(fgpd.dof_pick_list_indices)]
]

if not fgpd.is_surjective:
from_el_present = actx.thaw(fgpd.from_el_present)
grp_ary_contrib = actx.np.where(
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1))),
grp_ary_contrib,
0)

Expand Down Expand Up @@ -784,8 +794,10 @@ def group_pick_knl(is_surjective: bool):
mat = self._resample_matrix(actx, i_tgrp, i_batch)
if actx.permits_advanced_indexing and not _force_use_loopy:
batch_result = actx.np.where(
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
actx, from_el_present, (-1, 1))),
actx.einsum("ij,ej->ei",
mat, grp_ary[from_element_indices]),
0)
Expand All @@ -806,11 +818,15 @@ def group_pick_knl(is_surjective: bool):

if actx.permits_advanced_indexing and not _force_use_loopy:
batch_result = actx.np.where(
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
actx, from_el_present, (-1, 1))),
from_vec[
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_element_indices, (-1, 1)),
actx, from_element_indices, (-1, 1))),
pick_list],
0)
else:
Expand All @@ -837,10 +853,13 @@ def group_pick_knl(is_surjective: bool):
else:
# If no batched data at all, return zeros for this
# particular group array
group_array = actx.np.zeros(
group_array = tag_axes(actx, {
0: DiscretizationElementAxisTag(),
1: DiscretizationDOFAxisTag()},
actx.np.zeros(
shape=(self.to_discr.groups[i_tgrp].nelements,
self.to_discr.groups[i_tgrp].nunit_dofs),
dtype=ary.entry_dtype)
dtype=ary.entry_dtype))

group_arrays.append(group_array)

Expand Down
6 changes: 6 additions & 0 deletions meshmode/discretization/poly_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,14 @@ def __init__(self, mesh_el_group: _MeshTensorProductElementGroup,
"`unit_nodes` dim = {unit_nodes.shape[0]}.")

self._basis = basis
self._bases_1d = basis.bases[0]
self._nodes = unit_nodes

def bases_1d(self):
"""Return 1D component bases used to construct the tensor product basis.
"""
return self._bases_1d

def basis_obj(self):
return self._basis

Expand Down
114 changes: 83 additions & 31 deletions meshmode/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
.. autoclass:: InterRankBoundaryInfo
.. autoclass:: MPIBoundaryCommSetupHelper

.. autofunction:: mpi_distribute
.. autofunction:: get_partition_by_pymetis
.. autofunction:: membership_list_to_map
.. autofunction:: get_connected_parts
Expand Down Expand Up @@ -36,11 +37,22 @@
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Hashable, List, Mapping, Sequence, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Hashable,
List,
Optional,
Mapping,
Sequence,
Set,
Union,
cast
)
from warnings import warn

import numpy as np

from contextlib import contextmanager
from arraycontext import ArrayContext

from meshmode.discretization import ElementGroupFactory
Expand All @@ -66,6 +78,70 @@

# {{{ mesh distributor

@contextmanager
def _duplicate_mpi_comm(mpi_comm):
dup_comm = mpi_comm.Dup()
try:
yield dup_comm
finally:
dup_comm.Free()


def mpi_distribute(
mpi_comm: "mpi4py.MPI.Intracomm",
source_data: Optional[Mapping[int, Any]] = None,
source_rank: int = 0) -> Optional[Any]:
"""
Distribute data to a set of processes.

:arg mpi_comm: An ``MPI.Intracomm``
:arg source_data: A :class:`dict` mapping destination ranks to data to be sent.
Only present on the source rank.
:arg source_rank: The rank from which the data is being sent.

:returns: The data local to the current process if there is any, otherwise
*None*.
"""
with _duplicate_mpi_comm(mpi_comm) as mpi_comm:
num_proc = mpi_comm.Get_size()
rank = mpi_comm.Get_rank()

local_data = None

if rank == source_rank:
if source_data is None:
raise TypeError("source rank has no data.")

sending_to = [False] * num_proc
for dest_rank in source_data.keys():
sending_to[dest_rank] = True

mpi_comm.scatter(sending_to, root=source_rank)

reqs = []
for dest_rank, data in source_data.items():
if dest_rank == rank:
local_data = data
logger.info("rank %d: received data", rank)
else:
reqs.append(mpi_comm.isend(data, dest=dest_rank))

logger.info("rank %d: sent all data", rank)

from mpi4py import MPI
MPI.Request.waitall(reqs)

else:
receiving = mpi_comm.scatter([], root=source_rank)

if receiving:
local_data = mpi_comm.recv(source=source_rank)
logger.info("rank %d: received data", rank)

return local_data


# TODO: Deprecate?
class MPIMeshDistributor:
"""
.. automethod:: is_manager_rank
Expand Down Expand Up @@ -104,9 +180,7 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
Sends each part to a different rank.
Returns one part that was not sent to any other rank.
"""
mpi_comm = self.mpi_comm
rank = mpi_comm.Get_rank()
assert num_parts <= mpi_comm.Get_size()
assert num_parts <= self.mpi_comm.Get_size()

assert self.is_manager_rank()

Expand All @@ -115,38 +189,16 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
from meshmode.mesh.processing import partition_mesh
parts = partition_mesh(mesh, part_num_to_elements)

local_part = None

reqs = []
for r, part in parts.items():
if r == self.manager_rank:
local_part = part
else:
reqs.append(mpi_comm.isend(part, dest=r, tag=TAG_DISTRIBUTE_MESHES))

logger.info("rank %d: sent all mesh parts", rank)
for req in reqs:
req.wait()

return local_part
return mpi_distribute(
self.mpi_comm, source_data=parts, source_rank=self.manager_rank)

def receive_mesh_part(self):
"""
Returns the mesh sent by the manager rank.
"""
mpi_comm = self.mpi_comm
rank = mpi_comm.Get_rank()

assert not self.is_manager_rank(), "Manager rank cannot receive mesh"

from mpi4py import MPI
status = MPI.Status()
result = self.mpi_comm.recv(
source=self.manager_rank, tag=TAG_DISTRIBUTE_MESHES,
status=status)
logger.info("rank %d: received local mesh (size = %d)", rank, status.count)
assert not self.is_mananger_rank(), "Manager rank cannot receive mesh"

return result
return mpi_distribute(self.mpi_comm, source_rank=self.manager_rank)

# }}}

Expand Down
22 changes: 19 additions & 3 deletions meshmode/mesh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,8 @@ def check_mesh_consistency(
"parameter force_positive_orientation=True to make_mesh().")
else:
warn("Unimplemented: Cannot check element orientation for a mesh with "
"mesh.dim != mesh.ambient_dim", stacklevel=2)
f"mesh.dim != mesh.ambient_dim ({mesh.dim=},{mesh.ambient_dim=})",
stacklevel=2)


def is_mesh_consistent(
Expand Down Expand Up @@ -957,7 +958,8 @@ def make_mesh(
node_vertex_consistency_tolerance: float | None = None,
skip_element_orientation_test: bool = False,
force_positive_orientation: bool = False,
) -> Mesh:
face_vertex_indices_to_tags=None,
) -> "Mesh":
"""Construct a new mesh from a given list of *groups*.

This constructor performs additional checks on the mesh once constructed and
Expand Down Expand Up @@ -1045,6 +1047,15 @@ def make_mesh(
nodal_adjacency = (
NodalAdjacency(neighbors_starts=nb_starts, neighbors=nbs))

face_vert_ind_to_tags_local = None
if face_vertex_indices_to_tags is not None:
face_vert_ind_to_tags_local = face_vertex_indices_to_tags.copy()

if (facial_adjacency_groups is False or facial_adjacency_groups is None):
if face_vertex_indices_to_tags is not None:
facial_adjacency_groups = _compute_facial_adjacency_from_vertices(
groups, np.int32, np.int8, face_vertex_indices_to_tags)

if (
facial_adjacency_groups is not False
and facial_adjacency_groups is not None):
Expand All @@ -1069,8 +1080,13 @@ def make_mesh(
if force_positive_orientation:
if mesh.dim == mesh.ambient_dim:
import meshmode.mesh.processing as mproc
mesh_making_kwargs = {
"face_vertex_indices_to_tags": face_vert_ind_to_tags_local
}
mesh = mproc.perform_flips(
mesh, mproc.find_volume_mesh_element_orientations(mesh) < 0)
mesh=mesh,
flip_flags=mproc.find_volume_mesh_element_orientations(mesh) < 0,
skip_tests=False, mesh_making_kwargs=mesh_making_kwargs)
else:
raise ValueError("cannot enforce positive element orientation "
"on non-volume meshes")
Expand Down
Loading
Loading