Skip to content

Commit

Permalink
[FEAT] [New Query Planner] Add support for Sort, Repartition, and Dis…
Browse files Browse the repository at this point in the history
…tinct in new query planner. (#1248)

This PR adds support for `df.sort()`, `df.repartition()`, and
`df.distinct()` in the new query planner.
  • Loading branch information
clarkzinzow authored Aug 10, 2023
1 parent 03380a4 commit 074e37c
Show file tree
Hide file tree
Showing 21 changed files with 577 additions and 30 deletions.
2 changes: 0 additions & 2 deletions daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,6 @@ def coalesce(

def reduce(
fanout_plan: InProgressPhysicalPlan[PartitionT],
num_partitions: int,
reduce_instruction: ReduceInstruction,
) -> InProgressPhysicalPlan[PartitionT]:
"""Reduce the result of fanout_plan.
Expand Down Expand Up @@ -656,7 +655,6 @@ def sort(
# Execute a sorting reduce on it.
yield from reduce(
fanout_plan=range_fanout_plan,
num_partitions=num_partitions,
reduce_instruction=execution_step.ReduceMergeAndSort(
sort_by=sort_by,
descending=descending,
Expand Down
1 change: 0 additions & 1 deletion daft/execution/physical_plan_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) ->
# Do the reduce.
return physical_plan.reduce(
fanout_plan=fanout_plan,
num_partitions=node.num_partitions(),
reduce_instruction=execution_step.ReduceMerge(),
)

Expand Down
39 changes: 39 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,42 @@ def tabular_scan(
return physical_plan.file_read(
file_info_iter, limit, Schema._from_pyschema(schema), None, None, file_format_config, filepaths_column_name
)


def sort(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
sort_by: list[PyExpr],
descending: list[bool],
num_partitions: int,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in sort_by])
return physical_plan.sort(
child_plan=input,
sort_by=expr_projection,
descending=descending,
num_partitions=num_partitions,
)


def split_by_hash(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
num_partitions: int,
partition_by: list[PyExpr],
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in partition_by])
fanout_instruction = execution_step.FanoutHash(
_num_outputs=num_partitions,
partition_by=expr_projection,
)
return physical_plan.pipeline_instruction(
input,
fanout_instruction,
ResourceRequest(), # TODO(Clark): Propagate resource request.
)


def reduce_merge(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
reduce_instruction = execution_step.ReduceMerge()
return physical_plan.reduce(input, reduce_instruction)
24 changes: 19 additions & 5 deletions daft/logical/rust_logical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

import fsspec

from daft import DataType
from daft.context import get_context
from daft.daft import FileFormat, FileFormatConfig
from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder
from daft.daft import PartitionScheme, PartitionSpec
from daft.errors import ExpressionTypeError
from daft.expressions.expressions import Expression, ExpressionsProjection
from daft.logical.builder import JoinType, LogicalPlanBuilder
from daft.logical.schema import Schema
Expand Down Expand Up @@ -110,15 +112,29 @@ def count(self) -> RustLogicalPlanBuilder:
raise NotImplementedError("not implemented")

def distinct(self) -> RustLogicalPlanBuilder:
raise NotImplementedError("not implemented")
builder = self._builder.distinct()
return RustLogicalPlanBuilder(builder)

def sort(self, sort_by: ExpressionsProjection, descending: list[bool] | bool = False) -> RustLogicalPlanBuilder:
raise NotImplementedError("not implemented")
# Disallow sorting by null, binary, and boolean columns.
# TODO(Clark): This is a port of an existing constraint, we should look at relaxing this.
resolved_sort_by_schema = sort_by.resolve_schema(self.schema())
for f, sort_by_expr in zip(resolved_sort_by_schema, sort_by):
if f.dtype == DataType.null() or f.dtype == DataType.binary() or f.dtype == DataType.bool():
raise ExpressionTypeError(f"Cannot sort on expression {sort_by_expr} with type: {f.dtype}")

sort_by_exprs = [expr._expr for expr in sort_by]
if not isinstance(descending, list):
descending = [descending] * len(sort_by_exprs)
builder = self._builder.sort(sort_by_exprs, descending)
return RustLogicalPlanBuilder(builder)

def repartition(
self, num_partitions: int, partition_by: ExpressionsProjection, scheme: PartitionScheme
) -> RustLogicalPlanBuilder:
raise NotImplementedError("not implemented")
partition_by_exprs = [expr._expr for expr in partition_by]
builder = self._builder.repartition(num_partitions, partition_by_exprs, scheme)
return RustLogicalPlanBuilder(builder)

def coalesce(self, num_partitions: int) -> RustLogicalPlanBuilder:
raise NotImplementedError("not implemented")
Expand Down Expand Up @@ -147,8 +163,6 @@ def agg(
builder = self._builder.aggregate([expr._expr for expr in exprs])
return RustLogicalPlanBuilder(builder)

raise NotImplementedError("not implemented")

def concat(self, other: LogicalPlanBuilder) -> RustLogicalPlanBuilder:
raise NotImplementedError("not implemented")

Expand Down
39 changes: 39 additions & 0 deletions src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,45 @@ impl LogicalPlanBuilder {
Ok(logical_plan_builder)
}

pub fn sort(
&self,
sort_by: Vec<PyExpr>,
descending: Vec<bool>,
) -> PyResult<LogicalPlanBuilder> {
let sort_by_exprs: Vec<Expr> = sort_by.iter().map(|expr| expr.clone().into()).collect();
let logical_plan: LogicalPlan =
ops::Sort::new(sort_by_exprs, descending, self.plan.clone()).into();
let logical_plan_builder = LogicalPlanBuilder::new(logical_plan.into());
Ok(logical_plan_builder)
}

pub fn repartition(
&self,
num_partitions: usize,
partition_by: Vec<PyExpr>,
scheme: PartitionScheme,
) -> PyResult<LogicalPlanBuilder> {
let partition_by_exprs: Vec<Expr> = partition_by
.iter()
.map(|expr| expr.clone().into())
.collect();
let logical_plan: LogicalPlan = ops::Repartition::new(
num_partitions,
partition_by_exprs,
scheme,
self.plan.clone(),
)
.into();
let logical_plan_builder = LogicalPlanBuilder::new(logical_plan.into());
Ok(logical_plan_builder)
}

pub fn distinct(&self) -> PyResult<LogicalPlanBuilder> {
let logical_plan: LogicalPlan = ops::Distinct::new(self.plan.clone()).into();
let logical_plan_builder = LogicalPlanBuilder::new(logical_plan.into());
Ok(logical_plan_builder)
}

pub fn aggregate(&self, agg_exprs: Vec<PyExpr>) -> PyResult<LogicalPlanBuilder> {
use crate::ops::Aggregate;
let agg_exprs = agg_exprs
Expand Down
37 changes: 35 additions & 2 deletions src/daft-plan/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ use std::sync::Arc;

use daft_core::schema::SchemaRef;

use crate::{ops::*, PartitionSpec};
use crate::{ops::*, PartitionScheme, PartitionSpec};

#[derive(Clone, Debug)]
pub enum LogicalPlan {
Source(Source),
Filter(Filter),
Limit(Limit),
Sort(Sort),
Repartition(Repartition),
Distinct(Distinct),
Aggregate(Aggregate),
}

Expand All @@ -18,6 +21,9 @@ impl LogicalPlan {
Self::Source(Source { schema, .. }) => schema.clone(),
Self::Filter(Filter { input, .. }) => input.schema(),
Self::Limit(Limit { input, .. }) => input.schema(),
Self::Sort(Sort { input, .. }) => input.schema(),
Self::Repartition(Repartition { input, .. }) => input.schema(),
Self::Distinct(Distinct { input, .. }) => input.schema(),
Self::Aggregate(Aggregate { schema, .. }) => schema.clone(),
}
}
Expand All @@ -27,6 +33,24 @@ impl LogicalPlan {
Self::Source(Source { partition_spec, .. }) => partition_spec.clone(),
Self::Filter(Filter { input, .. }) => input.partition_spec(),
Self::Limit(Limit { input, .. }) => input.partition_spec(),
Self::Sort(Sort { input, sort_by, .. }) => PartitionSpec::new_internal(
PartitionScheme::Range,
input.partition_spec().num_partitions,
Some(sort_by.clone()),
)
.into(),
Self::Repartition(Repartition {
num_partitions,
partition_by,
scheme,
..
}) => PartitionSpec::new_internal(
scheme.clone(),
*num_partitions,
Some(partition_by.clone()),
)
.into(),
Self::Distinct(Distinct { input, .. }) => input.partition_spec(),
Self::Aggregate(Aggregate { input, .. }) => input.partition_spec(), // TODO
}
}
Expand All @@ -36,6 +60,9 @@ impl LogicalPlan {
Self::Source(..) => vec![],
Self::Filter(Filter { input, .. }) => vec![input],
Self::Limit(Limit { input, .. }) => vec![input],
Self::Sort(Sort { input, .. }) => vec![input],
Self::Repartition(Repartition { input, .. }) => vec![input],
Self::Distinct(Distinct { input, .. }) => vec![input],
Self::Aggregate(Aggregate { input, .. }) => vec![input],
}
}
Expand All @@ -45,6 +72,9 @@ impl LogicalPlan {
Self::Source(source) => source.multiline_display(),
Self::Filter(Filter { predicate, .. }) => vec![format!("Filter: {predicate}")],
Self::Limit(Limit { limit, .. }) => vec![format!("Limit: {limit}")],
Self::Sort(sort) => sort.multiline_display(),
Self::Repartition(repartition) => repartition.multiline_display(),
Self::Distinct(_) => vec!["Distinct".to_string()],
Self::Aggregate(aggregate) => aggregate.multiline_display(),
}
}
Expand All @@ -68,5 +98,8 @@ macro_rules! impl_from_data_struct_for_logical_plan {

impl_from_data_struct_for_logical_plan!(Source);
impl_from_data_struct_for_logical_plan!(Filter);
impl_from_data_struct_for_logical_plan!(Aggregate);
impl_from_data_struct_for_logical_plan!(Limit);
impl_from_data_struct_for_logical_plan!(Sort);
impl_from_data_struct_for_logical_plan!(Repartition);
impl_from_data_struct_for_logical_plan!(Distinct);
impl_from_data_struct_for_logical_plan!(Aggregate);
15 changes: 15 additions & 0 deletions src/daft-plan/src/ops/distinct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use std::sync::Arc;

use crate::LogicalPlan;

#[derive(Clone, Debug)]
pub struct Distinct {
// Upstream node.
pub input: Arc<LogicalPlan>,
}

impl Distinct {
pub(crate) fn new(input: Arc<LogicalPlan>) -> Self {
Self { input }
}
}
6 changes: 6 additions & 0 deletions src/daft-plan/src/ops/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
mod agg;
mod distinct;
mod filter;
mod limit;
mod repartition;
mod sort;
mod source;

pub use agg::Aggregate;
pub use distinct::Distinct;
pub use filter::Filter;
pub use limit::Limit;
pub use repartition::Repartition;
pub use sort::Sort;
pub use source::Source;
42 changes: 42 additions & 0 deletions src/daft-plan/src/ops/repartition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use std::sync::Arc;

use daft_dsl::Expr;

use crate::{LogicalPlan, PartitionScheme};

#[derive(Clone, Debug)]
pub struct Repartition {
pub num_partitions: usize,
pub partition_by: Vec<Expr>,
pub scheme: PartitionScheme,
// Upstream node.
pub input: Arc<LogicalPlan>,
}

impl Repartition {
pub(crate) fn new(
num_partitions: usize,
partition_by: Vec<Expr>,
scheme: PartitionScheme,
input: Arc<LogicalPlan>,
) -> Self {
Self {
num_partitions,
partition_by,
scheme,
input,
}
}

pub fn multiline_display(&self) -> Vec<String> {
let mut res = vec![];
res.push(format!(
"Repartition ({:?}): n={}",
self.scheme, self.num_partitions
));
if !self.partition_by.is_empty() {
res.push(format!(" Partition by: {:?}", self.partition_by));
}
res
}
}
44 changes: 44 additions & 0 deletions src/daft-plan/src/ops/sort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use std::sync::Arc;

use daft_dsl::Expr;

use crate::LogicalPlan;

#[derive(Clone, Debug)]
pub struct Sort {
pub sort_by: Vec<Expr>,
pub descending: Vec<bool>,
// Upstream node.
pub input: Arc<LogicalPlan>,
}

impl Sort {
pub(crate) fn new(sort_by: Vec<Expr>, descending: Vec<bool>, input: Arc<LogicalPlan>) -> Self {
Self {
sort_by,
descending,
input,
}
}

pub fn multiline_display(&self) -> Vec<String> {
let mut res = vec![];
res.push("Sort:".to_string());
if !self.sort_by.is_empty() {
let pairs: Vec<String> = self
.sort_by
.iter()
.zip(self.descending.iter())
.map(|(sb, d)| {
format!(
"({:?}, {})",
sb,
if *d { "descending" } else { "ascending" },
)
})
.collect();
res.push(format!(" Sort by: {:?}", pairs));
}
res
}
}
Loading

0 comments on commit 074e37c

Please sign in to comment.