From 91d9fe9175d5feb05c2932236e3480ecb9af00ad Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Thu, 5 Sep 2024 18:18:05 -0700 Subject: [PATCH] [CHORE] Use treenode for tree traversal in logical optimizer rules (#2797) Follow-up from #2791 after offline conversation. This PR still uses the existing logical plan optimizer, but now uses `common_treenode::Transformed` instead of a custom implementation. In addition, it removes the apply order logic from the optimizer and replaces it with `common_treenode::TreeNode` transforms done in the optimizer rules themselves. This PR should not cause any functional changes to any of the optimizer rules or the way they are applied, except for the fact that rules in a batch are now each applied to the whole tree before the next rule, instead of each being applied to a single plan node if they have the same apply order. Future PRs will separate the rules out and make better use of treenode. --- src/common/treenode/src/lib.rs | 25 ++ src/daft-plan/src/logical_ops/project.rs | 84 +++---- src/daft-plan/src/logical_optimization/mod.rs | 1 - .../src/logical_optimization/optimizer.rs | 238 +++++------------- .../rules/drop_repartition.rs | 42 ++-- .../src/logical_optimization/rules/mod.rs | 2 +- .../rules/push_down_filter.rs | 42 ++-- .../rules/push_down_limit.rs | 36 +-- .../rules/push_down_projection.rs | 113 +++++---- .../src/logical_optimization/rules/rule.rs | 67 +---- .../rules/split_actor_pool_projects.rs | 23 +- .../src/logical_optimization/test/mod.rs | 8 +- 12 files changed, 260 insertions(+), 421 deletions(-) diff --git a/src/common/treenode/src/lib.rs b/src/common/treenode/src/lib.rs index 908ebc3530..de7ff59c30 100644 --- a/src/common/treenode/src/lib.rs +++ b/src/common/treenode/src/lib.rs @@ -631,6 +631,31 @@ impl Transformed { f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr)) } + /// Returns self if self is transformed, otherwise returns other. 
+ pub fn or(self, other: Self) -> Self { + if self.transformed { + self + } else { + other + } + } + + /// Maps a `Transformed<T>` to `Transformed<U>`, + /// by supplying a function to apply to a contained Yes value + /// as well as a function to apply to a contained No value. + #[inline] + pub fn map_yes_no<U, Y: FnOnce(T) -> U, N: FnOnce(T) -> U>( + self, + yes_op: Y, + no_op: N, + ) -> Transformed<U> { + if self.transformed { + Transformed::yes(yes_op(self.data)) + } else { + Transformed::no(no_op(self.data)) + } + } + /// Maps the [`Transformed`] object to the result of the given `f`. pub fn transform_data<U, F: FnOnce(T) -> Result<Transformed<U>>>( self, diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 5290f46242..3c344c56aa 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -1,12 +1,12 @@ use std::sync::Arc; +use common_treenode::Transformed; use daft_core::prelude::*; use daft_dsl::{optimization, resolve_exprs, AggExpr, ApproxPercentileParams, Expr, ExprRef}; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; use snafu::ResultExt; -use crate::logical_optimization::Transformed; use crate::logical_plan::{CreationSnafu, Result}; use crate::LogicalPlan; @@ -144,7 +144,7 @@ impl Project { .map(|e| { let new_expr = replace_column_with_semantic_id(e.clone(), &subexprs_to_replace, schema); - let new_expr = new_expr.unwrap(); + let new_expr = new_expr.data; // The substitution can unintentionally change the expression's name // (since the name depends on the first column referenced, which can be substituted away) // so re-alias the original name here if it has changed. 
@@ -185,10 +185,10 @@ fn replace_column_with_semantic_id( Expr::Alias(_, name) => Expr::Alias(new_expr.into(), name.clone()), _ => new_expr, }; - Transformed::Yes(new_expr.into()) + Transformed::yes(new_expr.into()) } else { match e.as_ref() { - Expr::Column(_) | Expr::Literal(_) => Transformed::No(e), + Expr::Column(_) | Expr::Literal(_) => Transformed::no(e), Expr::Agg(agg_expr) => replace_column_with_semantic_id_aggexpr( agg_expr.clone(), subexprs_to_replace, @@ -241,11 +241,11 @@ fn replace_column_with_semantic_id( subexprs_to_replace, schema, ); - if child.is_no() && fill_value.is_no() { - Transformed::No(e) + if !child.transformed && !fill_value.transformed { + Transformed::no(e) } else { - Transformed::Yes( - Expr::FillNull(child.unwrap().clone(), fill_value.unwrap().clone()).into(), + Transformed::yes( + Expr::FillNull(child.data.clone(), fill_value.data.clone()).into(), ) } } @@ -254,12 +254,10 @@ fn replace_column_with_semantic_id( replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema); let items = replace_column_with_semantic_id(items.clone(), subexprs_to_replace, schema); - if child.is_no() && items.is_no() { - Transformed::No(e) + if !child.transformed && !items.transformed { + Transformed::no(e) } else { - Transformed::Yes( - Expr::IsIn(child.unwrap().clone(), items.unwrap().clone()).into(), - ) + Transformed::yes(Expr::IsIn(child.data.clone(), items.data.clone()).into()) } } Expr::Between(child, lower, upper) => { @@ -269,16 +267,12 @@ fn replace_column_with_semantic_id( replace_column_with_semantic_id(lower.clone(), subexprs_to_replace, schema); let upper = replace_column_with_semantic_id(upper.clone(), subexprs_to_replace, schema); - if child.is_no() && lower.is_no() && upper.is_no() { - Transformed::No(e) + if !child.transformed && !lower.transformed && !upper.transformed { + Transformed::no(e) } else { - Transformed::Yes( - Expr::Between( - child.unwrap().clone(), - lower.unwrap().clone(), - upper.unwrap().clone(), - ) - 
.into(), + Transformed::yes( + Expr::Between(child.data.clone(), lower.data.clone(), upper.data.clone()) + .into(), ) } } @@ -287,14 +281,14 @@ fn replace_column_with_semantic_id( replace_column_with_semantic_id(left.clone(), subexprs_to_replace, schema); let right = replace_column_with_semantic_id(right.clone(), subexprs_to_replace, schema); - if left.is_no() && right.is_no() { - Transformed::No(e) + if !left.transformed && !right.transformed { + Transformed::no(e) } else { - Transformed::Yes( + Transformed::yes( Expr::BinaryOp { op: *op, - left: left.unwrap().clone(), - right: right.unwrap().clone(), + left: left.data.clone(), + right: right.data.clone(), } .into(), ) @@ -311,14 +305,14 @@ fn replace_column_with_semantic_id( replace_column_with_semantic_id(if_true.clone(), subexprs_to_replace, schema); let if_false = replace_column_with_semantic_id(if_false.clone(), subexprs_to_replace, schema); - if predicate.is_no() && if_true.is_no() && if_false.is_no() { - Transformed::No(e) + if !predicate.transformed && !if_true.transformed && !if_false.transformed { + Transformed::no(e) } else { - Transformed::Yes( + Transformed::yes( Expr::IfElse { - predicate: predicate.unwrap().clone(), - if_true: if_true.unwrap().clone(), - if_false: if_false.unwrap().clone(), + predicate: predicate.data.clone(), + if_true: if_true.data.clone(), + if_false: if_false.data.clone(), } .into(), ) @@ -331,13 +325,13 @@ fn replace_column_with_semantic_id( replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema) }) .collect::>(); - if transforms.iter().all(|e| e.is_no()) { - Transformed::No(e) + if transforms.iter().all(|e| !e.transformed) { + Transformed::no(e) } else { - Transformed::Yes( + Transformed::yes( Expr::Function { func: func.clone(), - inputs: transforms.iter().map(|t| t.unwrap()).cloned().collect(), + inputs: transforms.iter().map(|t| t.data.clone()).collect(), } .into(), ) @@ -352,11 +346,11 @@ fn replace_column_with_semantic_id( 
replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema) }) .collect::>(); - if transforms.iter().all(|e| e.is_no()) { - Transformed::No(e) + if transforms.iter().all(|e| !e.transformed) { + Transformed::no(e) } else { - func.inputs = transforms.iter().map(|t| t.unwrap()).cloned().collect(); - Transformed::Yes(Expr::ScalarFunction(func).into()) + func.inputs = transforms.iter().map(|t| t.data.clone()).collect(); + Transformed::yes(Expr::ScalarFunction(func).into()) } } } @@ -446,12 +440,12 @@ fn replace_column_with_semantic_id_aggexpr( .iter() .map(|e| replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema)) .collect::>(); - if transforms.iter().all(|e| e.is_no()) { - Transformed::No(AggExpr::MapGroups { func, inputs }) + if transforms.iter().all(|e| !e.transformed) { + Transformed::no(AggExpr::MapGroups { func, inputs }) } else { - Transformed::Yes(AggExpr::MapGroups { + Transformed::yes(AggExpr::MapGroups { func: func.clone(), - inputs: transforms.iter().map(|t| t.unwrap()).cloned().collect(), + inputs: transforms.iter().map(|t| t.data.clone()).collect(), }) } } diff --git a/src/daft-plan/src/logical_optimization/mod.rs b/src/daft-plan/src/logical_optimization/mod.rs index d270947f19..37bd2306cd 100644 --- a/src/daft-plan/src/logical_optimization/mod.rs +++ b/src/daft-plan/src/logical_optimization/mod.rs @@ -5,4 +5,3 @@ mod rules; mod test; pub use optimizer::{Optimizer, OptimizerConfig}; -pub use rules::Transformed; diff --git a/src/daft-plan/src/logical_optimization/optimizer.rs b/src/daft-plan/src/logical_optimization/optimizer.rs index 25191955d3..fd411d021c 100644 --- a/src/daft-plan/src/logical_optimization/optimizer.rs +++ b/src/daft-plan/src/logical_optimization/optimizer.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, ops::ControlFlow, sync::Arc}; +use std::{ops::ControlFlow, sync::Arc}; use common_error::DaftResult; @@ -7,11 +7,11 @@ use crate::LogicalPlan; use super::{ logical_plan_tracker::LogicalPlanTracker, 
rules::{ - ApplyOrder, DropRepartition, OptimizerRule, PushDownFilter, PushDownLimit, - PushDownProjection, SplitActorPoolProjects, Transformed, + DropRepartition, OptimizerRule, PushDownFilter, PushDownLimit, PushDownProjection, + SplitActorPoolProjects, }, }; -use common_treenode::DynTreeNode; +use common_treenode::Transformed; /// Config for optimizer. #[derive(Debug)] @@ -49,52 +49,12 @@ pub struct RuleBatch { pub rules: Vec>, // The rule execution strategy (once, fixed-point). pub strategy: RuleExecutionStrategy, - // The application order for the entire rule batch, derived from the application - // order of the contained rules. - // If all rules in the batch have the same application order (e.g. top-down), the - // optimizer will apply the rules on a single tree traversal, where a given node - // in the plan tree will be transformed sequentially by each rule in the batch before - // moving on to the next node. - pub order: Option, } impl RuleBatch { pub fn new(rules: Vec>, strategy: RuleExecutionStrategy) -> Self { // Get all unique application orders for the rules. - let unique_application_orders: Vec = rules - .iter() - .map(|rule| rule.apply_order()) - .collect::>() - .into_iter() - .collect(); - let order = match unique_application_orders.as_slice() { - // All rules have the same application order, so use that as the application order for - // the entire batch. - [order] => Some(order.clone()), - // If rules have different application orders, run each rule as a separate tree pass with its own application order. 
- _ => None, - }; - Self { - rules, - strategy, - order, - } - } - - #[allow(dead_code)] - pub fn with_order( - rules: Vec>, - strategy: RuleExecutionStrategy, - order: Option, - ) -> Self { - debug_assert!(order.clone().map_or(true, |order| rules - .iter() - .all(|rule| rule.apply_order() == order))); - Self { - rules, - strategy, - order, - } + Self { rules, strategy } } /// Get the maximum number of passes the optimizer should make over this rule batch. @@ -211,8 +171,12 @@ impl Optimizer { let result = (0..batch.max_passes(&self.config)).try_fold( plan, |plan, pass| -> ControlFlow>, Arc> { - match self.optimize_with_rules(batch.rules.as_slice(), plan, &batch.order) { - Ok(Transformed::Yes(new_plan)) => { + match self.optimize_with_rules(batch.rules.as_slice(), plan) { + Ok(Transformed { + data: new_plan, + transformed: true, + .. + }) => { // Plan was transformed by the rule batch. if plan_tracker.add_plan(new_plan.as_ref()) { // Transformed plan has not yet been seen by this optimizer, which means we have @@ -226,7 +190,11 @@ impl Optimizer { ControlFlow::Break(Ok(new_plan)) } } - Ok(Transformed::No(plan)) => { + Ok(Transformed { + data: plan, + transformed: false, + .. + }) => { // Plan was not transformed by the rule batch, suggesting that we have reached a fixed-point. // We therefore stop applying this rule batch. observer(plan.as_ref(), batch, pass, false, false); @@ -243,96 +211,17 @@ impl Optimizer { } } - /// Optimize the provided plan with the provided rules using the provided application order. - /// - /// If order.is_some(), all rules are expected to have that application order. + /// Optimize the provided plan with each of the provided rules once. pub fn optimize_with_rules( &self, rules: &[Box], plan: Arc, - order: &Option, - ) -> DaftResult>> { - // Double-check that all rules have the same application order as `order`, if `order` is not None. 
- debug_assert!(order.clone().map_or(true, |order| rules - .iter() - .all(|rule| rule.apply_order() == order))); - match order { - // Perform a single top-down traversal and apply all rules on each node. - Some(ApplyOrder::TopDown) => { - // First optimize the current node, and then it's children. - let curr_opt = self.optimize_node(rules, plan)?; - let children_opt = - self.optimize_children(rules, curr_opt.unwrap().clone(), ApplyOrder::TopDown)?; - Ok(children_opt.or(curr_opt)) - } - // Perform a single bottom-up traversal and apply all rules on each node. - Some(ApplyOrder::BottomUp) => { - // First optimize the current node's children, and then the current node. - let children_opt = self.optimize_children(rules, plan, ApplyOrder::BottomUp)?; - let curr_opt = self.optimize_node(rules, children_opt.unwrap().clone())?; - Ok(curr_opt.or(children_opt)) - } - // All rules do their own internal tree traversals. - Some(ApplyOrder::Delegated) => self.optimize_node(rules, plan), - // Rule batch doesn't share a single application order, so we apply each rule with its own dedicated tree pass. - None => rules - .windows(1) - .try_fold(Transformed::No(plan), |plan, rule| { - self.optimize_with_rules( - rule, - plan.unwrap().clone(), - &Some(rule[0].apply_order()), - ) - }), - } - } - - /// Optimize a single plan node with the provided rules. - /// - /// This method does not drive traversal of the tree unless the tree is traversed by the rule itself, - /// in rule.try_optimize(). - fn optimize_node( - &self, - rules: &[Box], - plan: Arc, ) -> DaftResult>> { // Fold over the rules, applying each rule to this plan node sequentially. 
- rules.iter().try_fold(Transformed::No(plan), |plan, rule| { - Ok(rule.try_optimize(plan.unwrap().clone())?.or(plan)) + rules.iter().try_fold(Transformed::no(plan), |plan, rule| { + plan.transform_data(|data| rule.try_optimize(data)) }) } - - /// Optimize the children of the provided plan, updating the provided plan's pointed-to children - /// if the children are transformed. - fn optimize_children( - &self, - rules: &[Box], - plan: Arc, - order: ApplyOrder, - ) -> DaftResult>> { - // Run optimization rules on children. - let children = plan.arc_children(); - let result = children - .into_iter() - .map(|child_plan| { - self.optimize_with_rules(rules, child_plan.clone(), &Some(order.clone())) - }) - .collect::>>()?; - // If the optimization rule didn't change any of the children, return without modifying the plan. - if result.is_empty() || result.iter().all(|o| o.is_no()) { - return Ok(Transformed::No(plan)); - } - // Otherwise, update the parent to point to its optimized children. - let new_children = result - .into_iter() - .map(|maybe_opt_child| maybe_opt_child.unwrap().clone()) - .collect::>(); - - // Return new plan with optimized children. 
- Ok(Transformed::Yes( - plan.with_new_children(&new_children).into(), - )) - } } #[cfg(test)] @@ -340,13 +229,14 @@ mod tests { use std::sync::{Arc, Mutex}; use common_error::DaftResult; + use common_treenode::{Transformed, TreeNode}; use daft_core::prelude::*; use daft_dsl::{col, lit}; use crate::{ logical_ops::{Filter, Project}, - logical_optimization::rules::{ApplyOrder, OptimizerRule, Transformed}, + logical_optimization::rules::OptimizerRule, test::{dummy_scan_node, dummy_scan_operator}, LogicalPlan, }; @@ -388,15 +278,11 @@ mod tests { } impl OptimizerRule for NoOp { - fn apply_order(&self) -> ApplyOrder { - ApplyOrder::TopDown - } - fn try_optimize( &self, plan: Arc, ) -> DaftResult>> { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } @@ -552,22 +438,20 @@ mod tests { } impl OptimizerRule for FilterOrFalse { - fn apply_order(&self) -> ApplyOrder { - ApplyOrder::TopDown - } - fn try_optimize( &self, plan: Arc, ) -> DaftResult>> { - let filter = match plan.as_ref() { - LogicalPlan::Filter(filter) => filter.clone(), - _ => return Ok(Transformed::No(plan)), - }; - let new_predicate = filter.predicate.or(lit(false)); - Ok(Transformed::Yes( - LogicalPlan::from(Filter::try_new(filter.input.clone(), new_predicate)?).into(), - )) + plan.transform_down(|node| { + let filter = match node.as_ref() { + LogicalPlan::Filter(filter) => filter.clone(), + _ => return Ok(Transformed::no(node)), + }; + let new_predicate = filter.predicate.or(lit(false)); + Ok(Transformed::yes( + LogicalPlan::from(Filter::try_new(filter.input.clone(), new_predicate)?).into(), + )) + }) } } @@ -581,22 +465,20 @@ mod tests { } impl OptimizerRule for FilterAndTrue { - fn apply_order(&self) -> ApplyOrder { - ApplyOrder::TopDown - } - fn try_optimize( &self, plan: Arc, ) -> DaftResult>> { - let filter = match plan.as_ref() { - LogicalPlan::Filter(filter) => filter.clone(), - _ => return Ok(Transformed::No(plan)), - }; - let new_predicate = filter.predicate.and(lit(true)); - 
Ok(Transformed::Yes( - LogicalPlan::from(Filter::try_new(filter.input.clone(), new_predicate)?).into(), - )) + plan.transform_down(|node| { + let filter = match node.as_ref() { + LogicalPlan::Filter(filter) => filter.clone(), + _ => return Ok(Transformed::no(node)), + }; + let new_predicate = filter.predicate.and(lit(true)); + Ok(Transformed::yes( + LogicalPlan::from(Filter::try_new(filter.input.clone(), new_predicate)?).into(), + )) + }) } } @@ -614,29 +496,27 @@ mod tests { } impl OptimizerRule for RotateProjection { - fn apply_order(&self) -> ApplyOrder { - ApplyOrder::TopDown - } - fn try_optimize( &self, plan: Arc, ) -> DaftResult>> { - let project = match plan.as_ref() { - LogicalPlan::Project(project) => project.clone(), - _ => return Ok(Transformed::No(plan)), - }; - let mut exprs = project.projection.clone(); - let mut reverse = self.reverse_first.lock().unwrap(); - if *reverse { - exprs.reverse(); - *reverse = false; - } else { - exprs.rotate_left(1); - } - Ok(Transformed::Yes( - LogicalPlan::from(Project::try_new(project.input.clone(), exprs)?).into(), - )) + plan.transform_down(|node| { + let project = match node.as_ref() { + LogicalPlan::Project(project) => project.clone(), + _ => return Ok(Transformed::no(node)), + }; + let mut exprs = project.projection.clone(); + let mut reverse = self.reverse_first.lock().unwrap(); + if *reverse { + exprs.reverse(); + *reverse = false; + } else { + exprs.rotate_left(1); + } + Ok(Transformed::yes( + LogicalPlan::from(Project::try_new(project.input.clone(), exprs)?).into(), + )) + }) } } } diff --git a/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs b/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs index c8bbdaa7a0..838623e30b 100644 --- a/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs +++ b/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs @@ -4,9 +4,9 @@ use common_error::DaftResult; use crate::LogicalPlan; -use super::{ApplyOrder, OptimizerRule, 
Transformed}; +use super::OptimizerRule; -use common_treenode::DynTreeNode; +use common_treenode::{DynTreeNode, Transformed, TreeNode}; /// Optimization rules for dropping unnecessary Repartitions. /// @@ -22,27 +22,25 @@ impl DropRepartition { } impl OptimizerRule for DropRepartition { - fn apply_order(&self) -> ApplyOrder { - ApplyOrder::TopDown - } - fn try_optimize(&self, plan: Arc) -> DaftResult>> { - let repartition = match plan.as_ref() { - LogicalPlan::Repartition(repartition) => repartition, - _ => return Ok(Transformed::No(plan)), - }; - let child_plan = repartition.input.as_ref(); - let new_plan = match child_plan { - LogicalPlan::Repartition(_) => { - // Drop upstream Repartition for back-to-back Repartitions. - // - // Repartition1-Repartition2 -> Repartition1 - plan.with_new_children(&[child_plan.arc_children()[0].clone()]) - .into() - } - _ => return Ok(Transformed::No(plan)), - }; - Ok(Transformed::Yes(new_plan)) + plan.transform_down(|node| { + let repartition = match node.as_ref() { + LogicalPlan::Repartition(repartition) => repartition, + _ => return Ok(Transformed::no(node)), + }; + let child_plan = repartition.input.as_ref(); + let new_plan = match child_plan { + LogicalPlan::Repartition(_) => { + // Drop upstream Repartition for back-to-back Repartitions. 
+ // + // Repartition1-Repartition2 -> Repartition1 + node.with_new_children(&[child_plan.arc_children()[0].clone()]) + .into() + } + _ => return Ok(Transformed::no(node)), + }; + Ok(Transformed::yes(new_plan)) + }) } } diff --git a/src/daft-plan/src/logical_optimization/rules/mod.rs b/src/daft-plan/src/logical_optimization/rules/mod.rs index fc137589c5..ac8579123a 100644 --- a/src/daft-plan/src/logical_optimization/rules/mod.rs +++ b/src/daft-plan/src/logical_optimization/rules/mod.rs @@ -9,5 +9,5 @@ pub use drop_repartition::DropRepartition; pub use push_down_filter::PushDownFilter; pub use push_down_limit::PushDownLimit; pub use push_down_projection::PushDownProjection; -pub use rule::{ApplyOrder, OptimizerRule, Transformed}; +pub use rule::OptimizerRule; pub use split_actor_pool_projects::SplitActorPoolProjects; diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs index c180acbeff..a8961b470d 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs @@ -20,8 +20,8 @@ use crate::{ LogicalPlan, }; -use super::{ApplyOrder, OptimizerRule, Transformed}; -use common_treenode::DynTreeNode; +use super::OptimizerRule; +use common_treenode::{DynTreeNode, Transformed, TreeNode}; /// Optimization rules for pushing Filters further into the logical plan. 
#[derive(Default, Debug)] @@ -34,14 +34,20 @@ impl PushDownFilter { } impl OptimizerRule for PushDownFilter { - fn apply_order(&self) -> ApplyOrder { - ApplyOrder::TopDown + fn try_optimize(&self, plan: Arc) -> DaftResult>> { + plan.transform_down(|node| self.try_optimize_node(node)) } +} - fn try_optimize(&self, plan: Arc) -> DaftResult>> { +impl PushDownFilter { + #[allow(clippy::only_used_in_recursion)] + fn try_optimize_node( + &self, + plan: Arc, + ) -> DaftResult>> { let filter = match plan.as_ref() { LogicalPlan::Filter(filter) => filter, - _ => return Ok(Transformed::No(plan)), + _ => return Ok(Transformed::no(plan)), }; let child_plan = filter.input.as_ref(); let new_plan = match child_plan { @@ -68,20 +74,20 @@ impl OptimizerRule for PushDownFilter { let new_filter: Arc = LogicalPlan::from(Filter::try_new(child_filter.input.clone(), new_predicate)?) .into(); - self.try_optimize(new_filter.clone())? - .or(Transformed::Yes(new_filter)) - .unwrap() + self.try_optimize_node(new_filter.clone())? + .or(Transformed::yes(new_filter)) + .data .clone() } LogicalPlan::Source(source) => { match source.source_info.as_ref() { // Filter pushdown is not supported for in-memory sources. - SourceInfo::InMemory(_) => return Ok(Transformed::No(plan)), + SourceInfo::InMemory(_) => return Ok(Transformed::no(plan)), // Do not pushdown if Source node already has a limit SourceInfo::Physical(external_info) if let Some(_) = external_info.pushdowns.limit => { - return Ok(Transformed::No(plan)) + return Ok(Transformed::no(plan)) } // Pushdown filter into the Source node @@ -126,7 +132,7 @@ impl OptimizerRule for PushDownFilter { // column and a data column (or contain a UDF), then no pushdown into the scan is possible, // so we short-circuit. // TODO(Clark): Support pushing predicates referencing both partition and data columns into the scan. 
- return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } let data_filter = conjuct(data_only_filter); @@ -157,9 +163,9 @@ impl OptimizerRule for PushDownFilter { conjuct(needing_filter_op).unwrap(), )? .into(); - return Ok(Transformed::Yes(filter_op.into())); + return Ok(Transformed::yes(filter_op.into())); } else { - return Ok(Transformed::Yes(new_source.into())); + return Ok(Transformed::yes(new_source.into())); } } SourceInfo::PlaceHolder(..) => { @@ -205,7 +211,7 @@ impl OptimizerRule for PushDownFilter { } if can_push.is_empty() { // No predicate expressions can be pushed through projection. - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } // Create new Filter with predicates that can be pushed past Projection. let predicates_to_push = conjuct(can_push).unwrap(); @@ -335,12 +341,12 @@ impl OptimizerRule for PushDownFilter { new_join } } else { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } } - _ => return Ok(Transformed::No(plan)), + _ => return Ok(Transformed::no(plan)), }; - Ok(Transformed::Yes(new_plan)) + Ok(Transformed::yes(new_plan)) } } diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs b/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs index 405889af32..2359cbdcfa 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs @@ -8,8 +8,8 @@ use crate::{ LogicalPlan, }; -use super::{ApplyOrder, OptimizerRule, Transformed}; -use common_treenode::DynTreeNode; +use super::OptimizerRule; +use common_treenode::{DynTreeNode, Transformed, TreeNode}; /// Optimization rules for pushing Limits further into the logical plan. 
#[derive(Default, Debug)] @@ -22,11 +22,17 @@ impl PushDownLimit { } impl OptimizerRule for PushDownLimit { - fn apply_order(&self) -> ApplyOrder { - ApplyOrder::TopDown + fn try_optimize(&self, plan: Arc) -> DaftResult>> { + plan.transform_down(|node| self.try_optimize_node(node)) } +} - fn try_optimize(&self, plan: Arc) -> DaftResult>> { +impl PushDownLimit { + #[allow(clippy::only_used_in_recursion)] + fn try_optimize_node( + &self, + plan: Arc, + ) -> DaftResult>> { match plan.as_ref() { LogicalPlan::Limit(LogicalLimit { input, @@ -42,7 +48,7 @@ impl OptimizerRule for PushDownLimit { let new_limit = plan .with_new_children(&[input.arc_children()[0].clone()]) .into(); - Ok(Transformed::Yes( + Ok(Transformed::yes( input.with_new_children(&[new_limit]).into(), )) } @@ -52,13 +58,13 @@ impl OptimizerRule for PushDownLimit { LogicalPlan::Source(source) => { match source.source_info.as_ref() { // Limit pushdown is not supported for in-memory sources. - SourceInfo::InMemory(_) => Ok(Transformed::No(plan)), + SourceInfo::InMemory(_) => Ok(Transformed::no(plan)), // Do not pushdown if Source node is already more limited than `limit` SourceInfo::Physical(external_info) if let Some(existing_limit) = external_info.pushdowns.limit && existing_limit <= limit => { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } // Pushdown limit into the Source node as a "local" limit SourceInfo::Physical(external_info) => { @@ -74,7 +80,7 @@ impl OptimizerRule for PushDownLimit { } else { plan.with_new_children(&[new_source]).into() }; - Ok(Transformed::Yes(out_plan)) + Ok(Transformed::yes(out_plan)) } SourceInfo::PlaceHolder(..) => { panic!("PlaceHolderInfo should not exist for optimization!"); @@ -99,16 +105,16 @@ impl OptimizerRule for PushDownLimit { ))); // we rerun the optimizer, ideally when we move to a visitor pattern this should go away let optimized = self - .try_optimize(new_plan.clone())? 
- .or(Transformed::Yes(new_plan)) - .unwrap() + .try_optimize_node(new_plan.clone())? + .or(Transformed::yes(new_plan)) + .data .clone(); - Ok(Transformed::Yes(optimized)) + Ok(Transformed::yes(optimized)) } - _ => Ok(Transformed::No(plan)), + _ => Ok(Transformed::no(plan)), } } - _ => Ok(Transformed::No(plan)), + _ => Ok(Transformed::no(plan)), } } } diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs index f990985518..a06f947746 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc}; use common_error::DaftResult; -use common_treenode::TreeNode; +use common_treenode::{Transformed, TreeNode}; use daft_core::prelude::*; use daft_dsl::{ @@ -19,7 +19,7 @@ use crate::{ LogicalPlan, LogicalPlanRef, }; -use super::{ApplyOrder, OptimizerRule, Transformed}; +use super::OptimizerRule; use common_treenode::DynTreeNode; #[derive(Default, Debug)] @@ -56,8 +56,8 @@ impl PushDownProjection { // Projection discarded but new root node has not been looked at; // look at the new root node. let new_plan = self - .try_optimize(upstream_plan.clone())? - .or(Transformed::Yes(upstream_plan.clone())); + .try_optimize_node(upstream_plan.clone())? + .or(Transformed::yes(upstream_plan.clone())); return Ok(new_plan); } @@ -133,8 +133,8 @@ impl PushDownProjection { // Root node is changed, look at it again. let new_plan = self - .try_optimize(new_plan.clone())? - .or(Transformed::Yes(new_plan.clone())); + .try_optimize_node(new_plan.clone())? + .or(Transformed::yes(new_plan.clone())); return Ok(new_plan); } } @@ -167,14 +167,14 @@ impl PushDownProjection { let new_plan = Arc::new(plan.with_new_children(&[new_source.into()])); // Retry optimization now that the upstream node is different. 
let new_plan = self - .try_optimize(new_plan.clone())? - .or(Transformed::Yes(new_plan)); + .try_optimize_node(new_plan.clone())? + .or(Transformed::yes(new_plan)); Ok(new_plan) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } - SourceInfo::InMemory(_) => Ok(Transformed::No(plan)), + SourceInfo::InMemory(_) => Ok(Transformed::no(plan)), SourceInfo::PlaceHolder(..) => { panic!("PlaceHolderInfo should not exist for optimization!"); } @@ -200,11 +200,11 @@ impl PushDownProjection { let new_plan = Arc::new(plan.with_new_children(&[new_upstream.into()])); // Retry optimization now that the upstream node is different. let new_plan = self - .try_optimize(new_plan.clone())? - .or(Transformed::Yes(new_plan)); + .try_optimize_node(new_plan.clone())? + .or(Transformed::yes(new_plan)); Ok(new_plan) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } LogicalPlan::Aggregate(aggregate) => { @@ -228,11 +228,11 @@ impl PushDownProjection { let new_plan = Arc::new(plan.with_new_children(&[new_upstream.into()])); // Retry optimization now that the upstream node is different. let new_plan = self - .try_optimize(new_plan.clone())? - .or(Transformed::Yes(new_plan)); + .try_optimize_node(new_plan.clone())? + .or(Transformed::yes(new_plan)); Ok(new_plan) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } LogicalPlan::ActorPoolProject(upstream_actor_pool_projection) => { @@ -295,8 +295,8 @@ impl PushDownProjection { // Retry optimization now that the node is different. let new_plan = self - .try_optimize(new_plan.clone())? - .or(Transformed::Yes(new_plan)); + .try_optimize_node(new_plan.clone())? + .or(Transformed::yes(new_plan)); return Ok(new_plan); } } @@ -335,11 +335,11 @@ impl PushDownProjection { // Retry optimization now that the upstream node is different. let new_plan = self - .try_optimize(new_plan.clone())? - .or(Transformed::Yes(new_plan)); + .try_optimize_node(new_plan.clone())? 
+ .or(Transformed::yes(new_plan)); Ok(new_plan) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } LogicalPlan::Sort(..) @@ -362,7 +362,7 @@ impl PushDownProjection { let grand_upstream_plan = &upstream_plan.arc_children()[0]; let grand_upstream_columns = grand_upstream_plan.schema().names(); if grand_upstream_columns.len() == combined_dependencies.len() { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } let new_subprojection: LogicalPlan = { @@ -378,8 +378,8 @@ impl PushDownProjection { let new_plan = Arc::new(plan.with_new_children(&[new_upstream.into()])); // Retry optimization now that the upstream node is different. let new_plan = self - .try_optimize(new_plan.clone())? - .or(Transformed::Yes(new_plan)); + .try_optimize_node(new_plan.clone())? + .or(Transformed::yes(new_plan)); Ok(new_plan) } LogicalPlan::Concat(concat) => { @@ -396,7 +396,7 @@ impl PushDownProjection { let grand_upstream_plan = &upstream_plan.children()[0]; let grand_upstream_columns = grand_upstream_plan.schema().names(); if grand_upstream_columns.len() == combined_dependencies.len() { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } let pushdown_column_exprs: Vec = combined_dependencies @@ -417,8 +417,8 @@ impl PushDownProjection { let new_plan = Arc::new(plan.with_new_children(&[new_upstream.into()])); // Retry optimization now that the upstream node is different. let new_plan = self - .try_optimize(new_plan.clone())? - .or(Transformed::Yes(new_plan)); + .try_optimize_node(new_plan.clone())? 
+ .or(Transformed::yes(new_plan)); Ok(new_plan) } LogicalPlan::Join(join) => { @@ -457,9 +457,9 @@ impl PushDownProjection { .collect(); let new_project: LogicalPlan = Project::try_new(side.clone(), pushdown_column_exprs)?.into(); - Ok(Transformed::Yes(new_project.into())) + Ok(Transformed::yes(new_project.into())) } else { - Ok(Transformed::No(side.clone())) + Ok(Transformed::no(side.clone())) } } @@ -474,21 +474,21 @@ impl PushDownProjection { projection_dependencies, )?; - if new_left_upstream.is_no() && new_right_upstream.is_no() { - Ok(Transformed::No(plan)) + if !new_left_upstream.transformed && !new_right_upstream.transformed { + Ok(Transformed::no(plan)) } else { // If either pushdown is possible, create a new Join node. let new_join = upstream_plan.with_new_children(&[ - new_left_upstream.unwrap().clone(), - new_right_upstream.unwrap().clone(), + new_left_upstream.data.clone(), + new_right_upstream.data.clone(), ]); let new_plan = Arc::new(plan.with_new_children(&[new_join.into()])); // Retry optimization now that the upstream node is different. let new_plan = self - .try_optimize(new_plan.clone())? - .or(Transformed::Yes(new_plan)); + .try_optimize_node(new_plan.clone())? + .or(Transformed::yes(new_plan)); Ok(new_plan) } @@ -496,11 +496,11 @@ impl PushDownProjection { LogicalPlan::Distinct(_) => { // Cannot push down past a Distinct, // since Distinct implicitly requires all parent columns. - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } LogicalPlan::Pivot(_) | LogicalPlan::MonotonicallyIncreasingId(_) => { // Cannot push down past a Pivot/MonotonicallyIncreasingId because it changes the schema. 
- Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } LogicalPlan::Sink(_) => { panic!("Bad projection due to upstream sink node: {:?}", projection) @@ -530,9 +530,9 @@ impl PushDownProjection { }; let new_actor_pool_project = plan.with_new_children(&[new_subprojection.into()]); - Ok(Transformed::Yes(new_actor_pool_project.into())) + Ok(Transformed::yes(new_actor_pool_project.into())) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } @@ -558,9 +558,9 @@ impl PushDownProjection { }; let new_aggregation = plan.with_new_children(&[new_subprojection.into()]); - Ok(Transformed::Yes(new_aggregation.into())) + Ok(Transformed::yes(new_aggregation.into())) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } @@ -595,13 +595,13 @@ impl PushDownProjection { .arced(); Ok(self - .try_optimize(new_join.clone())? - .or(Transformed::Yes(new_join))) + .try_optimize_node(new_join.clone())? + .or(Transformed::yes(new_join))) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } @@ -627,19 +627,16 @@ impl PushDownProjection { }; let new_pivot = plan.with_new_children(&[new_subprojection.into()]); - Ok(Transformed::Yes(new_pivot.into())) + Ok(Transformed::yes(new_pivot.into())) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } -} - -impl OptimizerRule for PushDownProjection { - fn apply_order(&self) -> ApplyOrder { - ApplyOrder::TopDown - } - fn try_optimize(&self, plan: Arc) -> DaftResult>> { + fn try_optimize_node( + &self, + plan: Arc, + ) -> DaftResult>> { match plan.as_ref() { LogicalPlan::Project(projection) => self.try_optimize_project(projection, plan.clone()), // ActorPoolProjects also do column projection @@ -654,11 +651,17 @@ impl OptimizerRule for PushDownProjection { LogicalPlan::Join(join) => self.try_optimize_join(join, plan.clone()), // Pivots also do column projection LogicalPlan::Pivot(pivot) => self.try_optimize_pivot(pivot, 
plan.clone()), - _ => Ok(Transformed::No(plan)), + _ => Ok(Transformed::no(plan)), } } } +impl OptimizerRule for PushDownProjection { + fn try_optimize(&self, plan: Arc) -> DaftResult>> { + plan.transform_down(|node| self.try_optimize_node(node)) + } +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/src/daft-plan/src/logical_optimization/rules/rule.rs b/src/daft-plan/src/logical_optimization/rules/rule.rs index bee69cfc6b..331ee2fe05 100644 --- a/src/daft-plan/src/logical_optimization/rules/rule.rs +++ b/src/daft-plan/src/logical_optimization/rules/rule.rs @@ -1,77 +1,14 @@ use std::sync::Arc; use common_error::DaftResult; +use common_treenode::Transformed; use crate::LogicalPlan; -/// Application order of a rule or rule batch. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum ApplyOrder { - // Apply a rule to a node and then it's children. - TopDown, - #[allow(dead_code)] - // Apply a rule to a node's children and then the node itself. - BottomUp, - #[allow(dead_code)] - // Delegate tree traversal to the rule. - Delegated, -} - /// A logical plan optimization rule. pub trait OptimizerRule { /// Try to optimize the logical plan with this rule. /// - /// This returns Transformed::Yes(new_plan) if the rule modified the plan, Transformed::No(old_plan) otherwise. + /// This returns Transformed::yes(new_plan) if the rule modified the plan, Transformed::no(old_plan) otherwise. fn try_optimize(&self, plan: Arc) -> DaftResult>>; - - /// The plan tree order in which this rule should be applied (top-down, bottom-up, or delegated to rule). - fn apply_order(&self) -> ApplyOrder; -} - -/// An enum indicating whether or not the wrapped data has been transformed. -#[derive(Debug)] -pub enum Transformed { - // Yes, the data has been transformed. - Yes(T), - // No, the data has not been transformed. - No(T), -} - -impl Transformed { - /// Returns self if self is Yes, otherwise returns other. 
- pub fn or(self, other: Self) -> Self { - match self { - Self::Yes(_) => self, - Self::No(_) => other, - } - } - - /// Returns whether self is No. - pub fn is_no(&self) -> bool { - matches!(self, Self::No(_)) - } - - /// Unwraps the enum and returns a reference to the inner value. - // TODO(Clark): Take ownership of self and return an owned T? - pub fn unwrap(&self) -> &T { - match self { - Self::Yes(inner) => inner, - Self::No(inner) => inner, - } - } - - /// Maps a `Transformed` to `Transformed`, - /// by supplying a function to apply to a contained Yes value - /// as well as a function to apply to a contained No value. - #[inline] - pub fn map_yes_no U, N: FnOnce(T) -> U>( - self, - yes_op: Y, - no_op: N, - ) -> Transformed { - match self { - Self::Yes(t) => Transformed::Yes(yes_op(t)), - Self::No(t) => Transformed::No(no_op(t)), - } - } } diff --git a/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs b/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs index 1c8b87696b..34e8d491ef 100644 --- a/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs +++ b/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs @@ -1,7 +1,7 @@ use std::{collections::HashSet, iter, sync::Arc}; use common_error::DaftResult; -use common_treenode::{TreeNode, TreeNodeRecursion, TreeNodeRewriter}; +use common_treenode::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter}; use daft_dsl::{ functions::{ python::{PythonUDF, StatefulPythonUDF}, @@ -17,7 +17,7 @@ use crate::{ LogicalPlan, }; -use super::{ApplyOrder, OptimizerRule, Transformed}; +use super::OptimizerRule; #[derive(Default, Debug)] pub struct SplitActorPoolProjects {} @@ -121,16 +121,11 @@ impl SplitActorPoolProjects { /// │ │ │ │ │ │ /// └─────────────────┘ └────────────────────┘ └───────────┘ impl OptimizerRule for SplitActorPoolProjects { - fn apply_order(&self) -> ApplyOrder { - ApplyOrder::TopDown - } - fn try_optimize(&self, plan: 
Arc) -> DaftResult>> { - match plan.as_ref() { - LogicalPlan::Project(projection) => try_optimize_project(projection, plan.clone(), 0), - // TODO: Figure out how to split other nodes as well such as Filter, Agg etc - _ => Ok(Transformed::No(plan)), - } + plan.transform_down(|node| match node.as_ref() { + LogicalPlan::Project(projection) => try_optimize_project(projection, node.clone(), 0), + _ => Ok(Transformed::no(node)), + }) } } @@ -381,7 +376,7 @@ fn try_optimize_project( // Base case: no stateful UDFs at all let has_stateful_udfs = projection.projection.iter().any(has_stateful_udf); if !has_stateful_udfs { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } log::debug!( @@ -424,7 +419,7 @@ fn try_optimize_project( let new_child_project = LogicalPlan::Project(new_project.clone()).arced(); let optimized_child_plan = try_optimize_project(&new_project, new_child_project.clone(), recursive_count + 1)?; - optimized_child_plan.unwrap().clone() + optimized_child_plan.data.clone() }; // Start building a chain of `child -> Project -> ActorPoolProject -> ActorPoolProject -> ... -> Project` @@ -500,7 +495,7 @@ fn try_optimize_project( )?) .arced(); - Ok(Transformed::Yes(final_selection_project)) + Ok(Transformed::yes(final_selection_project)) } #[inline] diff --git a/src/daft-plan/src/logical_optimization/test/mod.rs b/src/daft-plan/src/logical_optimization/test/mod.rs index 7f26e317e6..a6540da0d5 100644 --- a/src/daft-plan/src/logical_optimization/test/mod.rs +++ b/src/daft-plan/src/logical_optimization/test/mod.rs @@ -25,12 +25,8 @@ pub fn assert_optimized_plan_with_rules_eq( Default::default(), ); let optimized_plan = optimizer - .optimize_with_rules( - optimizer.rule_batches[0].rules.as_slice(), - plan.clone(), - &optimizer.rule_batches[0].order, - )? - .unwrap() + .optimize_with_rules(optimizer.rule_batches[0].rules.as_slice(), plan.clone())? + .data .clone(); assert_eq!( optimized_plan,