diff --git a/Cargo.lock b/Cargo.lock index 4443676e34..3077cc5308 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1116,6 +1116,7 @@ dependencies = [ "indexmap 2.0.0", "pyo3", "serde", + "snafu", ] [[package]] diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 701d82e441..66b5a7e539 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -641,7 +641,6 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) class Join(SingleOutputInstruction): left_on: ExpressionsProjection right_on: ExpressionsProjection - output_projection: ExpressionsProjection how: JoinType def run(self, inputs: list[Table]) -> list[Table]: @@ -653,7 +652,6 @@ def _join(self, inputs: list[Table]) -> list[Table]: right, left_on=self.left_on, right_on=self.right_on, - output_projection=self.output_projection, how=self.how, ) return [result] diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index cf6bc01cee..84bb4ca656 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -172,7 +172,6 @@ def join( right_plan: InProgressPhysicalPlan[PartitionT], left_on: ExpressionsProjection, right_on: ExpressionsProjection, - output_projection: ExpressionsProjection, how: JoinType, ) -> InProgressPhysicalPlan[PartitionT]: """Pairwise join the partitions from `left_child_plan` and `right_child_plan` together.""" @@ -202,7 +201,6 @@ def join( instruction=execution_step.Join( left_on=left_on, right_on=right_on, - output_projection=output_projection, how=how, ) ) diff --git a/daft/execution/physical_plan_factory.py b/daft/execution/physical_plan_factory.py index 1fe1ee08ce..45ddd7a823 100644 --- a/daft/execution/physical_plan_factory.py +++ b/daft/execution/physical_plan_factory.py @@ -175,7 +175,6 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) -> right_plan=_get_physical_plan(right_child, psets), left_on=node._left_on, right_on=node._right_on, - output_projection=node._output_projection, how=node._how, ) diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 977b3c8953..2f6a3a41cd 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -68,7 +68,9 @@ def run(self, input_partition: Table) -> Table: def explode( input: physical_plan.InProgressPhysicalPlan[PartitionT], explode_exprs: list[PyExpr] ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: - explode_expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in explode_exprs]) + explode_expr_projection = ExpressionsProjection( + [Expression._from_pyexpr(expr)._explode() for expr in explode_exprs] + ) explode_op = ShimExplodeOp(explode_expr_projection) return physical_plan.pipeline_instruction( child_plan=input, @@ -138,18 +140,15 @@ def join( right: physical_plan.InProgressPhysicalPlan[PartitionT], left_on: list[PyExpr], right_on: list[PyExpr], - output_projection: list[PyExpr], join_type: JoinType, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on]) - output_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in output_projection]) return physical_plan.join( left_plan=input, right_plan=right, left_on=left_on_expr_proj, right_on=right_on_expr_proj, - output_projection=output_expr_proj, how=join_type, ) diff --git a/daft/logical/rust_logical_plan.py b/daft/logical/rust_logical_plan.py index b4cf16d361..771e690f33 100644 --- a/daft/logical/rust_logical_plan.py +++ b/daft/logical/rust_logical_plan.py @@ -93,9 +93,8 @@ def project( projection: ExpressionsProjection, custom_resource_request: ResourceRequest = ResourceRequest(), ) -> RustLogicalPlanBuilder: - schema = projection.resolve_schema(self.schema()) exprs = [expr._expr for expr in projection] - builder = self._builder.project(exprs, schema._schema, custom_resource_request) + builder = self._builder.project(exprs, custom_resource_request) return RustLogicalPlanBuilder(builder) def filter(self, predicate: Expression) -> RustLogicalPlanBuilder: @@ -116,29 +115,15 @@ def limit(self, num_rows: int) -> RustLogicalPlanBuilder: return RustLogicalPlanBuilder(builder) def explode(self, explode_expressions: ExpressionsProjection) -> RustLogicalPlanBuilder: - # TODO(Clark): Move this logic to Rust side after we've ported ExpressionsProjection. - explode_expressions = ExpressionsProjection([expr._explode() for expr in explode_expressions]) - input_schema = self.schema() - explode_schema = explode_expressions.resolve_schema(input_schema) - output_fields = [] - for f in input_schema: - if f.name in explode_schema.column_names(): - output_fields.append(explode_schema[f.name]) - else: - output_fields.append(f) - - exploded_schema = Schema._from_field_name_and_types([(f.name, f.dtype) for f in output_fields]) explode_pyexprs = [expr._expr for expr in explode_expressions] - builder = self._builder.explode(explode_pyexprs, exploded_schema._schema) + builder = self._builder.explode(explode_pyexprs) return RustLogicalPlanBuilder(builder) def count(self) -> RustLogicalPlanBuilder: # TODO(Clark): Add dedicated logical/physical ops when introducing metadata-based count optimizations. first_col = col(self.schema().column_names()[0]) builder = self._builder.aggregate([first_col._count(CountMode.All)._expr], []) - rename_expr = ExpressionsProjection([first_col.alias("count")]) - schema = rename_expr.resolve_schema(Schema._from_pyschema(builder.schema())) - builder = builder.project(rename_expr.to_inner_py_exprs(), schema._schema, ResourceRequest()) + builder = builder.project([first_col.alias("count")._expr], ResourceRequest()) return RustLogicalPlanBuilder(builder) def distinct(self) -> RustLogicalPlanBuilder: @@ -220,21 +205,10 @@ def join( # type: ignore[override] elif how == JoinType.Right: raise NotImplementedError("Right join not implemented.") elif how == JoinType.Inner: - # TODO(Clark): Port this logic to Rust-side once ExpressionsProjection has been ported. - right_drop_set = {r.name() for l, r in zip(left_on, right_on) if l.name() == r.name()} - left_columns = ExpressionsProjection.from_schema(self.schema()) - right_columns = ExpressionsProjection([col(f.name) for f in right.schema() if f.name not in right_drop_set]) - output_projection = left_columns.union(right_columns, rename_dup="right.") - right_columns = ExpressionsProjection(list(output_projection)[len(left_columns) :]) - output_schema = left_columns.resolve_schema(self.schema()).union( - right_columns.resolve_schema(right.schema()) - ) builder = self._builder.join( right._builder, left_on.to_inner_py_exprs(), right_on.to_inner_py_exprs(), - output_projection.to_inner_py_exprs(), - output_schema._schema, how, ) return RustLogicalPlanBuilder(builder) diff --git a/daft/table/table.py b/daft/table/table.py index ffc221c0fd..5bc12d6447 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -293,7 +293,6 @@ def join( right: Table, left_on: ExpressionsProjection, right_on: ExpressionsProjection, - output_projection: ExpressionsProjection | None = None, how: JoinType = JoinType.Inner, ) -> Table: if how != JoinType.Inner: diff --git a/src/common/error/src/error.rs b/src/common/error/src/error.rs index 4a78137d2f..604b139e48 100644 --- a/src/common/error/src/error.rs +++ b/src/common/error/src/error.rs @@ -23,6 +23,23 @@ pub enum DaftError { External(GenericError), } +impl std::error::Error for DaftError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + DaftError::FieldNotFound(_) + | DaftError::SchemaMismatch(_) + | DaftError::TypeError(_) + | DaftError::ComputeError(_) + | DaftError::ArrowError(_) + | DaftError::ValueError(_) => None, + DaftError::IoError(io_error) => Some(io_error), + DaftError::FileNotFound { source, .. } | DaftError::External(source) => Some(&**source), + #[cfg(feature = "python")] + DaftError::PyO3Error(pyerr) => Some(pyerr), + } + } +} + impl From for DaftError { fn from(error: arrow2::error::Error) -> Self { DaftError::ArrowError(error.to_string()) diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index 556b192797..7a80697d3e 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -9,6 +9,7 @@ daft-table = {path = "../daft-table", default-features = false} indexmap = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true, features = ["rc"]} +snafu = {workspace = true} [features] default = ["python"] diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index d113523900..cafd63508d 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -1,7 +1,5 @@ use std::sync::Arc; -use common_error::DaftResult; - use crate::{logical_plan::LogicalPlan, optimization::Optimizer, ResourceRequest}; #[cfg(feature = "python")] @@ -84,20 +82,14 @@ impl LogicalPlanBuilder { pub fn project( &self, projection: Vec, - projected_schema: &PySchema, resource_request: ResourceRequest, ) -> PyResult { let projection_exprs = projection .iter() .map(|e| e.clone().into()) .collect::>(); - let logical_plan: LogicalPlan = ops::Project::new( - projection_exprs, - projected_schema.clone().into(), - resource_request, - self.plan.clone(), - ) - .into(); + let logical_plan: LogicalPlan = + ops::Project::new(projection_exprs, resource_request, self.plan.clone())?.into(); Ok(logical_plan.into()) } @@ -112,21 +104,13 @@ impl LogicalPlanBuilder { Ok(logical_plan.into()) } - pub fn explode( - &self, - explode_pyexprs: Vec, - exploded_schema: &PySchema, - ) -> PyResult { - let explode_exprs = explode_pyexprs + pub fn explode(&self, to_explode_pyexprs: Vec) -> PyResult { + let to_explode = to_explode_pyexprs .iter() .map(|e| e.clone().into()) .collect::>(); - let logical_plan: LogicalPlan = ops::Explode::new( - explode_exprs, - exploded_schema.clone().into(), - self.plan.clone(), - ) - .into(); + let logical_plan: LogicalPlan = + ops::Explode::try_new(self.plan.clone(), to_explode)?.into(); Ok(logical_plan.into()) } @@ -199,24 +183,8 @@ impl LogicalPlanBuilder { .map(|expr| expr.clone().into()) .collect::>(); - let input_schema = self.plan.schema(); - let fields = groupby_exprs - .iter() - .map(|expr| expr.to_field(&input_schema)) - .chain( - agg_exprs - .iter() - .map(|agg_expr| agg_expr.to_field(&input_schema)), - ) - .collect::>>()?; - let output_schema = Schema::new(fields)?; - let logical_plan: LogicalPlan = Aggregate::new( - agg_exprs, - groupby_exprs, - output_schema.into(), - self.plan.clone(), - ) - .into(); + let logical_plan: LogicalPlan = + Aggregate::try_new(self.plan.clone(), agg_exprs, groupby_exprs)?.into(); Ok(logical_plan.into()) } @@ -225,8 +193,6 @@ impl LogicalPlanBuilder { other: &Self, left_on: Vec, right_on: Vec, - output_projection: Vec, - output_schema: &PySchema, join_type: JoinType, ) -> PyResult { let left_on_exprs = left_on @@ -237,19 +203,13 @@ impl LogicalPlanBuilder { .iter() .map(|e| e.clone().into()) .collect::>(); - let output_projection_exprs = output_projection - .iter() - .map(|e| e.clone().into()) - .collect::>(); - let logical_plan: LogicalPlan = ops::Join::new( + let logical_plan: LogicalPlan = ops::Join::try_new( + self.plan.clone(), other.plan.clone(), left_on_exprs, right_on_exprs, - output_projection_exprs, - output_schema.clone().into(), join_type, - self.plan.clone(), - ) + )? .into(); Ok(logical_plan.into()) } diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 328e5e291d..b4757f3e6e 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -1,6 +1,8 @@ use std::{cmp::max, sync::Arc}; +use common_error::DaftError; use daft_core::schema::SchemaRef; +use snafu::Snafu; use crate::{ops::*, PartitionScheme, PartitionSpec}; @@ -129,24 +131,24 @@ impl LogicalPlan { let new_plan = match children { [input] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), - Self::Project(Project { projection, projected_schema, resource_request, .. }) => Self::Project(Project::new( - projection.clone(), projected_schema.clone(), resource_request.clone(), input.clone(), - )), + Self::Project(Project { projection, resource_request, .. }) => Self::Project(Project::new( + projection.clone(), resource_request.clone(), input.clone(), + ).unwrap()), Self::Filter(Filter { predicate, .. }) => Self::Filter(Filter::new(predicate.clone(), input.clone())), Self::Limit(Limit { limit, .. }) => Self::Limit(Limit::new(*limit, input.clone())), - Self::Explode(Explode { explode_exprs, exploded_schema, .. }) => Self::Explode(Explode::new(explode_exprs.clone(), exploded_schema.clone(), input.clone())), + Self::Explode(Explode { to_explode, .. }) => Self::Explode(Explode::try_new(input.clone(), to_explode.clone()).unwrap()), Self::Sort(Sort { sort_by, descending, .. }) => Self::Sort(Sort::new(sort_by.clone(), descending.clone(), input.clone())), Self::Repartition(Repartition { num_partitions, partition_by, scheme, .. }) => Self::Repartition(Repartition::new(*num_partitions, partition_by.clone(), scheme.clone(), input.clone())), Self::Coalesce(Coalesce { num_to, .. }) => Self::Coalesce(Coalesce::new(*num_to, input.clone())), Self::Distinct(_) => Self::Distinct(Distinct::new(input.clone())), - Self::Aggregate(Aggregate { aggregations, groupby, output_schema, ..}) => Self::Aggregate(Aggregate::new(aggregations.clone(), groupby.clone(), output_schema.clone(), input.clone())), + Self::Aggregate(Aggregate { aggregations, groupby, ..}) => Self::Aggregate(Aggregate::try_new(input.clone(), aggregations.clone(), groupby.clone()).unwrap()), Self::Sink(Sink { schema, sink_info, .. }) => Self::Sink(Sink::new(schema.clone(), sink_info.clone(), input.clone())), _ => panic!("Logical op {} has two inputs, but got one", self), }, [input1, input2] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), Self::Concat(_) => Self::Concat(Concat::new(input2.clone(), input1.clone())), - Self::Join(Join { left_on, right_on, output_projection, output_schema, join_type, .. }) => Self::Join(Join::new(input2.clone(), left_on.clone(), right_on.clone(), output_projection.clone(), output_schema.clone(), *join_type, input1.clone())), + Self::Join(Join { left_on, right_on, join_type, .. }) => Self::Join(Join::try_new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type).unwrap()), _ => panic!("Logical op {} has one input, but got two", self), }, _ => panic!("Logical ops should never have more than 2 inputs, but got: {}", children.len()) @@ -169,8 +171,8 @@ impl LogicalPlan { } Self::Filter(Filter { predicate, .. }) => vec![format!("Filter: {predicate}")], Self::Limit(Limit { limit, .. }) => vec![format!("Limit: {limit}")], - Self::Explode(Explode { explode_exprs, .. }) => { - vec![format!("Explode: {explode_exprs:?}")] + Self::Explode(Explode { to_explode, .. }) => { + vec![format!("Explode: {to_explode:?}")] } Self::Sort(sort) => sort.multiline_display(), Self::Repartition(repartition) => repartition.multiline_display(), @@ -190,6 +192,30 @@ impl LogicalPlan { } } +#[derive(Debug, Snafu)] +#[snafu(visibility(pub(crate)))] +pub(crate) enum Error { + #[snafu(display("Unable to create logical plan node due to: {}", source))] + CreationError { source: DaftError }, +} +pub(crate) type Result = std::result::Result; + +impl From for DaftError { + fn from(err: Error) -> DaftError { + match err { + Error::CreationError { source } => source, + } + } +} + +#[cfg(feature = "python")] +impl std::convert::From for pyo3::PyErr { + fn from(value: Error) -> Self { + let daft_error: DaftError = value.into(); + daft_error.into() + } +} + macro_rules! impl_from_data_struct_for_logical_plan { ($name:ident) => { impl From<$name> for LogicalPlan { diff --git a/src/daft-plan/src/ops/agg.rs b/src/daft-plan/src/ops/agg.rs index 0cd2aa20f5..463def076a 100644 --- a/src/daft-plan/src/ops/agg.rs +++ b/src/daft-plan/src/ops/agg.rs @@ -1,12 +1,18 @@ use std::sync::Arc; +use snafu::ResultExt; + use daft_core::schema::{Schema, SchemaRef}; use daft_dsl::{AggExpr, Expr}; +use crate::logical_plan::{self, CreationSnafu}; use crate::LogicalPlan; #[derive(Clone, Debug)] pub struct Aggregate { + // Upstream node. + pub input: Arc, + /// Aggregations to apply. pub aggregations: Vec, @@ -14,24 +20,31 @@ pub struct Aggregate { pub groupby: Vec, pub output_schema: SchemaRef, - - // Upstream node. - pub input: Arc, } impl Aggregate { - pub(crate) fn new( + pub(crate) fn try_new( + input: Arc, aggregations: Vec, groupby: Vec, - output_schema: SchemaRef, - input: Arc, - ) -> Self { - Self { + ) -> logical_plan::Result { + let output_schema = { + let upstream_schema = input.schema(); + let fields = groupby + .iter() + .map(|e| e.to_field(&upstream_schema)) + .chain(aggregations.iter().map(|ae| ae.to_field(&upstream_schema))) + .collect::>>() + .context(CreationSnafu)?; + Schema::new(fields).context(CreationSnafu)?.into() + }; + + Ok(Self { aggregations, groupby, output_schema, input, - } + }) } pub(crate) fn schema(&self) -> SchemaRef { diff --git a/src/daft-plan/src/ops/explode.rs b/src/daft-plan/src/ops/explode.rs index 9f98c3c487..fe8ad037ff 100644 --- a/src/daft-plan/src/ops/explode.rs +++ b/src/daft-plan/src/ops/explode.rs @@ -1,28 +1,54 @@ use std::sync::Arc; -use daft_core::schema::SchemaRef; +use daft_core::schema::{Schema, SchemaRef}; use daft_dsl::Expr; +use snafu::ResultExt; -use crate::LogicalPlan; +use crate::{ + logical_plan::{self, CreationSnafu}, + LogicalPlan, +}; #[derive(Clone, Debug)] pub struct Explode { - pub explode_exprs: Vec, - pub exploded_schema: SchemaRef, // Upstream node. pub input: Arc, + // Expressions to explode. e.g. col("a") + pub to_explode: Vec, + pub exploded_schema: SchemaRef, } impl Explode { - pub(crate) fn new( - explode_exprs: Vec, - exploded_schema: SchemaRef, + pub(crate) fn try_new( input: Arc, - ) -> Self { - Self { - explode_exprs, - exploded_schema, + to_explode: Vec, + ) -> logical_plan::Result { + let explode_exprs = to_explode + .iter() + .map(daft_dsl::functions::list::explode) + .collect::>(); + let exploded_schema = { + let upstream_schema = input.schema(); + let explode_schema = { + let explode_fields = explode_exprs + .iter() + .map(|e| e.to_field(&upstream_schema)) + .collect::>>() + .context(CreationSnafu)?; + Schema::new(explode_fields).context(CreationSnafu)? + }; + let fields = upstream_schema + .fields + .iter() + .map(|(name, field)| explode_schema.fields.get(name).unwrap_or(field)) + .cloned() + .collect::>(); + Schema::new(fields).context(CreationSnafu)?.into() + }; + Ok(Self { input, - } + to_explode, + exploded_schema, + }) } } diff --git a/src/daft-plan/src/ops/join.rs b/src/daft-plan/src/ops/join.rs index be56bbc81b..6a153de947 100644 --- a/src/daft-plan/src/ops/join.rs +++ b/src/daft-plan/src/ops/join.rs @@ -1,41 +1,68 @@ -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; -use daft_core::schema::SchemaRef; +use daft_core::schema::{Schema, SchemaRef}; use daft_dsl::Expr; +use snafu::ResultExt; -use crate::{JoinType, LogicalPlan}; +use crate::{ + logical_plan::{self, CreationSnafu}, + JoinType, LogicalPlan, +}; #[derive(Clone, Debug)] pub struct Join { + // Upstream nodes. + pub input: Arc, pub right: Arc, + pub left_on: Vec, pub right_on: Vec, - pub output_projection: Vec, pub output_schema: SchemaRef, pub join_type: JoinType, - // Upstream node. - pub input: Arc, } impl Join { - pub(crate) fn new( + pub(crate) fn try_new( + input: Arc, right: Arc, left_on: Vec, right_on: Vec, - output_projection: Vec, - output_schema: SchemaRef, join_type: JoinType, - input: Arc, - ) -> Self { - Self { + ) -> logical_plan::Result { + // Schema inference ported from existing behaviour for parity, + // but contains bug https://github.com/Eventual-Inc/Daft/issues/1294 + let output_schema = { + let left_join_keys = left_on + .iter() + .map(|e| e.name()) + .collect::>>() + .context(CreationSnafu)?; + let left_schema = &input.schema().fields; + let fields = left_schema + .iter() + .map(|(_, field)| field) + .cloned() + .chain(right.schema().fields.iter().filter_map(|(rname, rfield)| { + if left_join_keys.contains(rname.as_str()) { + None + } else if left_schema.contains_key(rname) { + let new_name = format!("right.{}", rname); + Some(rfield.rename(new_name)) + } else { + Some(rfield.clone()) + } + })) + .collect::>(); + Schema::new(fields).context(CreationSnafu)?.into() + }; + Ok(Self { + input, right, left_on, right_on, - output_projection, output_schema, join_type, - input, - } + }) } pub fn multiline_display(&self) -> Vec { diff --git a/src/daft-plan/src/ops/project.rs b/src/daft-plan/src/ops/project.rs index fbd74c7d65..72cfcb3320 100644 --- a/src/daft-plan/src/ops/project.rs +++ b/src/daft-plan/src/ops/project.rs @@ -1,8 +1,10 @@ use std::sync::Arc; -use daft_core::schema::SchemaRef; +use daft_core::schema::{Schema, SchemaRef}; use daft_dsl::Expr; +use snafu::ResultExt; +use crate::logical_plan::{CreationSnafu, Result}; use crate::{LogicalPlan, ResourceRequest}; #[derive(Clone, Debug)] @@ -17,15 +19,23 @@ pub struct Project { impl Project { pub(crate) fn new( projection: Vec, - projected_schema: SchemaRef, resource_request: ResourceRequest, input: Arc, - ) -> Self { - Self { + ) -> Result { + let upstream_schema = input.schema(); + let projected_schema = { + let fields = projection + .iter() + .map(|e| e.to_field(&upstream_schema)) + .collect::>>() + .context(CreationSnafu)?; + Schema::new(fields).context(CreationSnafu)?.into() + }; + Ok(Self { projection, projected_schema, resource_request, input, - } + }) } } diff --git a/src/daft-plan/src/optimization/rules/push_down_filter.rs b/src/daft-plan/src/optimization/rules/push_down_filter.rs index aad1d777ed..a43dc9b079 100644 --- a/src/daft-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-plan/src/optimization/rules/push_down_filter.rs @@ -115,10 +115,9 @@ impl OptimizerRule for PushDownFilter { // Create new Projection. let new_projection: LogicalPlan = Project::new( child_project.projection.clone(), - child_project.projected_schema.clone(), child_project.resource_request.clone(), push_down_filter.into(), - ) + )? .into(); if can_not_push.is_empty() { // If all Filter predicate expressions were pushable past Projection, return new @@ -272,13 +271,8 @@ mod tests { Field::new("b", DataType::Utf8), ]) .into(); - let projection: LogicalPlan = Project::new( - vec![col("a")], - Schema::new(vec![source.schema().get_field("a")?.clone()])?.into(), - Default::default(), - source.into(), - ) - .into(); + let projection: LogicalPlan = + Project::new(vec![col("a")], Default::default(), source.into())?.into(); let filter: LogicalPlan = Filter::new(col("a").lt(&lit(2)), projection.into()).into(); let expected = "\ Project: col(a)\ @@ -295,13 +289,8 @@ mod tests { Field::new("b", DataType::Utf8), ]) .into(); - let projection: LogicalPlan = Project::new( - vec![col("a"), col("b")], - source.schema().clone(), - Default::default(), - source.into(), - ) - .into(); + let projection: LogicalPlan = + Project::new(vec![col("a"), col("b")], Default::default(), source.into())?.into(); let filter: LogicalPlan = Filter::new( col("a").lt(&lit(2)).and(&col("b").eq(&lit("foo"))), projection.into(), @@ -323,13 +312,8 @@ mod tests { ]) .into(); // Projection involves compute on filtered column "a". - let projection: LogicalPlan = Project::new( - vec![col("a") + lit(1)], - Schema::new(vec![source.schema().get_field("a")?.clone()])?.into(), - Default::default(), - source.into(), - ) - .into(); + let projection: LogicalPlan = + Project::new(vec![col("a") + lit(1)], Default::default(), source.into())?.into(); let filter: LogicalPlan = Filter::new(col("a").lt(&lit(2)), projection.into()).into(); // Filter should NOT commute with Project, since this would involve redundant computation. let expected = "\ @@ -349,13 +333,8 @@ mod tests { Field::new("b", DataType::Utf8), ]) .into(); - let projection: LogicalPlan = Project::new( - vec![col("a") + lit(1)], - Schema::new(vec![source.schema().get_field("a")?.clone()])?.into(), - Default::default(), - source.into(), - ) - .into(); + let projection: LogicalPlan = + Project::new(vec![col("a") + lit(1)], Default::default(), source.into())?.into(); let filter: LogicalPlan = Filter::new(col("a").lt(&lit(2)), projection.into()).into(); let expected = "\ Project: col(a) + lit(1)\ @@ -451,16 +430,13 @@ mod tests { Field::new("c", DataType::Float64), ]) .into(); - let output_schema = source1.schema().union(source2.schema().as_ref())?; - let join: LogicalPlan = Join::new( + let join: LogicalPlan = Join::try_new( + source1.into(), source2.into(), vec![col("b")], vec![col("b")], - vec![], - output_schema.into(), JoinType::Inner, - source1.into(), - ) + )? .into(); let filter: LogicalPlan = Filter::new(col("a").lt(&lit(2)), join.into()).into(); let expected = "\ @@ -484,16 +460,13 @@ mod tests { Field::new("c", DataType::Float64), ]) .into(); - let output_schema = source1.schema().union(source2.schema().as_ref())?; - let join: LogicalPlan = Join::new( + let join: LogicalPlan = Join::try_new( + source1.into(), source2.into(), vec![col("b")], vec![col("b")], - vec![], - output_schema.into(), JoinType::Inner, - source1.into(), - ) + )? .into(); let filter: LogicalPlan = Filter::new(col("c").lt(&lit(2.0)), join.into()).into(); let expected = "\ @@ -509,33 +482,26 @@ mod tests { fn filter_commutes_with_join_both_sides() -> DaftResult<()> { let source1: LogicalPlan = dummy_scan_node(vec![ Field::new("a", DataType::Int64), - Field::new("b", DataType::Utf8), + Field::new("b", DataType::Int64), Field::new("c", DataType::Float64), ]) .into(); - let source2: LogicalPlan = dummy_scan_node(vec![ - Field::new("b", DataType::Utf8), - Field::new("c", DataType::Float64), - ]) - .into(); - let output_schema = source1.schema().union(source2.schema().as_ref())?; - let join: LogicalPlan = Join::new( + let source2: LogicalPlan = dummy_scan_node(vec![Field::new("b", DataType::Int64)]).into(); + let join: LogicalPlan = Join::try_new( + source1.into(), source2.into(), vec![col("b")], vec![col("b")], - vec![], - output_schema.into(), JoinType::Inner, - source1.into(), - ) + )? .into(); - let filter: LogicalPlan = Filter::new(col("c").lt(&lit(2.0)), join.into()).into(); + let filter: LogicalPlan = Filter::new(col("b").lt(&lit(2.0)), join.into()).into(); let expected = "\ - Join: Type = Inner, On = col(b), Output schema = a (Int64), b (Utf8), c (Float64)\ - \n Filter: col(c) < lit(2.0)\ - \n Source: \"Json\", File paths = /foo, File schema = a (Int64), b (Utf8), c (Float64), Format-specific config = Json(JsonSourceConfig), Output schema = a (Int64), b (Utf8), c (Float64)\ - \n Filter: col(c) < lit(2.0)\ - \n Source: \"Json\", File paths = /foo, File schema = b (Utf8), c (Float64), Format-specific config = Json(JsonSourceConfig), Output schema = b (Utf8), c (Float64)"; + Join: Type = Inner, On = col(b), Output schema = a (Int64), b (Int64), c (Float64)\ + \n Filter: col(b) < lit(2.0)\ + \n Source: \"Json\", File paths = /foo, File schema = a (Int64), b (Int64), c (Float64), Format-specific config = Json(JsonSourceConfig), Output schema = a (Int64), b (Int64), c (Float64)\ + \n Filter: col(b) < lit(2.0)\ + \n Source: \"Json\", File paths = /foo, File schema = b (Int64), Format-specific config = Json(JsonSourceConfig), Output schema = b (Int64)"; assert_optimized_plan_eq(filter.into(), expected)?; Ok(()) } diff --git a/src/daft-plan/src/physical_ops/explode.rs b/src/daft-plan/src/physical_ops/explode.rs index 7d8ad5c6c0..9d3a6275b0 100644 --- a/src/daft-plan/src/physical_ops/explode.rs +++ b/src/daft-plan/src/physical_ops/explode.rs @@ -7,16 +7,13 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Explode { - pub explode_exprs: Vec, // Upstream node. pub input: Arc, + pub to_explode: Vec, } impl Explode { - pub(crate) fn new(explode_exprs: Vec, input: Arc) -> Self { - Self { - explode_exprs, - input, - } + pub(crate) fn new(input: Arc, to_explode: Vec) -> Self { + Self { input, to_explode } } } diff --git a/src/daft-plan/src/physical_ops/join.rs b/src/daft-plan/src/physical_ops/join.rs index aaa36a8d1c..aedc791b1d 100644 --- a/src/daft-plan/src/physical_ops/join.rs +++ b/src/daft-plan/src/physical_ops/join.rs @@ -10,7 +10,6 @@ pub struct Join { pub right: Arc, pub left_on: Vec, pub right_on: Vec, - pub output_projection: Vec, pub join_type: JoinType, // Upstream node. pub input: Arc, @@ -21,7 +20,6 @@ impl Join { right: Arc, left_on: Vec, right_on: Vec, - output_projection: Vec, join_type: JoinType, input: Arc, ) -> Self { @@ -29,7 +27,6 @@ impl Join { right, left_on, right_on, - output_projection, join_type, input, } diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 7a909b3e82..64ecf2f54c 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -276,12 +276,9 @@ impl PhysicalPlan { .call1((local_limit_iter, *limit, *num_partitions))?; Ok(global_limit_iter.into()) } - PhysicalPlan::Explode(Explode { - input, - explode_exprs, - }) => { + PhysicalPlan::Explode(Explode { input, to_explode }) => { let upstream_iter = input.to_partition_tasks(py, psets)?; - let explode_pyexprs: Vec = explode_exprs + let explode_pyexprs: Vec = to_explode .iter() .map(|expr| PyExpr::from(expr.clone())) .collect(); @@ -417,9 +414,9 @@ impl PhysicalPlan { right, left_on, right_on, - output_projection, join_type, input, + .. }) => { let upstream_input_iter = input.to_partition_tasks(py, psets)?; let upstream_right_iter = right.to_partition_tasks(py, psets)?; @@ -431,10 +428,6 @@ impl PhysicalPlan { .iter() .map(|expr| PyExpr::from(expr.clone())) .collect(); - let output_projection_pyexprs: Vec = output_projection - .iter() - .map(|expr| PyExpr::from(expr.clone())) - .collect(); let py_iter = py .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? .getattr(pyo3::intern!(py, "join"))? @@ -443,7 +436,6 @@ impl PhysicalPlan { upstream_right_iter, left_on_pyexprs, right_on_pyexprs, - output_projection_pyexprs, *join_type, ))?; Ok(py_iter.into()) diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 0e175f12c1..3718a743a5 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -95,14 +95,12 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult { ))) } LogicalPlan::Explode(LogicalExplode { - input, - explode_exprs, - .. + input, to_explode, .. }) => { let input_physical = plan(input)?; Ok(PhysicalPlan::Explode(Explode::new( - explode_exprs.clone(), input_physical.into(), + to_explode.clone(), ))) } LogicalPlan::Sort(LogicalSort { @@ -399,7 +397,6 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult { input, left_on, right_on, - output_projection, join_type, .. }) => { @@ -442,7 +439,6 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult { right_physical.into(), left_on.clone(), right_on.clone(), - output_projection.clone(), *join_type, left_physical.into(), )))