diff --git a/crates/polars-core/src/chunked_array/bitwise.rs b/crates/polars-core/src/chunked_array/bitwise.rs index a47cd9c82aa3..9e8fc482498c 100644 --- a/crates/polars-core/src/chunked_array/bitwise.rs +++ b/crates/polars-core/src/chunked_array/bitwise.rs @@ -161,7 +161,7 @@ impl BitAnd for &BooleanChunked { (1, 1) => {}, (1, _) => { return match self.get(0) { - Some(true) => rhs.clone(), + Some(true) => rhs.clone().with_name(self.name()), Some(false) => BooleanChunked::full(self.name(), false, rhs.len()), None => &self.new_from_index(0, rhs.len()) & rhs, }; diff --git a/crates/polars-plan/src/logical_plan/aexpr/mod.rs b/crates/polars-plan/src/logical_plan/aexpr/mod.rs index 4520afd845ad..ae2245ce3721 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/mod.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/mod.rs @@ -167,7 +167,10 @@ pub enum AExpr { options: FunctionOptions, }, Function { - /// function arguments + /// Function arguments + /// Some functions rely on aliases, + /// for instance assignment of struct fields. + /// Therefore we need `[ExprIr]`. input: Vec, /// function to apply function: FunctionExpr, diff --git a/crates/polars-plan/src/logical_plan/aexpr/schema.rs b/crates/polars-plan/src/logical_plan/aexpr/schema.rs index 9141bb60f37c..2dfb0eae0a0f 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/schema.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/schema.rs @@ -228,16 +228,24 @@ impl AExpr { } } -fn func_args_to_fields(input: &[ExprIR], schema: &Schema, arena: &Arena) -> PolarsResult> { +fn func_args_to_fields( + input: &[ExprIR], + schema: &Schema, + arena: &Arena, +) -> PolarsResult> { input .iter() // Default context because `col()` would return a list in aggregation context .map(|e| { - arena.get(e.node()).to_field(schema, Context::Default, arena).map(|mut field| { - field.name = e.output_name().into(); - field - }) - }).collect() + arena + .get(e.node()) + .to_field(schema, Context::Default, arena) + .map(|mut field| { + field.name = e.output_name().into(); + field + }) + }) + .collect() } fn get_arithmetic_field( diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index 943d450aab5c..a936ecb5a00e 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; + use polars_core::prelude::*; use polars_utils::vec::ConvertVec; use recursive::recursive; @@ -64,12 +65,9 @@ fn to_aexprs(input: Vec, arena: &mut Arena, state: &mut ConversionS .collect() } -fn set_function_output_name( - e: &[ExprIR], - state: &mut ConversionState, - function_fmt: F -) -where F: FnOnce() -> Cow<'static, str> +fn set_function_output_name(e: &[ExprIR], state: &mut ConversionState, function_fmt: F) +where + F: FnOnce() -> Cow<'static, str>, { if state.output_name.is_none() { if e.is_empty() { @@ -243,7 +241,15 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta options, } => { let e = to_expr_irs(input, arena); - set_function_output_name(&e, state, || Cow::Owned(format!("{}", &function))); + + if state.output_name.is_none() { + // Handles special case functions like `struct.field`. + if let Some(name) = function.output_name() { + state.output_name = OutputName::ColumnLhs(name.clone()) + } else { + set_function_output_name(&e, state, || Cow::Owned(format!("{}", &function))); + } + } AExpr::Function { input: e, function, diff --git a/crates/polars-plan/src/logical_plan/expr_ir.rs b/crates/polars-plan/src/logical_plan/expr_ir.rs index 3fd0fea2565a..6208934bb161 100644 --- a/crates/polars-plan/src/logical_plan/expr_ir.rs +++ b/crates/polars-plan/src/logical_plan/expr_ir.rs @@ -3,7 +3,7 @@ use std::hash::Hash; use std::hash::Hasher; use super::*; -use crate::constants::LITERAL_NAME; +use crate::constants::{get_len_name, LITERAL_NAME}; #[derive(Default, Debug, Clone, Hash, PartialEq, Eq)] pub enum OutputName { @@ -64,9 +64,32 @@ impl ExprIR { } break; }, + AExpr::Function { + input, function, .. + } => { + if input.is_empty() { + out.output_name = + OutputName::LiteralLhs(ColumnName::from(format!("{}", function))); + } else { + out.output_name = input[0].output_name.clone(); + } + break; + }, + AExpr::AnonymousFunction { input, options, .. } => { + if input.is_empty() { + out.output_name = OutputName::LiteralLhs(ColumnName::from(options.fmt_str)); + } else { + out.output_name = input[0].output_name.clone(); + } + break; + }, + AExpr::Len => out.output_name = OutputName::LiteralLhs(get_len_name()), AExpr::Alias(_, _) => { // Should be removed during conversion. - unreachable!() + #[cfg(debug_assertions)] + { + unreachable!() + } }, _ => {}, } diff --git a/crates/polars-plan/src/logical_plan/optimizer/fused.rs b/crates/polars-plan/src/logical_plan/optimizer/fused.rs index 0221e7498d9b..01692b30587b 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/fused.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/fused.rs @@ -3,7 +3,11 @@ use super::*; pub struct FusedArithmetic {} fn get_expr(input: &[Node], op: FusedOperator, expr_arena: &Arena) -> AExpr { - let input = input.iter().copied().map(|n| ExprIR::from_node(n, expr_arena)).collect(); + let input = input + .iter() + .copied() + .map(|n| ExprIR::from_node(n, expr_arena)) + .collect(); let mut options = FunctionOptions { collect_groups: ApplyOptions::ElementWise, cast_to_supertypes: true, @@ -111,7 +115,11 @@ impl OptimizationRule for FusedArithmetic { (None, _) | (Some(false), _) => Ok(None), (Some(true), _) => { let input = &[*left, *a, *b]; - Ok(Some(get_expr(input, FusedOperator::MultiplyAdd, expr_arena))) + Ok(Some(get_expr( + input, + FusedOperator::MultiplyAdd, + expr_arena, + ))) }, }, _ => Ok(None), @@ -137,7 +145,11 @@ impl OptimizationRule for FusedArithmetic { (None, _) | (Some(false), _) => Ok(None), (Some(true), _) => { let input = &[*left, *a, *b]; - Ok(Some(get_expr(input, FusedOperator::SubMultiply, expr_arena))) + Ok(Some(get_expr( + input, + FusedOperator::SubMultiply, + expr_arena, + ))) }, }, _ => { @@ -155,7 +167,11 @@ impl OptimizationRule for FusedArithmetic { (None, _) | (Some(false), _) => Ok(None), (Some(true), _) => { let input = &[*a, *b, *right]; - Ok(Some(get_expr(input, FusedOperator::MultiplySub, expr_arena))) + Ok(Some(get_expr( + input, + FusedOperator::MultiplySub, + expr_arena, + ))) }, } }, diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs index 5caa1abfcb3d..e6c4c08a4862 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs @@ -1,7 +1,6 @@ use polars_utils::floor_divmod::FloorDivMod; use polars_utils::total_ord::ToTotalOrd; -use crate::constants::get_literal_name; use crate::logical_plan::*; use crate::prelude::optimizer::simplify_functions::optimize_functions; @@ -108,10 +107,12 @@ impl OptimizationRule for SimplifyBooleanRule { &mut self, expr_arena: &mut Arena, expr_node: Node, - _: &Arena, - _: Node, + lp_arena: &Arena, + lp_node: Node, ) -> PolarsResult> { let expr = expr_arena.get(expr_node); + let in_filter = matches!(lp_arena.get(lp_node), ALogicalPlan::Selection { .. }); + let out = match expr { // true AND x => x AExpr::BinaryExpr { @@ -121,10 +122,11 @@ impl OptimizationRule for SimplifyBooleanRule { } if matches!( expr_arena.get(*left), AExpr::Literal(LiteralValue::Boolean(true)) - ) => + ) && in_filter => { - // We alias because of the left-hand naming rule. - Some(AExpr::Alias(*right, get_literal_name())) + // Only in filter as we we might change the name from "literal" + // to whatever lhs columns is. + return Ok(Some(expr_arena.get(*right).clone())); }, // x AND true => x AExpr::BinaryExpr { @@ -178,10 +180,11 @@ impl OptimizationRule for SimplifyBooleanRule { } if matches!( expr_arena.get(*left), AExpr::Literal(LiteralValue::Boolean(false)) - ) => + ) && in_filter => { - // We alias because of the left-hand naming rule. - Some(AExpr::Alias(*right, get_literal_name())) + // Only in filter as we we might change the name from "literal" + // to whatever lhs columns is. + return Ok(Some(expr_arena.get(*right).clone())); }, // x or false => x AExpr::BinaryExpr { @@ -236,16 +239,6 @@ impl OptimizationRule for SimplifyBooleanRule { let ae = expr_arena.get(input.node()); eval_negate(ae) }, - // Flatten Aliases. - AExpr::Alias(inner, name) => { - let input = expr_arena.get(*inner); - - if let AExpr::Alias(input, _) = input { - Some(AExpr::Alias(*input, name.clone())) - } else { - None - } - }, _ => None, }; Ok(out) diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs index d19aa32c013d..2dcc9847b9ac 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs @@ -137,7 +137,7 @@ pub(super) fn optimize_functions( } => Some(AExpr::Function { input: input.clone(), function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), - options: options.clone(), + options: *options, }), // not(x.is_not_null) => x.is_null AExpr::Function { @@ -147,7 +147,7 @@ pub(super) fn optimize_functions( } => Some(AExpr::Function { input: input.clone(), function: FunctionExpr::Boolean(BooleanFunction::IsNull), - options: options.clone(), + options: *options, }), // not(a == b) => a != b AExpr::BinaryExpr { @@ -239,7 +239,7 @@ pub(super) fn optimize_functions( left: expr_arena.add(AExpr::BinaryExpr { left: left_left, op: left_cmp_op, - right: right_left + right: right_left, }), // OR op: Operator::Or, @@ -247,7 +247,7 @@ pub(super) fn optimize_functions( right: expr_arena.add(AExpr::BinaryExpr { left: left_right, op: right_cmp_op, - right: right_right + right: right_right, }), }) } diff --git a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_expr.rs index e32fdfde4090..4e13f07a8009 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_expr.rs @@ -80,12 +80,10 @@ impl OptimizationRule for SlicePushDown { options, } = m.clone() { - input - .iter_mut() - .for_each(|e| { - let n = pushdown(e.node(), offset, length, expr_arena); - e.set_node(n); - }); + input.iter_mut().for_each(|e| { + let n = pushdown(e.node(), offset, length, expr_arena); + e.set_node(n); + }); Some(AnonymousFunction { input, @@ -98,30 +96,28 @@ impl OptimizationRule for SlicePushDown { } }, m @ Function { options, .. } - if matches!(options.collect_groups, ApplyOptions::ElementWise) => + if matches!(options.collect_groups, ApplyOptions::ElementWise) => + { + if let Function { + mut input, + function, + options, + } = m.clone() { - if let Function { - mut input, + input.iter_mut().for_each(|e| { + let n = pushdown(e.node(), offset, length, expr_arena); + e.set_node(n); + }); + + Some(Function { + input, function, options, - } = m.clone() - { - input - .iter_mut() - .for_each(|e| { - let n = pushdown(e.node(), offset, length, expr_arena); - e.set_node(n); - }); - - Some(Function { - input, - function, - options, - }) - } else { - unreachable!() - } - }, + }) + } else { + unreachable!() + } + }, _ => None, }; Ok(out) diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs index 14c9caa8128f..8d6700780dab 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs @@ -349,10 +349,16 @@ impl OptimizationRule for TypeCoercionRule { } => { let input_schema = get_schema(lp_arena, lp_node); let other_e = &input[1]; - let (_, type_left) = - unpack!(get_aexpr_and_type(expr_arena, input[0].node(), &input_schema)); - let (_, type_other) = - unpack!(get_aexpr_and_type(expr_arena, other_e.node(), &input_schema)); + let (_, type_left) = unpack!(get_aexpr_and_type( + expr_arena, + input[0].node(), + &input_schema + )); + let (_, type_other) = unpack!(get_aexpr_and_type( + expr_arena, + other_e.node(), + &input_schema + )); unpack!(early_escape(&type_left, &type_other)); @@ -499,8 +505,11 @@ impl OptimizationRule for TypeCoercionRule { } => { let input_schema = get_schema(lp_arena, lp_node); let other_node = input[1].node(); - let (left, type_left) = - unpack!(get_aexpr_and_type(expr_arena, input[0].node(), &input_schema)); + let (left, type_left) = unpack!(get_aexpr_and_type( + expr_arena, + input[0].node(), + &input_schema + )); let (fill_value, type_fill_value) = unpack!(get_aexpr_and_type(expr_arena, other_node, &input_schema)); @@ -579,7 +588,7 @@ impl OptimizationRule for TypeCoercionRule { }; let mut other_node = other_node.clone(); if type_other != super_type { - let n= expr_arena.add(AExpr::Cast { + let n = expr_arena.add(AExpr::Cast { expr: other_node.node(), data_type: super_type.clone(), strict: false, diff --git a/crates/polars-plan/src/logical_plan/visitor/expr.rs b/crates/polars-plan/src/logical_plan/visitor/expr.rs index ab0deb88c10b..f4f9cef19474 100644 --- a/crates/polars-plan/src/logical_plan/visitor/expr.rs +++ b/crates/polars-plan/src/logical_plan/visitor/expr.rs @@ -225,13 +225,15 @@ impl AexprNode { function: fr, options: or, }, - ) => fl == fr && ol == or && { - let mut all_same_name = true; - for (l,r) in il.iter().zip(ir) { - all_same_name &= l.output_name() == r.output_name() + ) => { + fl == fr && ol == or && { + let mut all_same_name = true; + for (l, r) in il.iter().zip(ir) { + all_same_name &= l.output_name() == r.output_name() + } + + all_same_name } - - all_same_name }, (AnonymousFunction { .. }, AnonymousFunction { .. }) => false, (BinaryExpr { op: l, .. }, BinaryExpr { op: r, .. }) => l == r, diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index 6c662359ed19..4d01f5ec674e 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -645,3 +645,16 @@ def s_per_count(count_diff, span) -> pl.Expr: assert_frame_equal( ldf.collect(comm_subexpr_elim=True), ldf.collect(comm_subexpr_elim=False) ) + + +def test_cse_15536() -> None: + source = pl.DataFrame({"a": range(10)}) + + data = source.lazy().filter(pl.col("a") >= 5) + + assert pl.concat( + [ + data.filter(pl.lit(True) & (pl.col("a") == 6) | (pl.col("a") == 9)), + data.filter(pl.lit(True) & (pl.col("a") == 7) | (pl.col("a") == 8)), + ] + ).collect()["a"].to_list() == [6, 9, 7, 8] diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index e8d615d65857..b0785dad8987 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -662,3 +662,21 @@ def test_alias_prune_in_fold_15438() -> None: } ) assert_frame_equal(df, expected) + + +def test_resolved_names_15442() -> None: + df = pl.DataFrame( + { + "x": [206.0], + "y": [225.0], + } + ) + center = pl.struct( + x=pl.col("x"), + y=pl.col("y"), + ) + + left = 0 + right = 1000 + in_x = (left < center.struct.field("x")) & (center.struct.field("x") <= right) + assert df.lazy().filter(in_x).collect().shape == (1, 2)