Skip to content

Commit

Permalink
[CHORE] Move schema construction under LogicalPlan construction (#1290)
Browse files Browse the repository at this point in the history
Currently, the schema for certain LogicalPlan nodes is being passed in
explicitly during construction, which means:

1. The schema passed in can be incorrect
2. Schema resolution logic is duplicated for all node construction

This PR addresses these by moving all schema resolution logic to
underneath LogicalPlan construction. Note that this means the
constructors are now fallible, returning a new snafu-based
`logical_plan::Result`.

---------

Co-authored-by: Xiayue Charles Lin <[email protected]>
Co-authored-by: Clark Zinzow <[email protected]>
  • Loading branch information
3 people committed Aug 24, 2023
1 parent e0b988c commit 1c1abfb
Show file tree
Hide file tree
Showing 20 changed files with 221 additions and 225 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,6 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
class Join(SingleOutputInstruction):
left_on: ExpressionsProjection
right_on: ExpressionsProjection
output_projection: ExpressionsProjection
how: JoinType

def run(self, inputs: list[Table]) -> list[Table]:
Expand All @@ -653,7 +652,6 @@ def _join(self, inputs: list[Table]) -> list[Table]:
right,
left_on=self.left_on,
right_on=self.right_on,
output_projection=self.output_projection,
how=self.how,
)
return [result]
Expand Down
2 changes: 0 additions & 2 deletions daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def join(
right_plan: InProgressPhysicalPlan[PartitionT],
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
output_projection: ExpressionsProjection,
how: JoinType,
) -> InProgressPhysicalPlan[PartitionT]:
"""Pairwise join the partitions from `left_child_plan` and `right_child_plan` together."""
Expand Down Expand Up @@ -202,7 +201,6 @@ def join(
instruction=execution_step.Join(
left_on=left_on,
right_on=right_on,
output_projection=output_projection,
how=how,
)
)
Expand Down
1 change: 0 additions & 1 deletion daft/execution/physical_plan_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) ->
right_plan=_get_physical_plan(right_child, psets),
left_on=node._left_on,
right_on=node._right_on,
output_projection=node._output_projection,
how=node._how,
)

Expand Down
7 changes: 3 additions & 4 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def run(self, input_partition: Table) -> Table:
def explode(
input: physical_plan.InProgressPhysicalPlan[PartitionT], explode_exprs: list[PyExpr]
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
explode_expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in explode_exprs])
explode_expr_projection = ExpressionsProjection(
[Expression._from_pyexpr(expr)._explode() for expr in explode_exprs]
)
explode_op = ShimExplodeOp(explode_expr_projection)
return physical_plan.pipeline_instruction(
child_plan=input,
Expand Down Expand Up @@ -138,18 +140,15 @@ def join(
right: physical_plan.InProgressPhysicalPlan[PartitionT],
left_on: list[PyExpr],
right_on: list[PyExpr],
output_projection: list[PyExpr],
join_type: JoinType,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on])
right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on])
output_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in output_projection])
return physical_plan.join(
left_plan=input,
right_plan=right,
left_on=left_on_expr_proj,
right_on=right_on_expr_proj,
output_projection=output_expr_proj,
how=join_type,
)

Expand Down
32 changes: 3 additions & 29 deletions daft/logical/rust_logical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@ def project(
projection: ExpressionsProjection,
custom_resource_request: ResourceRequest = ResourceRequest(),
) -> RustLogicalPlanBuilder:
schema = projection.resolve_schema(self.schema())
exprs = [expr._expr for expr in projection]
builder = self._builder.project(exprs, schema._schema, custom_resource_request)
builder = self._builder.project(exprs, custom_resource_request)
return RustLogicalPlanBuilder(builder)

def filter(self, predicate: Expression) -> RustLogicalPlanBuilder:
Expand All @@ -116,29 +115,15 @@ def limit(self, num_rows: int) -> RustLogicalPlanBuilder:
return RustLogicalPlanBuilder(builder)

