Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a CommTag class with a stable hash #319

Closed
wants to merge 15 commits into from
6 changes: 3 additions & 3 deletions examples/wave/wave-op-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa

from grudge.dof_desc import as_dofdesc, DOFDesc, DISCR_TAG_BASE, DISCR_TAG_QUAD
from grudge.trace_pair import TracePair
from grudge.trace_pair import TracePair, CommTag
from grudge.discretization import DiscretizationCollection
from grudge.shortcuts import make_visualizer, compiled_lsrk45_step

Expand Down Expand Up @@ -95,7 +95,7 @@ def wave_flux(actx, dcoll, c, w_tpair):
return op.project(dcoll, dd, dd.with_dtag("all_faces"), c*flux_weak)


class _WaveStateTag:
class _WaveStateTag(CommTag):
pass


Expand Down Expand Up @@ -144,7 +144,7 @@ def interp_to_surf_quad(utpair):
) + sum(
wave_flux(actx, dcoll, c=c, w_tpair=interp_to_surf_quad(tpair))
for tpair in op.interior_trace_pairs(dcoll, w,
comm_tag=_WaveStateTag)
comm_tag=_WaveStateTag())
)
)
)
Expand Down
40 changes: 29 additions & 11 deletions grudge/trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
-------------------------------------------

.. autoclass:: TracePair
.. autoclass:: CommTag

.. currentmodule:: grudge.op

Expand Down Expand Up @@ -70,7 +71,7 @@

from numbers import Number

from pytools import memoize_on_first_arg
from pytools import memoize_on_first_arg, memoize_method

from grudge.discretization import DiscretizationCollection
from grudge.projection import project
Expand Down Expand Up @@ -318,8 +319,20 @@ def interior_trace_pair(dcoll: DiscretizationCollection, vec) -> TracePair:
return local_interior_trace_pair(dcoll, vec)


@dataclass(frozen=True) # for KeyBuilder support
class CommTag:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't really make sense for all CommTags to be subclasses of this. Some may have data contained in them.

Copy link
Collaborator Author

@matthiasdiener matthiasdiener Sep 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a small test in 9a63df4 with subclassed dataclasses, is that what you had in mind?

Edit: I may have misunderstood your comment; I removed dataclasses in 6918c63. Currently, all (hashable) classes are accepted, not just CommTag subclasses.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should stick to arbitrary-data-with-certain-requirements as tags. I mostly don't want to restrict to just subclasses of something specific.

Copy link
Collaborator Author

@matthiasdiener matthiasdiener Nov 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, would you like to keep the CommTag class as an optional way to construct the comm tags? I removed the warning in 9454067.

"""A communication tag with a hash value that is stable across
runs, even without setting ``PYTHONHASHSEED``."""
@memoize_method
def __hash__(self):
return hash(tuple(str(type(self)).encode("ascii")))

def __eq__(self, other):
return isinstance(other, type(self))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is incorrect: It accepts arbitrary subclasses for other, and violates the property of "__eq__ implies hash equality."

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is (hopefully) fixed in 6918c63



def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
comm_tag: Hashable = None, tag: Hashable = None,
comm_tag: CommTag = None, tag: Hashable = None,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't agree with this change. There's no good reason why other Hashables should not be accepted.

Copy link
Collaborator Author

@matthiasdiener matthiasdiener Sep 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I undid this change in 9b2ced4. I left the warning in https://github.com/inducer/grudge/pull/319/files#diff-5fa6e6b713d37028c16f1bbc5a6d0b4547ecba0f86c8a52eb30e1ba1dc2614e4R518 in place for now, would you like me to remove it?

Edit:
Removed in 9454067

volume_dd: Optional[DOFDesc] = None) -> List[TracePair]:
r"""Return a :class:`list` of :class:`TracePair` objects
defined on the interior faces of *dcoll* and any faces connected to a
Expand All @@ -331,10 +344,11 @@ def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,

:arg vec: a :class:`~meshmode.dof_array.DOFArray` or an
:class:`~arraycontext.ArrayContainer` of them.
:arg comm_tag: a hashable object used to match sent and received data
across ranks. Communication will only match if both endpoints specify
objects that compare equal. A generalization of MPI communication
tags to arbitary, potentially composite objects.
:arg comm_tag: a :class:`~grudge.trace_pair.CommTag` used to match
sent and received data across ranks. Communication will only match
if both endpoints specify objects that compare equal. A
generalization of MPI communication tags to arbitrary, potentially
composite objects.
:returns: a :class:`list` of :class:`TracePair` objects.
"""

Expand Down Expand Up @@ -379,7 +393,7 @@ def connected_ranks(
dcoll._volume_discrs[volume_dd.domain_tag.tag].mesh)


def _sym_tag_to_num_tag(comm_tag: Optional[Hashable]) -> Optional[int]:
def _sym_tag_to_num_tag(comm_tag: Optional[CommTag]) -> Optional[int]:
if comm_tag is None:
return comm_tag

Expand Down Expand Up @@ -498,10 +512,14 @@ class _RankBoundaryCommunicationLazy:
def __init__(self,
dcoll: DiscretizationCollection,
array_container: ArrayOrContainer,
remote_rank: int, comm_tag: Hashable,
remote_rank: int, comm_tag: Optional[CommTag],
volume_dd=DD_VOLUME_ALL):
if comm_tag is None:
raise ValueError("lazy communication requires 'tag' to be supplied")
raise ValueError("lazy communication requires 'comm_tag' to be supplied")

if not isinstance(comm_tag, CommTag):
from warnings import warn
warn(f"comm_tag {comm_tag} should be an instance of CommTag")

bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_rank))

Expand Down Expand Up @@ -544,7 +562,7 @@ def finish(self):
def cross_rank_trace_pairs(
dcoll: DiscretizationCollection, ary: ArrayOrContainer,
tag: Hashable = None,
*, comm_tag: Hashable = None,
*, comm_tag: CommTag = None,
volume_dd: Optional[DOFDesc] = None) -> List[TracePair]:
r"""Get a :class:`list` of *ary* trace pairs for each partition boundary.

Expand All @@ -570,7 +588,7 @@ def cross_rank_trace_pairs(
:arg comm_tag: a hashable object used to match sent and received data
across ranks. Communication will only match if both endpoints specify
objects that compare equal. A generalization of MPI communication
tags to arbitary, potentially composite objects.
tags to arbitrary, potentially composite objects.

:returns: a :class:`list` of :class:`TracePair` objects.
"""
Expand Down
21 changes: 20 additions & 1 deletion test/test_trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


import numpy as np
from grudge.trace_pair import TracePair
from grudge.trace_pair import TracePair, CommTag
import meshmode.mesh.generation as mgen
from meshmode.dof_array import DOFArray

Expand Down Expand Up @@ -67,3 +67,22 @@ def rand():
assert op.norm(dcoll, tpair.diff - (exterior - interior), np.inf) == 0
assert op.norm(dcoll, tpair.int - interior, np.inf) == 0
assert op.norm(dcoll, tpair.ext - exterior, np.inf) == 0


def test_commtag(actx_factory):

class DerivedCommTag(CommTag):
pass

x = CommTag()
x2 = CommTag()
y = DerivedCommTag()

assert hash(x) == hash(x2)
assert hash(x) != hash(y)
assert hash(x) == 4644528671524962420
assert hash(y) == -1013583671995716582

assert hash((x, 123)) == -578844573019921397
assert hash((y, 123)) == -8009406276367324841
assert hash((y, x)) == 6599529611285265043
Loading