Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Oct 9, 2024
1 parent ee9bafb commit 9b3abc9
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 9 deletions.
40 changes: 34 additions & 6 deletions crates/polars-core/src/chunked_array/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,11 @@ where
right.offsets().range().try_into().unwrap(),
);

arity::unary_mut_values(lhs, |a| broadcast_op(a, &values).into())
if missing {
arity::unary_mut_with_options(lhs, |a| broadcast_op(a, &values).into())
} else {
arity::unary_mut_values(lhs, |a| broadcast_op(a, &values).into())
}
},
(1, _) => {
let left = lhs.chunks()[0]
Expand All @@ -699,9 +703,19 @@ where
left.offsets().range().try_into().unwrap(),
);

arity::unary_mut_values(rhs, |a| broadcast_op(a, &values).into())
if missing {
arity::unary_mut_with_options(rhs, |a| broadcast_op(a, &values).into())
} else {
arity::unary_mut_values(rhs, |a| broadcast_op(a, &values).into())
}
},
_ => {
if missing {
arity::binary_mut_with_options(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY)
} else {
arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY)
}
},
_ => arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY),
}
}

Expand Down Expand Up @@ -874,7 +888,11 @@ where
}
}

arity::unary_mut_values(lhs, |a| broadcast_op(a, right.values()).into())
if missing {
arity::unary_mut_with_options(lhs, |a| broadcast_op(a, right.values()).into())
} else {
arity::unary_mut_values(lhs, |a| broadcast_op(a, right.values()).into())
}
},
(1, _) => {
let left = lhs.chunks()[0]
Expand All @@ -894,9 +912,19 @@ where
}
}

arity::unary_mut_values(rhs, |a| broadcast_op(a, left.values()).into())
if missing {
arity::unary_mut_with_options(rhs, |a| broadcast_op(a, left.values()).into())
} else {
arity::unary_mut_values(rhs, |a| broadcast_op(a, left.values()).into())
}
},
_ => {
if missing {
arity::binary_mut_with_options(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY)
} else {
arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY)
}
},
_ => arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY),
}
}

Expand Down
3 changes: 0 additions & 3 deletions py-polars/tests/unit/operations/test_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,6 @@ def test_series_str_explode_deprecated(
) -> None:
with pytest.deprecated_call():
result = pl.Series(values).str.explode()
if result.to_list() != exploded:
print(result.to_list())
print(exploded)
assert result.to_list() == exploded


Expand Down
192 changes: 192 additions & 0 deletions py-polars/tests/unit/series/test_equals.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Callable

import pytest

Expand Down Expand Up @@ -105,3 +106,194 @@ def test_series_equals_strict_deprecated() -> None:
s2 = pl.Series("a", [1, 2, None], pl.Int64)
with pytest.deprecated_call():
assert not s1.equals(s2, strict=True) # type: ignore[call-arg]


@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)])
@pytest.mark.parametrize(
    ("cmp_eq", "cmp_ne"),
    [
        # Both operand orders are exercised because the implementation
        # branches on the lengths of the two sides:
        #   match (left.len(), right.len()) {
        #       (1, _) => ...,
        #       (_, 1) => ...,
        #       (_, _) => ...,
        #   }
        (pl.Series.eq, pl.Series.ne),
        (
            lambda a, b: pl.Series.eq(b, a),
            lambda a, b: pl.Series.ne(b, a),
        ),
    ],
)
def test_eq_lists_arrays(
    dtype: pl.DataType,
    cmp_eq: Callable[[pl.Series, pl.Series], pl.Series],
    cmp_ne: Callable[[pl.Series, pl.Series], pl.Series],
) -> None:
    """Check `eq` / `ne` on nested dtypes, where a null row yields a null result.

    Every scenario compares a varying left operand against the same
    three-row right operand, first with `cmp_eq` and then with `cmp_ne`.
    """
    other_rows = [None, [1, None], [1, 1]]

    # (left operand rows, expected `eq` output, expected `ne` output)
    scenarios = [
        # Broadcast NULL
        ([None], [None, None, None], [None, None, None]),
        # Non-broadcast full-NULL
        (3 * [None], [None, None, None], [None, None, None]),
        # Broadcast valid
        ([[1, None]], [None, True, False], [None, False, True]),
        # Non-broadcast mixed
        ([None, [1, 1], [1, 1]], [None, False, True], [None, True, False]),
    ]

    for rows, want_eq, want_ne in scenarios:
        for compare, want in ((cmp_eq, want_eq), (cmp_ne, want_ne)):
            outcome = compare(
                pl.Series(rows, dtype=dtype),
                pl.Series(other_rows, dtype=dtype),
            )
            assert_series_equal(outcome, pl.Series(want, dtype=pl.Boolean))


@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)])
@pytest.mark.parametrize(
    ("cmp_eq_missing", "cmp_ne_missing"),
    [
        (pl.Series.eq_missing, pl.Series.ne_missing),
        (
            lambda a, b: pl.Series.eq_missing(b, a),
            lambda a, b: pl.Series.ne_missing(b, a),
        ),
    ],
)
def test_eq_missing_lists_arrays_19153(
    dtype: pl.DataType,
    cmp_eq_missing: Callable[[pl.Series, pl.Series], pl.Series],
    cmp_ne_missing: Callable[[pl.Series, pl.Series], pl.Series],
) -> None:
    """Check `eq_missing` / `ne_missing` on nested dtypes.

    Unlike plain `eq` / `ne`, the expected outputs here contain no nulls:
    null rows are compared as ordinary values. Every scenario compares a
    varying left operand against the same three-row right operand, first
    with `cmp_eq_missing` and then with `cmp_ne_missing`.
    """

    def check_strictly(actual: pl.Series, expected: pl.Series) -> None:
        # `assert_series_equal` also uses `ne_missing` underneath, so a few
        # independent checks are layered on top of it to be sure.
        assert_series_equal(actual, expected)
        assert actual.to_list() == expected.to_list()
        assert actual.null_count() == 0
        assert expected.null_count() == 0

    other_rows = [None, [1, None], [1, 1]]

    # (left operand rows, expected `eq_missing`, expected `ne_missing`)
    scenarios = [
        # Broadcast NULL
        ([None], [True, False, False], [False, True, True]),
        # Non-broadcast full-NULL
        (3 * [None], [True, False, False], [False, True, True]),
        # Broadcast valid
        ([[1, None]], [False, True, False], [True, False, True]),
        # Non-broadcast mixed
        ([None, [1, 1], [1, 1]], [True, False, True], [False, True, False]),
    ]

    for rows, want_eq, want_ne in scenarios:
        for compare, want in ((cmp_eq_missing, want_eq), (cmp_ne_missing, want_ne)):
            outcome = compare(
                pl.Series(rows, dtype=dtype),
                pl.Series(other_rows, dtype=dtype),
            )
            check_strictly(outcome, pl.Series(want))

0 comments on commit 9b3abc9

Please sign in to comment.