Skip to content

Commit

Permalink
[CHORE] Use treenode for tree traversal in logical optimizer rules (#…
Browse files Browse the repository at this point in the history
…2797)

Follow-up from #2791 after offline conversation.

This PR still uses the existing logical plan optimizer, but now uses
`common_treenode::Transformed` instead of a custom implementation. In
addition, it removes the apply order logic from the optimizer and
replaces it with `common_treenode::TreeNode` transforms done in the
optimizer rules themselves.

This PR should not change the behavior of any optimizer rule or the way
rules are applied, with one exception: each rule in a batch is now
applied to the whole tree before the next rule runs. Previously, rules
that shared the same apply order were each applied to a single plan node
at a time. Future PRs will separate the rules out and make better use of
treenode.
  • Loading branch information
kevinzwang committed Sep 6, 2024
1 parent 6fe408c commit 91d9fe9
Show file tree
Hide file tree
Showing 12 changed files with 260 additions and 421 deletions.
25 changes: 25 additions & 0 deletions src/common/treenode/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,31 @@ impl<T> Transformed<T> {
f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr))
}

/// Picks between two `Transformed` values based on this one's flag.
///
/// Yields `self` when `self.transformed` is `true`; otherwise yields `other`
/// unchanged. `other` is taken by value, so it is always evaluated by the
/// caller regardless of which value is returned.
pub fn or(self, other: Self) -> Self {
    // Guard clause: nothing was transformed here, so defer to the fallback.
    if !self.transformed {
        return other;
    }
    self
}

/// Converts a `Transformed<T>` into a `Transformed<U>` by applying exactly
/// one of two closures to the contained data.
///
/// * `yes_op` runs when `self.transformed` is `true`; its output is wrapped
///   with `Transformed::yes`.
/// * `no_op` runs when `self.transformed` is `false`; its output is wrapped
///   with `Transformed::no`.
///
/// NOTE(review): `self.tnr` is not forwarded to the returned value; confirm
/// that discarding it is intended for every caller.
#[inline]
pub fn map_yes_no<U, Y: FnOnce(T) -> U, N: FnOnce(T) -> U>(
    self,
    yes_op: Y,
    no_op: N,
) -> Transformed<U> {
    // Handle the untransformed case first and bail out early.
    if !self.transformed {
        return Transformed::no(no_op(self.data));
    }
    Transformed::yes(yes_op(self.data))
}

