From 59c59ce9a0f478414f412f5dd6676ae666c9ef16 Mon Sep 17 00:00:00 2001 From: Marshall Date: Mon, 25 Mar 2024 07:19:41 -0400 Subject: [PATCH] fix(rust): ensure `eq` for `BinaryViewArray` checks all elements (#15268) --- crates/polars-arrow/src/array/binview/mod.rs | 2 +- py-polars/tests/unit/datatypes/test_enum.py | 41 ++++++++++++++++--- .../tests/unit/operations/test_replace.py | 2 +- 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/crates/polars-arrow/src/array/binview/mod.rs b/crates/polars-arrow/src/array/binview/mod.rs index bea1c00d6709..f889e43dbc89 100644 --- a/crates/polars-arrow/src/array/binview/mod.rs +++ b/crates/polars-arrow/src/array/binview/mod.rs @@ -120,7 +120,7 @@ pub struct BinaryViewArrayGeneric { impl PartialEq for BinaryViewArrayGeneric { fn eq(&self, other: &Self) -> bool { - self.into_iter().zip(other).all(|(l, r)| l == r) + self.len() == other.len() && self.into_iter().zip(other).all(|(l, r)| l == r) } } diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index 5562f703e2e9..c2eab224e51b 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -214,8 +214,8 @@ def test_append_to_an_enum() -> None: def test_append_to_an_enum_with_new_category() -> None: with pytest.raises( - pl.ComputeError, - match=("can not merge incompatible Enum types"), + pl.SchemaError, + match=("cannot extend/append Enum"), ): pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])).append( pl.Series(["d", "a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"])) @@ -469,7 +469,36 @@ def test_enum_cse_eq() -> None: dt1 = pl.Enum(["a", "b"]) dt2 = pl.Enum(["a", "c"]) - df.lazy().select( - pl.when(True).then(pl.lit("a", dtype=dt1)).alias("dt1"), - pl.when(True).then(pl.lit("a", dtype=dt2)).alias("dt2"), - ).collect() + out = ( + df.lazy() + .select( + pl.when(True).then(pl.lit("a", dtype=dt1)).alias("dt1"), + pl.when(True).then(pl.lit("a", dtype=dt2)).alias("dt2"), + ) + .collect() + ) + + assert out["dt1"].item() == "a" + assert out["dt2"].item() == "a" + assert out["dt1"].dtype == pl.Enum(["a", "b"]) + assert out["dt2"].dtype == pl.Enum(["a", "c"]) + assert out["dt1"].dtype != out["dt2"].dtype + + +def test_category_comparison_subset() -> None: + dt1 = pl.Enum(["a"]) + dt2 = pl.Enum(["a", "b"]) + out = ( + pl.LazyFrame() + .select( + pl.lit("a", dtype=dt1).alias("dt1"), + pl.lit("a", dtype=dt2).alias("dt2"), + ) + .collect() + ) + + assert out["dt1"].item() == "a" + assert out["dt2"].item() == "a" + assert out["dt1"].dtype == pl.Enum(["a"]) + assert out["dt2"].dtype == pl.Enum(["a", "b"]) + assert out["dt1"].dtype != out["dt2"].dtype diff --git a/py-polars/tests/unit/operations/test_replace.py b/py-polars/tests/unit/operations/test_replace.py index c077e26338b1..56598c7d3be2 100644 --- a/py-polars/tests/unit/operations/test_replace.py +++ b/py-polars/tests/unit/operations/test_replace.py @@ -94,7 +94,7 @@ def test_replace_enum_to_new_enum() -> None: new_dtype = pl.Enum(["a", "b", "c", "d", "e"]) new = pl.Series(["c", "e"], dtype=new_dtype) - result = s.replace(old, new) + result = s.replace(old, new, return_dtype=new_dtype) expected = pl.Series(["c", "e", "c"], dtype=new_dtype) assert_series_equal(result, expected)