From fd8c9401a63ae5762dab13664ba095777d167a40 Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Tue, 6 Aug 2024 12:54:02 -0700 Subject: [PATCH] [FEAT] Add ActorPoolProject logical and physical plans (#2601) Adds new LogicalPlan and PhysicalPlan nodes for ActorPoolProject Note that this is not yet reachable via user-facing code, since nothing from our API should be able to create logical `ActorPoolProject` plans yet (this will be left to a follow-on PR and we can feature-flag it) --- **ActorPoolProject** works just like a normal Project (it has a list of expressions that it is responsible for evaluating). The main differences are: 1. The scheduler is responsible for running these projections on top of an actor pool, instead of naively running it like a normal projection 2. The expressions are expected to contain `PyPartialUDFs`, which the execution backend is responsible for initializing once-per-actor. 3. (TODO in a follow-up PR) During evaluation of the expressions, we will be looking for the initialized Python classes from some global state. The projections in an ActorPoolProject cannot be evaluated without that global state having been initialized. --------- Co-authored-by: Jay Chia Co-authored-by: Desmond Cheong --- daft/daft.pyi | 4 +- daft/execution/physical_plan.py | 12 ++ daft/execution/rust_physical_plan_shim.py | 19 +++ daft/expressions/expressions.py | 18 +- daft/udf.py | 12 ++ src/daft-dsl/src/functions/mod.rs | 4 - src/daft-dsl/src/functions/python/mod.rs | 47 ++++- src/daft-dsl/src/functions/python/udf.rs | 66 ++++---- src/daft-dsl/src/python.rs | 18 +- src/daft-functions/Cargo.toml | 1 + src/daft-plan/Cargo.toml | 3 +- .../src/logical_ops/actor_pool_project.rs | 91 ++++++++++ src/daft-plan/src/logical_ops/mod.rs | 2 + .../rules/push_down_projection.rs | 160 +++++++++++++++++- src/daft-plan/src/logical_plan.rs | 12 ++ .../src/physical_ops/actor_pool_project.rs | 104 ++++++++++++ src/daft-plan/src/physical_ops/mod.rs | 2 + .../rules/reorder_partition_keys.rs | 10 ++ src/daft-plan/src/physical_plan.rs | 15 +- .../src/physical_planner/translate.rs | 20 ++- src/daft-scheduler/src/scheduler.rs | 54 ++++++ 21 files changed, 619 insertions(+), 55 deletions(-) create mode 100644 src/daft-plan/src/logical_ops/actor_pool_project.rs create mode 100644 src/daft-plan/src/physical_ops/actor_pool_project.rs diff --git a/daft/daft.pyi b/daft/daft.pyi index 8012a3edfe..67308b06ed 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1179,10 +1179,10 @@ def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ... def decimal_lit(sign: bool, digits: tuple[int, ...], exp: int) -> PyExpr: ... def series_lit(item: PySeries) -> PyExpr: ... def stateless_udf( - partial_stateless_udf: PartialStatelessUDF, expressions: list[PyExpr], return_dtype: PyDataType + name: str, partial_stateless_udf: PartialStatelessUDF, expressions: list[PyExpr], return_dtype: PyDataType ) -> PyExpr: ... def stateful_udf( - partial_stateful_udf: PartialStatefulUDF, expressions: list[PyExpr], return_dtype: PyDataType + name: str, partial_stateful_udf: PartialStatefulUDF, expressions: list[PyExpr], return_dtype: PyDataType ) -> PyExpr: ... def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ... def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ... diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 99affa1098..d0d529075a 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -56,6 +56,8 @@ from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties + from daft.udf import PartialStatefulUDF + # A PhysicalPlan that is still being built - may yield both PartitionTaskBuilders and PartitionTasks. InProgressPhysicalPlan = Iterator[Union[None, PartitionTask[PartitionT], PartitionTaskBuilder[PartitionT]]] @@ -199,6 +201,16 @@ def pipeline_instruction( ) +def actor_pool_project( + child_plan: InProgressPhysicalPlan[PartitionT], + projection: ExpressionsProjection, + partial_stateful_udfs: dict[str, PartialStatefulUDF], + resource_request: execution_step.ResourceRequest, + num_actors: int, +) -> InProgressPhysicalPlan[PartitionT]: + raise NotImplementedError("Execution of ActorPoolProjects not yet implemented") + + def monotonically_increasing_id( child_plan: InProgressPhysicalPlan[PartitionT], column_name: str ) -> InProgressPhysicalPlan[PartitionT]: diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index b14d424d0c..78c9fc95a0 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -23,6 +23,8 @@ from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties + from daft.udf import PartialStatefulUDF + def scan_with_tasks( scan_tasks: list[ScanTask], @@ -73,6 +75,23 @@ def project( ) +def actor_pool_project( + input: physical_plan.InProgressPhysicalPlan[PartitionT], + projection: list[PyExpr], + partial_stateful_udfs: dict[str, PartialStatefulUDF], + resource_request: ResourceRequest, + num_actors: int, +) -> physical_plan.InProgressPhysicalPlan[PartitionT]: + expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in projection]) + return physical_plan.actor_pool_project( + child_plan=input, + projection=expr_projection, + partial_stateful_udfs=partial_stateful_udfs, + resource_request=resource_request, + num_actors=num_actors, + ) + + class ShimExplodeOp(MapPartitionOp): explode_columns: ExpressionsProjection diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 1e8dc06f8e..16914718d7 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -230,17 +230,22 @@ def _to_expression(obj: object) -> Expression: @staticmethod def stateless_udf( + name: builtins.str, partial: PartialStatelessUDF, expressions: builtins.list[Expression], return_dtype: DataType, ) -> Expression: - return Expression._from_pyexpr(_stateless_udf(partial, [e._expr for e in expressions], return_dtype._dtype)) + return Expression._from_pyexpr( + _stateless_udf(name, partial, [e._expr for e in expressions], return_dtype._dtype) + ) @staticmethod def stateful_udf( - partial: PartialStatefulUDF, expressions: builtins.list[Expression], return_dtype: DataType + name: builtins.str, partial: PartialStatefulUDF, expressions: builtins.list[Expression], return_dtype: DataType ) -> Expression: - return Expression._from_pyexpr(_stateful_udf(partial, [e._expr for e in expressions], return_dtype._dtype)) + return Expression._from_pyexpr( + _stateful_udf(name, partial, [e._expr for e in expressions], return_dtype._dtype) + ) def __bool__(self) -> bool: raise ValueError( @@ -799,7 +804,12 @@ def apply(self, func: Callable, return_dtype: DataType) -> Expression: def batch_func(self_series): return [func(x) for x in self_series.to_pylist()] - return StatelessUDF(func=batch_func, return_dtype=return_dtype)(self) + name = getattr(func, "__module__", "") # type: ignore[call-overload] + if name: + name = name + "." + name = name + getattr(func, "__qualname__") # type: ignore[call-overload] + + return StatelessUDF(name=name, func=batch_func, return_dtype=return_dtype)(self) def is_null(self) -> Expression: """Checks if values in the Expression are Null (a special value indicating missing data) diff --git a/daft/udf.py b/daft/udf.py index d69d2f4fea..170e9f42d4 100644 --- a/daft/udf.py +++ b/daft/udf.py @@ -163,6 +163,7 @@ class PartialStatefulUDF: @dataclasses.dataclass class StatelessUDF(UDF): + name: str func: UserProvidedPythonFunction return_dtype: DataType @@ -178,6 +179,7 @@ def __call__(self, *args, **kwargs) -> Expression: bound_args = BoundUDFArgs(self.bind_func(*args, **kwargs)) expressions = list(bound_args.expressions().values()) return Expression.stateless_udf( + name=self.name, partial=PartialStatelessUDF(self.func, self.return_dtype, bound_args), expressions=expressions, return_dtype=self.return_dtype, @@ -195,6 +197,7 @@ def __hash__(self) -> int: @dataclasses.dataclass class StatefulUDF(UDF): + name: str cls: type return_dtype: DataType @@ -210,6 +213,7 @@ def __call__(self, *args, **kwargs) -> Expression: bound_args = BoundUDFArgs(self.bind_func(*args, **kwargs)) expressions = list(bound_args.expressions().values()) return Expression.stateful_udf( + name=self.name, partial=PartialStatefulUDF(self.cls, self.return_dtype, bound_args), expressions=expressions, return_dtype=self.return_dtype, @@ -284,13 +288,21 @@ def udf( """ def _udf(f: UserProvidedPythonFunction | type) -> UDF: + # Grab a name for the UDF. It **should** be unique. + name = getattr(f, "__module__", "") # type: ignore[call-overload] + if name: + name = name + "." + name = name + getattr(f, "__qualname__") # type: ignore[call-overload] + if inspect.isclass(f): return StatefulUDF( + name=name, cls=f, return_dtype=return_dtype, ) else: return StatelessUDF( + name=name, func=f, return_dtype=return_dtype, ) diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 439af54132..67b019c20b 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -35,9 +35,7 @@ use daft_core::{datatypes::Field, schema::Schema, series::Series}; use serde::{Deserialize, Serialize}; -#[cfg(feature = "python")] pub mod python; -#[cfg(feature = "python")] use python::PythonUDF; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -52,7 +50,6 @@ pub enum FunctionExpr { Struct(StructExpr), Json(JsonExpr), Image(ImageExpr), - #[cfg(feature = "python")] Python(PythonUDF), Partitioning(PartitioningExpr), } @@ -83,7 +80,6 @@ impl FunctionExpr { Struct(expr) => expr.get_evaluator(), Json(expr) => expr.get_evaluator(), Image(expr) => expr.get_evaluator(), - #[cfg(feature = "python")] Python(expr) => expr.get_evaluator(), Partitioning(expr) => expr.get_evaluator(), } diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs index df17429866..99e308fc36 100644 --- a/src/daft-dsl/src/functions/python/mod.rs +++ b/src/daft-dsl/src/functions/python/mod.rs @@ -1,6 +1,9 @@ +#[cfg(feature = "python")] mod partial_udf; mod udf; +use std::sync::Arc; + use common_error::DaftResult; use daft_core::datatypes::DataType; use serde::{Deserialize, Serialize}; @@ -27,6 +30,8 @@ impl PythonUDF { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct StatelessPythonUDF { + pub name: Arc, + #[cfg(feature = "python")] partial_func: partial_udf::PyPartialUDF, num_expressions: usize, pub return_dtype: DataType, @@ -34,18 +39,23 @@ pub struct StatelessPythonUDF { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct StatefulPythonUDF { + pub name: Arc, + #[cfg(feature = "python")] pub stateful_partial_func: partial_udf::PyPartialUDF, - num_expressions: usize, + pub num_expressions: usize, pub return_dtype: DataType, } +#[cfg(feature = "python")] pub fn stateless_udf( + name: &str, py_partial_stateless_udf: pyo3::PyObject, expressions: &[ExprRef], return_dtype: DataType, ) -> DaftResult { Ok(Expr::Function { func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF { + name: name.to_string().into(), partial_func: partial_udf::PyPartialUDF(py_partial_stateless_udf), num_expressions: expressions.len(), return_dtype, @@ -54,13 +64,32 @@ pub fn stateless_udf( }) } +#[cfg(not(feature = "python"))] +pub fn stateless_udf( + name: &str, + expressions: &[ExprRef], + return_dtype: DataType, +) -> DaftResult { + Ok(Expr::Function { + func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF { + name: name.to_string().into(), + num_expressions: expressions.len(), + return_dtype, + })), + inputs: expressions.into(), + }) +} + +#[cfg(feature = "python")] pub fn stateful_udf( + name: &str, py_stateful_partial_func: pyo3::PyObject, expressions: &[ExprRef], return_dtype: DataType, ) -> DaftResult { Ok(Expr::Function { func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { + name: name.to_string().into(), stateful_partial_func: partial_udf::PyPartialUDF(py_stateful_partial_func), num_expressions: expressions.len(), return_dtype, @@ -68,3 +97,19 @@ pub fn stateful_udf( inputs: expressions.into(), }) } + +#[cfg(not(feature = "python"))] +pub fn stateful_udf( + name: &str, + expressions: &[ExprRef], + return_dtype: DataType, +) -> DaftResult { + Ok(Expr::Function { + func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { + name: name.to_string().into(), + num_expressions: expressions.len(), + return_dtype, + })), + inputs: expressions.into(), + }) +} diff --git a/src/daft-dsl/src/functions/python/udf.rs b/src/daft-dsl/src/functions/python/udf.rs index 3c7d8fbfce..15d29f4c07 100644 --- a/src/daft-dsl/src/functions/python/udf.rs +++ b/src/daft-dsl/src/functions/python/udf.rs @@ -1,4 +1,6 @@ use daft_core::DataType; + +#[cfg(feature = "python")] use pyo3::{types::PyModule, PyAny, PyResult}; use daft_core::{datatypes::Field, schema::Schema, series::Series}; @@ -10,7 +12,6 @@ use common_error::{DaftError, DaftResult}; use super::super::FunctionEvaluator; use super::{StatefulPythonUDF, StatelessPythonUDF}; use crate::functions::FunctionExpr; -use daft_core::python::{PyDataType, PySeries}; impl FunctionEvaluator for StatelessPythonUDF { fn fn_name(&self) -> &'static str { @@ -38,11 +39,18 @@ impl FunctionEvaluator for StatelessPythonUDF { } } + #[cfg(feature = "python")] fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { self.call_udf(inputs) } + + #[cfg(not(feature = "python"))] + fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + panic!("Cannot evaluate a StatelessPythonUDF without compiling for Python"); + } } +#[cfg(feature = "python")] fn run_udf( py: pyo3::Python, inputs: &[Series], @@ -50,6 +58,8 @@ fn run_udf( bound_args: pyo3::Py, return_dtype: &DataType, ) -> DaftResult { + use daft_core::python::{PyDataType, PySeries}; + // Convert input Rust &[Series] to wrapped Python Vec<&PyAny> let py_series_module = PyModule::import(py, pyo3::intern!(py, "daft.series"))?; let py_series_class = py_series_module.getattr(pyo3::intern!(py, "Series"))?; @@ -88,19 +98,7 @@ fn run_udf( } impl StatelessPythonUDF { - pub fn get_func_and_bound_args( - &self, - py: pyo3::Python, - ) -> DaftResult<(pyo3::Py, pyo3::Py)> { - // Extract the required Python objects to call our run_udf helper - let func = self.partial_func.0.getattr(py, pyo3::intern!(py, "func"))?; - let bound_args = self - .partial_func - .0 - .getattr(py, pyo3::intern!(py, "bound_args"))?; - Ok((func, bound_args)) - } - + #[cfg(feature = "python")] pub fn call_udf(&self, inputs: &[Series]) -> DaftResult { use pyo3::Python; @@ -114,30 +112,17 @@ impl StatelessPythonUDF { Python::with_gil(|py| { // Extract the required Python objects to call our run_udf helper - let (func, bound_args) = self.get_func_and_bound_args(py)?; + let func = self.partial_func.0.getattr(py, pyo3::intern!(py, "func"))?; + let bound_args = self + .partial_func + .0 + .getattr(py, pyo3::intern!(py, "bound_args"))?; + run_udf(py, inputs, func, bound_args, &self.return_dtype) }) } } -impl StatefulPythonUDF { - pub fn get_func_and_bound_args( - &self, - py: pyo3::Python, - ) -> DaftResult<(pyo3::Py, pyo3::Py)> { - // Extract the required Python objects to call our run_udf helper - let func = self - .stateful_partial_func - .0 - .getattr(py, pyo3::intern!(py, "func_cls"))?; - let bound_args = self - .stateful_partial_func - .0 - .getattr(py, pyo3::intern!(py, "bound_args"))?; - Ok((func, bound_args)) - } -} - impl FunctionEvaluator for StatefulPythonUDF { fn fn_name(&self) -> &'static str { "pyclass_udf" @@ -164,6 +149,7 @@ impl FunctionEvaluator for StatefulPythonUDF { } } + #[cfg(feature = "python")] fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { use pyo3::Python; @@ -177,7 +163,14 @@ impl FunctionEvaluator for StatefulPythonUDF { Python::with_gil(|py| { // Extract the required Python objects to call our run_udf helper - let (func, bound_args) = self.get_func_and_bound_args(py)?; + let func = self + .stateful_partial_func + .0 + .getattr(py, pyo3::intern!(py, "func_cls"))?; + let bound_args = self + .stateful_partial_func + .0 + .getattr(py, pyo3::intern!(py, "bound_args"))?; // HACK: This is the naive initialization of the class. It is performed once-per-evaluate which is not ideal. // Ideally we need to allow evaluate to somehow take in the **initialized** Python class that is provided by the Actor. @@ -187,4 +180,9 @@ impl FunctionEvaluator for StatefulPythonUDF { run_udf(py, inputs, func, bound_args, &self.return_dtype) }) } + + #[cfg(not(feature = "python"))] + fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + panic!("Cannot evaluate a StatelessPythonUDF without compiling for Python"); + } } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 890307e661..90cd44604f 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -149,6 +149,7 @@ pub fn lit(item: &PyAny) -> PyResult { #[pyfunction] pub fn stateless_udf( py: Python, + name: &str, partial_stateless_udf: &PyAny, expressions: Vec, return_dtype: PyDataType, @@ -160,7 +161,13 @@ pub fn stateless_udf( let partial_stateless_udf = partial_stateless_udf.to_object(py); let expressions_map: Vec = expressions.into_iter().map(|pyexpr| pyexpr.expr).collect(); Ok(PyExpr { - expr: stateless_udf(partial_stateless_udf, &expressions_map, return_dtype.dtype)?.into(), + expr: stateless_udf( + name, + partial_stateless_udf, + &expressions_map, + return_dtype.dtype, + )? + .into(), }) } @@ -171,6 +178,7 @@ pub fn stateless_udf( #[pyfunction] pub fn stateful_udf( py: Python, + name: &str, partial_stateful_udf: &PyAny, expressions: Vec, return_dtype: PyDataType, @@ -182,7 +190,13 @@ pub fn stateful_udf( let partial_stateful_udf = partial_stateful_udf.to_object(py); let expressions_map: Vec = expressions.into_iter().map(|pyexpr| pyexpr.expr).collect(); Ok(PyExpr { - expr: stateful_udf(partial_stateful_udf, &expressions_map, return_dtype.dtype)?.into(), + expr: stateful_udf( + name, + partial_stateful_udf, + &expressions_map, + return_dtype.dtype, + )? + .into(), }) } diff --git a/src/daft-functions/Cargo.toml b/src/daft-functions/Cargo.toml index f3ed8e2a48..92b2d1bd1a 100644 --- a/src/daft-functions/Cargo.toml +++ b/src/daft-functions/Cargo.toml @@ -21,6 +21,7 @@ python = [ "dep:pyo3", "common-error/python", "daft-core/python", + "daft-io/python", "common-io-config/python" ] diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index 0d039c9743..e3f70dcf98 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -44,7 +44,8 @@ python = [ "daft-core/python", "daft-dsl/python", "daft-functions/python", - "daft-table/python" + "daft-table/python", + "daft-scan/python" ] [package] diff --git a/src/daft-plan/src/logical_ops/actor_pool_project.rs b/src/daft-plan/src/logical_ops/actor_pool_project.rs new file mode 100644 index 0000000000..e2971a9543 --- /dev/null +++ b/src/daft-plan/src/logical_ops/actor_pool_project.rs @@ -0,0 +1,91 @@ +use std::sync::Arc; + +use common_treenode::TreeNode; +use daft_core::schema::{Schema, SchemaRef}; +use daft_dsl::{ + functions::{ + python::{PythonUDF, StatefulPythonUDF}, + FunctionExpr, + }, + resolve_exprs, Expr, ExprRef, +}; +use itertools::Itertools; +use snafu::ResultExt; + +use crate::{ + logical_plan::{CreationSnafu, Result}, + LogicalPlan, ResourceRequest, +}; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ActorPoolProject { + // Upstream node. + pub input: Arc, + pub projection: Vec, + pub resource_request: ResourceRequest, + pub projected_schema: SchemaRef, + pub num_actors: usize, +} + +impl ActorPoolProject { + pub(crate) fn try_new( + input: Arc, + projection: Vec, + resource_request: ResourceRequest, + num_actors: usize, + ) -> Result { + let (projection, fields) = + resolve_exprs(projection, input.schema().as_ref()).context(CreationSnafu)?; + let projected_schema = Schema::new(fields).context(CreationSnafu)?.into(); + Ok(ActorPoolProject { + input, + projection, + resource_request, + projected_schema, + num_actors, + }) + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + res.push("ActorPoolProject:".to_string()); + res.push(format!( + "Projection = [{}]", + self.projection.iter().map(|e| e.to_string()).join(", ") + )); + res.push(format!( + "UDFs = [{}]", + self.projection + .iter() + .flat_map(|proj| { + let mut udf_names = vec![]; + proj.apply(|e| { + if let Expr::Function { + func: + FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { + name, + .. + })), + .. + } = e.as_ref() + { + udf_names.push(name.clone()); + } + Ok(common_treenode::TreeNodeRecursion::Continue) + }) + .unwrap(); + udf_names + }) + .join(", ") + )); + res.push(format!("Num actors = {}", self.num_actors,)); + let resource_request = self.resource_request.multiline_display(); + if !resource_request.is_empty() { + res.push(format!( + "Resource request = {{ {} }}", + resource_request.join(", ") + )); + } + res + } +} diff --git a/src/daft-plan/src/logical_ops/mod.rs b/src/daft-plan/src/logical_ops/mod.rs index 7216d25427..339589deea 100644 --- a/src/daft-plan/src/logical_ops/mod.rs +++ b/src/daft-plan/src/logical_ops/mod.rs @@ -1,3 +1,4 @@ +mod actor_pool_project; mod agg; mod concat; mod distinct; @@ -15,6 +16,7 @@ mod sort; mod source; mod unpivot; +pub use actor_pool_project::ActorPoolProject; pub use agg::Aggregate; pub use concat::Concat; pub use distinct::Distinct; diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs index 048cb33f0d..9fe76f37a0 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs @@ -2,12 +2,13 @@ use std::{collections::HashMap, sync::Arc}; use common_error::DaftResult; +use common_treenode::TreeNode; use daft_core::{schema::Schema, JoinType}; use daft_dsl::{col, optimization::replace_columns_with_expressions, Expr, ExprRef}; use indexmap::IndexSet; use crate::{ - logical_ops::{Aggregate, Join, Pivot, Project, Source}, + logical_ops::{ActorPoolProject, Aggregate, Join, Pivot, Project, Source}, source_info::SourceInfo, LogicalPlan, ResourceRequest, }; @@ -235,6 +236,50 @@ impl PushDownProjection { Ok(Transformed::No(plan)) } } + LogicalPlan::ActorPoolProject(upstream_actor_pool_projection) => { + // Prune columns from the child ActorPoolProjection that are not used in this projection. + let required_columns = &plan.required_columns()[0]; + if required_columns.len() < upstream_schema.names().len() { + let pruned_upstream_projections = upstream_actor_pool_projection + .projection + .iter() + .filter(|&e| required_columns.contains(e.name())) + .cloned() + .collect::>(); + + // If all StatefulUDF expressions end up being pruned, the ActorPoolProject should essentially become + // a no-op passthrough projection for the rest of the columns. In this case, we should just get rid of it + // altogether since it serves no purpose. + let all_projections_are_just_colexprs = + pruned_upstream_projections.iter().all(|proj| { + !proj.exists(|e| match e.as_ref() { + Expr::Column(_) => false, + // Check for existence of any non-ColumnExprs + _ => true, + }) + }); + let new_upstream = if all_projections_are_just_colexprs { + upstream_plan.children()[0].clone() + } else { + LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + upstream_actor_pool_projection.input.clone(), + pruned_upstream_projections, + upstream_actor_pool_projection.resource_request.clone(), + upstream_actor_pool_projection.num_actors, + )?) + .arced() + }; + let new_plan = Arc::new(plan.with_new_children(&[new_upstream])); + + // Retry optimization now that the upstream node is different. + let new_plan = self + .try_optimize(new_plan.clone())? + .or(Transformed::Yes(new_plan)); + Ok(new_plan) + } else { + Ok(Transformed::No(plan)) + } + } LogicalPlan::Sort(..) | LogicalPlan::Repartition(..) | LogicalPlan::Limit(..) @@ -813,6 +858,119 @@ mod tests { .build(); let expected = plan.clone(); assert_optimized_plan_eq(plan, expected)?; + + Ok(()) + } + + /// Projection<-ActorPoolProject prunes columns from the ActorPoolProject + #[cfg(not(feature = "python"))] + #[test] + fn test_projection_pushdown_into_actorpoolproject() -> DaftResult<()> { + use crate::logical_ops::ActorPoolProject; + use crate::logical_ops::Project; + use daft_dsl::functions::python::{PythonUDF, StatefulPythonUDF}; + use daft_dsl::functions::FunctionExpr; + use daft_dsl::Expr; + use std::default; + + let scan_op = dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Boolean), + Field::new("c", DataType::Int64), + ]); + let scan_node = dummy_scan_node(scan_op).build(); + let mock_stateful_udf = Expr::Function { + func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { + name: Arc::new("my-udf".to_string()), + num_expressions: 1, + return_dtype: DataType::Utf8, + })), + inputs: vec![col("c")], + } + .arced(); + + // Select the `udf_results` column, so the ActorPoolProject should apply column pruning to the other columns + let actor_pool_project = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + scan_node.clone(), + vec![col("a"), col("b"), mock_stateful_udf.alias("udf_results")], + default::Default::default(), + 8, + )?) + .arced(); + let project = LogicalPlan::Project(Project::try_new( + actor_pool_project, + vec![col("udf_results")], + default::Default::default(), + )?) + .arced(); + + let expected_actor_pool_project = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + scan_node.clone(), + vec![mock_stateful_udf.alias("udf_results")], + default::Default::default(), + 8, + )?) + .arced(); + + assert_optimized_plan_eq(project, expected_actor_pool_project)?; + Ok(()) + } + + /// Projection<-ActorPoolProject prunes ActorPoolProject entirely if the stateful projection column is pruned + #[cfg(not(feature = "python"))] + #[test] + fn test_projection_pushdown_into_actorpoolproject_completely_removed() -> DaftResult<()> { + use crate::logical_ops::ActorPoolProject; + use crate::logical_ops::Project; + use daft_dsl::functions::python::{PythonUDF, StatefulPythonUDF}; + use daft_dsl::functions::FunctionExpr; + use daft_dsl::Expr; + use std::default; + + let scan_op = dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Boolean), + Field::new("c", DataType::Int64), + ]); + let scan_node = dummy_scan_node(scan_op.clone()).build(); + let mock_stateful_udf = Expr::Function { + func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { + name: Arc::new("my-udf".to_string()), + num_expressions: 1, + return_dtype: DataType::Utf8, + })), + inputs: vec![col("c")], + } + .arced(); + + // Select only col("a"), so the ActorPoolProject node is now redundant and should be removed + let actor_pool_project = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + scan_node.clone(), + vec![col("a"), col("b"), mock_stateful_udf.alias("udf_results")], + default::Default::default(), + 8, + )?) + .arced(); + let project = LogicalPlan::Project(Project::try_new( + actor_pool_project, + vec![col("a")], + default::Default::default(), + )?) + .arced(); + + // Optimized plan will push the projection all the way down into the scan + let expected_scan = dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns { + limit: None, + partition_filters: None, + columns: Some(Arc::new(vec!["a".to_string()])), + filters: None, + }, + ) + .build(); + + assert_optimized_plan_eq(project, expected_scan)?; Ok(()) } } diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 1dbe58323e..fd5cdc694d 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -14,6 +14,7 @@ pub use crate::logical_ops::*; pub enum LogicalPlan { Source(Source), Project(Project), + ActorPoolProject(ActorPoolProject), Filter(Filter), Limit(Limit), Explode(Explode), @@ -42,6 +43,9 @@ impl LogicalPlan { Self::Project(Project { projected_schema, .. }) => projected_schema.clone(), + Self::ActorPoolProject(ActorPoolProject { + projected_schema, .. + }) => projected_schema.clone(), Self::Filter(Filter { input, .. }) => input.schema(), Self::Limit(Limit { input, .. }) => input.schema(), Self::Explode(Explode { @@ -78,6 +82,10 @@ impl LogicalPlan { .collect(); vec![res] } + Self::ActorPoolProject(ActorPoolProject { projection, .. }) => { + let res = projection.iter().flat_map(get_required_columns).collect(); + vec![res] + } Self::Filter(filter) => { vec![get_required_columns(&filter.predicate) .iter() @@ -163,6 +171,7 @@ impl LogicalPlan { match self { Self::Source(..) => vec![], Self::Project(Project { input, .. }) => vec![input.clone()], + Self::ActorPoolProject(ActorPoolProject { input, .. }) => vec![input.clone()], Self::Filter(Filter { input, .. }) => vec![input.clone()], Self::Limit(Limit { input, .. }) => vec![input.clone()], Self::Explode(Explode { input, .. }) => vec![input.clone()], @@ -189,6 +198,7 @@ impl LogicalPlan { Self::Project(Project { projection, resource_request, .. }) => Self::Project(Project::try_new( input.clone(), projection.clone(), resource_request.clone(), ).unwrap()), + Self::ActorPoolProject(ActorPoolProject {projection, resource_request, num_actors, ..}) => Self::ActorPoolProject(ActorPoolProject::try_new(input.clone(), projection.clone(), resource_request.clone(), *num_actors).unwrap()), Self::Filter(Filter { predicate, .. }) => Self::Filter(Filter::try_new(input.clone(), predicate.clone()).unwrap()), Self::Limit(Limit { limit, eager, .. }) => Self::Limit(Limit::new(input.clone(), *limit, *eager)), Self::Explode(Explode { to_explode, .. }) => Self::Explode(Explode::try_new(input.clone(), to_explode.clone()).unwrap()), @@ -233,6 +243,7 @@ impl LogicalPlan { let name = match self { Self::Source(..) => "Source", Self::Project(..) => "Project", + Self::ActorPoolProject(..) => "ActorPoolProject", Self::Filter(..) => "Filter", Self::Limit(..) => "Limit", Self::Explode(..) => "Explode", @@ -255,6 +266,7 @@ impl LogicalPlan { match self { Self::Source(source) => source.multiline_display(), Self::Project(projection) => projection.multiline_display(), + Self::ActorPoolProject(projection) => projection.multiline_display(), Self::Filter(Filter { predicate, .. }) => vec![format!("Filter: {predicate}")], Self::Limit(Limit { limit, .. }) => vec![format!("Limit: {limit}")], Self::Explode(explode) => explode.multiline_display(), diff --git a/src/daft-plan/src/physical_ops/actor_pool_project.rs b/src/daft-plan/src/physical_ops/actor_pool_project.rs new file mode 100644 index 0000000000..2b99a50d24 --- /dev/null +++ b/src/daft-plan/src/physical_ops/actor_pool_project.rs @@ -0,0 +1,104 @@ +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; +use common_treenode::TreeNode; +use daft_dsl::{ + functions::{ + python::{PythonUDF, StatefulPythonUDF}, + FunctionExpr, + }, + Expr, ExprRef, +}; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; + +use crate::{ + partitioning::translate_clustering_spec, ClusteringSpec, PhysicalPlanRef, ResourceRequest, +}; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct ActorPoolProject { + pub input: PhysicalPlanRef, + pub projection: Vec, + pub resource_request: ResourceRequest, + pub clustering_spec: Arc, + pub num_actors: usize, +} + +impl ActorPoolProject { + pub(crate) fn try_new( + input: PhysicalPlanRef, + projection: Vec, + resource_request: ResourceRequest, + num_actors: usize, + ) -> DaftResult { + let clustering_spec = translate_clustering_spec(input.clustering_spec(), &projection); + + if !projection.iter().any(|expr| { + matches!( + expr.as_ref(), + Expr::Function { + func: FunctionExpr::Python(PythonUDF::Stateful(_)), + .. + } + ) + }) { + return Err(DaftError::InternalError("Cannot create ActorPoolProject from expressions that don't contain a stateful Python UDF".to_string())); + } + + Ok(ActorPoolProject { + input, + projection, + resource_request, + clustering_spec, + num_actors, + }) + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + res.push("ActorPoolProject:".to_string()); + res.push(format!( + "Projection = [{}]", + self.projection.iter().map(|e| e.to_string()).join(", ") + )); + res.push(format!( + "UDFs = [{}]", + self.projection + .iter() + .flat_map(|proj| { + let mut udf_names = vec![]; + proj.apply(|e| { + if let Expr::Function { + func: + FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { + name, + .. + })), + .. + } = e.as_ref() + { + udf_names.push(name.clone()); + } + Ok(common_treenode::TreeNodeRecursion::Continue) + }) + .unwrap(); + udf_names + }) + .join(", ") + )); + res.push(format!("Num actors = {}", self.num_actors,)); + res.push(format!( + "Clustering spec = {{ {} }}", + self.clustering_spec.multiline_display().join(", ") + )); + let resource_request = self.resource_request.multiline_display(); + if !resource_request.is_empty() { + res.push(format!( + "Resource request = {{ {} }}", + resource_request.join(", ") + )); + } + res + } +} diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs index bfd66777c4..d37a48fd2a 100644 --- a/src/daft-plan/src/physical_ops/mod.rs +++ b/src/daft-plan/src/physical_ops/mod.rs @@ -1,3 +1,4 @@ +mod actor_pool_project; mod agg; mod broadcast_join; mod coalesce; @@ -31,6 +32,7 @@ mod sort_merge_join; mod split; mod unpivot; +pub use actor_pool_project::ActorPoolProject; pub use agg::Aggregate; pub use broadcast_join::BroadcastJoin; pub use coalesce::Coalesce; diff --git a/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs b/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs index a5eef7c432..1d81d11fa4 100644 --- a/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs +++ b/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs @@ -86,6 +86,16 @@ impl PhysicalOptimizerRule for ReorderPartitionKeys { )?); Ok(Transformed::yes(c.with_plan(new_plan.into()).propagate())) } + PhysicalPlan::ActorPoolProject(crate::physical_ops::ActorPoolProject { input, projection, resource_request, clustering_spec: _, num_actors }) => { + let new_plan = PhysicalPlan::ActorPoolProject(crate::physical_ops::ActorPoolProject { + input: input.clone(), + projection: projection.clone(), + resource_request: resource_request.clone(), + num_actors: *num_actors, + clustering_spec: new_spec.into(), + }); + Ok(Transformed::yes(c.with_plan(new_plan.into()).propagate())) + } PhysicalPlan::Explode(Explode { input, to_explode, .. }) => { // can't use try_new because we are setting the clustering spec ourselves let new_plan = PhysicalPlan::Explode(Explode { diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 04d0e70071..7aac3203d2 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -19,6 +19,7 @@ pub enum PhysicalPlan { TabularScan(TabularScan), EmptyScan(EmptyScan), Project(Project), + ActorPoolProject(ActorPoolProject), Filter(Filter), Limit(Limit), Explode(Explode), @@ -113,6 +114,9 @@ impl PhysicalPlan { Self::Project(Project { clustering_spec, .. }) => clustering_spec.clone(), + Self::ActorPoolProject(ActorPoolProject { + clustering_spec, .. + }) => clustering_spec.clone(), Self::Filter(Filter { input, .. }) => input.clustering_spec(), Self::Limit(Limit { input, .. }) => input.clustering_spec(), Self::Explode(Explode { @@ -299,7 +303,8 @@ impl PhysicalPlan { } } Self::Project(Project { input, .. }) - | Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { input, .. }) => { + | Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { input, .. }) + | Self::ActorPoolProject(ActorPoolProject { input, .. }) => { // TODO(sammy), we need the schema to estimate the new size per row input.approximate_stats() } @@ -410,6 +415,7 @@ impl PhysicalPlan { Self::InMemoryScan(..) => vec![], Self::TabularScan(..) | Self::EmptyScan(..) => vec![], Self::Project(Project { input, .. }) => vec![input.clone()], + Self::ActorPoolProject(ActorPoolProject { input, .. }) => vec![input.clone()], Self::Filter(Filter { input, .. }) => vec![input.clone()], Self::Limit(Limit { input, .. }) => vec![input.clone()], Self::Explode(Explode { input, .. }) => vec![input.clone()], @@ -453,7 +459,6 @@ impl PhysicalPlan { pub fn with_new_children(&self, children: &[PhysicalPlanRef]) -> PhysicalPlan { match children { [input] => match self { - #[cfg(feature = "python")] Self::InMemoryScan(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), Self::TabularScan(..) | Self::EmptyScan(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), @@ -461,6 +466,7 @@ impl PhysicalPlan { Self::Project(Project::new_with_clustering_spec( input.clone(), projection.clone(), resource_request.clone(), clustering_spec.clone(), ).unwrap()), + Self::ActorPoolProject(ActorPoolProject {projection, resource_request, num_actors, ..}) => Self::ActorPoolProject(ActorPoolProject::try_new(input.clone(), projection.clone(), resource_request.clone(), *num_actors).unwrap()), Self::Filter(Filter { predicate, .. }) => Self::Filter(Filter::new(input.clone(), predicate.clone())), Self::Limit(Limit { limit, eager, num_partitions, .. }) => Self::Limit(Limit::new(input.clone(), *limit, *eager, *num_partitions)), Self::Explode(Explode { to_explode, .. }) => Self::Explode(Explode::try_new(input.clone(), to_explode.clone()).unwrap()), @@ -486,8 +492,7 @@ impl PhysicalPlan { Self::DeltaLakeWrite(DeltaLakeWrite {schema, delta_lake_info, .. }) => Self::DeltaLakeWrite(DeltaLakeWrite::new(schema.clone(), delta_lake_info.clone(), input.clone())), #[cfg(feature = "python")] Self::LanceWrite(LanceWrite { schema, lance_info, .. }) => Self::LanceWrite(LanceWrite::new(schema.clone(), lance_info.clone(), input.clone())), - // we should really remove this catch-all - _ => panic!("Physical op {:?} has two inputs, but got one", self), + Self::Concat(_) | Self::HashJoin(_) | Self::SortMergeJoin(_) | Self::BroadcastJoin(_) => panic!("{} requires more than 1 input, but received: {}", self, children.len()), }, [input1, input2] => match self { #[cfg(feature = "python")] @@ -516,6 +521,7 @@ impl PhysicalPlan { Self::TabularScan(..) => "TabularScan", Self::EmptyScan(..) => "EmptyScan", Self::Project(..) => "Project", + Self::ActorPoolProject(..) => "ActorPoolProject", Self::Filter(..) => "Filter", Self::Limit(..) => "Limit", Self::Explode(..) => "Explode", @@ -555,6 +561,7 @@ impl PhysicalPlan { Self::TabularScan(tabular_scan) => tabular_scan.multiline_display(), Self::EmptyScan(empty_scan) => empty_scan.multiline_display(), Self::Project(project) => project.multiline_display(), + Self::ActorPoolProject(ap_project) => ap_project.multiline_display(), Self::Filter(filter) => filter.multiline_display(), Self::Limit(limit) => limit.multiline_display(), Self::Explode(explode) => explode.multiline_display(), diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index b62a762db8..5a4e05ffbc 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -18,8 +18,9 @@ use daft_dsl::{is_partition_compatible, ExprRef}; use daft_scan::PhysicalScanInfo; use crate::logical_ops::{ - Aggregate as LogicalAggregate, Distinct as LogicalDistinct, Explode as LogicalExplode, - Filter as LogicalFilter, Join as LogicalJoin, Limit as LogicalLimit, + ActorPoolProject as LogicalActorPoolProject, Aggregate as LogicalAggregate, + Distinct as LogicalDistinct, Explode as LogicalExplode, Filter as LogicalFilter, + Join as LogicalJoin, Limit as LogicalLimit, MonotonicallyIncreasingId as LogicalMonotonicallyIncreasingId, Pivot as LogicalPivot, Project as LogicalProject, Repartition as LogicalRepartition, Sample as LogicalSample, Sink as LogicalSink, Sort as LogicalSort, Source, Unpivot as LogicalUnpivot, @@ -110,6 +111,21 @@ pub(super) fn translate_single_logical_node( )?) .arced()) } + LogicalPlan::ActorPoolProject(LogicalActorPoolProject { + projection, + resource_request, + num_actors, + .. + }) => { + let input_physical = physical_children.pop().expect("requires 1 input"); + Ok(PhysicalPlan::ActorPoolProject(ActorPoolProject::try_new( + input_physical, + projection.clone(), + resource_request.clone(), + *num_actors, + )?) + .arced()) + } LogicalPlan::Filter(LogicalFilter { predicate, .. }) => { let input_physical = physical_children.pop().expect("requires 1 input"); Ok(PhysicalPlan::Filter(Filter::new(input_physical, predicate.clone())).arced()) diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index dbe79d0961..8b35e5baa2 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -302,6 +302,60 @@ fn physical_plan_to_partition_tasks( .call1((upstream_iter, projection_pyexprs, resource_request.clone()))?; Ok(py_iter.into()) } + + PhysicalPlan::ActorPoolProject(ActorPoolProject { + input, + projection, + resource_request, + num_actors, + .. + }) => { + use daft_dsl::{ + common_treenode::TreeNode, + functions::{ + python::{PythonUDF, StatefulPythonUDF}, + FunctionExpr, + }, + }; + + // Extract any StatefulUDFs from the projection + let mut py_partial_udfs = HashMap::new(); + projection.iter().for_each(|e| { + e.apply(|child| { + if let Expr::Function { + func: + FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { + name, + stateful_partial_func: py_partial_udf, + .. + })), + .. + } = child.as_ref() + { + py_partial_udfs.insert(name.as_ref().to_string(), py_partial_udf.0.clone()); + } + Ok(daft_dsl::common_treenode::TreeNodeRecursion::Continue) + }) + .unwrap(); + }); + + let upstream_iter = physical_plan_to_partition_tasks(input, py, psets)?; + let py_iter = py + .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? + .getattr(pyo3::intern!(py, "actor_pool_project"))? + .call1(( + upstream_iter, + projection + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect::>(), + py_partial_udfs, + resource_request.clone(), + *num_actors, + ))?; + Ok(py_iter.into()) + } + PhysicalPlan::Filter(Filter { input, predicate }) => { let upstream_iter = physical_plan_to_partition_tasks(input, py, psets)?; let expressions_mod = py.import(pyo3::intern!(py, "daft.expressions.expressions"))?;