/// Maps the [`Transformed`] object to the result of the given `f`.
pub fn transform_data<U, F: FnOnce(T) -> Result<Transformed<U>>>(
self,
Expand Down
84 changes: 39 additions & 45 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::sync::Arc;

use common_treenode::Transformed;
use daft_core::prelude::*;
use daft_dsl::{optimization, resolve_exprs, AggExpr, ApproxPercentileParams, Expr, ExprRef};
use indexmap::{IndexMap, IndexSet};
use itertools::Itertools;
use snafu::ResultExt;

use crate::logical_optimization::Transformed;
use crate::logical_plan::{CreationSnafu, Result};
use crate::LogicalPlan;

Expand Down Expand Up @@ -144,7 +144,7 @@ impl Project {
.map(|e| {
let new_expr =
replace_column_with_semantic_id(e.clone(), &subexprs_to_replace, schema);
let new_expr = new_expr.unwrap();
let new_expr = new_expr.data;
// The substitution can unintentionally change the expression's name
// (since the name depends on the first column referenced, which can be substituted away)
// so re-alias the original name here if it has changed.
Expand Down Expand Up @@ -185,10 +185,10 @@ fn replace_column_with_semantic_id(
Expr::Alias(_, name) => Expr::Alias(new_expr.into(), name.clone()),
_ => new_expr,
};
Transformed::Yes(new_expr.into())
Transformed::yes(new_expr.into())
} else {
match e.as_ref() {
Expr::Column(_) | Expr::Literal(_) => Transformed::No(e),
Expr::Column(_) | Expr::Literal(_) => Transformed::no(e),
Expr::Agg(agg_expr) => replace_column_with_semantic_id_aggexpr(
agg_expr.clone(),
subexprs_to_replace,
Expand Down Expand Up @@ -241,11 +241,11 @@ fn replace_column_with_semantic_id(
subexprs_to_replace,
schema,
);
if child.is_no() && fill_value.is_no() {
Transformed::No(e)
if !child.transformed && !fill_value.transformed {
Transformed::no(e)
} else {
Transformed::Yes(
Expr::FillNull(child.unwrap().clone(), fill_value.unwrap().clone()).into(),
Transformed::yes(
Expr::FillNull(child.data.clone(), fill_value.data.clone()).into(),
)
}
}
Expand All @@ -254,12 +254,10 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema);
let items =
replace_column_with_semantic_id(items.clone(), subexprs_to_replace, schema);
if child.is_no() && items.is_no() {
Transformed::No(e)
if !child.transformed && !items.transformed {
Transformed::no(e)
} else {
Transformed::Yes(
Expr::IsIn(child.unwrap().clone(), items.unwrap().clone()).into(),
)
Transformed::yes(Expr::IsIn(child.data.clone(), items.data.clone()).into())
}
}
Expr::Between(child, lower, upper) => {
Expand All @@ -269,16 +267,12 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(lower.clone(), subexprs_to_replace, schema);
let upper =
replace_column_with_semantic_id(upper.clone(), subexprs_to_replace, schema);
if child.is_no() && lower.is_no() && upper.is_no() {
Transformed::No(e)
if !child.transformed && !lower.transformed && !upper.transformed {
Transformed::no(e)
} else {
Transformed::Yes(
Expr::Between(
child.unwrap().clone(),
lower.unwrap().clone(),
upper.unwrap().clone(),
)
.into(),
Transformed::yes(
Expr::Between(child.data.clone(), lower.data.clone(), upper.data.clone())
.into(),
)
}
}
Expand All @@ -287,14 +281,14 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(left.clone(), subexprs_to_replace, schema);
let right =
replace_column_with_semantic_id(right.clone(), subexprs_to_replace, schema);
if left.is_no() && right.is_no() {
Transformed::No(e)
if !left.transformed && !right.transformed {
Transformed::no(e)
} else {
Transformed::Yes(
Transformed::yes(
Expr::BinaryOp {
op: *op,
left: left.unwrap().clone(),
right: right.unwrap().clone(),
left: left.data.clone(),
right: right.data.clone(),
}
.into(),
)
Expand All @@ -311,14 +305,14 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(if_true.clone(), subexprs_to_replace, schema);
let if_false =
replace_column_with_semantic_id(if_false.clone(), subexprs_to_replace, schema);
if predicate.is_no() && if_true.is_no() && if_false.is_no() {
Transformed::No(e)
if !predicate.transformed && !if_true.transformed && !if_false.transformed {
Transformed::no(e)
} else {
Transformed::Yes(
Transformed::yes(
Expr::IfElse {
predicate: predicate.unwrap().clone(),
if_true: if_true.unwrap().clone(),
if_false: if_false.unwrap().clone(),
predicate: predicate.data.clone(),
if_true: if_true.data.clone(),
if_false: if_false.data.clone(),
}
.into(),
)
Expand All @@ -331,13 +325,13 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema)
})
.collect::<Vec<_>>();
if transforms.iter().all(|e| e.is_no()) {
Transformed::No(e)
if transforms.iter().all(|e| !e.transformed) {
Transformed::no(e)
} else {
Transformed::Yes(
Transformed::yes(
Expr::Function {
func: func.clone(),
inputs: transforms.iter().map(|t| t.unwrap()).cloned().collect(),
inputs: transforms.iter().map(|t| t.data.clone()).collect(),
}
.into(),
)
Expand All @@ -352,11 +346,11 @@ fn replace_column_with_semantic_id(
replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema)
})
.collect::<Vec<_>>();
if transforms.iter().all(|e| e.is_no()) {
Transformed::No(e)
if transforms.iter().all(|e| !e.transformed) {
Transformed::no(e)
} else {
func.inputs = transforms.iter().map(|t| t.unwrap()).cloned().collect();
Transformed::Yes(Expr::ScalarFunction(func).into())
func.inputs = transforms.iter().map(|t| t.data.clone()).collect();
Transformed::yes(Expr::ScalarFunction(func).into())
}
}
}
Expand Down Expand Up @@ -446,12 +440,12 @@ fn replace_column_with_semantic_id_aggexpr(
.iter()
.map(|e| replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema))
.collect::<Vec<_>>();
if transforms.iter().all(|e| e.is_no()) {
Transformed::No(AggExpr::MapGroups { func, inputs })
if transforms.iter().all(|e| !e.transformed) {
Transformed::no(AggExpr::MapGroups { func, inputs })
} else {
Transformed::Yes(AggExpr::MapGroups {
Transformed::yes(AggExpr::MapGroups {
func: func.clone(),
inputs: transforms.iter().map(|t| t.unwrap()).cloned().collect(),
inputs: transforms.iter().map(|t| t.data.clone()).collect(),
})
}
}
Expand Down
1 change: 0 additions & 1 deletion src/daft-plan/src/logical_optimization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@ mod rules;
mod test;

pub use optimizer::{Optimizer, OptimizerConfig};
pub use rules::Transformed;
Loading

0 comments on commit 91d9fe9

Please sign in to comment.