diff --git a/crates/polars-core/src/schema.rs b/crates/polars-core/src/schema.rs index 569517bd37ad..9308cac5255a 100644 --- a/crates/polars-core/src/schema.rs +++ b/crates/polars-core/src/schema.rs @@ -81,6 +81,11 @@ impl Schema { Self { inner: map } } + /// Reserve `additional` memory spaces in the schema. + pub fn reserve(&mut self, additional: usize) { + self.inner.reserve(additional); + } + /// The number of fields in the schema #[inline] pub fn len(&self) -> usize { @@ -349,6 +354,21 @@ impl Schema { self.inner.extend(other.inner) } + /// Merge borrowed `other` into `self` + /// + /// Merging logic: + /// - Fields that occur in `self` but not `other` are unmodified + /// - Fields that occur in `other` but not `self` are appended, in order, to the end of `self` + /// - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original + /// index + pub fn merge_from_ref(&mut self, other: &Self) { + self.inner.extend( + other + .iter() + .map(|(column, datatype)| (column.clone(), datatype.clone())), + ) + } + /// Convert self to `ArrowSchema` by cloning the fields pub fn to_arrow(&self, pl_flavor: bool) -> ArrowSchema { let fields: Vec<_> = self diff --git a/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs b/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs index aa7008699fbf..84a7a8ed501e 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use arrow::bitmap::MutableBitmap; +use polars_core::schema::Schema; use polars_utils::aliases::{InitHashMaps, PlHashMap}; use polars_utils::arena::{Arena, Node}; @@ -148,8 +149,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) continue; } - let mut new_current_schema = current_schema.as_ref().clone(); - let mut new_input_schema = input_schema.as_ref().clone(); + let input_schema_inner = Arc::make_mut(input_schema); // @NOTE: We don't have to insert a SimpleProjection or redo the `current_schema` if // `pushable` contains only 0..N for some N. We use these two variables to keep track @@ -157,21 +157,20 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) let mut has_seen_unpushable = false; let mut needs_simple_projection = false; - let mut already_removed = 0; + input_schema_inner.reserve(pushable.set_bits()); *current_exprs.exprs_mut() = std::mem::take(current_exprs.exprs_mut()) .into_iter() .zip(pushable.iter()) - .enumerate() - .filter_map(|(i, (expr, do_pushdown))| { + .filter_map(|(expr, do_pushdown)| { if do_pushdown { needs_simple_projection = has_seen_unpushable; + let column = expr.output_name_arc().as_ref(); + // @NOTE: we cannot just use the index here, as there might be renames that sit + // earlier in the schema + let datatype = current_schema.get(column).unwrap(); + input_schema_inner.with_column(column.into(), datatype.clone()); input_exprs.exprs_mut().push(expr); - let (column, datatype) = new_current_schema - .shift_remove_index(i - already_removed) - .unwrap(); - new_input_schema.with_column(column, datatype); - already_removed += 1; None } else { @@ -188,8 +187,14 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) // @NOTE: Here we add a simple projection to make sure that the output still // has the right schema. if needs_simple_projection { - new_current_schema.merge(new_input_schema.clone()); - *input_schema = Arc::new(new_input_schema); + // @NOTE: This may seem stupid, but this way we prioritize the input columns and then + // the existing columns which is exactly what we want. + let mut new_current_schema = Schema::with_capacity(current_schema.len()); + new_current_schema.merge_from_ref(input_schema.as_ref()); + new_current_schema.merge_from_ref(current_schema.as_ref()); + + debug_assert_eq!(new_current_schema.len(), current_schema.len()); + let proj_schema = std::mem::replace(current_schema, Arc::new(new_current_schema)); let moved_current = lp_arena.add(IR::Invalid); @@ -199,8 +204,6 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) }; let current = lp_arena.replace(current, projection); lp_arena.replace(moved_current, current); - } else { - *input_schema = Arc::new(new_input_schema); } } } diff --git a/py-polars/tests/unit/test_cwc.py b/py-polars/tests/unit/test_cwc.py index 19d59292fca7..a08fb5b6e1d8 100644 --- a/py-polars/tests/unit/test_cwc.py +++ b/py-polars/tests/unit/test_cwc.py @@ -1,6 +1,5 @@ # Tests for the optimization pass cluster WITH_COLUMNS - import polars as pl @@ -152,7 +151,7 @@ def test_cwc_with_internal_aliases() -> None: ) -def test_issue_16436() -> None: +def test_read_of_pushed_column_16436() -> None: df = pl.DataFrame( { "x": [1.12, 2.21, 4.2, 3.21], @@ -169,3 +168,29 @@ def test_issue_16436() -> None: .fill_nan(0) .collect() ) + + +def test_multiple_simple_projections_16435() -> None: + df = pl.DataFrame({"a": [1]}).lazy() + + df = ( + df.with_columns(b=pl.col("a")) + .with_columns(c=pl.col("b")) + .with_columns(l2a=pl.lit(2)) + .with_columns(l2b=pl.col("l2a")) + .with_columns(m=pl.lit(3)) + ) + + df.collect() + + +def test_reverse_order() -> None: + df = pl.LazyFrame({"a": [1], "b": [2]}) + + df = ( + df.with_columns(a=pl.col("a"), b=pl.col("b"), c=pl.col("a") * pl.col("b")) + .with_columns(x=pl.col("a"), y=pl.col("b")) + .with_columns(b=pl.col("a"), a=pl.col("b")) + ) + + df.collect()