From d39bc8485ad6522652abde266101e82193bf3fd4 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Thu, 23 May 2024 16:29:05 +0200 Subject: [PATCH] fix(rust): fixes #16436 and #16437 --- .../optimizer/cluster_with_columns.rs | 6 +++++- py-polars/tests/unit/test_cwc.py | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs b/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs index 98d18597e784..aa7008699fbf 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs @@ -157,6 +157,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) let mut has_seen_unpushable = false; let mut needs_simple_projection = false; + let mut already_removed = 0; *current_exprs.exprs_mut() = std::mem::take(current_exprs.exprs_mut()) .into_iter() .zip(pushable.iter()) @@ -166,8 +167,11 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) needs_simple_projection = has_seen_unpushable; input_exprs.exprs_mut().push(expr); - let (column, datatype) = new_current_schema.shift_remove_index(i).unwrap(); + let (column, datatype) = new_current_schema + .shift_remove_index(i - already_removed) + .unwrap(); new_input_schema.with_column(column, datatype); + already_removed += 1; None } else { diff --git a/py-polars/tests/unit/test_cwc.py b/py-polars/tests/unit/test_cwc.py index cfcd924e7a90..fe27f0652e99 100644 --- a/py-polars/tests/unit/test_cwc.py +++ b/py-polars/tests/unit/test_cwc.py @@ -1,5 +1,6 @@ # Tests for the optimization pass cluster WITH_COLUMNS + import polars as pl @@ -134,3 +135,22 @@ def test_cwc_with_internal_aliases() -> None: """[[(col("a")) == (2)].cast(Boolean).alias("c"), [(col("b")) * (3)].alias("d")]""" in explain ) + + +def test_issue_16436() -> None: + df = pl.DataFrame( + { + "x": [1.12, 2.21, 4.2, 3.21], + "y": [2.11, 3.32, 2.1, 6.12], + } + ) + + df = ( + df.lazy() + .with_columns((pl.col("y") / pl.col("x")).alias("z")) + .with_columns( + pl.when(pl.col("z").is_infinite()).then(0).otherwise(pl.col("z")).alias("z") + ) + .fill_nan(0) + .collect() + )