Skip to content

Commit

Permalink
fix(rust): fix issue #16436
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed May 24, 2024
1 parent ee25cb7 commit 76f02bc
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 15 deletions.
20 changes: 20 additions & 0 deletions crates/polars-core/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ impl Schema {
Self { inner: map }
}

/// Reserve `additional` memory spaces in the schema.
pub fn reserve(&mut self, additional: usize) {
self.inner.reserve(additional);
}

/// The number of fields in the schema
#[inline]
pub fn len(&self) -> usize {
Expand Down Expand Up @@ -349,6 +354,21 @@ impl Schema {
self.inner.extend(other.inner)
}

/// Merge borrowed `other` into `self`
///
/// Merging logic:
/// - Fields that occur in `self` but not `other` are unmodified
/// - Fields that occur in `other` but not `self` are appended, in order, to the end of `self`
/// - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original
/// index
pub fn merge_from_ref(&mut self, other: &Self) {
self.inner.extend(
other
.iter()
.map(|(column, datatype)| (column.clone(), datatype.clone())),
)
}

/// Convert self to `ArrowSchema` by cloning the fields
pub fn to_arrow(&self, pl_flavor: bool) -> ArrowSchema {
let fields: Vec<_> = self
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;

use arrow::bitmap::MutableBitmap;
use polars_core::schema::Schema;
use polars_utils::aliases::{InitHashMaps, PlHashMap};
use polars_utils::arena::{Arena, Node};

Expand Down Expand Up @@ -148,30 +149,28 @@ pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &Arena<AExpr>)
continue;
}

let mut new_current_schema = current_schema.as_ref().clone();
let mut new_input_schema = input_schema.as_ref().clone();
let input_schema_inner = Arc::make_mut(input_schema);

// @NOTE: We don't have to insert a SimpleProjection or redo the `current_schema` if
// `pushable` contains only 0..N for some N. We use these two variables to keep track
// of this.
let mut has_seen_unpushable = false;
let mut needs_simple_projection = false;

let mut already_removed = 0;
input_schema_inner.reserve(pushable.set_bits());
*current_exprs.exprs_mut() = std::mem::take(current_exprs.exprs_mut())
.into_iter()
.zip(pushable.iter())
.enumerate()
.filter_map(|(i, (expr, do_pushdown))| {
.filter_map(|(expr, do_pushdown)| {
if do_pushdown {
needs_simple_projection = has_seen_unpushable;

let column = expr.output_name_arc().as_ref();
// @NOTE: we cannot just use the index here, as there might be renames that sit
// earlier in the schema
let datatype = current_schema.get(column).unwrap();
input_schema_inner.with_column(column.into(), datatype.clone());
input_exprs.exprs_mut().push(expr);
let (column, datatype) = new_current_schema
.shift_remove_index(i - already_removed)
.unwrap();
new_input_schema.with_column(column, datatype);
already_removed += 1;

None
} else {
Expand All @@ -188,8 +187,14 @@ pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &Arena<AExpr>)
// @NOTE: Here we add a simple projection to make sure that the output still
// has the right schema.
if needs_simple_projection {
new_current_schema.merge(new_input_schema.clone());
*input_schema = Arc::new(new_input_schema);
// @NOTE: This may seem stupid, but this way we prioritize the input columns and then
// the existing columns which is exactly what we want.
let mut new_current_schema = Schema::with_capacity(current_schema.len());
new_current_schema.merge_from_ref(input_schema.as_ref());
new_current_schema.merge_from_ref(current_schema.as_ref());

debug_assert_eq!(new_current_schema.len(), current_schema.len());

let proj_schema = std::mem::replace(current_schema, Arc::new(new_current_schema));

let moved_current = lp_arena.add(IR::Invalid);
Expand All @@ -199,8 +204,6 @@ pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &Arena<AExpr>)
};
let current = lp_arena.replace(current, projection);
lp_arena.replace(moved_current, current);
} else {
*input_schema = Arc::new(new_input_schema);
}
}
}
27 changes: 26 additions & 1 deletion py-polars/tests/unit/test_cwc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Tests for the optimization pass cluster WITH_COLUMNS


import polars as pl


Expand Down Expand Up @@ -169,3 +168,29 @@ def test_issue_16436() -> None:
.fill_nan(0)
.collect()
)


def test_issue_16435() -> None:
df = pl.DataFrame({"a": [1]}).lazy()

df = (
df.with_columns(b=pl.col("a"))
.with_columns(c=pl.col("b"))
.with_columns(l2a=pl.lit(2))
.with_columns(l2b=pl.col("l2a"))
.with_columns(m=pl.lit(3))
)

df.collect()


def test_reverse_order() -> None:
df = pl.LazyFrame({"a": [1], "b": [2]})

df = (
df.with_columns(a=pl.col("a"), b=pl.col("b"), c=pl.col("a") * pl.col("b"))
.with_columns(x=pl.col("a"), y=pl.col("b"))
.with_columns(b=pl.col("a"), a=pl.col("b"))
)

df.collect()

0 comments on commit 76f02bc

Please sign in to comment.