
Commit

[FEAT] Add all projection optimization rules to new query planner. (#1288)

This PR includes all the rules we currently have for optimizing
Projections, but organizes them slightly differently.

When run on a Projection node, ProjectionPushdown will:

1. Delete the node if it is unnecessary; or else
2. Trim columns in upstream Projections, Aggregates, and Source nodes;
or else
3. Create new projections behind other upstream nodes if possible. 

Finally, if a change has been made, this rule recurses on the new node
immediately. This ensures that projection nodes that have just become
redundant are pruned right away, while keeping a clean separation of
concerns: the logic for creating, deleting, and modifying nodes never
has to account for the others.
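
For concreteness, here is a minimal, self-contained sketch of that recursion
pattern. The `Plan` enum and `push_down_projection` function are hypothetical
stand-ins, not Daft's `LogicalPlan` or optimizer API, and they only model two
of the three cases above (deleting a redundant Projection and trimming a
Source):

```rust
use std::collections::HashSet;

/// A toy two-node plan, standing in for Daft's `LogicalPlan`.
#[derive(Clone, Debug, PartialEq)]
enum Plan {
    Source { columns: Vec<String> },
    Project { columns: Vec<String>, input: Box<Plan> },
}

impl Plan {
    fn output_columns(&self) -> &[String] {
        match self {
            Plan::Source { columns } | Plan::Project { columns, .. } => columns,
        }
    }
}

/// Toy ProjectionPushdown: delete a redundant Projection or trim an upstream
/// Source; whenever a change is made, recurse on the new node immediately.
fn push_down_projection(plan: Plan) -> Plan {
    match plan {
        Plan::Project { columns, input } => {
            // 1. Delete the node if it is unnecessary.
            if columns.as_slice() == input.output_columns() {
                return push_down_projection(*input);
            }
            // 2. Trim columns in an upstream Source.
            if let Plan::Source { columns: src_cols } = input.as_ref() {
                let trimmed: Vec<String> = {
                    let keep: HashSet<&String> = columns.iter().collect();
                    src_cols.iter().filter(|c| keep.contains(c)).cloned().collect()
                };
                if trimmed.len() < src_cols.len() {
                    let new_node = Plan::Project {
                        columns,
                        input: Box::new(Plan::Source { columns: trimmed }),
                    };
                    // A change was made: recurse on the new node immediately,
                    // so a projection that just became redundant is pruned
                    // right away.
                    return push_down_projection(new_node);
                }
            }
            Plan::Project { columns, input }
        }
        other => other,
    }
}

fn main() {
    let plan = Plan::Project {
        columns: vec!["a".into()],
        input: Box::new(Plan::Source {
            columns: vec!["a".into(), "b".into(), "c".into()],
        }),
    };
    // Trims the Source to ["a"], then deletes the now-redundant Projection.
    assert_eq!(
        push_down_projection(plan),
        Plan::Source { columns: vec!["a".into()] }
    );
}
```

In the real rule, the same pattern also covers trimming upstream Projections
and Aggregates and creating new projections behind other upstream nodes; the
sketch only illustrates the "recurse immediately after a change" structure.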

----

There is also a rule for Aggregate nodes (since an Aggregate also projects
columns from its parent). If the Aggregate implicitly drops columns from its
parent, a Projection is created upstream to make the drop explicit. This rule
only ever fires once, since projection pushdown monotonically decreases the
number of columns across the plan. The newly created projection is then
optimized just like any other projection.
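
As a rough sketch (the helper below is hypothetical, not Daft's API), the
Aggregate rule only needs to compare the columns the Aggregate requires
against the columns its parent produces, and emit an explicit column list for
a new Projection when something is implicitly dropped:

```rust
use std::collections::HashSet;

/// Hypothetical helper (not Daft's API): if the Aggregate's required columns
/// are a strict subset of its parent's output columns, return the column list
/// for an explicit upstream Projection; otherwise the rule does not fire.
fn explicit_projection_for_aggregate(
    parent_columns: &[String],
    required_columns: &HashSet<String>,
) -> Option<Vec<String>> {
    let needed: Vec<String> = parent_columns
        .iter()
        .filter(|c| required_columns.contains(*c))
        .cloned()
        .collect();
    // Fires at most once per Aggregate: pushdown only ever shrinks column
    // sets, so once the drop is explicit there is nothing left to make
    // explicit.
    (needed.len() < parent_columns.len()).then_some(needed)
}

fn main() {
    let parent = vec!["k".to_string(), "v".to_string(), "unused".to_string()];
    let required: HashSet<String> = HashSet::from(["k".to_string(), "v".to_string()]);
    // The Aggregate only needs "k" and "v", so a Projection selecting them is
    // inserted, making the drop of "unused" explicit.
    assert_eq!(
        explicit_projection_for_aggregate(&parent, &required),
        Some(vec!["k".to_string(), "v".to_string()])
    );
}
```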

----

This PR also updates the tabular scan shim so that a specific set of columns
to read can be passed in. We also rename the schema fields in source nodes and
associated structs to make explicit whether a schema describes the source data
or the intended plan output.

---------

Co-authored-by: Xiayue Charles Lin <[email protected]>
Co-authored-by: Clark Zinzow <[email protected]>
3 people authored Aug 29, 2023
1 parent afc874b commit 337f0fc
Showing 20 changed files with 653 additions and 77 deletions.
1 change: 0 additions & 1 deletion daft/execution/physical_plan.py
@@ -76,7 +76,6 @@ def file_read(
Yield a plan to read those filenames.
"""

materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque()
output_partition_index = 0

15 changes: 13 additions & 2 deletions daft/execution/rust_physical_plan_shim.py
@@ -22,7 +22,11 @@


def tabular_scan(
schema: PySchema, file_info_table: PyTable, file_format_config: FileFormatConfig, limit: int
schema: PySchema,
columns_to_read: list[str],
file_info_table: PyTable,
file_format_config: FileFormatConfig,
limit: int,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
# TODO(Clark): Fix this Ray runner hack.
part = Table._from_pytable(file_info_table)
@@ -36,8 +40,15 @@ def tabular_scan(

file_info_iter = physical_plan.partition_read(iter(parts_t))
filepaths_column_name = get_context().runner().runner_io().FS_LISTING_PATH_COLUMN_NAME
pyschema = Schema._from_pyschema(schema)
return physical_plan.file_read(
file_info_iter, limit, Schema._from_pyschema(schema), None, None, file_format_config, filepaths_column_name
child_plan=file_info_iter,
limit_rows=limit,
schema=pyschema,
fs=None,
columns_to_read=columns_to_read,
file_format_config=file_format_config,
filepaths_column_name=filepaths_column_name,
)


42 changes: 22 additions & 20 deletions src/daft-core/src/schema.rs
@@ -133,29 +133,31 @@ impl Eq for Schema {}

impl Hash for Schema {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
// Must preserve x == y --> hash(x) == hash(y).
// Since IndexMap implements order-independent equality, we must implement an order-independent hashing scheme.
// We achieve this by combining the hashes of key-value pairs with an associative + commutative operation so
// order does not matter, i.e. (u64, *, 0) must form a commutative monoid. This is satisfied by * = u64::wrapping_add.
//
// Moreover, the hashing of each individual element must be independent of the hashing of other elements, so we hash
// each element with a fresh state (hasher).
//
// NOTE: This is a relatively weak hash function, but should be fine for our current hashing use case, which is detecting
// logical optimization cycles in the optimizer.
state.write_u64(
self.fields
.iter()
.map(|kv| {
let mut h = DefaultHasher::new();
kv.hash(&mut h);
h.finish()
})
.fold(0, u64::wrapping_add),
)
state.write_u64(hash_index_map(&self.fields))
}
}

pub fn hash_index_map<K: Hash, V: Hash>(indexmap: &indexmap::IndexMap<K, V>) -> u64 {
// Must preserve x == y --> hash(x) == hash(y).
// Since IndexMap implements order-independent equality, we must implement an order-independent hashing scheme.
// We achieve this by combining the hashes of key-value pairs with an associative + commutative operation so
// order does not matter, i.e. (u64, *, 0) must form a commutative monoid. This is satisfied by * = u64::wrapping_add.
//
// Moreover, the hashing of each individual element must be independent of the hashing of other elements, so we hash
// each element with a fresh state (hasher).
//
// NOTE: This is a relatively weak hash function, but should be fine for our current hashing use case, which is detecting
// logical optimization cycles in the optimizer.
indexmap
.iter()
.map(|kv| {
let mut h = DefaultHasher::new();
kv.hash(&mut h);
h.finish()
})
.fold(0, u64::wrapping_add)
}

impl Default for Schema {
fn default() -> Self {
Self::empty()
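
The order-independence that the comment above describes can be checked with a
small standalone program. `hash_index_map` below is the function added in this
file (it requires the `indexmap` crate); the `main` harness is illustrative
only and not part of this PR:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Copied from src/daft-core/src/schema.rs above.
pub fn hash_index_map<K: Hash, V: Hash>(indexmap: &indexmap::IndexMap<K, V>) -> u64 {
    indexmap
        .iter()
        .map(|kv| {
            let mut h = DefaultHasher::new();
            kv.hash(&mut h);
            h.finish()
        })
        .fold(0, u64::wrapping_add)
}

// Illustrative harness, not part of the PR.
fn main() {
    // Same key-value pairs inserted in different orders.
    let a = indexmap::IndexMap::from([("x", 1), ("y", 2), ("z", 3)]);
    let b = indexmap::IndexMap::from([("z", 3), ("x", 1), ("y", 2)]);

    // IndexMap equality is order-independent, so the hash must be too;
    // folding per-entry hashes with wrapping_add makes that hold.
    assert_eq!(a, b);
    assert_eq!(hash_index_map(&a), hash_index_map(&b));
}
```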
3 changes: 2 additions & 1 deletion src/daft-plan/src/builder.rs
@@ -43,6 +43,7 @@ impl LogicalPlanBuilder {
partition_spec: &PartitionSpec,
) -> PyResult<LogicalPlanBuilder> {
let source_info = SourceInfo::InMemoryInfo(InMemoryInfo::new(
schema.schema.clone(),
partition_key.into(),
cache_entry.to_object(cache_entry.py()),
));
@@ -225,7 +226,7 @@ impl LogicalPlanBuilder {
)));
}
let logical_plan: LogicalPlan =
ops::Concat::new(other.plan.clone(), self.plan.clone()).into();
ops::Concat::new(self.plan.clone(), other.plan.clone()).into();
Ok(logical_plan.into())
}

80 changes: 78 additions & 2 deletions src/daft-plan/src/logical_plan.rs
@@ -2,6 +2,8 @@ use std::{cmp::max, num::NonZeroUsize, sync::Arc};

use common_error::DaftError;
use daft_core::schema::SchemaRef;
use daft_dsl::{optimization::get_required_columns, Expr};
use indexmap::IndexSet;
use snafu::Snafu;

use crate::{display::TreeDisplay, ops::*, PartitionScheme, PartitionSpec};
@@ -26,7 +28,7 @@ pub enum LogicalPlan {
impl LogicalPlan {
pub fn schema(&self) -> SchemaRef {
match self {
Self::Source(Source { schema, .. }) => schema.clone(),
Self::Source(Source { output_schema, .. }) => output_schema.clone(),
Self::Project(Project {
projected_schema, ..
}) => projected_schema.clone(),
@@ -46,6 +48,80 @@ impl LogicalPlan {
}
}

pub fn required_columns(&self) -> Vec<IndexSet<String>> {
// TODO: https://github.com/Eventual-Inc/Daft/pull/1288#discussion_r1307820697
match self {
Self::Limit(..) | Self::Coalesce(..) => vec![IndexSet::new()],
Self::Concat(..) => vec![IndexSet::new(), IndexSet::new()],
Self::Project(projection) => {
let res = projection
.projection
.iter()
.flat_map(get_required_columns)
.collect();
vec![res]
}
Self::Filter(filter) => {
vec![get_required_columns(&filter.predicate)
.iter()
.cloned()
.collect()]
}
Self::Sort(sort) => {
let res = sort.sort_by.iter().flat_map(get_required_columns).collect();
vec![res]
}
Self::Repartition(repartition) => {
let res = repartition
.partition_by
.iter()
.flat_map(get_required_columns)
.collect();
vec![res]
}
Self::Explode(explode) => {
let res = explode
.to_explode
.iter()
.flat_map(get_required_columns)
.collect();
vec![res]
}
Self::Distinct(distinct) => {
let res = distinct
.input
.schema()
.fields
.iter()
.map(|(name, _)| name)
.cloned()
.collect();
vec![res]
}
Self::Aggregate(aggregate) => {
let res = aggregate
.aggregations
.iter()
.map(|agg| get_required_columns(&Expr::Agg(agg.clone())))
.chain(aggregate.groupby.iter().map(get_required_columns))
.flatten()
.collect();
vec![res]
}
Self::Join(join) => {
let left = join.left_on.iter().flat_map(get_required_columns).collect();
let right = join
.right_on
.iter()
.flat_map(get_required_columns)
.collect();
vec![left, right]
}
Self::Source(_) => todo!(),
Self::Sink(_) => todo!(),
}
}

pub fn partition_spec(&self) -> Arc<PartitionSpec> {
match self {
Self::Source(Source { partition_spec, .. }) => partition_spec.clone(),
@@ -147,7 +223,7 @@ impl LogicalPlan {
},
[input1, input2] => match self {
Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"),
Self::Concat(_) => Self::Concat(Concat::new(input2.clone(), input1.clone())),
Self::Concat(_) => Self::Concat(Concat::new(input1.clone(), input2.clone())),
Self::Join(Join { left_on, right_on, join_type, .. }) => Self::Join(Join::try_new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type).unwrap()),
_ => panic!("Logical op {} has one input, but got two", self),
},
8 changes: 4 additions & 4 deletions src/daft-plan/src/ops/concat.rs
@@ -4,13 +4,13 @@ use crate::LogicalPlan;

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Concat {
pub other: Arc<LogicalPlan>,
// Upstream node.
// Upstream nodes.
pub input: Arc<LogicalPlan>,
pub other: Arc<LogicalPlan>,
}

impl Concat {
pub(crate) fn new(other: Arc<LogicalPlan>, input: Arc<LogicalPlan>) -> Self {
Self { other, input }
pub(crate) fn new(input: Arc<LogicalPlan>, other: Arc<LogicalPlan>) -> Self {
Self { input, other }
}
}
29 changes: 25 additions & 4 deletions src/daft-plan/src/ops/join.rs
@@ -1,6 +1,6 @@
use std::{collections::HashSet, sync::Arc};

use daft_core::schema::{Schema, SchemaRef};
use daft_core::schema::{hash_index_map, Schema, SchemaRef};
use daft_dsl::Expr;
use snafu::ResultExt;

@@ -9,16 +9,32 @@ use crate::{
JoinType, LogicalPlan,
};

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Join {
// Upstream nodes.
pub input: Arc<LogicalPlan>,
pub right: Arc<LogicalPlan>,

pub left_on: Vec<Expr>,
pub right_on: Vec<Expr>,
pub output_schema: SchemaRef,
pub join_type: JoinType,
pub output_schema: SchemaRef,

// Joins may rename columns from the right input; this struct tracks those renames.
// Output name -> Original name
pub right_input_mapping: indexmap::IndexMap<String, String>,
}

impl std::hash::Hash for Join {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
std::hash::Hash::hash(&self.input, state);
std::hash::Hash::hash(&self.right, state);
std::hash::Hash::hash(&self.left_on, state);
std::hash::Hash::hash(&self.right_on, state);
std::hash::Hash::hash(&self.join_type, state);
std::hash::Hash::hash(&self.output_schema, state);
state.write_u64(hash_index_map(&self.right_input_mapping))
}
}

impl Join {
@@ -29,6 +45,7 @@ impl Join {
right_on: Vec<Expr>,
join_type: JoinType,
) -> logical_plan::Result<Self> {
let mut right_input_mapping = indexmap::IndexMap::new();
// Schema inference ported from existing behaviour for parity,
// but contains bug https://github.com/Eventual-Inc/Daft/issues/1294
let output_schema = {
@@ -44,11 +61,14 @@
.cloned()
.chain(right.schema().fields.iter().filter_map(|(rname, rfield)| {
if left_join_keys.contains(rname.as_str()) {
right_input_mapping.insert(rname.clone(), rname.clone());
None
} else if left_schema.contains_key(rname) {
let new_name = format!("right.{}", rname);
right_input_mapping.insert(new_name.clone(), rname.clone());
Some(rfield.rename(new_name))
} else {
right_input_mapping.insert(rname.clone(), rname.clone());
Some(rfield.clone())
}
}))
@@ -60,8 +80,9 @@
right,
left_on,
right_on,
output_schema,
join_type,
output_schema,
right_input_mapping,
})
}

19 changes: 11 additions & 8 deletions src/daft-plan/src/ops/source.rs
@@ -12,7 +12,7 @@ use crate::{
pub struct Source {
/// The schema of the output of this node (the source data schema).
/// May be a subset of the source data schema; executors should push down this projection if possible.
pub schema: SchemaRef,
pub output_schema: SchemaRef,

/// Information about the source data location.
pub source_info: Arc<SourceInfo>,
@@ -27,12 +27,12 @@ pub struct Source {

impl Source {
pub(crate) fn new(
schema: SchemaRef,
output_schema: SchemaRef,
source_info: Arc<SourceInfo>,
partition_spec: Arc<PartitionSpec>,
) -> Self {
Self {
schema,
output_schema,
source_info,
partition_spec,
filters: vec![], // Will be populated by plan optimizer.
@@ -42,7 +42,7 @@

pub fn with_limit(&self, limit: Option<usize>) -> Self {
Self {
schema: self.schema.clone(),
output_schema: self.output_schema.clone(),
source_info: self.source_info.clone(),
partition_spec: self.partition_spec.clone(),
filters: self.filters.clone(),
@@ -52,7 +52,7 @@

pub fn with_filters(&self, filters: Vec<ExprRef>) -> Self {
Self {
schema: self.schema.clone(),
output_schema: self.output_schema.clone(),
source_info: self.source_info.clone(),
partition_spec: self.partition_spec.clone(),
filters,
@@ -65,21 +65,24 @@

match self.source_info.as_ref() {
SourceInfo::ExternalInfo(ExternalInfo {
schema,
source_schema,
file_info,
file_format_config,
}) => {
res.push(format!("Source: {:?}", file_format_config.var_name()));
for fp in file_info.file_paths.iter() {
res.push(format!("File paths = {}", fp));
}
res.push(format!("File schema = {}", schema.short_string()));
res.push(format!("File schema = {}", source_schema.short_string()));
res.push(format!("Format-specific config = {:?}", file_format_config));
}
#[cfg(feature = "python")]
SourceInfo::InMemoryInfo(_) => {}
}
res.push(format!("Output schema = {}", self.schema.short_string()));
res.push(format!(
"Output schema = {}",
self.output_schema.short_string()
));
if !self.filters.is_empty() {
res.push(format!("Filters = {:?}", self.filters));
}
