From 1c1abfbe82bfaa0e9a764c472bff20088df2ff32 Mon Sep 17 00:00:00 2001 From: xcharleslin <4212216+xcharleslin@users.noreply.github.com> Date: Thu, 24 Aug 2023 15:22:23 -0700 Subject: [PATCH] [CHORE] Move schema construction under LogicalPlan construction (#1290) Currently, the schema for certain LogicalPlan nodes is being passed in explicitly during construction, which means: 1. The schema passed in can be incorrect 2. Schema resolution logic is duplicated for all node construction This PR addresses these by moving all schema resolution logic to underneath LogicalPlan construction. Note that this means the constructors are now fallible, returning a new snafu-based `logical_plan::Result`. --------- Co-authored-by: Xiayue Charles Lin Co-authored-by: Clark Zinzow --- Cargo.lock | 1 + daft/execution/execution_step.py | 2 - daft/execution/physical_plan.py | 2 - daft/execution/physical_plan_factory.py | 1 - daft/execution/rust_physical_plan_shim.py | 7 +- daft/logical/rust_logical_plan.py | 32 +------ daft/table/table.py | 1 - src/common/error/src/error.rs | 17 ++++ src/daft-plan/Cargo.toml | 1 + src/daft-plan/src/builder.rs | 62 +++---------- src/daft-plan/src/logical_plan.rs | 42 +++++++-- src/daft-plan/src/ops/agg.rs | 31 +++++-- src/daft-plan/src/ops/explode.rs | 50 ++++++++--- src/daft-plan/src/ops/join.rs | 57 ++++++++---- src/daft-plan/src/ops/project.rs | 20 +++-- .../optimization/rules/push_down_filter.rs | 86 ++++++------------- src/daft-plan/src/physical_ops/explode.rs | 9 +- src/daft-plan/src/physical_ops/join.rs | 3 - src/daft-plan/src/physical_plan.rs | 14 +-- src/daft-plan/src/planner.rs | 8 +- 20 files changed, 221 insertions(+), 225 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4443676e34..3077cc5308 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1116,6 +1116,7 @@ dependencies = [ "indexmap 2.0.0", "pyo3", "serde", + "snafu", ] [[package]] diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 701d82e441..66b5a7e539 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -641,7 +641,6 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) class Join(SingleOutputInstruction): left_on: ExpressionsProjection right_on: ExpressionsProjection - output_projection: ExpressionsProjection how: JoinType def run(self, inputs: list[Table]) -> list[Table]: @@ -653,7 +652,6 @@ def _join(self, inputs: list[Table]) -> list[Table]: right, left_on=self.left_on, right_on=self.right_on, - output_projection=self.output_projection, how=self.how, ) return [result] diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index cf6bc01cee..84bb4ca656 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -172,7 +172,6 @@ def join( right_plan: InProgressPhysicalPlan[PartitionT], left_on: ExpressionsProjection, right_on: ExpressionsProjection, - output_projection: ExpressionsProjection, how: JoinType, ) -> InProgressPhysicalPlan[PartitionT]: """Pairwise join the partitions from `left_child_plan` and `right_child_plan` together.""" @@ -202,7 +201,6 @@ def join( instruction=execution_step.Join( left_on=left_on, right_on=right_on, - output_projection=output_projection, how=how, ) ) diff --git a/daft/execution/physical_plan_factory.py b/daft/execution/physical_plan_factory.py index 1fe1ee08ce..45ddd7a823 100644 --- a/daft/execution/physical_plan_factory.py +++ b/daft/execution/physical_plan_factory.py @@ -175,7 +175,6 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) -> right_plan=_get_physical_plan(right_child, psets), left_on=node._left_on, right_on=node._right_on, - output_projection=node._output_projection, how=node._how, ) diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 977b3c8953..2f6a3a41cd 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -68,7 +68,9 @@ def run(self, input_partition: Table) -> Table: def explode( input: physical_plan.InProgressPhysicalPlan[PartitionT], explode_exprs: list[PyExpr] ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: - explode_expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in explode_exprs]) + explode_expr_projection = ExpressionsProjection( + [Expression._from_pyexpr(expr)._explode() for expr in explode_exprs] + ) explode_op = ShimExplodeOp(explode_expr_projection) return physical_plan.pipeline_instruction( child_plan=input, @@ -138,18 +140,15 @@ def join( right: physical_plan.InProgressPhysicalPlan[PartitionT], left_on: list[PyExpr], right_on: list[PyExpr], - output_projection: list[PyExpr], join_type: JoinType, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on]) - output_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in output_projection]) return physical_plan.join( left_plan=input, right_plan=right, left_on=left_on_expr_proj, right_on=right_on_expr_proj, - output_projection=output_expr_proj, how=join_type, ) diff --git a/daft/logical/rust_logical_plan.py b/daft/logical/rust_logical_plan.py index b4cf16d361..771e690f33 100644 --- a/daft/logical/rust_logical_plan.py +++ b/daft/logical/rust_logical_plan.py @@ -93,9 +93,8 @@ def project( projection: ExpressionsProjection, custom_resource_request: ResourceRequest = ResourceRequest(), ) -> RustLogicalPlanBuilder: - schema = projection.resolve_schema(self.schema()) exprs = [expr._expr for expr in projection] - builder = self._builder.project(exprs, schema._schema, custom_resource_request) + builder = self._builder.project(exprs, custom_resource_request) return RustLogicalPlanBuilder(builder) def filter(self, predicate: Expression) -> RustLogicalPlanBuilder: @@ -116,29 +115,15 @@ def limit(self, num_rows: int) -> RustLogicalPlanBuilder: return RustLogicalPlanBuilder(builder) def explode(self, explode_expressions: ExpressionsProjection) -> RustLogicalPlanBuilder: - # TODO(Clark): Move this logic to Rust side after we've ported ExpressionsProjection. - explode_expressions = ExpressionsProjection([expr._explode() for expr in explode_expressions]) - input_schema = self.schema() - explode_schema = explode_expressions.resolve_schema(input_schema) - output_fields = [] - for f in input_schema: - if f.name in explode_schema.column_names(): - output_fields.append(explode_schema[f.name]) - else: - output_fields.append(f) - - exploded_schema = Schema._from_field_name_and_types([(f.name, f.dtype) for f in output_fields]) explode_pyexprs = [expr._expr for expr in explode_expressions] - builder = self._builder.explode(explode_pyexprs, exploded_schema._schema) + builder = self._builder.explode(explode_pyexprs) return RustLogicalPlanBuilder(builder) def count(self) -> RustLogicalPlanBuilder: # TODO(Clark): Add dedicated logical/physical ops when introducing metadata-based count optimizations. first_col = col(self.schema().column_names()[0]) builder = self._builder.aggregate([first_col._count(CountMode.All)._expr], []) - rename_expr = ExpressionsProjection([first_col.alias("count")]) - schema = rename_expr.resolve_schema(Schema._from_pyschema(builder.schema())) - builder = builder.project(rename_expr.to_inner_py_exprs(), schema._schema, ResourceRequest()) + builder = builder.project([first_col.alias("count")._expr], ResourceRequest()) return RustLogicalPlanBuilder(builder) def distinct(self) -> RustLogicalPlanBuilder: @@ -220,21 +205,10 @@ def join( # type: ignore[override] elif how == JoinType.Right: raise NotImplementedError("Right join not implemented.") elif how == JoinType.Inner: - # TODO(Clark): Port this logic to Rust-side once ExpressionsProjection has been ported. - right_drop_set = {r.name() for l, r in zip(left_on, right_on) if l.name() == r.name()} - left_columns = ExpressionsProjection.from_schema(self.schema()) - right_columns = ExpressionsProjection([col(f.name) for f in right.schema() if f.name not in right_drop_set]) - output_projection = left_columns.union(right_columns, rename_dup="right.") - right_columns = ExpressionsProjection(list(output_projection)[len(left_columns) :]) - output_schema = left_columns.resolve_schema(self.schema()).union( - right_columns.resolve_schema(right.schema()) - ) builder = self._builder.join( right._builder, left_on.to_inner_py_exprs(), right_on.to_inner_py_exprs(), - output_projection.to_inner_py_exprs(), - output_schema._schema, how, ) return RustLogicalPlanBuilder(builder) diff --git a/daft/table/table.py b/daft/table/table.py index ffc221c0fd..5bc12d6447 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -293,7 +293,6 @@ def join( right: Table, left_on: ExpressionsProjection, right_on: ExpressionsProjection, - output_projection: ExpressionsProjection | None = None, how: JoinType = JoinType.Inner, ) -> Table: if how != JoinType.Inner: diff --git a/src/common/error/src/error.rs b/src/common/error/src/error.rs index 4a78137d2f..604b139e48 100644 --- a/src/common/error/src/error.rs +++ b/src/common/error/src/error.rs @@ -23,6 +23,23 @@ pub enum DaftError { External(GenericError), } +impl std::error::Error for DaftError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + DaftError::FieldNotFound(_) + | DaftError::SchemaMismatch(_) + | DaftError::TypeError(_) + | DaftError::ComputeError(_) + | DaftError::ArrowError(_) + | DaftError::ValueError(_) => None, + DaftError::IoError(io_error) => Some(io_error), + DaftError::FileNotFound { source, .. } | DaftError::External(source) => Some(&**source), + #[cfg(feature = "python")] + DaftError::PyO3Error(pyerr) => Some(pyerr), + } + } +} + impl From for DaftError { fn from(error: arrow2::error::Error) -> Self { DaftError::ArrowError(error.to_string()) diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index 556b192797..7a80697d3e 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -9,6 +9,7 @@ daft-table = {path = "../daft-table", default-features = false} indexmap = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true, features = ["rc"]} +snafu = {workspace = true} [features] default = ["python"] diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index d113523900..cafd63508d 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -1,7 +1,5 @@ use std::sync::Arc; -use common_error::DaftResult; - use crate::{logical_plan::LogicalPlan, optimization::Optimizer, ResourceRequest}; #[cfg(feature = "python")] @@ -84,20 +82,14 @@ impl LogicalPlanBuilder { pub fn project( &self, projection: Vec, - projected_schema: &PySchema, resource_request: ResourceRequest, ) -> PyResult { let projection_exprs = projection .iter() .map(|e| e.clone().into()) .collect::>(); - let logical_plan: LogicalPlan = ops::Project::new( - projection_exprs, - projected_schema.clone().into(), - resource_request, - self.plan.clone(), - ) - .into(); + let logical_plan: LogicalPlan = + ops::Project::new(projection_exprs, resource_request, self.plan.clone())?.into(); Ok(logical_plan.into()) } @@ -112,21 +104,13 @@ impl LogicalPlanBuilder { Ok(logical_plan.into()) } - pub fn explode( - &self, - explode_pyexprs: Vec, - exploded_schema: &PySchema, - ) -> PyResult { - let explode_exprs = explode_pyexprs + pub fn explode(&self, to_explode_pyexprs: Vec) -> PyResult { + let to_explode = to_explode_pyexprs .iter() .map(|e| e.clone().into()) .collect::>(); - let logical_plan: LogicalPlan = ops::Explode::new( - explode_exprs, - exploded_schema.clone().into(), - self.plan.clone(), - ) - .into(); + let logical_plan: LogicalPlan = + ops::Explode::try_new(self.plan.clone(), to_explode)?.into(); Ok(logical_plan.into()) } @@ -199,24 +183,8 @@ impl LogicalPlanBuilder { .map(|expr| expr.clone().into()) .collect::>(); - let input_schema = self.plan.schema(); - let fields = groupby_exprs - .iter() - .map(|expr| expr.to_field(&input_schema)) - .chain( - agg_exprs - .iter() - .map(|agg_expr| agg_expr.to_field(&input_schema)), - ) - .collect::>>()?; - let output_schema = Schema::new(fields)?; - let logical_plan: LogicalPlan = Aggregate::new( - agg_exprs, - groupby_exprs, - output_schema.into(), - self.plan.clone(), - ) - .into(); + let logical_plan: LogicalPlan = + Aggregate::try_new(self.plan.clone(), agg_exprs, groupby_exprs)?.into(); Ok(logical_plan.into()) } @@ -225,8 +193,6 @@ impl LogicalPlanBuilder { other: &Self, left_on: Vec, right_on: Vec, - output_projection: Vec, - output_schema: &PySchema, join_type: JoinType, ) -> PyResult { let left_on_exprs = left_on @@ -237,19 +203,13 @@ impl LogicalPlanBuilder { .iter() .map(|e| e.clone().into()) .collect::>(); - let output_projection_exprs = output_projection - .iter() - .map(|e| e.clone().into()) - .collect::>(); - let logical_plan: LogicalPlan = ops::Join::new( + let logical_plan: LogicalPlan = ops::Join::try_new( + self.plan.clone(), other.plan.clone(), left_on_exprs, right_on_exprs, - output_projection_exprs, - output_schema.clone().into(), join_type, - self.plan.clone(), - ) + )? .into(); Ok(logical_plan.into()) } diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 328e5e291d..b4757f3e6e 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -1,6 +1,8 @@ use std::{cmp::max, sync::Arc}; +use common_error::DaftError; use daft_core::schema::SchemaRef; +use snafu::Snafu; use crate::{ops::*, PartitionScheme, PartitionSpec}; @@ -129,24 +131,24 @@ impl LogicalPlan { let new_plan = match children { [input] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), - Self::Project(Project { projection, projected_schema, resource_request, .. }) => Self::Project(Project::new( - projection.clone(), projected_schema.clone(), resource_request.clone(), input.clone(), - )), + Self::Project(Project { projection, resource_request, .. }) => Self::Project(Project::new( + projection.clone(), resource_request.clone(), input.clone(), + ).unwrap()), Self::Filter(Filter { predicate, .. }) => Self::Filter(Filter::new(predicate.clone(), input.clone())), Self::Limit(Limit { limit, .. }) => Self::Limit(Limit::new(*limit, input.clone())), - Self::Explode(Explode { explode_exprs, exploded_schema, .. }) => Self::Explode(Explode::new(explode_exprs.clone(), exploded_schema.clone(), input.clone())), + Self::Explode(Explode { to_explode, .. }) => Self::Explode(Explode::try_new(input.clone(), to_explode.clone()).unwrap()), Self::Sort(Sort { sort_by, descending, .. }) => Self::Sort(Sort::new(sort_by.clone(), descending.clone(), input.clone())), Self::Repartition(Repartition { num_partitions, partition_by, scheme, .. }) => Self::Repartition(Repartition::new(*num_partitions, partition_by.clone(), scheme.clone(), input.clone())), Self::Coalesce(Coalesce { num_to, .. }) => Self::Coalesce(Coalesce::new(*num_to, input.clone())), Self::Distinct(_) => Self::Distinct(Distinct::new(input.clone())), - Self::Aggregate(Aggregate { aggregations, groupby, output_schema, ..}) => Self::Aggregate(Aggregate::new(aggregations.clone(), groupby.clone(), output_schema.clone(), input.clone())), + Self::Aggregate(Aggregate { aggregations, groupby, ..}) => Self::Aggregate(Aggregate::try_new(input.clone(), aggregations.clone(), groupby.clone()).unwrap()), Self::Sink(Sink { schema, sink_info, .. }) => Self::Sink(Sink::new(schema.clone(), sink_info.clone(), input.clone())), _ => panic!("Logical op {} has two inputs, but got one", self), }, [input1, input2] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), Self::Concat(_) => Self::Concat(Concat::new(input2.clone(), input1.clone())), - Self::Join(Join { left_on, right_on, output_projection, output_schema, join_type, .. }) => Self::Join(Join::new(input2.clone(), left_on.clone(), right_on.clone(), output_projection.clone(), output_schema.clone(), *join_type, input1.clone())), + Self::Join(Join { left_on, right_on, join_type, .. }) => Self::Join(Join::try_new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type).unwrap()), _ => panic!("Logical op {} has one input, but got two", self), }, _ => panic!("Logical ops should never have more than 2 inputs, but got: {}", children.len()) @@ -169,8 +171,8 @@ impl LogicalPlan { } Self::Filter(Filter { predicate, .. }) => vec![format!("Filter: {predicate}")], Self::Limit(Limit { limit, .. }) => vec![format!("Limit: {limit}")], - Self::Explode(Explode { explode_exprs, .. }) => { - vec![format!("Explode: {explode_exprs:?}")] + Self::Explode(Explode { to_explode, .. }) => { + vec![format!("Explode: {to_explode:?}")] } Self::Sort(sort) => sort.multiline_display(), Self::Repartition(repartition) => repartition.multiline_display(), @@ -190,6 +192,30 @@ impl LogicalPlan { } } +#[derive(Debug, Snafu)] +#[snafu(visibility(pub(crate)))] +pub(crate) enum Error { + #[snafu(display("Unable to create logical plan node due to: {}", source))] + CreationError { source: DaftError }, +} +pub(crate) type Result = std::result::Result; + +impl From for DaftError { + fn from(err: Error) -> DaftError { + match err { + Error::CreationError { source } => source, + } + } +} + +#[cfg(feature = "python")] +impl std::convert::From for pyo3::PyErr { + fn from(value: Error) -> Self { + let daft_error: DaftError = value.into(); + daft_error.into() + } +} + macro_rules! impl_from_data_struct_for_logical_plan { ($name:ident) => { impl From<$name> for LogicalPlan { diff --git a/src/daft-plan/src/ops/agg.rs b/src/daft-plan/src/ops/agg.rs index 0cd2aa20f5..463def076a 100644 --- a/src/daft-plan/src/ops/agg.rs +++ b/src/daft-plan/src/ops/agg.rs @@ -1,12 +1,18 @@ use std::sync::Arc; +use snafu::ResultExt; + use daft_core::schema::{Schema, SchemaRef}; use daft_dsl::{AggExpr, Expr}; +use crate::logical_plan::{self, CreationSnafu}; use crate::LogicalPlan; #[derive(Clone, Debug)] pub struct Aggregate { + // Upstream node. + pub input: Arc, + /// Aggregations to apply. pub aggregations: Vec, @@ -14,24 +20,31 @@ pub struct Aggregate { pub groupby: Vec, pub output_schema: SchemaRef, - - // Upstream node. - pub input: Arc, } impl Aggregate { - pub(crate) fn new( + pub(crate) fn try_new( + input: Arc, aggregations: Vec, groupby: Vec, - output_schema: SchemaRef, - input: Arc, - ) -> Self { - Self { + ) -> logical_plan::Result { + let output_schema = { + let upstream_schema = input.schema(); + let fields = groupby + .iter() + .map(|e| e.to_field(&upstream_schema)) + .chain(aggregations.iter().map(|ae| ae.to_field(&upstream_schema))) + .collect::>>() + .context(CreationSnafu)?; + Schema::new(fields).context(CreationSnafu)?.into() + }; + + Ok(Self { aggregations, groupby, output_schema, input, - } + }) } pub(crate) fn schema(&self) -> SchemaRef { diff --git a/src/daft-plan/src/ops/explode.rs b/src/daft-plan/src/ops/explode.rs index 9f98c3c487..fe8ad037ff 100644 --- a/src/daft-plan/src/ops/explode.rs +++ b/src/daft-plan/src/ops/explode.rs @@ -1,28 +1,54 @@ use std::sync::Arc; -use daft_core::schema::SchemaRef; +use daft_core::schema::{Schema, SchemaRef}; use daft_dsl::Expr; +use snafu::ResultExt; -use crate::LogicalPlan; +use crate::{ + logical_plan::{self, CreationSnafu}, + LogicalPlan, +}; #[derive(Clone, Debug)] pub struct Explode { - pub explode_exprs: Vec, - pub exploded_schema: SchemaRef, // Upstream node. pub input: Arc, + // Expressions to explode. e.g. col("a") + pub to_explode: Vec, + pub exploded_schema: SchemaRef, } impl Explode { - pub(crate) fn new( - explode_exprs: Vec, - exploded_schema: SchemaRef, + pub(crate) fn try_new( input: Arc, - ) -> Self { - Self { - explode_exprs, - exploded_schema, + to_explode: Vec, + ) -> logical_plan::Result { + let explode_exprs = to_explode + .iter() + .map(daft_dsl::functions::list::explode) + .collect::>(); + let exploded_schema = { + let upstream_schema = input.schema(); + let explode_schema = { + let explode_fields = explode_exprs + .iter() + .map(|e| e.to_field(&upstream_schema)) + .collect::>>() + .context(CreationSnafu)?; + Schema::new(explode_fields).context(CreationSnafu)? + }; + let fields = upstream_schema + .fields + .iter() + .map(|(name, field)| explode_schema.fields.get(name).unwrap_or(field)) + .cloned() + .collect::>(); + Schema::new(fields).context(CreationSnafu)?.into() + }; + Ok(Self { input, - } + to_explode, + exploded_schema, + }) } } diff --git a/src/daft-plan/src/ops/join.rs b/src/daft-plan/src/ops/join.rs index be56bbc81b..6a153de947 100644 --- a/src/daft-plan/src/ops/join.rs +++ b/src/daft-plan/src/ops/join.rs @@ -1,41 +1,68 @@ -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; -use daft_core::schema::SchemaRef; +use daft_core::schema::{Schema, SchemaRef}; use daft_dsl::Expr; +use snafu::ResultExt; -use crate::{JoinType, LogicalPlan}; +use crate::{ + logical_plan::{self, CreationSnafu}, + JoinType, LogicalPlan, +}; #[derive(Clone, Debug)] pub struct Join { + // Upstream nodes. + pub input: Arc, pub right: Arc, + pub left_on: Vec, pub right_on: Vec, - pub output_projection: Vec, pub output_schema: SchemaRef, pub join_type: JoinType, - // Upstream node. - pub input: Arc, } impl Join { - pub(crate) fn new( + pub(crate) fn try_new( + input: Arc, right: Arc, left_on: Vec, right_on: Vec, - output_projection: Vec, - output_schema: SchemaRef, join_type: JoinType, - input: Arc, - ) -> Self { - Self { + ) -> logical_plan::Result { + // Schema inference ported from existing behaviour for parity, + // but contains bug https://github.com/Eventual-Inc/Daft/issues/1294 + let output_schema = { + let left_join_keys = left_on + .iter() + .map(|e| e.name()) + .collect::>>() + .context(CreationSnafu)?; + let left_schema = &input.schema().fields; + let fields = left_schema + .iter() + .map(|(_, field)| field) + .cloned() + .chain(right.schema().fields.iter().filter_map(|(rname, rfield)| { + if left_join_keys.contains(rname.as_str()) { + None + } else if left_schema.contains_key(rname) { + let new_name = format!("right.{}", rname); + Some(rfield.rename(new_name)) + } else { + Some(rfield.clone()) + } + })) + .collect::>(); + Schema::new(fields).context(CreationSnafu)?.into() + }; + Ok(Self { + input, right, left_on, right_on, - output_projection, output_schema, join_type, - input, - } + }) } pub fn multiline_display(&self) -> Vec { diff --git a/src/daft-plan/src/ops/project.rs b/src/daft-plan/src/ops/project.rs index fbd74c7d65..72cfcb3320 100644 --- a/src/daft-plan/src/ops/project.rs +++ b/src/daft-plan/src/ops/project.rs @@ -1,8 +1,10 @@ use std::sync::Arc; -use daft_core::schema::SchemaRef; +use daft_core::schema::{Schema, SchemaRef}; use daft_dsl::Expr; +use snafu::ResultExt; +use crate::logical_plan::{CreationSnafu, Result}; use crate::{LogicalPlan, ResourceRequest}; #[derive(Clone, Debug)] @@ -17,15 +19,23 @@ pub struct Project { impl Project { pub(crate) fn new( projection: Vec, - projected_schema: SchemaRef, resource_request: ResourceRequest, input: Arc, - ) -> Self { - Self { + ) -> Result { + let upstream_schema = input.schema(); + let projected_schema = { + let fields = projection + .iter() + .map(|e| e.to_field(&upstream_schema)) + .collect::>>() + .context(CreationSnafu)?; + Schema::new(fields).context(CreationSnafu)?.into() + }; + Ok(Self { projection, projected_schema, resource_request, input, - } + }) } } diff --git a/src/daft-plan/src/optimization/rules/push_down_filter.rs b/src/daft-plan/src/optimization/rules/push_down_filter.rs index aad1d777ed..a43dc9b079 100644 --- a/src/daft-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-plan/src/optimization/rules/push_down_filter.rs @@ -115,10 +115,9 @@ impl OptimizerRule for PushDownFilter { // Create new Projection. let new_projection: LogicalPlan = Project::new( child_project.projection.clone(), - child_project.projected_schema.clone(), child_project.resource_request.clone(), push_down_filter.into(), - ) + )? .into(); if can_not_push.is_empty() { // If all Filter predicate expressions were pushable past Projection, return new @@ -272,13 +271,8 @@ mod tests { Field::new("b", DataType::Utf8), ]) .into(); - let projection: LogicalPlan = Project::new( - vec![col("a")], - Schema::new(vec![source.schema().get_field("a")?.clone()])?.into(), - Default::default(), - source.into(), - ) - .into(); + let projection: LogicalPlan = + Project::new(vec![col("a")], Default::default(), source.into())?.into(); let filter: LogicalPlan = Filter::new(col("a").lt(&lit(2)), projection.into()).into(); let expected = "\ Project: col(a)\ @@ -295,13 +289,8 @@ mod tests { Field::new("b", DataType::Utf8), ]) .into(); - let projection: LogicalPlan = Project::new( - vec![col("a"), col("b")], - source.schema().clone(), - Default::default(), - source.into(), - ) - .into(); + let projection: LogicalPlan = + Project::new(vec![col("a"), col("b")], Default::default(), source.into())?.into(); let filter: LogicalPlan = Filter::new( col("a").lt(&lit(2)).and(&col("b").eq(&lit("foo"))), projection.into(), @@ -323,13 +312,8 @@ mod tests { ]) .into(); // Projection involves compute on filtered column "a". - let projection: LogicalPlan = Project::new( - vec![col("a") + lit(1)], - Schema::new(vec![source.schema().get_field("a")?.clone()])?.into(), - Default::default(), - source.into(), - ) - .into(); + let projection: LogicalPlan = + Project::new(vec![col("a") + lit(1)], Default::default(), source.into())?.into(); let filter: LogicalPlan = Filter::new(col("a").lt(&lit(2)), projection.into()).into(); // Filter should NOT commute with Project, since this would involve redundant computation. let expected = "\ @@ -349,13 +333,8 @@ mod tests { Field::new("b", DataType::Utf8), ]) .into(); - let projection: LogicalPlan = Project::new( - vec![col("a") + lit(1)], - Schema::new(vec![source.schema().get_field("a")?.clone()])?.into(), - Default::default(), - source.into(), - ) - .into(); + let projection: LogicalPlan = + Project::new(vec![col("a") + lit(1)], Default::default(), source.into())?.into(); let filter: LogicalPlan = Filter::new(col("a").lt(&lit(2)), projection.into()).into(); let expected = "\ Project: col(a) + lit(1)\ @@ -451,16 +430,13 @@ mod tests { Field::new("c", DataType::Float64), ]) .into(); - let output_schema = source1.schema().union(source2.schema().as_ref())?; - let join: LogicalPlan = Join::new( + let join: LogicalPlan = Join::try_new( + source1.into(), source2.into(), vec![col("b")], vec![col("b")], - vec![], - output_schema.into(), JoinType::Inner, - source1.into(), - ) + )? .into(); let filter: LogicalPlan = Filter::new(col("a").lt(&lit(2)), join.into()).into(); let expected = "\ @@ -484,16 +460,13 @@ mod tests { Field::new("c", DataType::Float64), ]) .into(); - let output_schema = source1.schema().union(source2.schema().as_ref())?; - let join: LogicalPlan = Join::new( + let join: LogicalPlan = Join::try_new( + source1.into(), source2.into(), vec![col("b")], vec![col("b")], - vec![], - output_schema.into(), JoinType::Inner, - source1.into(), - ) + )? .into(); let filter: LogicalPlan = Filter::new(col("c").lt(&lit(2.0)), join.into()).into(); let expected = "\ @@ -509,33 +482,26 @@ mod tests { fn filter_commutes_with_join_both_sides() -> DaftResult<()> { let source1: LogicalPlan = dummy_scan_node(vec![ Field::new("a", DataType::Int64), - Field::new("b", DataType::Utf8), + Field::new("b", DataType::Int64), Field::new("c", DataType::Float64), ]) .into(); - let source2: LogicalPlan = dummy_scan_node(vec![ - Field::new("b", DataType::Utf8), - Field::new("c", DataType::Float64), - ]) - .into(); - let output_schema = source1.schema().union(source2.schema().as_ref())?; - let join: LogicalPlan = Join::new( + let source2: LogicalPlan = dummy_scan_node(vec![Field::new("b", DataType::Int64)]).into(); + let join: LogicalPlan = Join::try_new( + source1.into(), source2.into(), vec![col("b")], vec![col("b")], - vec![], - output_schema.into(), JoinType::Inner, - source1.into(), - ) + )? .into(); - let filter: LogicalPlan = Filter::new(col("c").lt(&lit(2.0)), join.into()).into(); + let filter: LogicalPlan = Filter::new(col("b").lt(&lit(2.0)), join.into()).into(); let expected = "\ - Join: Type = Inner, On = col(b), Output schema = a (Int64), b (Utf8), c (Float64)\ - \n Filter: col(c) < lit(2.0)\ - \n Source: \"Json\", File paths = /foo, File schema = a (Int64), b (Utf8), c (Float64), Format-specific config = Json(JsonSourceConfig), Output schema = a (Int64), b (Utf8), c (Float64)\ - \n Filter: col(c) < lit(2.0)\ - \n Source: \"Json\", File paths = /foo, File schema = b (Utf8), c (Float64), Format-specific config = Json(JsonSourceConfig), Output schema = b (Utf8), c (Float64)"; + Join: Type = Inner, On = col(b), Output schema = a (Int64), b (Int64), c (Float64)\ + \n Filter: col(b) < lit(2.0)\ + \n Source: \"Json\", File paths = /foo, File schema = a (Int64), b (Int64), c (Float64), Format-specific config = Json(JsonSourceConfig), Output schema = a (Int64), b (Int64), c (Float64)\ + \n Filter: col(b) < lit(2.0)\ + \n Source: \"Json\", File paths = /foo, File schema = b (Int64), Format-specific config = Json(JsonSourceConfig), Output schema = b (Int64)"; assert_optimized_plan_eq(filter.into(), expected)?; Ok(()) } diff --git a/src/daft-plan/src/physical_ops/explode.rs b/src/daft-plan/src/physical_ops/explode.rs index 7d8ad5c6c0..9d3a6275b0 100644 --- a/src/daft-plan/src/physical_ops/explode.rs +++ b/src/daft-plan/src/physical_ops/explode.rs @@ -7,16 +7,13 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Explode { - pub explode_exprs: Vec, // Upstream node. pub input: Arc, + pub to_explode: Vec, } impl Explode { - pub(crate) fn new(explode_exprs: Vec, input: Arc) -> Self { - Self { - explode_exprs, - input, - } + pub(crate) fn new(input: Arc, to_explode: Vec) -> Self { + Self { input, to_explode } } } diff --git a/src/daft-plan/src/physical_ops/join.rs b/src/daft-plan/src/physical_ops/join.rs index aaa36a8d1c..aedc791b1d 100644 --- a/src/daft-plan/src/physical_ops/join.rs +++ b/src/daft-plan/src/physical_ops/join.rs @@ -10,7 +10,6 @@ pub struct Join { pub right: Arc, pub left_on: Vec, pub right_on: Vec, - pub output_projection: Vec, pub join_type: JoinType, // Upstream node. pub input: Arc, @@ -21,7 +20,6 @@ impl Join { right: Arc, left_on: Vec, right_on: Vec, - output_projection: Vec, join_type: JoinType, input: Arc, ) -> Self { @@ -29,7 +27,6 @@ impl Join { right, left_on, right_on, - output_projection, join_type, input, } diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 7a909b3e82..64ecf2f54c 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -276,12 +276,9 @@ impl PhysicalPlan { .call1((local_limit_iter, *limit, *num_partitions))?; Ok(global_limit_iter.into()) } - PhysicalPlan::Explode(Explode { - input, - explode_exprs, - }) => { + PhysicalPlan::Explode(Explode { input, to_explode }) => { let upstream_iter = input.to_partition_tasks(py, psets)?; - let explode_pyexprs: Vec = explode_exprs + let explode_pyexprs: Vec = to_explode .iter() .map(|expr| PyExpr::from(expr.clone())) .collect(); @@ -417,9 +414,9 @@ impl PhysicalPlan { right, left_on, right_on, - output_projection, join_type, input, + .. }) => { let upstream_input_iter = input.to_partition_tasks(py, psets)?; let upstream_right_iter = right.to_partition_tasks(py, psets)?; @@ -431,10 +428,6 @@ impl PhysicalPlan { .iter() .map(|expr| PyExpr::from(expr.clone())) .collect(); - let output_projection_pyexprs: Vec = output_projection - .iter() - .map(|expr| PyExpr::from(expr.clone())) - .collect(); let py_iter = py .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? .getattr(pyo3::intern!(py, "join"))? @@ -443,7 +436,6 @@ impl PhysicalPlan { upstream_right_iter, left_on_pyexprs, right_on_pyexprs, - output_projection_pyexprs, *join_type, ))?; Ok(py_iter.into()) diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 0e175f12c1..3718a743a5 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -95,14 +95,12 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult { ))) } LogicalPlan::Explode(LogicalExplode { - input, - explode_exprs, - .. + input, to_explode, .. }) => { let input_physical = plan(input)?; Ok(PhysicalPlan::Explode(Explode::new( - explode_exprs.clone(), input_physical.into(), + to_explode.clone(), ))) } LogicalPlan::Sort(LogicalSort { @@ -399,7 +397,6 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult { input, left_on, right_on, - output_projection, join_type, .. }) => { @@ -442,7 +439,6 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult { right_physical.into(), left_on.clone(), right_on.clone(), - output_projection.clone(), *join_type, left_physical.into(), )))