def explode(self, explode_expressions: ExpressionsProjection) -> RustLogicalPlanBuilder:
# TODO(Clark): Move this logic to Rust side after we've ported ExpressionsProjection.
explode_expressions = ExpressionsProjection([expr._explode() for expr in explode_expressions])
input_schema = self.schema()
explode_schema = explode_expressions.resolve_schema(input_schema)
output_fields = []
for f in input_schema:
if f.name in explode_schema.column_names():
output_fields.append(explode_schema[f.name])
else:
output_fields.append(f)

exploded_schema = Schema._from_field_name_and_types([(f.name, f.dtype) for f in output_fields])
explode_pyexprs = [expr._expr for expr in explode_expressions]
builder = self._builder.explode(explode_pyexprs, exploded_schema._schema)
builder = self._builder.explode(explode_pyexprs)
return RustLogicalPlanBuilder(builder)

def count(self) -> RustLogicalPlanBuilder:
# TODO(Clark): Add dedicated logical/physical ops when introducing metadata-based count optimizations.
first_col = col(self.schema().column_names()[0])
builder = self._builder.aggregate([first_col._count(CountMode.All)._expr], [])
rename_expr = ExpressionsProjection([first_col.alias("count")])
schema = rename_expr.resolve_schema(Schema._from_pyschema(builder.schema()))
builder = builder.project(rename_expr.to_inner_py_exprs(), schema._schema, ResourceRequest())
builder = builder.project([first_col.alias("count")._expr], ResourceRequest())
return RustLogicalPlanBuilder(builder)

def distinct(self) -> RustLogicalPlanBuilder:
Expand Down Expand Up @@ -220,21 +205,10 @@ def join( # type: ignore[override]
elif how == JoinType.Right:
raise NotImplementedError("Right join not implemented.")
elif how == JoinType.Inner:
# TODO(Clark): Port this logic to Rust-side once ExpressionsProjection has been ported.
right_drop_set = {r.name() for l, r in zip(left_on, right_on) if l.name() == r.name()}
left_columns = ExpressionsProjection.from_schema(self.schema())
right_columns = ExpressionsProjection([col(f.name) for f in right.schema() if f.name not in right_drop_set])
output_projection = left_columns.union(right_columns, rename_dup="right.")
right_columns = ExpressionsProjection(list(output_projection)[len(left_columns) :])
output_schema = left_columns.resolve_schema(self.schema()).union(
right_columns.resolve_schema(right.schema())
)
builder = self._builder.join(
right._builder,
left_on.to_inner_py_exprs(),
right_on.to_inner_py_exprs(),
output_projection.to_inner_py_exprs(),
output_schema._schema,
how,
)
return RustLogicalPlanBuilder(builder)
Expand Down
1 change: 0 additions & 1 deletion daft/table/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ def join(
right: Table,
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
output_projection: ExpressionsProjection | None = None,
how: JoinType = JoinType.Inner,
) -> Table:
if how != JoinType.Inner:
Expand Down
17 changes: 17 additions & 0 deletions src/common/error/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ pub enum DaftError {
External(GenericError),
}

impl std::error::Error for DaftError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
DaftError::FieldNotFound(_)
| DaftError::SchemaMismatch(_)
| DaftError::TypeError(_)
| DaftError::ComputeError(_)
| DaftError::ArrowError(_)
| DaftError::ValueError(_) => None,
DaftError::IoError(io_error) => Some(io_error),
DaftError::FileNotFound { source, .. } | DaftError::External(source) => Some(&**source),
#[cfg(feature = "python")]
DaftError::PyO3Error(pyerr) => Some(pyerr),
}
}
}

