Skip to content

Commit

Permalink
remove __eq__ function, more eq testing
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Sep 17, 2023
1 parent 9b2ced4 commit 5e21cf2
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 14 deletions.
3 changes: 0 additions & 3 deletions grudge/trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,6 @@ class CommTag:
def __hash__(self):
return hash(tuple(str(type(self)).encode("ascii")))

def __eq__(self, other):
return isinstance(other, type(self))


def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
comm_tag: Optional[Hashable] = None, tag: Hashable = None,
Expand Down
58 changes: 47 additions & 11 deletions test/test_trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from grudge.trace_pair import TracePair, CommTag
import meshmode.mesh.generation as mgen
from meshmode.dof_array import DOFArray
from dataclasses import dataclass

from grudge import DiscretizationCollection

Expand Down Expand Up @@ -74,15 +75,50 @@ def test_commtag(actx_factory):
class DerivedCommTag(CommTag):
pass

x = CommTag()
x2 = CommTag()
y = DerivedCommTag()

assert hash(x) == hash(x2)
assert hash(x) != hash(y)
assert hash(x) == 4644528671524962420
assert hash(y) == -1013583671995716582
class DerivedDerivedCommTag(DerivedCommTag):
pass

assert hash((x, 123)) == -578844573019921397
assert hash((y, 123)) == -8009406276367324841
assert hash((y, x)) == 6599529611285265043
ct = CommTag()
ct2 = CommTag()
dct = DerivedCommTag()
dct2 = DerivedCommTag()
ddct = DerivedDerivedCommTag()

assert ct == ct2
assert ct != dct
assert dct == dct2
assert dct != ddct
assert ddct != dct
assert (ct, dct) != (dct, ct)

assert hash(ct) == hash(ct2)
assert hash(ct) != hash(dct)
assert hash(dct) != hash(ddct)

assert hash(ct) == 4644528671524962420
assert hash(dct) == -1013583671995716582
assert hash(ddct) == 626392264874077479

assert hash((ct, 123)) == -578844573019921397
assert hash((dct, 123)) == -8009406276367324841
assert hash((dct, ct)) == 6599529611285265043

@dataclass(frozen=True)
class DataCommTag(CommTag):
data: int

@dataclass(frozen=True)
class DataCommTag2(CommTag):
data: int

d1 = DataCommTag(1)
d2 = DataCommTag(2)
d3 = DataCommTag(1)

assert d1 != d2
assert hash(d1) != hash(d2)
assert d1 == d3
assert hash(d1) == hash(d3)

d4 = DataCommTag2(1)
assert d1 != d4

0 comments on commit 5e21cf2

Please sign in to comment.