Skip to content

Commit

Permalink
only simplify in filter
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 8, 2024
1 parent fdc5c49 commit 8476cf0
Show file tree
Hide file tree
Showing 13 changed files with 171 additions and 84 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/bitwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl BitAnd for &BooleanChunked {
(1, 1) => {},
(1, _) => {
return match self.get(0) {
Some(true) => rhs.clone(),
Some(true) => rhs.clone().with_name(self.name()),
Some(false) => BooleanChunked::full(self.name(), false, rhs.len()),
None => &self.new_from_index(0, rhs.len()) & rhs,
};
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-plan/src/logical_plan/aexpr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ pub enum AExpr {
options: FunctionOptions,
},
Function {
/// function arguments
/// Function arguments
/// Some functions rely on aliases,
/// for instance assignment of struct fields.
/// Therefore we need `[ExprIr]`.
input: Vec<ExprIR>,
/// function to apply
function: FunctionExpr,
Expand Down
20 changes: 14 additions & 6 deletions crates/polars-plan/src/logical_plan/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,24 @@ impl AExpr {
}
}

fn func_args_to_fields(input: &[ExprIR], schema: &Schema, arena: &Arena<AExpr>) -> PolarsResult<Vec<Field>> {
fn func_args_to_fields(
input: &[ExprIR],
schema: &Schema,
arena: &Arena<AExpr>,
) -> PolarsResult<Vec<Field>> {
input
.iter()
// Default context because `col()` would return a list in aggregation context
.map(|e| {
arena.get(e.node()).to_field(schema, Context::Default, arena).map(|mut field| {
field.name = e.output_name().into();
field
})
}).collect()
arena
.get(e.node())
.to_field(schema, Context::Default, arena)
.map(|mut field| {
field.name = e.output_name().into();
field
})
})
.collect()
}

fn get_arithmetic_field(
Expand Down
20 changes: 13 additions & 7 deletions crates/polars-plan/src/logical_plan/conversion.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;

use polars_core::prelude::*;
use polars_utils::vec::ConvertVec;
use recursive::recursive;
Expand Down Expand Up @@ -64,12 +65,9 @@ fn to_aexprs(input: Vec<Expr>, arena: &mut Arena<AExpr>, state: &mut ConversionS
.collect()
}

fn set_function_output_name<F>(
e: &[ExprIR],
state: &mut ConversionState,
function_fmt: F
)
where F: FnOnce() -> Cow<'static, str>
fn set_function_output_name<F>(e: &[ExprIR], state: &mut ConversionState, function_fmt: F)
where
F: FnOnce() -> Cow<'static, str>,
{
if state.output_name.is_none() {
if e.is_empty() {
Expand Down Expand Up @@ -243,7 +241,15 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena<AExpr>, state: &mut ConversionSta
options,
} => {
let e = to_expr_irs(input, arena);
set_function_output_name(&e, state, || Cow::Owned(format!("{}", &function)));

if state.output_name.is_none() {
// Handles special case functions like `struct.field`.
if let Some(name) = function.output_name() {
state.output_name = OutputName::ColumnLhs(name.clone())
} else {
set_function_output_name(&e, state, || Cow::Owned(format!("{}", &function)));
}
}
AExpr::Function {
input: e,
function,
Expand Down
27 changes: 25 additions & 2 deletions crates/polars-plan/src/logical_plan/expr_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::hash::Hash;
use std::hash::Hasher;

use super::*;
use crate::constants::LITERAL_NAME;
use crate::constants::{get_len_name, LITERAL_NAME};

#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)]
pub enum OutputName {
Expand Down Expand Up @@ -64,9 +64,32 @@ impl ExprIR {
}
break;
},
AExpr::Function {
input, function, ..
} => {
if input.is_empty() {
out.output_name =
OutputName::LiteralLhs(ColumnName::from(format!("{}", function)));
} else {
out.output_name = input[0].output_name.clone();
}
break;
},
AExpr::AnonymousFunction { input, options, .. } => {
if input.is_empty() {
out.output_name = OutputName::LiteralLhs(ColumnName::from(options.fmt_str));
} else {
out.output_name = input[0].output_name.clone();
}
break;
},
AExpr::Len => out.output_name = OutputName::LiteralLhs(get_len_name()),
AExpr::Alias(_, _) => {
// Should be removed during conversion.
unreachable!()
#[cfg(debug_assertions)]
{
unreachable!()
}
},
_ => {},
}
Expand Down
24 changes: 20 additions & 4 deletions crates/polars-plan/src/logical_plan/optimizer/fused.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ use super::*;
pub struct FusedArithmetic {}

fn get_expr(input: &[Node], op: FusedOperator, expr_arena: &Arena<AExpr>) -> AExpr {
let input = input.iter().copied().map(|n| ExprIR::from_node(n, expr_arena)).collect();
let input = input
.iter()
.copied()
.map(|n| ExprIR::from_node(n, expr_arena))
.collect();
let mut options = FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_to_supertypes: true,
Expand Down Expand Up @@ -111,7 +115,11 @@ impl OptimizationRule for FusedArithmetic {
(None, _) | (Some(false), _) => Ok(None),
(Some(true), _) => {
let input = &[*left, *a, *b];
Ok(Some(get_expr(input, FusedOperator::MultiplyAdd, expr_arena)))
Ok(Some(get_expr(
input,
FusedOperator::MultiplyAdd,
expr_arena,
)))
},
},
_ => Ok(None),
Expand All @@ -137,7 +145,11 @@ impl OptimizationRule for FusedArithmetic {
(None, _) | (Some(false), _) => Ok(None),
(Some(true), _) => {
let input = &[*left, *a, *b];
Ok(Some(get_expr(input, FusedOperator::SubMultiply, expr_arena)))
Ok(Some(get_expr(
input,
FusedOperator::SubMultiply,
expr_arena,
)))
},
},
_ => {
Expand All @@ -155,7 +167,11 @@ impl OptimizationRule for FusedArithmetic {
(None, _) | (Some(false), _) => Ok(None),
(Some(true), _) => {
let input = &[*a, *b, *right];
Ok(Some(get_expr(input, FusedOperator::MultiplySub, expr_arena)))
Ok(Some(get_expr(
input,
FusedOperator::MultiplySub,
expr_arena,
)))
},
}
},
Expand Down
31 changes: 12 additions & 19 deletions crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use polars_utils::floor_divmod::FloorDivMod;
use polars_utils::total_ord::ToTotalOrd;

use crate::constants::get_literal_name;
use crate::logical_plan::*;
use crate::prelude::optimizer::simplify_functions::optimize_functions;

Expand Down Expand Up @@ -108,10 +107,12 @@ impl OptimizationRule for SimplifyBooleanRule {
&mut self,
expr_arena: &mut Arena<AExpr>,
expr_node: Node,
_: &Arena<ALogicalPlan>,
_: Node,
lp_arena: &Arena<ALogicalPlan>,
lp_node: Node,
) -> PolarsResult<Option<AExpr>> {
let expr = expr_arena.get(expr_node);
let in_filter = matches!(lp_arena.get(lp_node), ALogicalPlan::Selection { .. });

let out = match expr {
// true AND x => x
AExpr::BinaryExpr {
Expand All @@ -121,10 +122,11 @@ impl OptimizationRule for SimplifyBooleanRule {
} if matches!(
expr_arena.get(*left),
AExpr::Literal(LiteralValue::Boolean(true))
) =>
) && in_filter =>
{
// We alias because of the left-hand naming rule.
Some(AExpr::Alias(*right, get_literal_name()))
// Only in filter as we we might change the name from "literal"
// to whatever lhs columns is.
return Ok(Some(expr_arena.get(*right).clone()));
},
// x AND true => x
AExpr::BinaryExpr {
Expand Down Expand Up @@ -178,10 +180,11 @@ impl OptimizationRule for SimplifyBooleanRule {
} if matches!(
expr_arena.get(*left),
AExpr::Literal(LiteralValue::Boolean(false))
) =>
) && in_filter =>
{
// We alias because of the left-hand naming rule.
Some(AExpr::Alias(*right, get_literal_name()))
// Only in filter as we we might change the name from "literal"
// to whatever lhs columns is.
return Ok(Some(expr_arena.get(*right).clone()));
},
// x or false => x
AExpr::BinaryExpr {
Expand Down Expand Up @@ -236,16 +239,6 @@ impl OptimizationRule for SimplifyBooleanRule {
let ae = expr_arena.get(input.node());
eval_negate(ae)
},
// Flatten Aliases.
AExpr::Alias(inner, name) => {
let input = expr_arena.get(*inner);

if let AExpr::Alias(input, _) = input {
Some(AExpr::Alias(*input, name.clone()))
} else {
None
}
},
_ => None,
};
Ok(out)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ pub(super) fn optimize_functions(
} => Some(AExpr::Function {
input: input.clone(),
function: FunctionExpr::Boolean(BooleanFunction::IsNotNull),
options: options.clone(),
options: *options,
}),
// not(x.is_not_null) => x.is_null
AExpr::Function {
Expand All @@ -147,7 +147,7 @@ pub(super) fn optimize_functions(
} => Some(AExpr::Function {
input: input.clone(),
function: FunctionExpr::Boolean(BooleanFunction::IsNull),
options: options.clone(),
options: *options,
}),
// not(a == b) => a != b
AExpr::BinaryExpr {
Expand Down Expand Up @@ -239,15 +239,15 @@ pub(super) fn optimize_functions(
left: expr_arena.add(AExpr::BinaryExpr {
left: left_left,
op: left_cmp_op,
right: right_left
right: right_left,
}),
// OR
op: Operator::Or,
// input[0] (>,>=) input[2]
right: expr_arena.add(AExpr::BinaryExpr {
left: left_right,
op: right_cmp_op,
right: right_right
right: right_right,
}),
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,10 @@ impl OptimizationRule for SlicePushDown {
options,
} = m.clone()
{
input
.iter_mut()
.for_each(|e| {
let n = pushdown(e.node(), offset, length, expr_arena);
e.set_node(n);
});
input.iter_mut().for_each(|e| {
let n = pushdown(e.node(), offset, length, expr_arena);
e.set_node(n);
});

Some(AnonymousFunction {
input,
Expand All @@ -98,30 +96,28 @@ impl OptimizationRule for SlicePushDown {
}
},
m @ Function { options, .. }
if matches!(options.collect_groups, ApplyOptions::ElementWise) =>
if matches!(options.collect_groups, ApplyOptions::ElementWise) =>
{
if let Function {
mut input,
function,
options,
} = m.clone()
{
if let Function {
mut input,
input.iter_mut().for_each(|e| {
let n = pushdown(e.node(), offset, length, expr_arena);
e.set_node(n);
});

Some(Function {
input,
function,
options,
} = m.clone()
{
input
.iter_mut()
.for_each(|e| {
let n = pushdown(e.node(), offset, length, expr_arena);
e.set_node(n);
});

Some(Function {
input,
function,
options,
})
} else {
unreachable!()
}
},
})
} else {
unreachable!()
}
},
_ => None,
};
Ok(out)
Expand Down
Loading

0 comments on commit 8476cf0

Please sign in to comment.