Skip to content

Commit

Permalink
[FEAT] Add ActorPoolProject logical and physical plans (#2601)
Browse files Browse the repository at this point in the history
Adds new LogicalPlan and PhysicalPlan nodes for ActorPoolProject

Note that this is not yet reachable via user-facing code, since nothing
from our API should be able to create logical `ActorPoolProject` plans
yet (this will be left to a follow-on PR and we can feature-flag it)

---

**ActorPoolProject** works just like a normal Project (it has a list of
expressions that it is responsible for evaluating). The main differences
are:

1. The scheduler is responsible for running these projections on top of
an actor pool, instead of naively running it like a normal projection
2. The expressions are expected to contain `PyPartialUDFs`, which the
execution backend is responsible for initializing once-per-actor.
3. (TODO in a follow-up PR) During evaluation of the expressions, we
will be looking for the initialized Python classes from some global
state. The projections in an ActorPoolProject cannot be evaluated
without that global state having been initialized.

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
Co-authored-by: Desmond Cheong <[email protected]>
  • Loading branch information
3 people committed Aug 6, 2024
1 parent aef5999 commit fd8c940
Show file tree
Hide file tree
Showing 21 changed files with 619 additions and 55 deletions.
4 changes: 2 additions & 2 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1179,10 +1179,10 @@ def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ...
def decimal_lit(sign: bool, digits: tuple[int, ...], exp: int) -> PyExpr: ...
def series_lit(item: PySeries) -> PyExpr: ...
def stateless_udf(
partial_stateless_udf: PartialStatelessUDF, expressions: list[PyExpr], return_dtype: PyDataType
name: str, partial_stateless_udf: PartialStatelessUDF, expressions: list[PyExpr], return_dtype: PyDataType
) -> PyExpr: ...
def stateful_udf(
partial_stateful_udf: PartialStatefulUDF, expressions: list[PyExpr], return_dtype: PyDataType
name: str, partial_stateful_udf: PartialStatefulUDF, expressions: list[PyExpr], return_dtype: PyDataType
) -> PyExpr: ...
def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
Expand Down
12 changes: 12 additions & 0 deletions daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties

from daft.udf import PartialStatefulUDF


# A PhysicalPlan that is still being built - may yield both PartitionTaskBuilders and PartitionTasks.
InProgressPhysicalPlan = Iterator[Union[None, PartitionTask[PartitionT], PartitionTaskBuilder[PartitionT]]]
Expand Down Expand Up @@ -199,6 +201,16 @@ def pipeline_instruction(
)


def actor_pool_project(
child_plan: InProgressPhysicalPlan[PartitionT],
projection: ExpressionsProjection,
partial_stateful_udfs: dict[str, PartialStatefulUDF],
resource_request: execution_step.ResourceRequest,
num_actors: int,
) -> InProgressPhysicalPlan[PartitionT]:
raise NotImplementedError("Execution of ActorPoolProjects not yet implemented")


def monotonically_increasing_id(
child_plan: InProgressPhysicalPlan[PartitionT], column_name: str
) -> InProgressPhysicalPlan[PartitionT]:
Expand Down
19 changes: 19 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties

from daft.udf import PartialStatefulUDF


def scan_with_tasks(
scan_tasks: list[ScanTask],
Expand Down Expand Up @@ -73,6 +75,23 @@ def project(
)


def actor_pool_project(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
projection: list[PyExpr],
partial_stateful_udfs: dict[str, PartialStatefulUDF],
resource_request: ResourceRequest,
num_actors: int,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in projection])
return physical_plan.actor_pool_project(
child_plan=input,
projection=expr_projection,
partial_stateful_udfs=partial_stateful_udfs,
resource_request=resource_request,
num_actors=num_actors,
)


class ShimExplodeOp(MapPartitionOp):
explode_columns: ExpressionsProjection

Expand Down
18 changes: 14 additions & 4 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,17 +230,22 @@ def _to_expression(obj: object) -> Expression:

@staticmethod
def stateless_udf(
name: builtins.str,
partial: PartialStatelessUDF,
expressions: builtins.list[Expression],
return_dtype: DataType,
) -> Expression:
return Expression._from_pyexpr(_stateless_udf(partial, [e._expr for e in expressions], return_dtype._dtype))
return Expression._from_pyexpr(
_stateless_udf(name, partial, [e._expr for e in expressions], return_dtype._dtype)
)

@staticmethod
def stateful_udf(
partial: PartialStatefulUDF, expressions: builtins.list[Expression], return_dtype: DataType
name: builtins.str, partial: PartialStatefulUDF, expressions: builtins.list[Expression], return_dtype: DataType
) -> Expression:
return Expression._from_pyexpr(_stateful_udf(partial, [e._expr for e in expressions], return_dtype._dtype))
return Expression._from_pyexpr(
_stateful_udf(name, partial, [e._expr for e in expressions], return_dtype._dtype)
)

def __bool__(self) -> bool:
raise ValueError(
Expand Down Expand Up @@ -799,7 +804,12 @@ def apply(self, func: Callable, return_dtype: DataType) -> Expression:
def batch_func(self_series):
return [func(x) for x in self_series.to_pylist()]

return StatelessUDF(func=batch_func, return_dtype=return_dtype)(self)
name = getattr(func, "__module__", "") # type: ignore[call-overload]
if name:
name = name + "."
name = name + getattr(func, "__qualname__") # type: ignore[call-overload]

return StatelessUDF(name=name, func=batch_func, return_dtype=return_dtype)(self)

def is_null(self) -> Expression:
"""Checks if values in the Expression are Null (a special value indicating missing data)
Expand Down
12 changes: 12 additions & 0 deletions daft/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class PartialStatefulUDF:

@dataclasses.dataclass
class StatelessUDF(UDF):
name: str
func: UserProvidedPythonFunction
return_dtype: DataType

Expand All @@ -178,6 +179,7 @@ def __call__(self, *args, **kwargs) -> Expression:
bound_args = BoundUDFArgs(self.bind_func(*args, **kwargs))
expressions = list(bound_args.expressions().values())
return Expression.stateless_udf(
name=self.name,
partial=PartialStatelessUDF(self.func, self.return_dtype, bound_args),
expressions=expressions,
return_dtype=self.return_dtype,
Expand All @@ -195,6 +197,7 @@ def __hash__(self) -> int:

@dataclasses.dataclass
class StatefulUDF(UDF):
name: str
cls: type
return_dtype: DataType

Expand All @@ -210,6 +213,7 @@ def __call__(self, *args, **kwargs) -> Expression:
bound_args = BoundUDFArgs(self.bind_func(*args, **kwargs))
expressions = list(bound_args.expressions().values())
return Expression.stateful_udf(
name=self.name,
partial=PartialStatefulUDF(self.cls, self.return_dtype, bound_args),
expressions=expressions,
return_dtype=self.return_dtype,
Expand Down Expand Up @@ -284,13 +288,21 @@ def udf(
"""

def _udf(f: UserProvidedPythonFunction | type) -> UDF:
# Grab a name for the UDF. It **should** be unique.
name = getattr(f, "__module__", "") # type: ignore[call-overload]
if name:
name = name + "."
name = name + getattr(f, "__qualname__") # type: ignore[call-overload]

if inspect.isclass(f):
return StatefulUDF(
name=name,
cls=f,
return_dtype=return_dtype,
)
else:
return StatelessUDF(
name=name,
func=f,
return_dtype=return_dtype,
)
Expand Down
4 changes: 0 additions & 4 deletions src/daft-dsl/src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ use daft_core::{datatypes::Field, schema::Schema, series::Series};

use serde::{Deserialize, Serialize};

#[cfg(feature = "python")]
pub mod python;
#[cfg(feature = "python")]
use python::PythonUDF;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
Expand All @@ -52,7 +50,6 @@ pub enum FunctionExpr {
Struct(StructExpr),
Json(JsonExpr),
Image(ImageExpr),
#[cfg(feature = "python")]
Python(PythonUDF),
Partitioning(PartitioningExpr),
}
Expand Down Expand Up @@ -83,7 +80,6 @@ impl FunctionExpr {
Struct(expr) => expr.get_evaluator(),
Json(expr) => expr.get_evaluator(),
Image(expr) => expr.get_evaluator(),
#[cfg(feature = "python")]
Python(expr) => expr.get_evaluator(),
Partitioning(expr) => expr.get_evaluator(),
}
Expand Down
47 changes: 46 additions & 1 deletion src/daft-dsl/src/functions/python/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#[cfg(feature = "python")]
mod partial_udf;
mod udf;

use std::sync::Arc;

use common_error::DaftResult;
use daft_core::datatypes::DataType;
use serde::{Deserialize, Serialize};
Expand All @@ -27,25 +30,32 @@ impl PythonUDF {

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct StatelessPythonUDF {
pub name: Arc<String>,
#[cfg(feature = "python")]
partial_func: partial_udf::PyPartialUDF,
num_expressions: usize,
pub return_dtype: DataType,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct StatefulPythonUDF {
pub name: Arc<String>,
#[cfg(feature = "python")]
pub stateful_partial_func: partial_udf::PyPartialUDF,
num_expressions: usize,
pub num_expressions: usize,
pub return_dtype: DataType,
}

#[cfg(feature = "python")]
pub fn stateless_udf(
name: &str,
py_partial_stateless_udf: pyo3::PyObject,
expressions: &[ExprRef],
return_dtype: DataType,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF {
name: name.to_string().into(),
partial_func: partial_udf::PyPartialUDF(py_partial_stateless_udf),
num_expressions: expressions.len(),
return_dtype,
Expand All @@ -54,17 +64,52 @@ pub fn stateless_udf(
})
}

#[cfg(not(feature = "python"))]
pub fn stateless_udf(
name: &str,
expressions: &[ExprRef],
return_dtype: DataType,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF {
name: name.to_string().into(),
num_expressions: expressions.len(),
return_dtype,
})),
inputs: expressions.into(),
})
}

#[cfg(feature = "python")]
pub fn stateful_udf(
name: &str,
py_stateful_partial_func: pyo3::PyObject,
expressions: &[ExprRef],
return_dtype: DataType,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
name: name.to_string().into(),
stateful_partial_func: partial_udf::PyPartialUDF(py_stateful_partial_func),
num_expressions: expressions.len(),
return_dtype,
})),
inputs: expressions.into(),
})
}

#[cfg(not(feature = "python"))]
pub fn stateful_udf(
name: &str,
expressions: &[ExprRef],
return_dtype: DataType,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
name: name.to_string().into(),
num_expressions: expressions.len(),
return_dtype,
})),
inputs: expressions.into(),
})
}
Loading

0 comments on commit fd8c940

Please sign in to comment.