impl From<arrow2::error::Error> for DaftError {
fn from(error: arrow2::error::Error) -> Self {
DaftError::ArrowError(error.to_string())
Expand Down
1 change: 1 addition & 0 deletions src/daft-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ daft-table = {path = "../daft-table", default-features = false}
indexmap = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true, features = ["rc"]}
snafu = {workspace = true}

[features]
default = ["python"]
Expand Down
62 changes: 11 additions & 51 deletions src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::sync::Arc;

use common_error::DaftResult;

use crate::{logical_plan::LogicalPlan, optimization::Optimizer, ResourceRequest};

#[cfg(feature = "python")]
Expand Down Expand Up @@ -84,20 +82,14 @@ impl LogicalPlanBuilder {
pub fn project(
&self,
projection: Vec<PyExpr>,
projected_schema: &PySchema,
resource_request: ResourceRequest,
) -> PyResult<LogicalPlanBuilder> {
let projection_exprs = projection
.iter()
.map(|e| e.clone().into())
.collect::<Vec<Expr>>();
let logical_plan: LogicalPlan = ops::Project::new(
projection_exprs,
projected_schema.clone().into(),
resource_request,
self.plan.clone(),
)
.into();
let logical_plan: LogicalPlan =
ops::Project::new(projection_exprs, resource_request, self.plan.clone())?.into();
Ok(logical_plan.into())
}

Expand All @@ -112,21 +104,13 @@ impl LogicalPlanBuilder {
Ok(logical_plan.into())
}

pub fn explode(
&self,
explode_pyexprs: Vec<PyExpr>,
exploded_schema: &PySchema,
) -> PyResult<LogicalPlanBuilder> {
let explode_exprs = explode_pyexprs
pub fn explode(&self, to_explode_pyexprs: Vec<PyExpr>) -> PyResult<LogicalPlanBuilder> {
let to_explode = to_explode_pyexprs
.iter()
.map(|e| e.clone().into())
.collect::<Vec<Expr>>();
let logical_plan: LogicalPlan = ops::Explode::new(
explode_exprs,
exploded_schema.clone().into(),
self.plan.clone(),
)
.into();
let logical_plan: LogicalPlan =
ops::Explode::try_new(self.plan.clone(), to_explode)?.into();
Ok(logical_plan.into())
}

Expand Down Expand Up @@ -199,24 +183,8 @@ impl LogicalPlanBuilder {
.map(|expr| expr.clone().into())
.collect::<Vec<Expr>>();

let input_schema = self.plan.schema();
let fields = groupby_exprs
.iter()
.map(|expr| expr.to_field(&input_schema))
.chain(
agg_exprs
.iter()
.map(|agg_expr| agg_expr.to_field(&input_schema)),
)
.collect::<DaftResult<Vec<Field>>>()?;
let output_schema = Schema::new(fields)?;
let logical_plan: LogicalPlan = Aggregate::new(
agg_exprs,
groupby_exprs,
output_schema.into(),
self.plan.clone(),
)
.into();
let logical_plan: LogicalPlan =
Aggregate::try_new(self.plan.clone(), agg_exprs, groupby_exprs)?.into();
Ok(logical_plan.into())
}

Expand All @@ -225,8 +193,6 @@ impl LogicalPlanBuilder {
other: &Self,
left_on: Vec<PyExpr>,
right_on: Vec<PyExpr>,
output_projection: Vec<PyExpr>,
output_schema: &PySchema,
join_type: JoinType,
) -> PyResult<LogicalPlanBuilder> {
let left_on_exprs = left_on
Expand All @@ -237,19 +203,13 @@ impl LogicalPlanBuilder {
.iter()
.map(|e| e.clone().into())
.collect::<Vec<Expr>>();
let output_projection_exprs = output_projection
.iter()
.map(|e| e.clone().into())
.collect::<Vec<Expr>>();
let logical_plan: LogicalPlan = ops::Join::new(
let logical_plan: LogicalPlan = ops::Join::try_new(
self.plan.clone(),
other.plan.clone(),
left_on_exprs,
right_on_exprs,
output_projection_exprs,
output_schema.clone().into(),
join_type,
self.plan.clone(),
)
)?
.into();
Ok(logical_plan.into())
}
Expand Down
42 changes: 34 additions & 8 deletions src/daft-plan/src/logical_plan.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::{cmp::max, sync::Arc};

use common_error::DaftError;
use daft_core::schema::SchemaRef;
use snafu::Snafu;

use crate::{ops::*, PartitionScheme, PartitionSpec};

Expand Down Expand Up @@ -129,24 +131,24 @@ impl LogicalPlan {
let new_plan = match children {
[input] => match self {
Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"),
Self::Project(Project { projection, projected_schema, resource_request, .. }) => Self::Project(Project::new(
projection.clone(), projected_schema.clone(), resource_request.clone(), input.clone(),
)),
Self::Project(Project { projection, resource_request, .. }) => Self::Project(Project::new(
projection.clone(), resource_request.clone(), input.clone(),
).unwrap()),
Self::Filter(Filter { predicate, .. }) => Self::Filter(Filter::new(predicate.clone(), input.clone())),
Self::Limit(Limit { limit, .. }) => Self::Limit(Limit::new(*limit, input.clone())),
Self::Explode(Explode { explode_exprs, exploded_schema, .. }) => Self::Explode(Explode::new(explode_exprs.clone(), exploded_schema.clone(), input.clone())),
Self::Explode(Explode { to_explode, .. }) => Self::Explode(Explode::try_new(input.clone(), to_explode.clone()).unwrap()),
Self::Sort(Sort { sort_by, descending, .. }) => Self::Sort(Sort::new(sort_by.clone(), descending.clone(), input.clone())),
Self::Repartition(Repartition { num_partitions, partition_by, scheme, .. }) => Self::Repartition(Repartition::new(*num_partitions, partition_by.clone(), scheme.clone(), input.clone())),
Self::Coalesce(Coalesce { num_to, .. }) => Self::Coalesce(Coalesce::new(*num_to, input.clone())),
Self::Distinct(_) => Self::Distinct(Distinct::new(input.clone())),
Self::Aggregate(Aggregate { aggregations, groupby, output_schema, ..}) => Self::Aggregate(Aggregate::new(aggregations.clone(), groupby.clone(), output_schema.clone(), input.clone())),
Self::Aggregate(Aggregate { aggregations, groupby, ..}) => Self::Aggregate(Aggregate::try_new(input.clone(), aggregations.clone(), groupby.clone()).unwrap()),
Self::Sink(Sink { schema, sink_info, .. }) => Self::Sink(Sink::new(schema.clone(), sink_info.clone(), input.clone())),
_ => panic!("Logical op {} has two inputs, but got one", self),
},
[input1, input2] => match self {
Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"),
Self::Concat(_) => Self::Concat(Concat::new(input2.clone(), input1.clone())),
Self::Join(Join { left_on, right_on, output_projection, output_schema, join_type, .. }) => Self::Join(Join::new(input2.clone(), left_on.clone(), right_on.clone(), output_projection.clone(), output_schema.clone(), *join_type, input1.clone())),
Self::Join(Join { left_on, right_on, join_type, .. }) => Self::Join(Join::try_new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type).unwrap()),
_ => panic!("Logical op {} has one input, but got two", self),
},
_ => panic!("Logical ops should never have more than 2 inputs, but got: {}", children.len())
Expand All @@ -169,8 +171,8 @@ impl LogicalPlan {
}
Self::Filter(Filter { predicate, .. }) => vec![format!("Filter: {predicate}")],
Self::Limit(Limit { limit, .. }) => vec![format!("Limit: {limit}")],
Self::Explode(Explode { explode_exprs, .. }) => {
vec![format!("Explode: {explode_exprs:?}")]
Self::Explode(Explode { to_explode, .. }) => {
vec![format!("Explode: {to_explode:?}")]
}
Self::Sort(sort) => sort.multiline_display(),
Self::Repartition(repartition) => repartition.multiline_display(),
Expand All @@ -190,6 +192,30 @@ impl LogicalPlan {
}
}

#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub(crate) enum Error {
#[snafu(display("Unable to create logical plan node due to: {}", source))]
CreationError { source: DaftError },
}
pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;

impl From<Error> for DaftError {
fn from(err: Error) -> DaftError {
match err {
Error::CreationError { source } => source,
}
}
}

#[cfg(feature = "python")]
impl std::convert::From<Error> for pyo3::PyErr {
fn from(value: Error) -> Self {
let daft_error: DaftError = value.into();
daft_error.into()
}
}

macro_rules! impl_from_data_struct_for_logical_plan {
($name:ident) => {
impl From<$name> for LogicalPlan {
Expand Down
Loading

0 comments on commit 1c1abfb

Please sign in to comment.