Skip to content

Commit

Permalink
Extend test and resolve types in check
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Jun 27, 2024
1 parent 3292034 commit 6c695e9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
5 changes: 3 additions & 2 deletions plum/promotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ def add_promotion_rule(type1, type2, type_to):
def rule(t1: type1, t2: type2):
return type_to

# If the types are the same, we don't need to add the reverse rule.
if type1 is type2:
# If the types are the same, we don't need to add the reverse rule. Resolve the
# types to handle the case where types are equal, but not identical.
if TypeHint(resolve_type_hint(type1)) == TypeHint(resolve_type_hint(type2)):
return # Escape early.

@_promotion_rule.dispatch
Expand Down
27 changes: 23 additions & 4 deletions tests/test_promotion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing
import warnings
from numbers import Number
from typing import Union

Expand Down Expand Up @@ -159,7 +161,24 @@ def test_inheritance(convert, promote):
assert promote(Re(), n) == ("Num from Re", n)
assert promote(Re(), Rat()) == ("Num from Re", "Num from Rat")

# Test that explicit self-promotion works.
# This should also trigger the "escape hatch" in `add_promotion_rule`.
add_promotion_rule(Num, Num, Num)
assert promote(n, n) == (n, n)

def test_self_promotion(convert, promote):
# This should trigger the "escape hatch" in `add_promotion_rule`. It also should not
# trigger a redefinition warning. Explicitly test for that.
with warnings.catch_warnings():
warnings.simplefilter("error")

# Simple case where types are identical:
add_promotion_rule(Num, Num, Num)
n = Num()
assert promote(n, n) == (n, n)

# Also test a more complicated scenario where the types are equal, but not
# identical.
t1 = typing.Union[int, float]
t2 = typing.Union[float, int]
assert t1 is not t2
add_promotion_rule(t1, t2, str)
add_conversion_method(int, str, str)
add_conversion_method(float, str, str)
assert promote(1, 1.0) == ("1", "1.0")

0 comments on commit 6c695e9

Please sign in to comment.