diff --git a/src/common/treenode/src/lib.rs b/src/common/treenode/src/lib.rs
index 908ebc3530..de7ff59c30 100644
--- a/src/common/treenode/src/lib.rs
+++ b/src/common/treenode/src/lib.rs
@@ -631,6 +631,31 @@ impl<T> Transformed<T> {
         f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr))
     }
 
+    /// Returns self if self is transformed, otherwise returns other.
+    pub fn or(self, other: Self) -> Self {
+        if self.transformed {
+            self
+        } else {
+            other
+        }
+    }
+
+    /// Maps a `Transformed<T>` to `Transformed<U>`,
+    /// by supplying a function to apply to a contained Yes value
+    /// as well as a function to apply to a contained No value.
+    #[inline]
+    pub fn map_yes_no<U, Y: FnOnce(T) -> U, N: FnOnce(T) -> U>(
+        self,
+        yes_op: Y,
+        no_op: N,
+    ) -> Transformed<U> {
+        if self.transformed {
+            Transformed::yes(yes_op(self.data))
+        } else {
+            Transformed::no(no_op(self.data))
+        }
+    }
+
     /// Maps the [`Transformed`] object to the result of the given `f`.
     pub fn transform_data<U, F: FnOnce(T) -> Result<Transformed<U>>>(
         self,
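For reviewers coming from the old enum: the struct-based `Transformed` carries the rewritten value in `data` plus a boolean `transformed` flag (the real struct also has a `tnr` recursion field, elided below). A minimal self-contained sketch of the semantics of the `or` and `map_yes_no` helpers added above, using a simplified stand-in for the real struct:

```rust
// Simplified stand-in for `common_treenode::Transformed<T>`; the real struct
// also carries a `tnr: TreeNodeRecursion` field, omitted here for brevity.
#[derive(Debug)]
struct Transformed<T> {
    data: T,
    transformed: bool,
}

impl<T> Transformed<T> {
    fn yes(data: T) -> Self {
        Self { data, transformed: true }
    }
    fn no(data: T) -> Self {
        Self { data, transformed: false }
    }
    // Returns self if self is transformed, otherwise returns other.
    fn or(self, other: Self) -> Self {
        if self.transformed { self } else { other }
    }
    // Applies `yes_op` if transformed, `no_op` otherwise.
    fn map_yes_no<U, Y: FnOnce(T) -> U, N: FnOnce(T) -> U>(self, yes_op: Y, no_op: N) -> Transformed<U> {
        if self.transformed {
            Transformed::yes(yes_op(self.data))
        } else {
            Transformed::no(no_op(self.data))
        }
    }
}

fn main() {
    // `or` keeps the first result that actually changed something.
    let kept = Transformed::no("a").or(Transformed::yes("b"));
    assert_eq!(kept.data, "b");

    // `map_yes_no` replaces matching on the old `Yes(t)` / `No(t)` variants.
    let report = Transformed::yes(3).map_yes_no(|n| format!("rewrote {n}"), |n| format!("kept {n}"));
    assert!(report.transformed);
    assert_eq!(report.data, "rewrote 3");
}
```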
diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs
index 5290f46242..3c344c56aa 100644
--- a/src/daft-plan/src/logical_ops/project.rs
+++ b/src/daft-plan/src/logical_ops/project.rs
@@ -1,12 +1,12 @@
 use std::sync::Arc;
 
+use common_treenode::Transformed;
 use daft_core::prelude::*;
 use daft_dsl::{optimization, resolve_exprs, AggExpr, ApproxPercentileParams, Expr, ExprRef};
 use indexmap::{IndexMap, IndexSet};
 use itertools::Itertools;
 use snafu::ResultExt;
 
-use crate::logical_optimization::Transformed;
 use crate::logical_plan::{CreationSnafu, Result};
 use crate::LogicalPlan;
 
@@ -144,7 +144,7 @@ impl Project {
             .map(|e| {
                 let new_expr =
                     replace_column_with_semantic_id(e.clone(), &subexprs_to_replace, schema);
-                let new_expr = new_expr.unwrap();
+                let new_expr = new_expr.data;
                 // The substitution can unintentionally change the expression's name
                 // (since the name depends on the first column referenced, which can be substituted away)
                 // so re-alias the original name here if it has changed.
@@ -185,10 +185,10 @@ fn replace_column_with_semantic_id(
             Expr::Alias(_, name) => Expr::Alias(new_expr.into(), name.clone()),
             _ => new_expr,
         };
-        Transformed::Yes(new_expr.into())
+        Transformed::yes(new_expr.into())
     } else {
         match e.as_ref() {
-            Expr::Column(_) | Expr::Literal(_) => Transformed::No(e),
+            Expr::Column(_) | Expr::Literal(_) => Transformed::no(e),
             Expr::Agg(agg_expr) => replace_column_with_semantic_id_aggexpr(
                 agg_expr.clone(),
                 subexprs_to_replace,
@@ -241,11 +241,11 @@ fn replace_column_with_semantic_id(
                         subexprs_to_replace,
                         schema,
                     );
-                    if child.is_no() && fill_value.is_no() {
-                        Transformed::No(e)
+                    if !child.transformed && !fill_value.transformed {
+                        Transformed::no(e)
                     } else {
-                        Transformed::Yes(
-                            Expr::FillNull(child.unwrap().clone(), fill_value.unwrap().clone()).into(),
+                        Transformed::yes(
+                            Expr::FillNull(child.data.clone(), fill_value.data.clone()).into(),
                         )
                     }
                 }
@@ -254,12 +254,10 @@ fn replace_column_with_semantic_id(
                         replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema);
                     let items =
                         replace_column_with_semantic_id(items.clone(), subexprs_to_replace, schema);
-                    if child.is_no() && items.is_no() {
-                        Transformed::No(e)
+                    if !child.transformed && !items.transformed {
+                        Transformed::no(e)
                     } else {
-                        Transformed::Yes(
-                            Expr::IsIn(child.unwrap().clone(), items.unwrap().clone()).into(),
-                        )
+                        Transformed::yes(Expr::IsIn(child.data.clone(), items.data.clone()).into())
                     }
                 }
                 Expr::Between(child, lower, upper) => {
@@ -269,16 +267,12 @@ fn replace_column_with_semantic_id(
                         replace_column_with_semantic_id(lower.clone(), subexprs_to_replace, schema);
                     let upper =
                         replace_column_with_semantic_id(upper.clone(), subexprs_to_replace, schema);
-                    if child.is_no() && lower.is_no() && upper.is_no() {
-                        Transformed::No(e)
+                    if !child.transformed && !lower.transformed && !upper.transformed {
+                        Transformed::no(e)
                     } else {
-                        Transformed::Yes(
-                            Expr::Between(
-                                child.unwrap().clone(),
-                                lower.unwrap().clone(),
-                                upper.unwrap().clone(),
-                            )
-                            .into(),
+                        Transformed::yes(
+                            Expr::Between(child.data.clone(), lower.data.clone(), upper.data.clone())
+                                .into(),
                         )
                     }
                 }
@@ -287,14 +281,14 @@ fn replace_column_with_semantic_id(
                         replace_column_with_semantic_id(left.clone(), subexprs_to_replace, schema);
                     let right =
                         replace_column_with_semantic_id(right.clone(), subexprs_to_replace, schema);
-                    if left.is_no() && right.is_no() {
-                        Transformed::No(e)
+                    if !left.transformed && !right.transformed {
+                        Transformed::no(e)
                     } else {
-                        Transformed::Yes(
+                        Transformed::yes(
                             Expr::BinaryOp {
                                 op: *op,
-                                left: left.unwrap().clone(),
-                                right: right.unwrap().clone(),
+                                left: left.data.clone(),
+                                right: right.data.clone(),
                             }
                             .into(),
                         )
@@ -311,14 +305,14 @@ fn replace_column_with_semantic_id(
                         replace_column_with_semantic_id(if_true.clone(), subexprs_to_replace, schema);
                     let if_false =
                         replace_column_with_semantic_id(if_false.clone(), subexprs_to_replace, schema);
-                    if predicate.is_no() && if_true.is_no() && if_false.is_no() {
-                        Transformed::No(e)
+                    if !predicate.transformed && !if_true.transformed && !if_false.transformed {
+                        Transformed::no(e)
                     } else {
-                        Transformed::Yes(
+                        Transformed::yes(
                             Expr::IfElse {
-                                predicate: predicate.unwrap().clone(),
-                                if_true: if_true.unwrap().clone(),
-                                if_false: if_false.unwrap().clone(),
+                                predicate: predicate.data.clone(),
+                                if_true: if_true.data.clone(),
+                                if_false: if_false.data.clone(),
                             }
                             .into(),
                         )
@@ -331,13 +325,13 @@ fn replace_column_with_semantic_id(
                             replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema)
                         })
                        .collect::<Vec<_>>();
-                    if transforms.iter().all(|e| e.is_no()) {
-                        Transformed::No(e)
+                    if transforms.iter().all(|e| !e.transformed) {
+                        Transformed::no(e)
                     } else {
-                        Transformed::Yes(
+                        Transformed::yes(
                             Expr::Function {
                                 func: func.clone(),
-                                inputs: transforms.iter().map(|t| t.unwrap()).cloned().collect(),
+                                inputs: transforms.iter().map(|t| t.data.clone()).collect(),
                             }
                             .into(),
                         )
@@ -352,11 +346,11 @@ fn replace_column_with_semantic_id(
                             replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema)
                         })
                        .collect::<Vec<_>>();
-                    if transforms.iter().all(|e| e.is_no()) {
-                        Transformed::No(e)
+                    if transforms.iter().all(|e| !e.transformed) {
+                        Transformed::no(e)
                     } else {
-                        func.inputs = transforms.iter().map(|t| t.unwrap()).cloned().collect();
-                        Transformed::Yes(Expr::ScalarFunction(func).into())
+                        func.inputs = transforms.iter().map(|t| t.data.clone()).collect();
+                        Transformed::yes(Expr::ScalarFunction(func).into())
                     }
                 }
             }
@@ -446,12 +440,12 @@ fn replace_column_with_semantic_id_aggexpr(
             .iter()
             .map(|e| replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema))
            .collect::<Vec<_>>();
-        if transforms.iter().all(|e| e.is_no()) {
-            Transformed::No(AggExpr::MapGroups { func, inputs })
+        if transforms.iter().all(|e| !e.transformed) {
+            Transformed::no(AggExpr::MapGroups { func, inputs })
         } else {
-            Transformed::Yes(AggExpr::MapGroups {
+            Transformed::yes(AggExpr::MapGroups {
                 func: func.clone(),
-                inputs: transforms.iter().map(|t| t.unwrap()).cloned().collect(),
+                inputs: transforms.iter().map(|t| t.data.clone()).collect(),
             })
         }
     }
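The expression arms above all share one idiom: recurse into the children, and only allocate a rebuilt parent when at least one child actually changed, otherwise return the original `Arc` untouched. A compilable sketch of that idiom on a hypothetical two-variant expression type (the `Expr`, `rename`, and column names here are illustrative, not daft's):

```rust
use std::sync::Arc;

// Hypothetical two-variant expression type standing in for daft's `Expr`.
#[derive(Debug)]
enum Expr {
    Col(String),
    Add(Arc<Expr>, Arc<Expr>),
}

struct Transformed<T> {
    data: T,
    transformed: bool,
}

impl<T> Transformed<T> {
    fn yes(data: T) -> Self { Self { data, transformed: true } }
    fn no(data: T) -> Self { Self { data, transformed: false } }
}

// Rename column "a" to "b", rebuilding a parent only when a child changed:
// the same shape as the FillNull/IsIn/Between/BinaryOp arms in the diff.
fn rename(e: Arc<Expr>) -> Transformed<Arc<Expr>> {
    match e.as_ref() {
        Expr::Col(name) if name.as_str() == "a" => {
            Transformed::yes(Arc::new(Expr::Col("b".to_string())))
        }
        Expr::Col(_) => Transformed::no(e),
        Expr::Add(l, r) => {
            let (l, r) = (rename(l.clone()), rename(r.clone()));
            if !l.transformed && !r.transformed {
                Transformed::no(e) // reuse the old Arc: no reallocation
            } else {
                Transformed::yes(Arc::new(Expr::Add(l.data, r.data)))
            }
        }
    }
}

fn main() {
    let e = Arc::new(Expr::Add(
        Arc::new(Expr::Col("a".to_string())),
        Arc::new(Expr::Col("c".to_string())),
    ));
    let out = rename(e);
    assert!(out.transformed);
    println!("{:?}", out.data); // Add(Col("b"), Col("c"))
}
```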
diff --git a/src/daft-plan/src/logical_optimization/mod.rs b/src/daft-plan/src/logical_optimization/mod.rs
index d270947f19..37bd2306cd 100644
--- a/src/daft-plan/src/logical_optimization/mod.rs
+++ b/src/daft-plan/src/logical_optimization/mod.rs
@@ -5,4 +5,3 @@ mod rules;
 mod test;
 
 pub use optimizer::{Optimizer, OptimizerConfig};
-pub use rules::Transformed;
diff --git a/src/daft-plan/src/logical_optimization/optimizer.rs b/src/daft-plan/src/logical_optimization/optimizer.rs
index 25191955d3..fd411d021c 100644
--- a/src/daft-plan/src/logical_optimization/optimizer.rs
+++ b/src/daft-plan/src/logical_optimization/optimizer.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashSet, ops::ControlFlow, sync::Arc};
+use std::{ops::ControlFlow, sync::Arc};
 
 use common_error::DaftResult;
 
@@ -7,11 +7,11 @@ use crate::LogicalPlan;
 use super::{
     logical_plan_tracker::LogicalPlanTracker,
     rules::{
-        ApplyOrder, DropRepartition, OptimizerRule, PushDownFilter, PushDownLimit,
-        PushDownProjection, SplitActorPoolProjects, Transformed,
+        DropRepartition, OptimizerRule, PushDownFilter, PushDownLimit, PushDownProjection,
+        SplitActorPoolProjects,
     },
 };
-use common_treenode::DynTreeNode;
+use common_treenode::Transformed;
 
 /// Config for optimizer.
 #[derive(Debug)]
@@ -49,52 +49,12 @@ pub struct RuleBatch {
     pub rules: Vec<Box<dyn OptimizerRule>>,
     // The rule execution strategy (once, fixed-point).
     pub strategy: RuleExecutionStrategy,
-    // The application order for the entire rule batch, derived from the application
-    // order of the contained rules.
-    // If all rules in the batch have the same application order (e.g. top-down), the
-    // optimizer will apply the rules on a single tree traversal, where a given node
-    // in the plan tree will be transformed sequentially by each rule in the batch before
-    // moving on to the next node.
-    pub order: Option<ApplyOrder>,
 }
 
 impl RuleBatch {
     pub fn new(rules: Vec<Box<dyn OptimizerRule>>, strategy: RuleExecutionStrategy) -> Self {
-        // Get all unique application orders for the rules.
-        let unique_application_orders: Vec<ApplyOrder> = rules
-            .iter()
-            .map(|rule| rule.apply_order())
-            .collect::<HashSet<_>>()
-            .into_iter()
-            .collect();
-        let order = match unique_application_orders.as_slice() {
-            // All rules have the same application order, so use that as the application order for
-            // the entire batch.
-            [order] => Some(order.clone()),
-            // If rules have different application orders, run each rule as a separate tree pass with its own application order.
-            _ => None,
-        };
-        Self {
-            rules,
-            strategy,
-            order,
-        }
-    }
-
-    #[allow(dead_code)]
-    pub fn with_order(
-        rules: Vec<Box<dyn OptimizerRule>>,
-        strategy: RuleExecutionStrategy,
-        order: Option<ApplyOrder>,
-    ) -> Self {
-        debug_assert!(order.clone().map_or(true, |order| rules
-            .iter()
-            .all(|rule| rule.apply_order() == order)));
-        Self {
-            rules,
-            strategy,
-            order,
-        }
+        Self { rules, strategy }
     }
 
     /// Get the maximum number of passes the optimizer should make over this rule batch.
@@ -211,8 +171,12 @@ impl Optimizer {
         let result = (0..batch.max_passes(&self.config)).try_fold(
             plan,
             |plan, pass| -> ControlFlow<DaftResult<Arc<LogicalPlan>>, Arc<LogicalPlan>> {
-                match self.optimize_with_rules(batch.rules.as_slice(), plan, &batch.order) {
-                    Ok(Transformed::Yes(new_plan)) => {
+                match self.optimize_with_rules(batch.rules.as_slice(), plan) {
+                    Ok(Transformed {
+                        data: new_plan,
+                        transformed: true,
+                        ..
+                    }) => {
                         // Plan was transformed by the rule batch.
                         if plan_tracker.add_plan(new_plan.as_ref()) {
                             // Transformed plan has not yet been seen by this optimizer, which means we have
@@ -226,7 +190,11 @@ impl Optimizer {
                             ControlFlow::Break(Ok(new_plan))
                         }
                     }
-                    Ok(Transformed::No(plan)) => {
+                    Ok(Transformed {
+                        data: plan,
+                        transformed: false,
+                        ..
+                    }) => {
                         // Plan was not transformed by the rule batch, suggesting that we have reached a fixed-point.
                         // We therefore stop applying this rule batch.
                         observer(plan.as_ref(), batch, pass, false, false);
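The driver loop above folds over optimizer passes with `ControlFlow`, breaking early once a pass stops changing the plan or a previously-seen plan is detected. A reduced model of that control flow (the plan and rule application are string stand-ins, and the `plan_tracker` cycle detection is elided):

```rust
use std::ops::ControlFlow;

// Pretend a rule batch rewrites "raw" once and is a no-op afterwards.
fn apply_rules(plan: String) -> (String, bool) {
    if plan == "raw" {
        ("optimized".to_string(), true)
    } else {
        (plan, false)
    }
}

fn main() {
    let max_passes = 5;
    let result = (0..max_passes).try_fold("raw".to_string(), |plan, _pass| {
        let (new_plan, transformed) = apply_rules(plan);
        if transformed {
            // Keep iterating toward a fixed point.
            ControlFlow::Continue(new_plan)
        } else {
            // Fixed point reached: stop early, like the optimizer loop above.
            ControlFlow::Break(new_plan)
        }
    });
    let final_plan = match result {
        ControlFlow::Continue(p) | ControlFlow::Break(p) => p,
    };
    assert_eq!(final_plan, "optimized");
}
```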
@@ -243,96 +211,17 @@ impl Optimizer {
         }
     }
 
-    /// Optimize the provided plan with the provided rules using the provided application order.
-    ///
-    /// If order.is_some(), all rules are expected to have that application order.
+    /// Optimize the provided plan with each of the provided rules once.
     pub fn optimize_with_rules(
         &self,
         rules: &[Box<dyn OptimizerRule>],
         plan: Arc<LogicalPlan>,
-        order: &Option<ApplyOrder>,
-    ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
-        // Double-check that all rules have the same application order as `order`, if `order` is not None.
-        debug_assert!(order.clone().map_or(true, |order| rules
-            .iter()
-            .all(|rule| rule.apply_order() == order)));
-        match order {
-            // Perform a single top-down traversal and apply all rules on each node.
-            Some(ApplyOrder::TopDown) => {
-                // First optimize the current node, and then it's children.
-                let curr_opt = self.optimize_node(rules, plan)?;
-                let children_opt =
-                    self.optimize_children(rules, curr_opt.unwrap().clone(), ApplyOrder::TopDown)?;
-                Ok(children_opt.or(curr_opt))
-            }
-            // Perform a single bottom-up traversal and apply all rules on each node.
-            Some(ApplyOrder::BottomUp) => {
-                // First optimize the current node's children, and then the current node.
-                let children_opt = self.optimize_children(rules, plan, ApplyOrder::BottomUp)?;
-                let curr_opt = self.optimize_node(rules, children_opt.unwrap().clone())?;
-                Ok(curr_opt.or(children_opt))
-            }
-            // All rules do their own internal tree traversals.
-            Some(ApplyOrder::Delegated) => self.optimize_node(rules, plan),
-            // Rule batch doesn't share a single application order, so we apply each rule with its own dedicated tree pass.
-            None => rules
-                .windows(1)
-                .try_fold(Transformed::No(plan), |plan, rule| {
-                    self.optimize_with_rules(
-                        rule,
-                        plan.unwrap().clone(),
-                        &Some(rule[0].apply_order()),
-                    )
-                }),
-        }
-    }
-
-    /// Optimize a single plan node with the provided rules.
-    ///
-    /// This method does not drive traversal of the tree unless the tree is traversed by the rule itself,
-    /// in rule.try_optimize().
-    fn optimize_node(
-        &self,
-        rules: &[Box<dyn OptimizerRule>],
-        plan: Arc<LogicalPlan>,
     ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
         // Fold over the rules, applying each rule to this plan node sequentially.
-        rules.iter().try_fold(Transformed::No(plan), |plan, rule| {
-            Ok(rule.try_optimize(plan.unwrap().clone())?.or(plan))
+        rules.iter().try_fold(Transformed::no(plan), |plan, rule| {
+            plan.transform_data(|data| rule.try_optimize(data))
         })
     }
-
-    /// Optimize the children of the provided plan, updating the provided plan's pointed-to children
-    /// if the children are transformed.
-    fn optimize_children(
-        &self,
-        rules: &[Box<dyn OptimizerRule>],
-        plan: Arc<LogicalPlan>,
-        order: ApplyOrder,
-    ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
-        // Run optimization rules on children.
-        let children = plan.arc_children();
-        let result = children
-            .into_iter()
-            .map(|child_plan| {
-                self.optimize_with_rules(rules, child_plan.clone(), &Some(order.clone()))
-            })
-            .collect::<DaftResult<Vec<_>>>()?;
-        // If the optimization rule didn't change any of the children, return without modifying the plan.
-        if result.is_empty() || result.iter().all(|o| o.is_no()) {
-            return Ok(Transformed::No(plan));
-        }
-        // Otherwise, update the parent to point to its optimized children.
-        let new_children = result
-            .into_iter()
-            .map(|maybe_opt_child| maybe_opt_child.unwrap().clone())
-            .collect::<Vec<_>>();
-
-        // Return new plan with optimized children.
-        Ok(Transformed::Yes(
-            plan.with_new_children(&new_children).into(),
-        ))
-    }
 }
 
 #[cfg(test)]
 mod tests {
     use std::sync::{Arc, Mutex};
 
     use common_error::DaftResult;
+    use common_treenode::{Transformed, TreeNode};
     use daft_core::prelude::*;
     use daft_dsl::{col, lit};
 
     use crate::{
         logical_ops::{Filter, Project},
-        logical_optimization::rules::{ApplyOrder, OptimizerRule, Transformed},
+        logical_optimization::rules::OptimizerRule,
         test::{dummy_scan_node, dummy_scan_operator},
         LogicalPlan,
     };
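The new `optimize_with_rules` body is a single fold: each rule sees the previous rule's output, and the batch's `transformed` flag should end up true if any rule fired. The sketch below assumes `transform_data` ORs the previously-accumulated flag into the callback's result, which is how the DataFusion-style treenode API behaves; all types here are stand-ins rather than daft's:

```rust
// Reduced model of the fold in `optimize_with_rules`.
type DaftResult<T> = Result<T, String>; // stand-in error type

struct Transformed<T> {
    data: T,
    transformed: bool,
}

impl<T> Transformed<T> {
    fn yes(data: T) -> Self { Self { data, transformed: true } }
    fn no(data: T) -> Self { Self { data, transformed: false } }
    // Assumption: feed `data` into `f`, then OR the prior flag into the result.
    fn transform_data<U>(
        self,
        f: impl FnOnce(T) -> DaftResult<Transformed<U>>,
    ) -> DaftResult<Transformed<U>> {
        let prior = self.transformed;
        f(self.data).map(|mut t| {
            t.transformed |= prior;
            t
        })
    }
}

type Rule = Box<dyn Fn(String) -> DaftResult<Transformed<String>>>;

fn main() -> DaftResult<()> {
    // Two toy "rules": one rewrites the plan, one leaves it alone.
    let rules: Vec<Rule> = vec![
        Box::new(|p| Ok(Transformed::yes(format!("{p} -> filter_pushdown")))),
        Box::new(|p| Ok(Transformed::no(p))),
    ];

    // The fold from the diff: each rule sees the previous rule's output, and
    // the batch reports transformed == true if any rule changed the plan.
    let out = rules
        .iter()
        .try_fold(Transformed::no("scan".to_string()), |plan, rule| {
            plan.transform_data(|data| rule(data))
        })?;

    assert!(out.transformed);
    assert_eq!(out.data, "scan -> filter_pushdown");
    Ok(())
}
```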
@@ -388,15 +278,11 @@ mod tests {
     }
 
     impl OptimizerRule for NoOp {
-        fn apply_order(&self) -> ApplyOrder {
-            ApplyOrder::TopDown
-        }
-
         fn try_optimize(
             &self,
             plan: Arc<LogicalPlan>,
         ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
-            Ok(Transformed::No(plan))
+            Ok(Transformed::no(plan))
         }
     }
 
@@ -552,22 +438,20 @@ mod tests {
         }
     }
 
     impl OptimizerRule for FilterOrFalse {
-        fn apply_order(&self) -> ApplyOrder {
-            ApplyOrder::TopDown
-        }
-
         fn try_optimize(
             &self,
             plan: Arc<LogicalPlan>,
         ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
-            let filter = match plan.as_ref() {
-                LogicalPlan::Filter(filter) => filter.clone(),
-                _ => return Ok(Transformed::No(plan)),
-            };
-            let new_predicate = filter.predicate.or(lit(false));
-            Ok(Transformed::Yes(
-                LogicalPlan::from(Filter::try_new(filter.input.clone(), new_predicate)?).into(),
-            ))
+            plan.transform_down(|node| {
+                let filter = match node.as_ref() {
+                    LogicalPlan::Filter(filter) => filter.clone(),
+                    _ => return Ok(Transformed::no(node)),
+                };
+                let new_predicate = filter.predicate.or(lit(false));
+                Ok(Transformed::yes(
+                    LogicalPlan::from(Filter::try_new(filter.input.clone(), new_predicate)?).into(),
+                ))
+            })
         }
     }
 
@@ -581,22 +465,20 @@ mod tests {
         }
     }
 
     impl OptimizerRule for FilterAndTrue {
-        fn apply_order(&self) -> ApplyOrder {
-            ApplyOrder::TopDown
-        }
-
         fn try_optimize(
             &self,
             plan: Arc<LogicalPlan>,
         ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
-            let filter = match plan.as_ref() {
-                LogicalPlan::Filter(filter) => filter.clone(),
-                _ => return Ok(Transformed::No(plan)),
-            };
-            let new_predicate = filter.predicate.and(lit(true));
-            Ok(Transformed::Yes(
-                LogicalPlan::from(Filter::try_new(filter.input.clone(), new_predicate)?).into(),
-            ))
+            plan.transform_down(|node| {
+                let filter = match node.as_ref() {
+                    LogicalPlan::Filter(filter) => filter.clone(),
+                    _ => return Ok(Transformed::no(node)),
+                };
+                let new_predicate = filter.predicate.and(lit(true));
+                Ok(Transformed::yes(
+                    LogicalPlan::from(Filter::try_new(filter.input.clone(), new_predicate)?).into(),
+                ))
+            })
         }
     }
 
@@ -614,29 +496,27 @@ mod tests {
     }
 
     impl OptimizerRule for RotateProjection {
-        fn apply_order(&self) -> ApplyOrder {
-            ApplyOrder::TopDown
-        }
-
         fn try_optimize(
             &self,
             plan: Arc<LogicalPlan>,
         ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
-            let project = match plan.as_ref() {
-                LogicalPlan::Project(project) => project.clone(),
-                _ => return Ok(Transformed::No(plan)),
-            };
-            let mut exprs = project.projection.clone();
-            let mut reverse = self.reverse_first.lock().unwrap();
-            if *reverse {
-                exprs.reverse();
-                *reverse = false;
-            } else {
-                exprs.rotate_left(1);
-            }
-            Ok(Transformed::Yes(
-                LogicalPlan::from(Project::try_new(project.input.clone(), exprs)?).into(),
-            ))
+            plan.transform_down(|node| {
+                let project = match node.as_ref() {
+                    LogicalPlan::Project(project) => project.clone(),
+                    _ => return Ok(Transformed::no(node)),
+                };
+                let mut exprs = project.projection.clone();
+                let mut reverse = self.reverse_first.lock().unwrap();
+                if *reverse {
+                    exprs.reverse();
+                    *reverse = false;
+                } else {
+                    exprs.rotate_left(1);
+                }
+                Ok(Transformed::yes(
+                    LogicalPlan::from(Project::try_new(project.input.clone(), exprs)?).into(),
+                ))
+            })
         }
     }
 }
diff --git a/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs b/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs
index c8bbdaa7a0..838623e30b 100644
--- a/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs
+++ b/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs
@@ -4,9 +4,9 @@ use common_error::DaftResult;
 
 use crate::LogicalPlan;
 
-use super::{ApplyOrder, OptimizerRule, Transformed};
+use super::OptimizerRule;
 
-use common_treenode::DynTreeNode;
+use common_treenode::{DynTreeNode, Transformed, TreeNode};
 
 /// Optimization rules for dropping unnecessary Repartitions.
 ///
@@ -22,27 +22,25 @@ impl DropRepartition {
 }
 
 impl OptimizerRule for DropRepartition {
-    fn apply_order(&self) -> ApplyOrder {
-        ApplyOrder::TopDown
-    }
-
     fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
-        let repartition = match plan.as_ref() {
-            LogicalPlan::Repartition(repartition) => repartition,
-            _ => return Ok(Transformed::No(plan)),
-        };
-        let child_plan = repartition.input.as_ref();
-        let new_plan = match child_plan {
-            LogicalPlan::Repartition(_) => {
-                // Drop upstream Repartition for back-to-back Repartitions.
-                //
-                // Repartition1-Repartition2 -> Repartition1
-                plan.with_new_children(&[child_plan.arc_children()[0].clone()])
-                    .into()
-            }
-            _ => return Ok(Transformed::No(plan)),
-        };
-        Ok(Transformed::Yes(new_plan))
+        plan.transform_down(|node| {
+            let repartition = match node.as_ref() {
+                LogicalPlan::Repartition(repartition) => repartition,
+                _ => return Ok(Transformed::no(node)),
+            };
+            let child_plan = repartition.input.as_ref();
+            let new_plan = match child_plan {
+                LogicalPlan::Repartition(_) => {
+                    // Drop upstream Repartition for back-to-back Repartitions.
+                    //
+                    // Repartition1-Repartition2 -> Repartition1
+                    node.with_new_children(&[child_plan.arc_children()[0].clone()])
+                        .into()
+                }
+                _ => return Ok(Transformed::no(node)),
+            };
+            Ok(Transformed::yes(new_plan))
+        })
     }
 }
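`DropRepartition` is the simplest example of the new shape: the rule body is a pre-order callback, and `TreeNode::transform_down` owns the traversal that `ApplyOrder::TopDown` used to drive. A toy `transform_down` over a hypothetical node type, ignoring errors and `TreeNodeRecursion` early exits for brevity:

```rust
use std::sync::Arc;

// Hypothetical plan node, standing in for Arc<LogicalPlan>.
#[derive(Debug)]
struct Node {
    name: &'static str,
    children: Vec<Arc<Node>>,
}

struct Transformed<T> {
    data: T,
    transformed: bool,
}

impl<T> Transformed<T> {
    fn yes(data: T) -> Self { Self { data, transformed: true } }
    fn no(data: T) -> Self { Self { data, transformed: false } }
}

// Pre-order: rewrite the node first, then recurse into the rewritten node's children.
fn transform_down(
    node: Arc<Node>,
    f: &impl Fn(Arc<Node>) -> Transformed<Arc<Node>>,
) -> Transformed<Arc<Node>> {
    let t = f(node);
    let mut any_child = false;
    let new_children: Vec<Arc<Node>> = t
        .data
        .children
        .iter()
        .map(|c| {
            let tc = transform_down(c.clone(), f);
            any_child |= tc.transformed;
            tc.data
        })
        .collect();
    if any_child {
        Transformed::yes(Arc::new(Node { name: t.data.name, children: new_children }))
    } else {
        t
    }
}

fn main() {
    let scan = Arc::new(Node { name: "Scan", children: vec![] });
    let inner = Arc::new(Node { name: "Repartition", children: vec![scan] });
    let plan = Arc::new(Node { name: "Repartition", children: vec![inner] });

    // Rule: collapse a Repartition whose child is also a Repartition,
    // mirroring the DropRepartition rewrite above.
    let rule = |n: Arc<Node>| -> Transformed<Arc<Node>> {
        if n.name == "Repartition" && n.children.first().map_or(false, |c| c.name == "Repartition") {
            let grandchildren = n.children[0].children.clone();
            Transformed::yes(Arc::new(Node { name: "Repartition", children: grandchildren }))
        } else {
            Transformed::no(n)
        }
    };
    let out = transform_down(plan, &rule);
    assert!(out.transformed);
    println!("{:?}", out.data); // Repartition -> Scan
}
```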
diff --git a/src/daft-plan/src/logical_optimization/rules/mod.rs b/src/daft-plan/src/logical_optimization/rules/mod.rs
index fc137589c5..ac8579123a 100644
--- a/src/daft-plan/src/logical_optimization/rules/mod.rs
+++ b/src/daft-plan/src/logical_optimization/rules/mod.rs
@@ -9,5 +9,5 @@ pub use drop_repartition::DropRepartition;
 pub use push_down_filter::PushDownFilter;
 pub use push_down_limit::PushDownLimit;
 pub use push_down_projection::PushDownProjection;
-pub use rule::{ApplyOrder, OptimizerRule, Transformed};
+pub use rule::OptimizerRule;
 pub use split_actor_pool_projects::SplitActorPoolProjects;
diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs
index c180acbeff..a8961b470d 100644
--- a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs
+++ b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs
@@ -20,8 +20,8 @@ use crate::{
     LogicalPlan,
 };
 
-use super::{ApplyOrder, OptimizerRule, Transformed};
-use common_treenode::DynTreeNode;
+use super::OptimizerRule;
+use common_treenode::{DynTreeNode, Transformed, TreeNode};
 
 /// Optimization rules for pushing Filters further into the logical plan.
 #[derive(Default, Debug)]
@@ -34,14 +34,20 @@ impl PushDownFilter {
 }
 
 impl OptimizerRule for PushDownFilter {
-    fn apply_order(&self) -> ApplyOrder {
-        ApplyOrder::TopDown
+    fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
+        plan.transform_down(|node| self.try_optimize_node(node))
     }
+}
 
-    fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
+impl PushDownFilter {
+    #[allow(clippy::only_used_in_recursion)]
+    fn try_optimize_node(
+        &self,
+        plan: Arc<LogicalPlan>,
+    ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
         let filter = match plan.as_ref() {
             LogicalPlan::Filter(filter) => filter,
-            _ => return Ok(Transformed::No(plan)),
+            _ => return Ok(Transformed::no(plan)),
         };
         let child_plan = filter.input.as_ref();
         let new_plan = match child_plan {
@@ -68,20 +74,20 @@ impl OptimizerRule for PushDownFilter {
                 let new_filter: Arc<LogicalPlan> =
                     LogicalPlan::from(Filter::try_new(child_filter.input.clone(), new_predicate)?)
                         .into();
-                self.try_optimize(new_filter.clone())?
-                    .or(Transformed::Yes(new_filter))
-                    .unwrap()
+                self.try_optimize_node(new_filter.clone())?
+                    .or(Transformed::yes(new_filter))
+                    .data
                     .clone()
             }
             LogicalPlan::Source(source) => {
                 match source.source_info.as_ref() {
                     // Filter pushdown is not supported for in-memory sources.
-                    SourceInfo::InMemory(_) => return Ok(Transformed::No(plan)),
+                    SourceInfo::InMemory(_) => return Ok(Transformed::no(plan)),
 
                     // Do not pushdown if Source node already has a limit
                     SourceInfo::Physical(external_info)
                         if let Some(_) = external_info.pushdowns.limit =>
                     {
-                        return Ok(Transformed::No(plan))
+                        return Ok(Transformed::no(plan))
                     }
 
                     // Pushdown filter into the Source node
@@ -126,7 +132,7 @@ impl OptimizerRule for PushDownFilter {
                         // column and a data column (or contain a UDF), then no pushdown into the scan is possible,
                         // so we short-circuit.
                         // TODO(Clark): Support pushing predicates referencing both partition and data columns into the scan.
-                        return Ok(Transformed::No(plan));
+                        return Ok(Transformed::no(plan));
                     }
 
                     let data_filter = conjuct(data_only_filter);
@@ -157,9 +163,9 @@ impl OptimizerRule for PushDownFilter {
                                 conjuct(needing_filter_op).unwrap(),
                             )?
                             .into();
-                            return Ok(Transformed::Yes(filter_op.into()));
+                            return Ok(Transformed::yes(filter_op.into()));
                         } else {
-                            return Ok(Transformed::Yes(new_source.into()));
+                            return Ok(Transformed::yes(new_source.into()));
                         }
                     }
                     SourceInfo::PlaceHolder(..) => {
@@ -205,7 +211,7 @@ impl OptimizerRule for PushDownFilter {
                 }
                 if can_push.is_empty() {
                     // No predicate expressions can be pushed through projection.
-                    return Ok(Transformed::No(plan));
+                    return Ok(Transformed::no(plan));
                 }
                 // Create new Filter with predicates that can be pushed past Projection.
                 let predicates_to_push = conjuct(can_push).unwrap();
@@ -335,12 +341,12 @@ impl OptimizerRule for PushDownFilter {
                         new_join
                     }
                 } else {
-                    return Ok(Transformed::No(plan));
+                    return Ok(Transformed::no(plan));
                 }
             }
-            _ => return Ok(Transformed::No(plan)),
+            _ => return Ok(Transformed::no(plan)),
         };
-        Ok(Transformed::Yes(new_plan))
+        Ok(Transformed::yes(new_plan))
     }
 }
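One pattern worth calling out in `PushDownFilter` (and repeated in `PushDownLimit` and `PushDownProjection`): after building `new_plan`, the rule re-runs `try_optimize_node` on it and then tacks on `.or(Transformed::yes(new_plan))`, so the caller still sees `transformed: true` even when the retry itself is a no-op. In miniature, with stand-in types and a hypothetical one-step rewrite:

```rust
struct Transformed<T> {
    data: T,
    transformed: bool,
}

impl<T> Transformed<T> {
    fn yes(data: T) -> Self { Self { data, transformed: true } }
    fn no(data: T) -> Self { Self { data, transformed: false } }
    fn or(self, other: Self) -> Self {
        if self.transformed { self } else { other }
    }
}

// A hypothetical single-step rewrite that only fires on "limit(scan)".
fn rewrite_once(plan: String) -> Transformed<String> {
    if plan == "limit(scan)" {
        Transformed::yes("scan[limit]".to_string())
    } else {
        Transformed::no(plan)
    }
}

fn main() {
    let new_plan = "scan[limit]".to_string(); // we already rewrote something upstream
    // Retry on the rewritten plan; the retry is a no-op here, but `or` keeps
    // the overall answer `transformed: true` with the already-new plan.
    let result = rewrite_once(new_plan.clone()).or(Transformed::yes(new_plan));
    assert!(result.transformed);
    assert_eq!(result.data, "scan[limit]");
}
```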
diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs b/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs
index 405889af32..2359cbdcfa 100644
--- a/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs
+++ b/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs
@@ -8,8 +8,8 @@ use crate::{
     LogicalPlan,
 };
 
-use super::{ApplyOrder, OptimizerRule, Transformed};
-use common_treenode::DynTreeNode;
+use super::OptimizerRule;
+use common_treenode::{DynTreeNode, Transformed, TreeNode};
 
 /// Optimization rules for pushing Limits further into the logical plan.
 #[derive(Default, Debug)]
@@ -22,11 +22,17 @@ impl PushDownLimit {
 }
 
 impl OptimizerRule for PushDownLimit {
-    fn apply_order(&self) -> ApplyOrder {
-        ApplyOrder::TopDown
+    fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
+        plan.transform_down(|node| self.try_optimize_node(node))
     }
+}
 
-    fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
+impl PushDownLimit {
+    #[allow(clippy::only_used_in_recursion)]
+    fn try_optimize_node(
+        &self,
+        plan: Arc<LogicalPlan>,
+    ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
         match plan.as_ref() {
             LogicalPlan::Limit(LogicalLimit {
                 input,
@@ -42,7 +48,7 @@ impl OptimizerRule for PushDownLimit {
                         let new_limit = plan
                             .with_new_children(&[input.arc_children()[0].clone()])
                             .into();
-                        Ok(Transformed::Yes(
+                        Ok(Transformed::yes(
                             input.with_new_children(&[new_limit]).into(),
                         ))
                     }
@@ -52,13 +58,13 @@ impl OptimizerRule for PushDownLimit {
                     LogicalPlan::Source(source) => {
                         match source.source_info.as_ref() {
                             // Limit pushdown is not supported for in-memory sources.
-                            SourceInfo::InMemory(_) => Ok(Transformed::No(plan)),
+                            SourceInfo::InMemory(_) => Ok(Transformed::no(plan)),
                             // Do not pushdown if Source node is already more limited than `limit`
                             SourceInfo::Physical(external_info)
                                 if let Some(existing_limit) = external_info.pushdowns.limit
                                     && existing_limit <= limit =>
                             {
-                                Ok(Transformed::No(plan))
+                                Ok(Transformed::no(plan))
                             }
                             // Pushdown limit into the Source node as a "local" limit
                             SourceInfo::Physical(external_info) => {
@@ -74,7 +80,7 @@ impl OptimizerRule for PushDownLimit {
                                 } else {
                                     plan.with_new_children(&[new_source]).into()
                                 };
-                                Ok(Transformed::Yes(out_plan))
+                                Ok(Transformed::yes(out_plan))
                             }
                             SourceInfo::PlaceHolder(..) => {
                                 panic!("PlaceHolderInfo should not exist for optimization!");
@@ -99,16 +105,16 @@ impl OptimizerRule for PushDownLimit {
                         )));
                         // we rerun the optimizer, ideally when we move to a visitor pattern this should go away
                         let optimized = self
-                            .try_optimize(new_plan.clone())?
-                            .or(Transformed::Yes(new_plan))
-                            .unwrap()
+                            .try_optimize_node(new_plan.clone())?
+                            .or(Transformed::yes(new_plan))
+                            .data
                             .clone();
-                        Ok(Transformed::Yes(optimized))
+                        Ok(Transformed::yes(optimized))
                     }
-                    _ => Ok(Transformed::No(plan)),
+                    _ => Ok(Transformed::no(plan)),
                 }
             }
-            _ => Ok(Transformed::No(plan)),
+            _ => Ok(Transformed::no(plan)),
         }
     }
 }
diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
index f990985518..a06f947746 100644
--- a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
+++ b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
@@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc};
 
 use common_error::DaftResult;
 
-use common_treenode::TreeNode;
+use common_treenode::{Transformed, TreeNode};
 
 use daft_core::prelude::*;
 use daft_dsl::{
@@ -19,7 +19,7 @@ use crate::{
     LogicalPlan, LogicalPlanRef,
 };
 
-use super::{ApplyOrder, OptimizerRule, Transformed};
+use super::OptimizerRule;
 use common_treenode::DynTreeNode;
 
 #[derive(Default, Debug)]
@@ -56,8 +56,8 @@ impl PushDownProjection {
             // Projection discarded but new root node has not been looked at;
             // look at the new root node.
             let new_plan = self
-                .try_optimize(upstream_plan.clone())?
-                .or(Transformed::Yes(upstream_plan.clone()));
+                .try_optimize_node(upstream_plan.clone())?
+                .or(Transformed::yes(upstream_plan.clone()));
             return Ok(new_plan);
         }
 
@@ -133,8 +133,8 @@ impl PushDownProjection {
 
                 // Root node is changed, look at it again.
                 let new_plan = self
-                    .try_optimize(new_plan.clone())?
-                    .or(Transformed::Yes(new_plan.clone()));
+                    .try_optimize_node(new_plan.clone())?
+                    .or(Transformed::yes(new_plan.clone()));
                 return Ok(new_plan);
             }
         }
@@ -167,14 +167,14 @@ impl PushDownProjection {
                     let new_plan = Arc::new(plan.with_new_children(&[new_source.into()]));
                     // Retry optimization now that the upstream node is different.
                     let new_plan = self
-                        .try_optimize(new_plan.clone())?
-                        .or(Transformed::Yes(new_plan));
+                        .try_optimize_node(new_plan.clone())?
+                        .or(Transformed::yes(new_plan));
                     Ok(new_plan)
                 } else {
-                    Ok(Transformed::No(plan))
+                    Ok(Transformed::no(plan))
                 }
             }
-            SourceInfo::InMemory(_) => Ok(Transformed::No(plan)),
+            SourceInfo::InMemory(_) => Ok(Transformed::no(plan)),
             SourceInfo::PlaceHolder(..) => {
                 panic!("PlaceHolderInfo should not exist for optimization!");
             }
@@ -200,11 +200,11 @@ impl PushDownProjection {
                     let new_plan = Arc::new(plan.with_new_children(&[new_upstream.into()]));
                     // Retry optimization now that the upstream node is different.
                     let new_plan = self
-                        .try_optimize(new_plan.clone())?
-                        .or(Transformed::Yes(new_plan));
+                        .try_optimize_node(new_plan.clone())?
+                        .or(Transformed::yes(new_plan));
                     Ok(new_plan)
                 } else {
-                    Ok(Transformed::No(plan))
+                    Ok(Transformed::no(plan))
                 }
             }
             LogicalPlan::Aggregate(aggregate) => {
@@ -228,11 +228,11 @@ impl PushDownProjection {
                     let new_plan = Arc::new(plan.with_new_children(&[new_upstream.into()]));
                     // Retry optimization now that the upstream node is different.
                     let new_plan = self
-                        .try_optimize(new_plan.clone())?
-                        .or(Transformed::Yes(new_plan));
+                        .try_optimize_node(new_plan.clone())?
+                        .or(Transformed::yes(new_plan));
                     Ok(new_plan)
                 } else {
-                    Ok(Transformed::No(plan))
+                    Ok(Transformed::no(plan))
                 }
             }
             LogicalPlan::ActorPoolProject(upstream_actor_pool_projection) => {
@@ -295,8 +295,8 @@ impl PushDownProjection {
 
                     // Retry optimization now that the node is different.
                     let new_plan = self
-                        .try_optimize(new_plan.clone())?
-                        .or(Transformed::Yes(new_plan));
+                        .try_optimize_node(new_plan.clone())?
+                        .or(Transformed::yes(new_plan));
                     return Ok(new_plan);
                 }
             }
@@ -335,11 +335,11 @@ impl PushDownProjection {
 
                     // Retry optimization now that the upstream node is different.
                     let new_plan = self
-                        .try_optimize(new_plan.clone())?
-                        .or(Transformed::Yes(new_plan));
+                        .try_optimize_node(new_plan.clone())?
+                        .or(Transformed::yes(new_plan));
                     Ok(new_plan)
                 } else {
-                    Ok(Transformed::No(plan))
+                    Ok(Transformed::no(plan))
                 }
             }
             LogicalPlan::Sort(..)
@@ -362,7 +362,7 @@ impl PushDownProjection {
                 let grand_upstream_plan = &upstream_plan.arc_children()[0];
                 let grand_upstream_columns = grand_upstream_plan.schema().names();
                 if grand_upstream_columns.len() == combined_dependencies.len() {
-                    return Ok(Transformed::No(plan));
+                    return Ok(Transformed::no(plan));
                 }
 
                 let new_subprojection: LogicalPlan = {
@@ -378,8 +378,8 @@ impl PushDownProjection {
                 let new_plan = Arc::new(plan.with_new_children(&[new_upstream.into()]));
                 // Retry optimization now that the upstream node is different.
                 let new_plan = self
-                    .try_optimize(new_plan.clone())?
-                    .or(Transformed::Yes(new_plan));
+                    .try_optimize_node(new_plan.clone())?
+                    .or(Transformed::yes(new_plan));
                 Ok(new_plan)
             }
             LogicalPlan::Concat(concat) => {
@@ -396,7 +396,7 @@ impl PushDownProjection {
                 let grand_upstream_plan = &upstream_plan.children()[0];
                 let grand_upstream_columns = grand_upstream_plan.schema().names();
                 if grand_upstream_columns.len() == combined_dependencies.len() {
-                    return Ok(Transformed::No(plan));
+                    return Ok(Transformed::no(plan));
                 }
 
                 let pushdown_column_exprs: Vec<ExprRef> = combined_dependencies
@@ -417,8 +417,8 @@ impl PushDownProjection {
                 let new_plan = Arc::new(plan.with_new_children(&[new_upstream.into()]));
                 // Retry optimization now that the upstream node is different.
                 let new_plan = self
-                    .try_optimize(new_plan.clone())?
-                    .or(Transformed::Yes(new_plan));
+                    .try_optimize_node(new_plan.clone())?
+                    .or(Transformed::yes(new_plan));
                 Ok(new_plan)
             }
             LogicalPlan::Join(join) => {
@@ -457,9 +457,9 @@ impl PushDownProjection {
                         .collect();
                     let new_project: LogicalPlan =
                         Project::try_new(side.clone(), pushdown_column_exprs)?.into();
-                    Ok(Transformed::Yes(new_project.into()))
+                    Ok(Transformed::yes(new_project.into()))
                 } else {
-                    Ok(Transformed::No(side.clone()))
+                    Ok(Transformed::no(side.clone()))
                 }
             }
 
@@ -474,21 +474,21 @@ impl PushDownProjection {
                     projection_dependencies,
                 )?;
 
-                if new_left_upstream.is_no() && new_right_upstream.is_no() {
-                    Ok(Transformed::No(plan))
+                if !new_left_upstream.transformed && !new_right_upstream.transformed {
+                    Ok(Transformed::no(plan))
                 } else {
                     // If either pushdown is possible, create a new Join node.
                     let new_join = upstream_plan.with_new_children(&[
-                        new_left_upstream.unwrap().clone(),
-                        new_right_upstream.unwrap().clone(),
+                        new_left_upstream.data.clone(),
+                        new_right_upstream.data.clone(),
                     ]);
 
                     let new_plan = Arc::new(plan.with_new_children(&[new_join.into()]));
 
                     // Retry optimization now that the upstream node is different.
                     let new_plan = self
-                        .try_optimize(new_plan.clone())?
-                        .or(Transformed::Yes(new_plan));
+                        .try_optimize_node(new_plan.clone())?
+                        .or(Transformed::yes(new_plan));
 
                     Ok(new_plan)
                 }
@@ -496,11 +496,11 @@ impl PushDownProjection {
             LogicalPlan::Distinct(_) => {
                 // Cannot push down past a Distinct,
                 // since Distinct implicitly requires all parent columns.
-                Ok(Transformed::No(plan))
+                Ok(Transformed::no(plan))
             }
             LogicalPlan::Pivot(_) | LogicalPlan::MonotonicallyIncreasingId(_) => {
                 // Cannot push down past a Pivot/MonotonicallyIncreasingId because it changes the schema.
-                Ok(Transformed::No(plan))
+                Ok(Transformed::no(plan))
             }
             LogicalPlan::Sink(_) => {
                 panic!("Bad projection due to upstream sink node: {:?}", projection)
@@ -530,9 +530,9 @@ impl PushDownProjection {
             };
             let new_actor_pool_project = plan.with_new_children(&[new_subprojection.into()]);
 
-            Ok(Transformed::Yes(new_actor_pool_project.into()))
+            Ok(Transformed::yes(new_actor_pool_project.into()))
         } else {
-            Ok(Transformed::No(plan))
+            Ok(Transformed::no(plan))
         }
     }
 
@@ -558,9 +558,9 @@ impl PushDownProjection {
             };
             let new_aggregation = plan.with_new_children(&[new_subprojection.into()]);
 
-            Ok(Transformed::Yes(new_aggregation.into()))
+            Ok(Transformed::yes(new_aggregation.into()))
         } else {
-            Ok(Transformed::No(plan))
+            Ok(Transformed::no(plan))
         }
     }
 
@@ -595,13 +595,13 @@ impl PushDownProjection {
                 .arced();
 
             Ok(self
-                .try_optimize(new_join.clone())?
-                .or(Transformed::Yes(new_join)))
+                .try_optimize_node(new_join.clone())?
+                .or(Transformed::yes(new_join)))
         } else {
-            Ok(Transformed::No(plan))
+            Ok(Transformed::no(plan))
         }
     } else {
-        Ok(Transformed::No(plan))
+        Ok(Transformed::no(plan))
     }
 }
 
@@ -627,19 +627,16 @@ impl PushDownProjection {
             };
             let new_pivot = plan.with_new_children(&[new_subprojection.into()]);
 
-            Ok(Transformed::Yes(new_pivot.into()))
+            Ok(Transformed::yes(new_pivot.into()))
         } else {
-            Ok(Transformed::No(plan))
+            Ok(Transformed::no(plan))
         }
     }
-}
-
-impl OptimizerRule for PushDownProjection {
-    fn apply_order(&self) -> ApplyOrder {
-        ApplyOrder::TopDown
-    }
 
-    fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
+    fn try_optimize_node(
+        &self,
+        plan: Arc<LogicalPlan>,
+    ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
         match plan.as_ref() {
             LogicalPlan::Project(projection) => self.try_optimize_project(projection, plan.clone()),
             // ActorPoolProjects also do column projection
@@ -654,11 +651,17 @@ impl PushDownProjection {
             LogicalPlan::Join(join) => self.try_optimize_join(join, plan.clone()),
             // Pivots also do column projection
             LogicalPlan::Pivot(pivot) => self.try_optimize_pivot(pivot, plan.clone()),
-            _ => Ok(Transformed::No(plan)),
+            _ => Ok(Transformed::no(plan)),
         }
     }
 }
 
+impl OptimizerRule for PushDownProjection {
+    fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
+        plan.transform_down(|node| self.try_optimize_node(node))
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
diff --git a/src/daft-plan/src/logical_optimization/rules/rule.rs b/src/daft-plan/src/logical_optimization/rules/rule.rs
index bee69cfc6b..331ee2fe05 100644
--- a/src/daft-plan/src/logical_optimization/rules/rule.rs
+++ b/src/daft-plan/src/logical_optimization/rules/rule.rs
@@ -1,77 +1,14 @@
 use std::sync::Arc;
 
 use common_error::DaftResult;
+use common_treenode::Transformed;
 
 use crate::LogicalPlan;
 
-/// Application order of a rule or rule batch.
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-pub enum ApplyOrder {
-    // Apply a rule to a node and then it's children.
-    TopDown,
-    #[allow(dead_code)]
-    // Apply a rule to a node's children and then the node itself.
-    BottomUp,
-    #[allow(dead_code)]
-    // Delegate tree traversal to the rule.
-    Delegated,
-}
-
 /// A logical plan optimization rule.
 pub trait OptimizerRule {
     /// Try to optimize the logical plan with this rule.
     ///
-    /// This returns Transformed::Yes(new_plan) if the rule modified the plan, Transformed::No(old_plan) otherwise.
+    /// This returns Transformed::yes(new_plan) if the rule modified the plan, Transformed::no(old_plan) otherwise.
     fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>>;
-
-    /// The plan tree order in which this rule should be applied (top-down, bottom-up, or delegated to rule).
-    fn apply_order(&self) -> ApplyOrder;
-}
-
-/// An enum indicating whether or not the wrapped data has been transformed.
-#[derive(Debug)]
-pub enum Transformed<T> {
-    // Yes, the data has been transformed.
-    Yes(T),
-    // No, the data has not been transformed.
-    No(T),
-}
-
-impl<T> Transformed<T> {
-    /// Returns self if self is Yes, otherwise returns other.
-    pub fn or(self, other: Self) -> Self {
-        match self {
-            Self::Yes(_) => self,
-            Self::No(_) => other,
-        }
-    }
-
-    /// Returns whether self is No.
-    pub fn is_no(&self) -> bool {
-        matches!(self, Self::No(_))
-    }
-
-    /// Unwraps the enum and returns a reference to the inner value.
-    // TODO(Clark): Take ownership of self and return an owned T?
-    pub fn unwrap(&self) -> &T {
-        match self {
-            Self::Yes(inner) => inner,
-            Self::No(inner) => inner,
-        }
-    }
-
-    /// Maps a `Transformed<T>` to `Transformed<U>`,
-    /// by supplying a function to apply to a contained Yes value
-    /// as well as a function to apply to a contained No value.
-    #[inline]
-    pub fn map_yes_no<U, Y: FnOnce(T) -> U, N: FnOnce(T) -> U>(
-        self,
-        yes_op: Y,
-        no_op: N,
-    ) -> Transformed<U> {
-        match self {
-            Self::Yes(t) => Transformed::Yes(yes_op(t)),
-            Self::No(t) => Transformed::No(no_op(t)),
-        }
-    }
-}
diff --git a/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs b/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs
index 1c8b87696b..34e8d491ef 100644
--- a/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs
+++ b/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs
@@ -1,7 +1,7 @@
 use std::{collections::HashSet, iter, sync::Arc};
 
 use common_error::DaftResult;
-use common_treenode::{TreeNode, TreeNodeRecursion, TreeNodeRewriter};
+use common_treenode::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter};
 use daft_dsl::{
     functions::{
         python::{PythonUDF, StatefulPythonUDF},
@@ -17,7 +17,7 @@ use crate::{
     LogicalPlan,
 };
 
-use super::{ApplyOrder, OptimizerRule, Transformed};
+use super::OptimizerRule;
 
 #[derive(Default, Debug)]
 pub struct SplitActorPoolProjects {}
@@ -121,16 +121,11 @@ impl SplitActorPoolProjects {
/// │                 │ │                    │ │           │
/// └─────────────────┘ └────────────────────┘ └───────────┘
 impl OptimizerRule for SplitActorPoolProjects {
-    fn apply_order(&self) -> ApplyOrder {
-        ApplyOrder::TopDown
-    }
-
     fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
-        match plan.as_ref() {
-            LogicalPlan::Project(projection) => try_optimize_project(projection, plan.clone(), 0),
-            // TODO: Figure out how to split other nodes as well such as Filter, Agg etc
-            _ => Ok(Transformed::No(plan)),
-        }
+        plan.transform_down(|node| match node.as_ref() {
+            LogicalPlan::Project(projection) => try_optimize_project(projection, node.clone(), 0),
+            _ => Ok(Transformed::no(node)),
+        })
     }
 }
 
@@ -381,7 +376,7 @@ fn try_optimize_project(
     // Base case: no stateful UDFs at all
     let has_stateful_udfs = projection.projection.iter().any(has_stateful_udf);
     if !has_stateful_udfs {
-        return Ok(Transformed::No(plan));
+        return Ok(Transformed::no(plan));
     }
 
     log::debug!(
@@ -424,7 +419,7 @@ fn try_optimize_project(
         let new_child_project = LogicalPlan::Project(new_project.clone()).arced();
         let optimized_child_plan =
             try_optimize_project(&new_project, new_child_project.clone(), recursive_count + 1)?;
-        optimized_child_plan.unwrap().clone()
+        optimized_child_plan.data.clone()
     };
 
     // Start building a chain of `child -> Project -> ActorPoolProject -> ActorPoolProject -> ... -> Project`
@@ -500,7 +495,7 @@ fn try_optimize_project(
     )?)
     .arced();
 
-    Ok(Transformed::Yes(final_selection_project))
+    Ok(Transformed::yes(final_selection_project))
 }
 
 #[inline]
diff --git a/src/daft-plan/src/logical_optimization/test/mod.rs b/src/daft-plan/src/logical_optimization/test/mod.rs
index 7f26e317e6..a6540da0d5 100644
--- a/src/daft-plan/src/logical_optimization/test/mod.rs
+++ b/src/daft-plan/src/logical_optimization/test/mod.rs
@@ -25,12 +25,8 @@ pub fn assert_optimized_plan_with_rules_eq(
         Default::default(),
     );
     let optimized_plan = optimizer
-        .optimize_with_rules(
-            optimizer.rule_batches[0].rules.as_slice(),
-            plan.clone(),
-            &optimizer.rule_batches[0].order,
-        )?
-        .unwrap()
+        .optimize_with_rules(optimizer.rule_batches[0].rules.as_slice(), plan.clone())?
+        .data
         .clone();
     assert_eq!(
         optimized_plan,