Skip to content

Commit

Permalink
Merge branch 'main' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Jun 19, 2024
2 parents 223665e + 369da34 commit 0b38f68
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions skbase/validate/tests/test_type_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ def test_is_sequence_output():
)

# Test with 3rd party types works in default way via exact type
assert is_sequence([1.2, 4.7], element_type=np.float_) is False
assert is_sequence([np.float_(1.2), np.float_(4.7)], element_type=np.float_) is True
assert is_sequence([1.2, 4.7], element_type=np.float64) is False
assert is_sequence([np.float64(1.2), np.float64(4.7)], element_type=np.float64)

# np.nan is float, not int or np.float_
# np.nan is float, not int or np.float64
assert is_sequence([np.nan, 4.8], element_type=float) is True
assert is_sequence([np.nan, 4], element_type=int) is False

Expand Down Expand Up @@ -243,11 +243,11 @@ def test_check_sequence_output():
TypeError,
match="Invalid sequence: .*",
):
check_sequence([1.2, 4.7], element_type=np.float_)
input_seq = [np.float_(1.2), np.float_(4.7)]
assert check_sequence(input_seq, element_type=np.float_) == input_seq
check_sequence([1.2, 4.7], element_type=np.float64)
input_seq = [np.float64(1.2), np.float64(4.7)]
assert check_sequence(input_seq, element_type=np.float64) == input_seq

# np.nan is float, not int or np.float_
# np.nan is float, not int or np.float64
assert check_sequence([np.nan, 4.8], element_type=float) == [np.nan, 4.8]
assert check_sequence([np.nan, 4.8, 7], element_type=(float, int)) == [
np.nan,
Expand Down

0 comments on commit 0b38f68

Please sign in to comment.