From be4169796e6974252ada6e13cf6127d85c9bdb4f Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Sat, 30 Mar 2024 16:41:48 +0100 Subject: [PATCH] refactor: make dsl immutable and cheap to clone (#15394) --- .../polars-lazy/src/physical_plan/exotic.rs | 19 +- crates/polars-plan/src/dsl/arity.rs | 10 +- crates/polars-plan/src/dsl/expr.rs | 72 ++++---- .../src/dsl/functions/syntactic_sugar.rs | 2 +- crates/polars-plan/src/dsl/meta.rs | 19 +- crates/polars-plan/src/dsl/mod.rs | 58 +++--- crates/polars-plan/src/dsl/name.rs | 4 +- crates/polars-plan/src/dsl/statistics.rs | 18 +- .../src/logical_plan/conversion.rs | 145 ++++++++------- .../polars-plan/src/logical_plan/iterator.rs | 157 ++++++++--------- .../optimizer/predicate_pushdown/mod.rs | 16 +- .../src/logical_plan/projection.rs | 166 +++++++----------- .../src/logical_plan/visitor/expr.rs | 58 +++++- crates/polars-plan/src/utils.rs | 9 +- crates/polars-sql/src/context.rs | 15 +- crates/polars-utils/src/functions.rs | 39 ++++ 16 files changed, 419 insertions(+), 388 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/exotic.rs b/crates/polars-lazy/src/physical_plan/exotic.rs index 14fb5a5c3517..664cd2bfbb2d 100644 --- a/crates/polars-lazy/src/physical_plan/exotic.rs +++ b/crates/polars-lazy/src/physical_plan/exotic.rs @@ -4,19 +4,12 @@ use crate::physical_plan::planner::create_physical_expr; use crate::prelude::*; #[cfg(feature = "pivot")] -pub(crate) fn prepare_eval_expr(mut expr: Expr) -> Expr { - expr.mutate().apply(|e| match e { - Expr::Column(name) => { - *name = Arc::from(""); - true - }, - Expr::Nth(_) => { - *e = Expr::Column(Arc::from("")); - true - }, - _ => true, - }); - expr +pub(crate) fn prepare_eval_expr(expr: Expr) -> Expr { + expr.map_expr(|e| match e { + Expr::Column(_) => Expr::Column(Arc::from("")), + Expr::Nth(_) => Expr::Column(Arc::from("")), + e => e, + }) } pub(crate) fn prepare_expression_for_context( diff --git a/crates/polars-plan/src/dsl/arity.rs b/crates/polars-plan/src/dsl/arity.rs index 05ff22df52b0..9883936f6c10 100644 --- a/crates/polars-plan/src/dsl/arity.rs +++ b/crates/polars-plan/src/dsl/arity.rs @@ -139,17 +139,17 @@ pub fn when>(condition: E) -> When { pub fn ternary_expr(predicate: Expr, truthy: Expr, falsy: Expr) -> Expr { Expr::Ternary { - predicate: Box::new(predicate), - truthy: Box::new(truthy), - falsy: Box::new(falsy), + predicate: Arc::new(predicate), + truthy: Arc::new(truthy), + falsy: Arc::new(falsy), } } /// Compute `op(l, r)` (or equivalently `l op r`). `l` and `r` must have types compatible with the Operator. pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr { Expr::BinaryExpr { - left: Box::new(l), + left: Arc::new(l), op, - right: Box::new(r), + right: Arc::new(r), } } diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index cb88fc45be55..5e8a31dd65f2 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -12,30 +12,30 @@ use crate::prelude::*; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum AggExpr { Min { - input: Box, + input: Arc, propagate_nans: bool, }, Max { - input: Box, + input: Arc, propagate_nans: bool, }, - Median(Box), - NUnique(Box), - First(Box), - Last(Box), - Mean(Box), - Implode(Box), + Median(Arc), + NUnique(Arc), + First(Arc), + Last(Arc), + Mean(Arc), + Implode(Arc), // include_nulls - Count(Box, bool), + Count(Arc, bool), Quantile { - expr: Box, - quantile: Box, + expr: Arc, + quantile: Arc, interpol: QuantileInterpolOptions, }, - Sum(Box), - AggGroups(Box), - Std(Box, u8), - Var(Box, u8), + Sum(Arc), + AggGroups(Arc), + Std(Arc, u8), + Var(Arc, u8), } impl AsRef for AggExpr { @@ -67,32 +67,32 @@ impl AsRef for AggExpr { #[must_use] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum Expr { - Alias(Box, Arc), + Alias(Arc, Arc), Column(Arc), Columns(Vec), DtypeColumn(Vec), Literal(LiteralValue), BinaryExpr { - left: Box, + left: Arc, op: Operator, - right: Box, + right: Arc, }, Cast { - expr: Box, + expr: Arc, data_type: DataType, strict: bool, }, Sort { - expr: Box, + expr: Arc, options: SortOptions, }, Gather { - expr: Box, - idx: Box, + expr: Arc, + idx: Arc, returns_scalar: bool, }, SortBy { - expr: Box, + expr: Arc, by: Vec, descending: Vec, }, @@ -100,9 +100,9 @@ pub enum Expr { /// A ternary operation /// if true then "foo" else "bar" Ternary { - predicate: Box, - truthy: Box, - falsy: Box, + predicate: Arc, + truthy: Arc, + falsy: Arc, }, Function { /// function arguments @@ -111,29 +111,29 @@ pub enum Expr { function: FunctionExpr, options: FunctionOptions, }, - Explode(Box), + Explode(Arc), Filter { - input: Box, - by: Box, + input: Arc, + by: Arc, }, /// See postgres window functions Window { /// Also has the input. i.e. avg("foo") - function: Box, + function: Arc, partition_by: Vec, options: WindowType, }, Wildcard, Slice { - input: Box, + input: Arc, /// length is not yet known so we accept negative offsets - offset: Box, - length: Box, + offset: Arc, + length: Arc, }, /// Can be used in a select statement to exclude a column from selection - Exclude(Box, Vec), + Exclude(Arc, Vec), /// Set root name as Alias - KeepName(Box), + KeepName(Arc), Len, /// Take the nth column in the `DataFrame` Nth(i64), @@ -141,7 +141,7 @@ pub enum Expr { #[cfg_attr(feature = "serde", serde(skip))] RenameAlias { function: SpecialEq>, - expr: Box, + expr: Arc, }, AnonymousFunction { /// function arguments diff --git a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs index df778ee60ee6..5315709da4cf 100644 --- a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs +++ b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs @@ -57,7 +57,7 @@ pub fn is_not_null(expr: Expr) -> Expr { /// nominal type of the column. pub fn cast(expr: Expr, data_type: DataType) -> Expr { Expr::Cast { - expr: Box::new(expr), + expr: Arc::new(expr), data_type, strict: false, } diff --git a/crates/polars-plan/src/dsl/meta.rs b/crates/polars-plan/src/dsl/meta.rs index 28a554007a50..ac753024b8ce 100644 --- a/crates/polars-plan/src/dsl/meta.rs +++ b/crates/polars-plan/src/dsl/meta.rs @@ -41,22 +41,13 @@ impl MetaNameSpace { } /// Undo any renaming operation like `alias`, `keep_name`. - pub fn undo_aliases(mut self) -> Expr { - self.0.mutate().apply(|e| match e { + pub fn undo_aliases(self) -> Expr { + self.0.map_expr(|e| match e { Expr::Alias(input, _) | Expr::KeepName(input) - | Expr::RenameAlias { expr: input, .. } => { - // remove this node - *e = *input.clone(); - - // continue iteration - true - }, - // continue iteration - _ => true, - }); - - self.0 + | Expr::RenameAlias { expr: input, .. } => Arc::unwrap_or_clone(input), + e => e, + }) } /// Indicate if this expression expands to multiple expressions. diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 872153db3961..0d5c7f5025b6 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -166,7 +166,7 @@ impl Expr { /// Rename Column. pub fn alias(self, name: &str) -> Expr { - Expr::Alias(Box::new(self), ColumnName::from(name)) + Expr::Alias(Arc::new(self), ColumnName::from(name)) } /// Run is_null operation on `Expr`. @@ -193,29 +193,29 @@ impl Expr { /// Get the number of unique values in the groups. pub fn n_unique(self) -> Self { - AggExpr::NUnique(Box::new(self)).into() + AggExpr::NUnique(Arc::new(self)).into() } /// Get the first value in the group. pub fn first(self) -> Self { - AggExpr::First(Box::new(self)).into() + AggExpr::First(Arc::new(self)).into() } /// Get the last value in the group. pub fn last(self) -> Self { - AggExpr::Last(Box::new(self)).into() + AggExpr::Last(Arc::new(self)).into() } /// Aggregate the group to a Series. pub fn implode(self) -> Self { - AggExpr::Implode(Box::new(self)).into() + AggExpr::Implode(Arc::new(self)).into() } /// Compute the quantile per group. pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self { AggExpr::Quantile { - expr: Box::new(self), - quantile: Box::new(quantile), + expr: Arc::new(self), + quantile: Arc::new(quantile), interpol, } .into() @@ -223,7 +223,7 @@ impl Expr { /// Get the group indexes of the group by operation. pub fn agg_groups(self) -> Self { - AggExpr::AggGroups(Box::new(self)).into() + AggExpr::AggGroups(Arc::new(self)).into() } /// Alias for `explode`. @@ -233,16 +233,16 @@ impl Expr { /// Explode the String/List column. pub fn explode(self) -> Self { - Expr::Explode(Box::new(self)) + Expr::Explode(Arc::new(self)) } /// Slice the Series. /// `offset` may be negative. pub fn slice, F: Into>(self, offset: E, length: F) -> Self { Expr::Slice { - input: Box::new(self), - offset: Box::new(offset.into()), - length: Box::new(length.into()), + input: Arc::new(self), + offset: Arc::new(offset.into()), + length: Arc::new(length.into()), } } @@ -375,7 +375,7 @@ impl Expr { /// Throws an error if conversion had overflows. pub fn strict_cast(self, data_type: DataType) -> Self { Expr::Cast { - expr: Box::new(self), + expr: Arc::new(self), data_type, strict: true, } @@ -384,7 +384,7 @@ impl Expr { /// Cast expression to another data type. pub fn cast(self, data_type: DataType) -> Self { Expr::Cast { - expr: Box::new(self), + expr: Arc::new(self), data_type, strict: false, } @@ -393,8 +393,8 @@ impl Expr { /// Take the values by idx. pub fn gather>(self, idx: E) -> Self { Expr::Gather { - expr: Box::new(self), - idx: Box::new(idx.into()), + expr: Arc::new(self), + idx: Arc::new(idx.into()), returns_scalar: false, } } @@ -402,8 +402,8 @@ impl Expr { /// Take the values by a single index. pub fn get>(self, idx: E) -> Self { Expr::Gather { - expr: Box::new(self), - idx: Box::new(idx.into()), + expr: Arc::new(self), + idx: Arc::new(idx.into()), returns_scalar: true, } } @@ -411,7 +411,7 @@ impl Expr { /// Sort in increasing order. See [the eager implementation](Series::sort). pub fn sort(self, descending: bool) -> Self { Expr::Sort { - expr: Box::new(self), + expr: Arc::new(self), options: SortOptions { descending, ..Default::default() @@ -422,7 +422,7 @@ impl Expr { /// Sort with given options. pub fn sort_with(self, options: SortOptions) -> Self { Expr::Sort { - expr: Box::new(self), + expr: Arc::new(self), options, } } @@ -903,7 +903,7 @@ impl Expr { .map(|e| e.clone().into()) .collect(); Expr::Window { - function: Box::new(self), + function: Arc::new(self), partition_by, options: options.into(), } @@ -915,7 +915,7 @@ impl Expr { // not ignore it. let index_col = col(options.index_column.as_str()); Expr::Window { - function: Box::new(self), + function: Arc::new(self), partition_by: vec![index_col], options: WindowType::Rolling(options), } @@ -961,11 +961,11 @@ impl Expr { /// or /// Get counts of the group by operation. pub fn count(self) -> Self { - AggExpr::Count(Box::new(self), false).into() + AggExpr::Count(Arc::new(self), false).into() } pub fn len(self) -> Self { - AggExpr::Count(Box::new(self), true).into() + AggExpr::Count(Arc::new(self), true).into() } /// Get a mask of duplicated values. @@ -1037,8 +1037,8 @@ impl Expr { panic!("filter '*' not allowed, use LazyFrame::filter") }; Expr::Filter { - input: Box::new(self), - by: Box::new(predicate.into()), + input: Arc::new(self), + by: Arc::new(predicate.into()), } } @@ -1081,7 +1081,7 @@ impl Expr { let by = by.as_ref().iter().map(|e| e.clone().into()).collect(); let descending = descending.as_ref().to_vec(); Expr::SortBy { - expr: Box::new(self), + expr: Arc::new(self), by, descending, } @@ -1137,7 +1137,7 @@ impl Expr { .into_iter() .map(|s| Excluded::Name(ColumnName::from(s))) .collect(); - Expr::Exclude(Box::new(self), v) + Expr::Exclude(Arc::new(self), v) } pub fn exclude_dtype>(self, dtypes: D) -> Expr { @@ -1146,7 +1146,7 @@ impl Expr { .iter() .map(|dt| Excluded::Dtype(dt.clone())) .collect(); - Expr::Exclude(Box::new(self), v) + Expr::Exclude(Arc::new(self), v) } #[cfg(feature = "interpolate")] diff --git a/crates/polars-plan/src/dsl/name.rs b/crates/polars-plan/src/dsl/name.rs index 61fe7951e741..def56c87e87e 100644 --- a/crates/polars-plan/src/dsl/name.rs +++ b/crates/polars-plan/src/dsl/name.rs @@ -21,7 +21,7 @@ impl ExprNameNameSpace { /// } /// ``` pub fn keep(self) -> Expr { - Expr::KeepName(Box::new(self.0)) + Expr::KeepName(Arc::new(self.0)) } /// Define an alias by mapping a function over the original root column name. @@ -31,7 +31,7 @@ impl ExprNameNameSpace { { let function = SpecialEq::new(Arc::new(function) as Arc); Expr::RenameAlias { - expr: Box::new(self.0), + expr: Arc::new(self.0), function, } } diff --git a/crates/polars-plan/src/dsl/statistics.rs b/crates/polars-plan/src/dsl/statistics.rs index 20a63d1e2bf1..6220f6b88b58 100644 --- a/crates/polars-plan/src/dsl/statistics.rs +++ b/crates/polars-plan/src/dsl/statistics.rs @@ -3,18 +3,18 @@ use super::*; impl Expr { /// Standard deviation of the values of the Series. pub fn std(self, ddof: u8) -> Self { - AggExpr::Std(Box::new(self), ddof).into() + AggExpr::Std(Arc::new(self), ddof).into() } /// Variance of the values of the Series. pub fn var(self, ddof: u8) -> Self { - AggExpr::Var(Box::new(self), ddof).into() + AggExpr::Var(Arc::new(self), ddof).into() } /// Reduce groups to minimal value. pub fn min(self) -> Self { AggExpr::Min { - input: Box::new(self), + input: Arc::new(self), propagate_nans: false, } .into() @@ -23,7 +23,7 @@ impl Expr { /// Reduce groups to maximum value. pub fn max(self) -> Self { AggExpr::Max { - input: Box::new(self), + input: Arc::new(self), propagate_nans: false, } .into() @@ -32,7 +32,7 @@ impl Expr { /// Reduce groups to minimal value. pub fn nan_min(self) -> Self { AggExpr::Min { - input: Box::new(self), + input: Arc::new(self), propagate_nans: true, } .into() @@ -41,7 +41,7 @@ impl Expr { /// Reduce groups to maximum value. pub fn nan_max(self) -> Self { AggExpr::Max { - input: Box::new(self), + input: Arc::new(self), propagate_nans: true, } .into() @@ -49,17 +49,17 @@ impl Expr { /// Reduce groups to the mean value. pub fn mean(self) -> Self { - AggExpr::Mean(Box::new(self)).into() + AggExpr::Mean(Arc::new(self)).into() } /// Reduce groups to the median value. pub fn median(self) -> Self { - AggExpr::Median(Box::new(self)).into() + AggExpr::Median(Arc::new(self)).into() } /// Reduce groups to the sum of all the values. pub fn sum(self) -> Self { - AggExpr::Sum(Box::new(self)).into() + AggExpr::Sum(Arc::new(self)).into() } /// Compute the histogram of a dataset. diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index 1cfb23b79bc4..3c1f6f65ab3a 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -66,17 +66,18 @@ fn to_aexprs(input: Vec, arena: &mut Arena, state: &mut ConversionS /// Converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation. #[recursive] fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionState) -> Node { + let owned = Arc::unwrap_or_clone; let v = match expr { - Expr::Explode(expr) => AExpr::Explode(to_aexpr_impl(*expr, arena, state)), + Expr::Explode(expr) => AExpr::Explode(to_aexpr_impl(owned(expr), arena, state)), Expr::Alias(e, name) => { if state.prune_alias { if state.output_name.is_none() && !state.ignore_alias { state.output_name = OutputName::Alias(name); } - to_aexpr_impl(*e, arena, state); + to_aexpr_impl(owned(e), arena, state); arena.pop().unwrap() } else { - AExpr::Alias(to_aexpr_impl(*e, arena, state), name) + AExpr::Alias(to_aexpr_impl(owned(e), arena, state), name) } }, Expr::Literal(lv) => { @@ -92,8 +93,8 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta AExpr::Column(name) }, Expr::BinaryExpr { left, op, right } => { - let l = to_aexpr_impl(*left, arena, state); - let r = to_aexpr_impl(*right, arena, state); + let l = to_aexpr_impl(owned(left), arena, state); + let r = to_aexpr_impl(owned(right), arena, state); AExpr::BinaryExpr { left: l, op, @@ -105,7 +106,7 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta data_type, strict, } => AExpr::Cast { - expr: to_aexpr_impl(*expr, arena, state), + expr: to_aexpr_impl(owned(expr), arena, state), data_type, strict, }, @@ -114,12 +115,12 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta idx, returns_scalar, } => AExpr::Gather { - expr: to_aexpr_impl(*expr, arena, state), - idx: to_aexpr_impl(*idx, arena, state), + expr: to_aexpr_impl(owned(expr), arena, state), + idx: to_aexpr_impl(owned(idx), arena, state), returns_scalar, }, Expr::Sort { expr, options } => AExpr::Sort { - expr: to_aexpr_impl(*expr, arena, state), + expr: to_aexpr_impl(owned(expr), arena, state), options, }, Expr::SortBy { @@ -127,7 +128,7 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta by, descending, } => AExpr::SortBy { - expr: to_aexpr_impl(*expr, arena, state), + expr: to_aexpr_impl(owned(expr), arena, state), by: by .into_iter() .map(|e| to_aexpr_impl(e, arena, state)) @@ -135,8 +136,8 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta descending, }, Expr::Filter { input, by } => AExpr::Filter { - input: to_aexpr_impl(*input, arena, state), - by: to_aexpr_impl(*by, arena, state), + input: to_aexpr_impl(owned(input), arena, state), + by: to_aexpr_impl(owned(by), arena, state), }, Expr::Agg(agg) => { let a_agg = match agg { @@ -144,38 +145,48 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta input, propagate_nans, } => AAggExpr::Min { - input: to_aexpr_impl(*input, arena, state), + input: to_aexpr_impl(owned(input), arena, state), propagate_nans, }, AggExpr::Max { input, propagate_nans, } => AAggExpr::Max { - input: to_aexpr_impl(*input, arena, state), + input: to_aexpr_impl(owned(input), arena, state), propagate_nans, }, - AggExpr::Median(expr) => AAggExpr::Median(to_aexpr_impl(*expr, arena, state)), - AggExpr::NUnique(expr) => AAggExpr::NUnique(to_aexpr_impl(*expr, arena, state)), - AggExpr::First(expr) => AAggExpr::First(to_aexpr_impl(*expr, arena, state)), - AggExpr::Last(expr) => AAggExpr::Last(to_aexpr_impl(*expr, arena, state)), - AggExpr::Mean(expr) => AAggExpr::Mean(to_aexpr_impl(*expr, arena, state)), - AggExpr::Implode(expr) => AAggExpr::Implode(to_aexpr_impl(*expr, arena, state)), + AggExpr::Median(expr) => AAggExpr::Median(to_aexpr_impl(owned(expr), arena, state)), + AggExpr::NUnique(expr) => { + AAggExpr::NUnique(to_aexpr_impl(owned(expr), arena, state)) + }, + AggExpr::First(expr) => AAggExpr::First(to_aexpr_impl(owned(expr), arena, state)), + AggExpr::Last(expr) => AAggExpr::Last(to_aexpr_impl(owned(expr), arena, state)), + AggExpr::Mean(expr) => AAggExpr::Mean(to_aexpr_impl(owned(expr), arena, state)), + AggExpr::Implode(expr) => { + AAggExpr::Implode(to_aexpr_impl(owned(expr), arena, state)) + }, AggExpr::Count(expr, include_nulls) => { - AAggExpr::Count(to_aexpr_impl(*expr, arena, state), include_nulls) + AAggExpr::Count(to_aexpr_impl(owned(expr), arena, state), include_nulls) }, AggExpr::Quantile { expr, quantile, interpol, } => AAggExpr::Quantile { - expr: to_aexpr_impl(*expr, arena, state), - quantile: to_aexpr_impl(*quantile, arena, state), + expr: to_aexpr_impl(owned(expr), arena, state), + quantile: to_aexpr_impl(owned(quantile), arena, state), interpol, }, - AggExpr::Sum(expr) => AAggExpr::Sum(to_aexpr_impl(*expr, arena, state)), - AggExpr::Std(expr, ddof) => AAggExpr::Std(to_aexpr_impl(*expr, arena, state), ddof), - AggExpr::Var(expr, ddof) => AAggExpr::Var(to_aexpr_impl(*expr, arena, state), ddof), - AggExpr::AggGroups(expr) => AAggExpr::AggGroups(to_aexpr_impl(*expr, arena, state)), + AggExpr::Sum(expr) => AAggExpr::Sum(to_aexpr_impl(owned(expr), arena, state)), + AggExpr::Std(expr, ddof) => { + AAggExpr::Std(to_aexpr_impl(owned(expr), arena, state), ddof) + }, + AggExpr::Var(expr, ddof) => { + AAggExpr::Var(to_aexpr_impl(owned(expr), arena, state), ddof) + }, + AggExpr::AggGroups(expr) => { + AAggExpr::AggGroups(to_aexpr_impl(owned(expr), arena, state)) + }, }; AExpr::Agg(a_agg) }, @@ -185,9 +196,9 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta falsy, } => { // Truthy must be resolved first to get the lhs name first set. - let t = to_aexpr_impl(*truthy, arena, state); - let p = to_aexpr_impl(*predicate, arena, state); - let f = to_aexpr_impl(*falsy, arena, state); + let t = to_aexpr_impl(owned(truthy), arena, state); + let p = to_aexpr_impl(owned(predicate), arena, state); + let f = to_aexpr_impl(owned(falsy), arena, state); AExpr::Ternary { predicate: p, truthy: t, @@ -228,7 +239,7 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta partition_by, options, } => AExpr::Window { - function: to_aexpr_impl(*function, arena, state), + function: to_aexpr_impl(owned(function), arena, state), partition_by: to_aexprs(partition_by, arena, state), options, }, @@ -237,9 +248,9 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta offset, length, } => AExpr::Slice { - input: to_aexpr_impl(*input, arena, state), - offset: to_aexpr_impl(*offset, arena, state), - length: to_aexpr_impl(*length, arena, state), + input: to_aexpr_impl(owned(input), arena, state), + offset: to_aexpr_impl(owned(offset), arena, state), + length: to_aexpr_impl(owned(length), arena, state), }, Expr::Len => { if state.output_name.is_none() { @@ -482,10 +493,10 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let expr = expr_arena.get(node).clone(); match expr { - AExpr::Explode(node) => Expr::Explode(Box::new(node_to_expr(node, expr_arena))), + AExpr::Explode(node) => Expr::Explode(Arc::new(node_to_expr(node, expr_arena))), AExpr::Alias(expr, name) => { let exp = node_to_expr(expr, expr_arena); - Expr::Alias(Box::new(exp), name) + Expr::Alias(Arc::new(exp), name) }, AExpr::Column(a) => Expr::Column(a), AExpr::Literal(s) => Expr::Literal(s), @@ -493,9 +504,9 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let l = node_to_expr(left, expr_arena); let r = node_to_expr(right, expr_arena); Expr::BinaryExpr { - left: Box::new(l), + left: Arc::new(l), op, - right: Box::new(r), + right: Arc::new(r), } }, AExpr::Cast { @@ -505,7 +516,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { } => { let exp = node_to_expr(expr, expr_arena); Expr::Cast { - expr: Box::new(exp), + expr: Arc::new(exp), data_type, strict, } @@ -513,7 +524,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { AExpr::Sort { expr, options } => { let exp = node_to_expr(expr, expr_arena); Expr::Sort { - expr: Box::new(exp), + expr: Arc::new(exp), options, } }, @@ -525,8 +536,8 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let expr = node_to_expr(expr, expr_arena); let idx = node_to_expr(idx, expr_arena); Expr::Gather { - expr: Box::new(expr), - idx: Box::new(idx), + expr: Arc::new(expr), + idx: Arc::new(idx), returns_scalar, } }, @@ -541,7 +552,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { .map(|node| node_to_expr(*node, expr_arena)) .collect(); Expr::SortBy { - expr: Box::new(expr), + expr: Arc::new(expr), by, descending, } @@ -550,8 +561,8 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let input = node_to_expr(input, expr_arena); let by = node_to_expr(by, expr_arena); Expr::Filter { - input: Box::new(input), - by: Box::new(by), + input: Arc::new(input), + by: Arc::new(by), } }, AExpr::Agg(agg) => match agg { @@ -561,7 +572,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { } => { let exp = node_to_expr(input, expr_arena); AggExpr::Min { - input: Box::new(exp), + input: Arc::new(exp), propagate_nans, } .into() @@ -572,7 +583,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { } => { let exp = node_to_expr(input, expr_arena); AggExpr::Max { - input: Box::new(exp), + input: Arc::new(exp), propagate_nans, } .into() @@ -580,27 +591,27 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { AAggExpr::Median(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Median(Box::new(exp)).into() + AggExpr::Median(Arc::new(exp)).into() }, AAggExpr::NUnique(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::NUnique(Box::new(exp)).into() + AggExpr::NUnique(Arc::new(exp)).into() }, AAggExpr::First(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::First(Box::new(exp)).into() + AggExpr::First(Arc::new(exp)).into() }, AAggExpr::Last(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Last(Box::new(exp)).into() + AggExpr::Last(Arc::new(exp)).into() }, AAggExpr::Mean(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Mean(Box::new(exp)).into() + AggExpr::Mean(Arc::new(exp)).into() }, AAggExpr::Implode(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Implode(Box::new(exp)).into() + AggExpr::Implode(Arc::new(exp)).into() }, AAggExpr::Quantile { expr, @@ -610,31 +621,31 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let expr = node_to_expr(expr, expr_arena); let quantile = node_to_expr(quantile, expr_arena); AggExpr::Quantile { - expr: Box::new(expr), - quantile: Box::new(quantile), + expr: Arc::new(expr), + quantile: Arc::new(quantile), interpol, } .into() }, AAggExpr::Sum(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Sum(Box::new(exp)).into() + AggExpr::Sum(Arc::new(exp)).into() }, AAggExpr::Std(expr, ddof) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Std(Box::new(exp), ddof).into() + AggExpr::Std(Arc::new(exp), ddof).into() }, AAggExpr::Var(expr, ddof) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Var(Box::new(exp), ddof).into() + AggExpr::Var(Arc::new(exp), ddof).into() }, AAggExpr::AggGroups(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::AggGroups(Box::new(exp)).into() + AggExpr::AggGroups(Arc::new(exp)).into() }, AAggExpr::Count(expr, include_nulls) => { let expr = node_to_expr(expr, expr_arena); - AggExpr::Count(Box::new(expr), include_nulls).into() + AggExpr::Count(Arc::new(expr), include_nulls).into() }, }, AExpr::Ternary { @@ -647,9 +658,9 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let f = node_to_expr(falsy, expr_arena); Expr::Ternary { - predicate: Box::new(p), - truthy: Box::new(t), - falsy: Box::new(f), + predicate: Arc::new(p), + truthy: Arc::new(t), + falsy: Arc::new(f), } }, AExpr::AnonymousFunction { @@ -677,7 +688,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { partition_by, options, } => { - let function = Box::new(node_to_expr(function, expr_arena)); + let function = Arc::new(node_to_expr(function, expr_arena)); let partition_by = nodes_to_exprs(&partition_by, expr_arena); Expr::Window { function, @@ -690,9 +701,9 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { offset, length, } => Expr::Slice { - input: Box::new(node_to_expr(input, expr_arena)), - offset: Box::new(node_to_expr(offset, expr_arena)), - length: Box::new(node_to_expr(length, expr_arena)), + input: Arc::new(node_to_expr(input, expr_arena)), + offset: Arc::new(node_to_expr(offset, expr_arena)), + length: Arc::new(node_to_expr(length, expr_arena)), }, AExpr::Len => Expr::Len, AExpr::Nth(i) => Expr::Nth(i), diff --git a/crates/polars-plan/src/logical_plan/iterator.rs b/crates/polars-plan/src/logical_plan/iterator.rs index 611f08badd83..9b1ac7ecb01e 100644 --- a/crates/polars-plan/src/logical_plan/iterator.rs +++ b/crates/polars-plan/src/logical_plan/iterator.rs @@ -1,55 +1,58 @@ -use arrow::legacy::error::PolarsResult; +use std::sync::Arc; + +use polars_core::error::PolarsResult; use polars_utils::idx_vec::UnitVec; use polars_utils::unitvec; +use visitor::{RewritingVisitor, TreeWalker}; use crate::prelude::*; macro_rules! push_expr { - ($current_expr:expr, $push:ident, $iter:ident) => {{ + ($current_expr:expr, $c:ident, $push:ident, $push_owned:ident, $iter:ident) => {{ use Expr::*; match $current_expr { Nth(_) | Column(_) | Literal(_) | Wildcard | Columns(_) | DtypeColumn(_) | Len => {}, - Alias(e, _) => $push(e), + Alias(e, _) => $push($c, e), BinaryExpr { left, op: _, right } => { // reverse order so that left is popped first - $push(right); - $push(left); + $push($c, right); + $push($c, left); }, - Cast { expr, .. } => $push(expr), - Sort { expr, .. } => $push(expr), + Cast { expr, .. } => $push($c, expr), + Sort { expr, .. } => $push($c, expr), Gather { expr, idx, .. } => { - $push(idx); - $push(expr); + $push($c, idx); + $push($c, expr); }, Filter { input, by } => { - $push(by); + $push($c, by); // latest, so that it is popped first - $push(input); + $push($c, input); }, SortBy { expr, by, .. } => { for e in by { - $push(e) + $push_owned($c, e) } // latest, so that it is popped first - $push(expr); + $push($c, expr); }, Agg(agg_e) => { use AggExpr::*; match agg_e { - Max { input, .. } => $push(input), - Min { input, .. } => $push(input), - Mean(e) => $push(e), - Median(e) => $push(e), - NUnique(e) => $push(e), - First(e) => $push(e), - Last(e) => $push(e), - Implode(e) => $push(e), - Count(e, _) => $push(e), - Quantile { expr, .. } => $push(expr), - Sum(e) => $push(e), - AggGroups(e) => $push(e), - Std(e, _) => $push(e), - Var(e, _) => $push(e), + Max { input, .. } => $push($c, input), + Min { input, .. } => $push($c, input), + Mean(e) => $push($c, e), + Median(e) => $push($c, e), + NUnique(e) => $push($c, e), + First(e) => $push($c, e), + Last(e) => $push($c, e), + Implode(e) => $push($c, e), + Count(e, _) => $push($c, e), + Quantile { expr, .. } => $push($c, expr), + Sum(e) => $push($c, e), + AggGroups(e) => $push($c, e), + Std(e, _) => $push($c, e), + Var(e, _) => $push($c, e), } }, Ternary { @@ -57,40 +60,40 @@ macro_rules! push_expr { falsy, predicate, } => { - $push(predicate); - $push(falsy); + $push($c, predicate); + $push($c, falsy); // latest, so that it is popped first - $push(truthy); + $push($c, truthy); }, // we iterate in reverse order, so that the lhs is popped first and will be found // as the root columns/ input columns by `_suffix` and `_keep_name` etc. - AnonymousFunction { input, .. } => input.$iter().rev().for_each(|e| $push(e)), - Function { input, .. } => input.$iter().rev().for_each(|e| $push(e)), - Explode(e) => $push(e), + AnonymousFunction { input, .. } => input.$iter().rev().for_each(|e| $push_owned($c, e)), + Function { input, .. } => input.$iter().rev().for_each(|e| $push_owned($c, e)), + Explode(e) => $push($c, e), Window { function, partition_by, .. } => { for e in partition_by.into_iter().rev() { - $push(e) + $push_owned($c, e) } // latest so that it is popped first - $push(function); + $push($c, function); }, Slice { input, offset, length, } => { - $push(length); - $push(offset); + $push($c, length); + $push($c, offset); // latest, so that it is popped first - $push(input); + $push($c, input); }, - Exclude(e, _) => $push(e), - KeepName(e) => $push(e), - RenameAlias { expr, .. } => $push(expr), + Exclude(e, _) => $push($c, e), + KeepName(e) => $push($c, e), + RenameAlias { expr, .. } => $push($c, expr), SubPlan { .. } => {}, // pass Selector(_) => {}, @@ -98,47 +101,6 @@ macro_rules! push_expr { }}; } -impl Expr { - /// Expr::mutate().apply(fn()) - pub fn mutate(&mut self) -> ExprMut { - let stack = unitvec!(self); - ExprMut { stack } - } -} - -pub struct ExprMut<'a> { - stack: UnitVec<&'a mut Expr>, -} - -impl<'a> ExprMut<'a> { - /// - /// # Arguments - /// * `f` - A function that may mutate an expression. If the function returns `true` iteration - /// continues. - pub fn apply(&mut self, mut f: F) - where - F: FnMut(&mut Expr) -> bool, - { - let _ = self.try_apply(|e| Ok(f(e))); - } - - pub fn try_apply(&mut self, mut f: F) -> PolarsResult<()> - where - F: FnMut(&mut Expr) -> PolarsResult, - { - while let Some(current_expr) = self.stack.pop() { - // the order is important, we first modify the Expr - // before we push its children on the stack. - // The modification can make the children invalid. - if !f(current_expr)? { - break; - } - current_expr.nodes_mut(&mut self.stack) - } - Ok(()) - } -} - pub struct ExprIter<'a> { stack: UnitVec<&'a Expr>, } @@ -154,15 +116,36 @@ impl<'a> Iterator for ExprIter<'a> { } } +pub struct ExprMapper { + f: F, +} + +impl PolarsResult> RewritingVisitor for ExprMapper { + type Node = Expr; + + fn mutate(&mut self, node: Self::Node) -> PolarsResult { + (self.f)(node) + } +} + impl Expr { pub fn nodes<'a>(&'a self, container: &mut UnitVec<&'a Expr>) { - let mut push = |e: &'a Expr| container.push(e); - push_expr!(self, push, iter); + let push = |c: &mut UnitVec<&'a Expr>, e: &'a Expr| c.push(e); + push_expr!(self, container, push, push, iter); + } + + pub fn nodes_owned(self, container: &mut UnitVec) { + let push_arc = |c: &mut UnitVec, e: Arc| c.push(Arc::unwrap_or_clone(e)); + let push_owned = |c: &mut UnitVec, e: Expr| c.push(e); + push_expr!(self, container, push_arc, push_owned, into_iter); + } + + pub fn map_expr Self>(self, mut f: F) -> Self { + self.rewrite(&mut ExprMapper { f: |e| Ok(f(e)) }).unwrap() } - pub fn nodes_mut<'a>(&'a mut self, container: &mut UnitVec<&'a mut Expr>) { - let mut push = |e: &'a mut Expr| container.push(e); - push_expr!(self, push, iter_mut); + pub fn try_map_expr PolarsResult>(self, f: F) -> PolarsResult { + self.rewrite(&mut ExprMapper { f }) } } diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index 679b58740b1c..c3bf2d2a2ea3 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -123,13 +123,15 @@ impl<'a> PredicatePushDown<'a> { if needs_rename { // TODO! Do this directly on AExpr. let mut new_expr = node_to_expr(e.node(), expr_arena); - new_expr.mutate().apply(|e| { - if let Expr::Column(name) = e { - if let Some(rename_to) = alias_rename_map.get(name) { - *name = rename_to.clone(); - }; - }; - true + new_expr = new_expr.map_expr(|e| match e { + Expr::Column(name) => { + if let Some(rename_to) = alias_rename_map.get(&*name) { + Expr::Column(rename_to.clone()) + } else { + Expr::Column(name) + } + }, + e => e, }); let predicate = to_aexpr(new_expr, expr_arena); e.set_node(predicate); diff --git a/crates/polars-plan/src/logical_plan/projection.rs b/crates/polars-plan/src/logical_plan/projection.rs index 343e0d0498e3..6450c4822bb5 100644 --- a/crates/polars-plan/src/logical_plan/projection.rs +++ b/crates/polars-plan/src/logical_plan/projection.rs @@ -5,33 +5,20 @@ use super::*; /// This replace the wildcard Expr with a Column Expr. It also removes the Exclude Expr from the /// expression chain. -pub(super) fn replace_wildcard_with_column(mut expr: Expr, column_name: Arc) -> Expr { - expr.mutate().apply(|e| { - match e { - Expr::Wildcard => { - *e = Expr::Column(column_name.clone()); - }, - Expr::Exclude(input, _) => { - *e = replace_wildcard_with_column(std::mem::take(input), column_name.clone()); - }, - _ => {}, - } - // always keep iterating all inputs - true - }); - expr +pub(super) fn replace_wildcard_with_column(expr: Expr, column_name: Arc) -> Expr { + expr.map_expr(|e| match e { + Expr::Wildcard => Expr::Column(column_name.clone()), + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }) } #[cfg(feature = "regex")] -fn remove_exclude(mut expr: Expr) -> Expr { - expr.mutate().apply(|e| { - if let Expr::Exclude(input, _) = e { - *e = remove_exclude(std::mem::take(input)); - } - // always keep iterating all inputs - true - }); - expr +fn remove_exclude(expr: Expr) -> Expr { + expr.map_expr(|e| match e { + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }) } fn rewrite_special_aliases(expr: Expr) -> PolarsResult { @@ -79,22 +66,22 @@ fn replace_wildcard( Ok(()) } -fn replace_nth(expr: &mut Expr, schema: &Schema) { - expr.mutate().apply(|e| match e { - Expr::Nth(i) => { +fn replace_nth(expr: Expr, schema: &Schema) -> Expr { + expr.map_expr(|e| { + if let Expr::Nth(i) = e { match i.negative_to_usize(schema.len()) { None => { - let name = if *i == 0 { "first" } else { "last" }; - *e = Expr::Column(ColumnName::from(name)); + let name = if i == 0 { "first" } else { "last" }; + Expr::Column(ColumnName::from(name)) }, Some(idx) => { let (name, _dtype) = schema.get_at_index(idx).unwrap(); - *e = Expr::Column(ColumnName::from(&**name)) + Expr::Column(ColumnName::from(&**name)) }, } - true - }, - _ => true, + } else { + e + } }) } @@ -114,12 +101,11 @@ fn expand_regex( if re.is_match(name) && !exclude.contains(name.as_str()) { let mut new_expr = remove_exclude(expr.clone()); - new_expr.mutate().apply(|e| match &e { + new_expr = new_expr.map_expr(|e| match e { Expr::Column(pat) if pat.as_ref() == pattern => { - *e = Expr::Column(ColumnName::from(name.as_str())); - true + Expr::Column(ColumnName::from(name.as_str())) }, - _ => true, + e => e, }); let new_expr = rewrite_special_aliases(new_expr)?; @@ -203,21 +189,12 @@ fn expand_columns( /// This replaces the dtypes Expr with a Column Expr. It also removes the Exclude Expr from the /// expression chain. -pub(super) fn replace_dtype_with_column(mut expr: Expr, column_name: Arc) -> Expr { - expr.mutate().apply(|e| { - match e { - Expr::DtypeColumn(_) => { - *e = Expr::Column(column_name.clone()); - }, - Expr::Exclude(input, _) => { - *e = replace_dtype_with_column(std::mem::take(input), column_name.clone()); - }, - _ => {}, - } - // always keep iterating all inputs - true - }); - expr +pub(super) fn replace_dtype_with_column(expr: Expr, column_name: Arc) -> Expr { + expr.map_expr(|e| match e { + Expr::DtypeColumn(_) => Expr::Column(column_name.clone()), + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }) } /// This replaces the columns Expr with a Column Expr. It also removes the Exclude Expr from the @@ -228,26 +205,18 @@ pub(super) fn replace_columns_with_column( column_name: &str, ) -> (Expr, bool) { let mut is_valid = true; - expr.mutate().apply(|e| { - match e { - Expr::Columns(members) => { - // `col([a, b]) + col([c, d])` - if members == names { - *e = Expr::Column(ColumnName::from(column_name)); - } else { - is_valid = false; - } - }, - Expr::Exclude(input, _) => { - let (new_expr, new_expr_valid) = - replace_columns_with_column(std::mem::take(input), names, column_name); - *e = new_expr; - is_valid &= new_expr_valid; - }, - _ => {}, - } - // always keep iterating all inputs - true + expr = expr.map_expr(|e| match e { + Expr::Columns(members) => { + // `col([a, b]) + col([c, d])` + if members == names { + Expr::Column(ColumnName::from(column_name)) + } else { + is_valid = false; + Expr::Columns(members) + } + }, + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, }); (expr, is_valid) } @@ -363,18 +332,16 @@ fn prepare_excluded( } // functions can have col(["a", "b"]) or col(String) as inputs -fn expand_function_inputs(mut expr: Expr, schema: &Schema) -> Expr { - expr.mutate().apply(|e| match e { +fn expand_function_inputs(expr: Expr, schema: &Schema) -> Expr { + expr.map_expr(|mut e| match &mut e { Expr::AnonymousFunction { input, options, .. } | Expr::Function { input, options, .. } if options.input_wildcard_expansion => { - *input = rewrite_projections(input.clone(), schema, &[]).unwrap(); - // continue iteration, there might be more functions. - true + *input = rewrite_projections(core::mem::take(input), schema, &[]).unwrap(); + e }, - _ => true, - }); - expr + _ => e, + }) } /// this is determined in type coercion @@ -460,7 +427,7 @@ pub(crate) fn rewrite_projections( let mut flags = find_flags(&expr); if flags.has_selector { - replace_selector(&mut expr, schema, keys)?; + expr = replace_selector(expr, schema, keys)?; // the selector is replaced with Expr::Columns flags.multiple_columns = true; } @@ -475,21 +442,19 @@ pub(crate) fn rewrite_projections( // them up there. if flags.replace_fill_null_type { for e in &mut result[result_offset..] { - e.mutate().apply(|e| { + *e = e.clone().map_expr(|mut e| { if let Expr::Function { input, function: FunctionExpr::FillNull { super_type }, .. - } = e + } = &mut e { if let Some(new_st) = early_supertype(input, schema) { *super_type = new_st; } } - - // continue iteration - true - }) + e + }); } } } @@ -504,7 +469,7 @@ fn replace_and_add_to_results( keys: &[Expr], ) -> PolarsResult<()> { if flags.has_nth { - replace_nth(&mut expr, schema); + expr = replace_nth(expr, schema); } // has multiple column names @@ -603,20 +568,18 @@ fn replace_selector_inner( Ok(()) } -fn replace_selector(expr: &mut Expr, schema: &Schema, keys: &[Expr]) -> PolarsResult<()> { - // first pass we replace the selectors - // with Expr::Columns - // we expand the `to_add` columns - // and then subtract the `to_subtract` columns - expr.mutate().try_apply(|e| match e { - Expr::Selector(s) => { +fn replace_selector(expr: Expr, schema: &Schema, keys: &[Expr]) -> PolarsResult { + // First pass we replace the selectors with Expr::Columns, we expand the `to_add` columns + // and then subtract the `to_subtract` columns. + expr.try_map_expr(|e| match e { + Expr::Selector(mut s) => { let mut swapped = Selector::Root(Box::new(Expr::Wildcard)); - std::mem::swap(s, &mut swapped); + std::mem::swap(&mut s, &mut swapped); let mut members = PlIndexSet::new(); replace_selector_inner(swapped, &mut members, &mut vec![], schema, keys)?; - *e = Expr::Columns( + Ok(Expr::Columns( members .into_iter() .map(|e| { @@ -626,11 +589,8 @@ fn replace_selector(expr: &mut Expr, schema: &Schema, keys: &[Expr]) -> PolarsRe name.to_string() }) .collect(), - ); - - Ok(true) + )) }, - _ => Ok(true), - })?; - Ok(()) + e => Ok(e), + }) } diff --git a/crates/polars-plan/src/logical_plan/visitor/expr.rs b/crates/polars-plan/src/logical_plan/visitor/expr.rs index 8faf4e67fcce..e75dee347dd3 100644 --- a/crates/polars-plan/src/logical_plan/visitor/expr.rs +++ b/crates/polars-plan/src/logical_plan/visitor/expr.rs @@ -24,8 +24,62 @@ impl TreeWalker for Expr { Ok(VisitRecursion::Continue) } - fn map_children(self, _op: &mut dyn FnMut(Self) -> PolarsResult) -> PolarsResult { - todo!() + fn map_children(self, mut f: &mut dyn FnMut(Self) -> PolarsResult) -> PolarsResult { + use polars_utils::functions::try_arc_map as am; + use AggExpr::*; + use Expr::*; + #[rustfmt::skip] + let ret = match self { + Alias(l, r) => Alias(am(l, f)?, r), + Column(_) => self, + Columns(_) => self, + DtypeColumn(_) => self, + Literal(_) => self, + BinaryExpr { left, op, right } => { + BinaryExpr { left: am(left, &mut f)? , op, right: am(right, f)?} + }, + Cast { expr, data_type, strict } => Cast { expr: am(expr, f)?, data_type, strict }, + Sort { expr, options } => Sort { expr: am(expr, f)?, options }, + Gather { expr, idx, returns_scalar } => Gather { expr: am(expr, &mut f)?, idx: am(idx, f)?, returns_scalar }, + SortBy { expr, by, descending } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::>()?, descending }, + Agg(agg_expr) => Agg(match agg_expr { + Min { input, propagate_nans } => Min { input: am(input, f)?, propagate_nans }, + Max { input, propagate_nans } => Max { input: am(input, f)?, propagate_nans }, + Median(x) => Median(am(x, f)?), + NUnique(x) => NUnique(am(x, f)?), + First(x) => First(am(x, f)?), + Last(x) => Last(am(x, f)?), + Mean(x) => Mean(am(x, f)?), + Implode(x) => Implode(am(x, f)?), + Count(x, nulls) => Count(am(x, f)?, nulls), + Quantile { expr, quantile, interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, interpol }, + Sum(x) => Sum(am(x, f)?), + AggGroups(x) => AggGroups(am(x, f)?), + Std(x, ddf) => Std(am(x, f)?, ddf), + Var(x, ddf) => Var(am(x, f)?, ddf), + }), + Ternary { predicate, truthy, falsy } => Ternary { predicate: am(predicate, &mut f)?, truthy: am(truthy, &mut f)?, falsy: am(falsy, f)? }, + Function { input, function, options } => Function { input: input.into_iter().map(f).collect::>()?, function, options }, + Explode(expr) => Explode(am(expr, f)?), + Filter { input, by } => Filter { input: am(input, &mut f)?, by: am(by, f)? }, + Window { function, partition_by, options } => { + let partition_by = partition_by.into_iter().map(&mut f).collect::>()?; + Window { function: am(function, f)?, partition_by, options } + }, + Wildcard => Wildcard, + Slice { input, offset, length } => Slice { input: am(input, &mut f)?, offset: am(offset, &mut f)?, length: am(length, f)? }, + Exclude(expr, excluded) => Exclude(am(expr, f)?, excluded), + KeepName(expr) => KeepName(am(expr, f)?), + Len => Len, + Nth(_) => self, + RenameAlias { function, expr } => RenameAlias { function, expr: am(expr, f)? }, + AnonymousFunction { input, function, output_type, options } => { + AnonymousFunction { input: input.into_iter().map(f).collect::>()?, function, output_type, options } + }, + SubPlan(_, _) => self, + Selector(_) => self, + }; + Ok(ret) } } diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index ff578ea599f7..258d42b4fc03 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -277,12 +277,9 @@ pub(crate) fn rename_matching_aexpr_leaf_names( if leaves.any(|node| matches!(arena.get(node.0), AExpr::Column(name) if &**name == current)) { // we convert to expression as we cannot easily copy the aexpr. let mut new_expr = node_to_expr(node, arena); - new_expr.mutate().apply(|e| match e { - Expr::Column(name) if &**name == current => { - *name = ColumnName::from(new_name); - true - }, - _ => true, + new_expr = new_expr.map_expr(|e| match e { + Expr::Column(name) if &*name == current => Expr::Column(ColumnName::from(new_name)), + e => e, }); to_aexpr(new_expr, arena) } else { diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 4e84219c9963..9a9963ed3259 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -525,15 +525,16 @@ impl SQLContext { fn process_subqueries(&self, lf: LazyFrame, exprs: Vec<&mut Expr>) -> LazyFrame { let mut contexts = vec![]; for expr in exprs { - expr.mutate().apply(|e| { - if let Expr::SubPlan(lp, names) = e { - contexts.push(::from((***lp).clone())); - + *expr = expr.clone().map_expr(|e| match e { + Expr::SubPlan(lp, names) => { + contexts.push(::from((**lp).clone())); if names.len() == 1 { - *e = Expr::Column(names[0].as_str().into()); + Expr::Column(names[0].as_str().into()) + } else { + Expr::SubPlan(lp, names) } - }; - true + }, + e => e, }) } diff --git a/crates/polars-utils/src/functions.rs b/crates/polars-utils/src/functions.rs index 4ff1d724cefb..528bae5ed291 100644 --- a/crates/polars-utils/src/functions.rs +++ b/crates/polars-utils/src/functions.rs @@ -1,4 +1,6 @@ +use std::mem::MaybeUninit; use std::ops::Range; +use std::sync::Arc; // The ith portion of a range split in k (as equal as possible) parts. #[inline(always)] @@ -23,3 +25,40 @@ pub fn flatten>(bufs: &[R], len: Option) -> Vec T>(mut arc: Arc, mut f: F) -> Arc { + unsafe { + // Make the Arc unique (cloning if necessary). + Arc::make_mut(&mut arc); + + // If f panics we must be able to drop the Arc without assuming it is initialized. + let mut uninit_arc = Arc::from_raw(Arc::into_raw(arc).cast::>()); + + // Replace the value inside the arc. + let ptr = Arc::get_mut(&mut uninit_arc).unwrap_unchecked() as *mut MaybeUninit; + *ptr = MaybeUninit::new(f(ptr.read().assume_init())); + + // Now the Arc is properly initialized again. + Arc::from_raw(Arc::into_raw(uninit_arc).cast::()) + } +} + +pub fn try_arc_map Result>( + mut arc: Arc, + mut f: F, +) -> Result, E> { + unsafe { + // Make the Arc unique (cloning if necessary). + Arc::make_mut(&mut arc); + + // If f panics we must be able to drop the Arc without assuming it is initialized. + let mut uninit_arc = Arc::from_raw(Arc::into_raw(arc).cast::>()); + + // Replace the value inside the arc. + let ptr = Arc::get_mut(&mut uninit_arc).unwrap_unchecked() as *mut MaybeUninit; + *ptr = MaybeUninit::new(f(ptr.read().assume_init())?); + + // Now the Arc is properly initialized again. + Ok(Arc::from_raw(Arc::into_raw(uninit_arc).cast::())) + } +}