Skip to content

Commit

Permalink
fix(rust): ensure eq for BinaryViewArray checks all elements (pol…
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Mar 25, 2024
1 parent 53f5536 commit 59c59ce
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/array/binview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ pub struct BinaryViewArrayGeneric<T: ViewType + ?Sized> {

impl<T: ViewType + ?Sized> PartialEq for BinaryViewArrayGeneric<T> {
fn eq(&self, other: &Self) -> bool {
self.into_iter().zip(other).all(|(l, r)| l == r)
self.len() == other.len() && self.into_iter().zip(other).all(|(l, r)| l == r)
}
}

Expand Down
41 changes: 35 additions & 6 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ def test_append_to_an_enum() -> None:

def test_append_to_an_enum_with_new_category() -> None:
with pytest.raises(
pl.ComputeError,
match=("can not merge incompatible Enum types"),
pl.SchemaError,
match=("cannot extend/append Enum"),
):
pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])).append(
pl.Series(["d", "a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"]))
Expand Down Expand Up @@ -469,7 +469,36 @@ def test_enum_cse_eq() -> None:
dt1 = pl.Enum(["a", "b"])
dt2 = pl.Enum(["a", "c"])

df.lazy().select(
pl.when(True).then(pl.lit("a", dtype=dt1)).alias("dt1"),
pl.when(True).then(pl.lit("a", dtype=dt2)).alias("dt2"),
).collect()
out = (
df.lazy()
.select(
pl.when(True).then(pl.lit("a", dtype=dt1)).alias("dt1"),
pl.when(True).then(pl.lit("a", dtype=dt2)).alias("dt2"),
)
.collect()
)

assert out["dt1"].item() == "a"
assert out["dt2"].item() == "a"
assert out["dt1"].dtype == pl.Enum(["a", "b"])
assert out["dt2"].dtype == pl.Enum(["a", "c"])
assert out["dt1"].dtype != out["dt2"].dtype


def test_category_comparison_subset() -> None:
dt1 = pl.Enum(["a"])
dt2 = pl.Enum(["a", "b"])
out = (
pl.LazyFrame()
.select(
pl.lit("a", dtype=dt1).alias("dt1"),
pl.lit("a", dtype=dt2).alias("dt2"),
)
.collect()
)

assert out["dt1"].item() == "a"
assert out["dt2"].item() == "a"
assert out["dt1"].dtype == pl.Enum(["a"])
assert out["dt2"].dtype == pl.Enum(["a", "b"])
assert out["dt1"].dtype != out["dt2"].dtype
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_replace_enum_to_new_enum() -> None:
new_dtype = pl.Enum(["a", "b", "c", "d", "e"])
new = pl.Series(["c", "e"], dtype=new_dtype)

result = s.replace(old, new)
result = s.replace(old, new, return_dtype=new_dtype)

expected = pl.Series(["c", "e", "c"], dtype=new_dtype)
assert_series_equal(result, expected)
Expand Down

0 comments on commit 59c59ce

Please sign in to comment.