Skip to content

Commit

Permalink
don't require comm_tag to be of class CommTag
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Sep 17, 2023
1 parent 35183b2 commit 9b2ced4
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions grudge/trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def __eq__(self, other):


def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
comm_tag: CommTag = None, tag: Hashable = None,
comm_tag: Optional[Hashable] = None, tag: Hashable = None,
volume_dd: Optional[DOFDesc] = None) -> List[TracePair]:
r"""Return a :class:`list` of :class:`TracePair` objects
defined on the interior faces of *dcoll* and any faces connected to a
Expand All @@ -344,11 +344,10 @@ def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
:arg vec: a :class:`~meshmode.dof_array.DOFArray` or an
:class:`~arraycontext.ArrayContainer` of them.
:arg comm_tag: a :class:`~grudge.trace_pair.CommTag` used to match
sent and received data across ranks. Communication will only match
if both endpoints specify objects that compare equal. A
generalization of MPI communication tags to arbitrary, potentially
composite objects.
:arg comm_tag: a hashable object used to match sent and received data
across ranks. Communication will only match if both endpoints specify
objects that compare equal. A generalization of MPI communication
tags to arbitrary, potentially composite objects.
:returns: a :class:`list` of :class:`TracePair` objects.
"""

Expand Down Expand Up @@ -393,7 +392,7 @@ def connected_ranks(
dcoll._volume_discrs[volume_dd.domain_tag.tag].mesh)


def _sym_tag_to_num_tag(comm_tag: Optional[CommTag]) -> Optional[int]:
def _sym_tag_to_num_tag(comm_tag: Optional[Hashable]) -> Optional[int]:
if comm_tag is None:
return comm_tag

Expand Down Expand Up @@ -512,7 +511,7 @@ class _RankBoundaryCommunicationLazy:
def __init__(self,
dcoll: DiscretizationCollection,
array_container: ArrayOrContainer,
remote_rank: int, comm_tag: Optional[CommTag],
remote_rank: int, comm_tag: Hashable,
volume_dd=DD_VOLUME_ALL):
if comm_tag is None:
raise ValueError("lazy communication requires 'comm_tag' to be supplied")
Expand Down Expand Up @@ -562,7 +561,7 @@ def finish(self):
def cross_rank_trace_pairs(
dcoll: DiscretizationCollection, ary: ArrayOrContainer,
tag: Hashable = None,
*, comm_tag: CommTag = None,
*, comm_tag: Hashable = None,
volume_dd: Optional[DOFDesc] = None) -> List[TracePair]:
r"""Get a :class:`list` of *ary* trace pairs for each partition boundary.
Expand Down

0 comments on commit 9b2ced4

Please sign in to comment.