diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index d2f8aeb0..8f4583b1 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -319,14 +319,22 @@ def interior_trace_pair(dcoll: DiscretizationCollection, vec) -> TracePair: return local_interior_trace_pair(dcoll, vec) -@dataclass(frozen=True) # for KeyBuilder support class CommTag: """A communication tag with a hash value that is stable across runs, even without setting ``PYTHONHASHSEED``.""" + @memoize_method - def __hash__(self): + def __hash__(self) -> int: return hash(tuple(str(type(self)).encode("ascii"))) + def __eq__(self, other: object) -> bool: + return type(self) is type(other) + + def update_persistent_hash(self, key_hash, key_builder): + key_builder.rec(key_hash, (self.__class__.__module__, + self.__class__.__qualname__)) + + def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *, comm_tag: Optional[Hashable] = None, tag: Hashable = None, diff --git a/test/test_trace_pair.py b/test/test_trace_pair.py index 9e26032c..2d2bd782 100644 --- a/test/test_trace_pair.py +++ b/test/test_trace_pair.py @@ -22,7 +22,7 @@ import numpy as np -from grudge.trace_pair import TracePair, CommTag +from grudge.trace_pair import TracePair import meshmode.mesh.generation as mgen from meshmode.dof_array import DOFArray from dataclasses import dataclass @@ -72,6 +72,8 @@ def rand(): def test_commtag(actx_factory): + from grudge.trace_pair import CommTag, _sym_tag_to_num_tag + class DerivedCommTag(CommTag): pass @@ -86,6 +88,8 @@ class DerivedDerivedCommTag(DerivedCommTag): dct2 = DerivedCommTag() ddct = DerivedDerivedCommTag() + assert _sym_tag_to_num_tag(ct) == 441551355 + assert ct == ct2 assert ct != dct assert dct == dct2 @@ -110,27 +114,3 @@ class DerivedDerivedCommTag(DerivedCommTag): assert hash((dct, ct)) == 6599529611285265043 # }}} - - # {{{ test using derived dataclasses - - @dataclass(frozen=True) - class DataCommTag(CommTag): - data: int - - @dataclass(frozen=True) - class DataCommTag2(CommTag): - data: int - - d1 = DataCommTag(1) - d2 = DataCommTag(2) - d3 = DataCommTag(1) - - assert d1 != d2 - assert hash(d1) != hash(d2) - assert d1 == d3 - assert hash(d1) == hash(d3) - - d4 = DataCommTag2(1) - assert d1 != d4 - - # }}}