Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a CommTag class with a stable hash #319

Closed
wants to merge 15 commits into from
6 changes: 3 additions & 3 deletions examples/wave/wave-op-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa

from grudge.dof_desc import as_dofdesc, DISCR_TAG_BASE, DISCR_TAG_QUAD
from grudge.trace_pair import TracePair
from grudge.trace_pair import TracePair, CommTag
from grudge.discretization import make_discretization_collection
from grudge.shortcuts import make_visualizer, compiled_lsrk45_step

Expand Down Expand Up @@ -95,7 +95,7 @@ def wave_flux(actx, dcoll, c, w_tpair):
return op.project(dcoll, dd, dd.with_domain_tag("all_faces"), c*flux_weak)


class _WaveStateTag:
class _WaveStateTag(CommTag):
pass


Expand Down Expand Up @@ -144,7 +144,7 @@ def interp_to_surf_quad(utpair):
) + sum(
wave_flux(actx, dcoll, c=c, w_tpair=interp_to_surf_quad(tpair))
for tpair in op.interior_trace_pairs(dcoll, w,
comm_tag=_WaveStateTag)
comm_tag=_WaveStateTag())
)
)
)
Expand Down
36 changes: 30 additions & 6 deletions grudge/trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
-------------------------------------------

.. autoclass:: TracePair
.. autoclass:: CommTag

.. currentmodule:: grudge.op

Expand Down Expand Up @@ -70,7 +71,7 @@

from numbers import Number

from pytools import memoize_on_first_arg
from pytools import memoize_on_first_arg, memoize_method

from grudge.discretization import DiscretizationCollection
from grudge.projection import project
Expand Down Expand Up @@ -318,8 +319,27 @@ def interior_trace_pair(dcoll: DiscretizationCollection, vec) -> TracePair:
return local_interior_trace_pair(dcoll, vec)


class CommTag:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't really make sense for all CommTags to be subclasses of this. Some may have data contained in them.

Copy link
Collaborator Author

@matthiasdiener matthiasdiener Sep 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a small test in 9a63df4 with subclassed dataclasses, is that what you had in mind?

Edit: I may have misunderstood your comment; I removed dataclasses in 6918c63. Currently, all (hashable) classes are accepted, not just CommTag subclasses.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should stick to arbitrary-data-with-certain-requirements as tags. I mostly don't want to restrict to just subclasses of something specific.

Copy link
Collaborator Author

@matthiasdiener matthiasdiener Nov 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, would you like to keep the CommTag class as an optional way to construct the comm tags? I removed the warning in 9454067.

"""A communication tag with a hash value that is stable across
runs, even without setting ``PYTHONHASHSEED``."""

@memoize_method
def __hash__(self) -> int:
return hash(tuple(str(type(self)).encode("ascii")))

def __eq__(self, other: object) -> bool:
return type(self) is type(other)

def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, (self.__class__.__module__,
self.__class__.__qualname__))

def __repr__(self) -> str:
return self.__class__.__module__ + "." + self.__class__.__qualname__


def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
comm_tag: Hashable = None, tag: Hashable = None,
comm_tag: Optional[Hashable] = None, tag: Hashable = None,
volume_dd: Optional[DOFDesc] = None) -> List[TracePair]:
r"""Return a :class:`list` of :class:`TracePair` objects
defined on the interior faces of *dcoll* and any faces connected to a
Expand All @@ -334,7 +354,7 @@ def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
:arg comm_tag: a hashable object used to match sent and received data
across ranks. Communication will only match if both endpoints specify
objects that compare equal. A generalization of MPI communication
tags to arbitary, potentially composite objects.
tags to arbitrary, potentially composite objects.
:returns: a :class:`list` of :class:`TracePair` objects.
"""

Expand Down Expand Up @@ -439,7 +459,7 @@ def __init__(self,
self.comm_tag += comm_tag
del comm_tag

# Here, we initialize both send and recieve operations through
# Here, we initialize both send and receive operations through
# mpi4py `Request` (MPI_Request) instances for comm.Isend (MPI_Isend)
# and comm.Irecv (MPI_Irecv) respectively. These initiate non-blocking
# point-to-point communication requests and require explicit management
Expand Down Expand Up @@ -501,7 +521,11 @@ def __init__(self,
remote_rank: int, comm_tag: Hashable,
volume_dd=DD_VOLUME_ALL):
if comm_tag is None:
raise ValueError("lazy communication requires 'tag' to be supplied")
raise ValueError("lazy communication requires 'comm_tag' to be supplied")

if not isinstance(comm_tag, CommTag):
from warnings import warn
warn(f"comm_tag {comm_tag} should be an instance of CommTag")

bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_rank))

Expand Down Expand Up @@ -570,7 +594,7 @@ def cross_rank_trace_pairs(
:arg comm_tag: a hashable object used to match sent and received data
across ranks. Communication will only match if both endpoints specify
objects that compare equal. A generalization of MPI communication
tags to arbitary, potentially composite objects.
tags to arbitrary, potentially composite objects.

:returns: a :class:`list` of :class:`TracePair` objects.
"""
Expand Down
59 changes: 59 additions & 0 deletions test/test_trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,62 @@ def rand():
assert op.norm(dcoll, tpair.diff - (exterior - interior), np.inf) == 0
assert op.norm(dcoll, tpair.int - interior, np.inf) == 0
assert op.norm(dcoll, tpair.ext - exterior, np.inf) == 0


def test_commtag(actx_factory):

from grudge.trace_pair import CommTag, _sym_tag_to_num_tag

class DerivedCommTag(CommTag):
pass

class DerivedDerivedCommTag(DerivedCommTag):
pass

# {{{ test equality and hash consistency

ct = CommTag()
ct2 = CommTag()
dct = DerivedCommTag()
dct2 = DerivedCommTag()
ddct = DerivedDerivedCommTag()

assert ct == ct2
assert ct != dct
assert dct == dct2
assert dct != ddct
assert ddct != dct
assert (ct, dct) != (dct, ct)

assert hash(ct) == hash(ct2)
assert hash(ct) != hash(dct)
assert hash(dct) == hash(dct2)
assert hash(dct) != hash(ddct)
assert hash(ddct) != hash(dct)
assert hash((ct, dct)) != hash((dct, ct))

# }}}

# {{{ test hash stability

assert hash(ct) == 4644528671524962420
assert hash(dct) == -1013583671995716582
assert hash(ddct) == 626392264874077479

assert hash((ct, 123)) == -578844573019921397
assert hash((dct, 123)) == -8009406276367324841
assert hash((dct, ct)) == 6599529611285265043

# }}}

# {{{ test _sym_tag_to_num_tag

try:
from mpi4py import MPI
except ModuleNotFoundError:
pass
else:
tag_ub = MPI.COMM_WORLD.Get_attr(MPI.TAG_UB)
assert _sym_tag_to_num_tag(ct) == (1549868734841116283675 % tag_ub)

# }}}
Loading