Skip to content

Commit

Permalink
feat(python): Add post-optimization callback (pola-rs#15972)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored and Wouittone committed Jun 22, 2024
1 parent dbb4954 commit 67424fa
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 19 deletions.
37 changes: 29 additions & 8 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,17 +590,22 @@ impl LazyFrame {
Ok(lp_top)
}

#[allow(unused_mut)]
fn prepare_collect(
mut self,
fn prepare_collect_post_opt<P>(
self,
check_sink: bool,
) -> PolarsResult<(ExecutionState, Box<dyn Executor>, bool)> {
let mut expr_arena = Arena::with_capacity(256);
let mut lp_arena = Arena::with_capacity(128);
post_opt: P,
) -> PolarsResult<(ExecutionState, Box<dyn Executor>, bool)>
where
P: Fn(Node, &mut Arena<IR>, &mut Arena<AExpr>) -> PolarsResult<()>,
{
let mut expr_arena = Arena::with_capacity(16);
let mut lp_arena = Arena::with_capacity(16);
let mut scratch = vec![];
let lp_top =
self.optimize_with_scratch(&mut lp_arena, &mut expr_arena, &mut scratch, false)?;

post_opt(lp_top, &mut lp_arena, &mut expr_arena)?;

// sink should be replaced
let no_file_sink = if check_sink {
!matches!(lp_arena.get(lp_top), IR::Sink { .. })
Expand All @@ -613,6 +618,23 @@ impl LazyFrame {
Ok((state, physical_plan, no_file_sink))
}

// post_opt: A function that is called after optimization. This can be used to modify the IR jit.
pub fn _collect_post_opt<P>(self, post_opt: P) -> PolarsResult<DataFrame>
where
P: Fn(Node, &mut Arena<IR>, &mut Arena<AExpr>) -> PolarsResult<()>,
{
let (mut state, mut physical_plan, _) = self.prepare_collect_post_opt(false, post_opt)?;
physical_plan.execute(&mut state)
}

#[allow(unused_mut)]
fn prepare_collect(
self,
check_sink: bool,
) -> PolarsResult<(ExecutionState, Box<dyn Executor>, bool)> {
self.prepare_collect_post_opt(check_sink, |_, _, _| Ok(()))
}

/// Execute all the lazy operations and collect them into a [`DataFrame`].
///
/// The query is optimized prior to execution.
Expand All @@ -631,8 +653,7 @@ impl LazyFrame {
/// }
/// ```
pub fn collect(self) -> PolarsResult<DataFrame> {
let (mut state, mut physical_plan, _) = self.prepare_collect(false)?;
physical_plan.execute(&mut state)
self._collect_post_opt(|_, _, _| Ok(()))
}

/// Profile a LazyFrame.
Expand Down
6 changes: 5 additions & 1 deletion py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,6 +1692,7 @@ def collect(
streaming: bool = False,
background: bool = False,
_eager: bool = False,
**_kwargs: Any,
) -> DataFrame | InProcessQuery:
"""
Materialize this LazyFrame into a DataFrame.
Expand Down Expand Up @@ -1807,7 +1808,10 @@ def collect(
if background:
return InProcessQuery(ldf.collect_concurrently())

return wrap_df(ldf.collect())
# Only for testing purposes atm.
callback = _kwargs.get("post_opt_callback")

return wrap_df(ldf.collect(callback))

@overload
def collect_async(
Expand Down
35 changes: 33 additions & 2 deletions py-polars/src/lazyframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::arrow_interop::to_rust::pyarrow_schema_to_rust;
use crate::error::PyPolarsErr;
use crate::expr::ToExprs;
use crate::file::get_file_like;
use crate::lazyframe::visit::NodeTraverser;
use crate::prelude::*;
use crate::{PyDataFrame, PyExpr, PyLazyGroupBy};

Expand Down Expand Up @@ -566,12 +567,42 @@ impl PyLazyFrame {
Ok((df.into(), time_df.into()))
}

fn collect(&self, py: Python) -> PyResult<PyDataFrame> {
fn collect(&self, py: Python, lamdba_post_opt: Option<PyObject>) -> PyResult<PyDataFrame> {
// if we don't allow threads and we have udfs trying to acquire the gil from different
// threads we deadlock.
let df = py.allow_threads(|| {
let ldf = self.ldf.clone();
ldf.collect().map_err(PyPolarsErr::from)
if let Some(lambda) = lamdba_post_opt {
ldf._collect_post_opt(|root, lp_arena, expr_arena| {
Python::with_gil(|py| {
let nt = NodeTraverser::new(
root,
std::mem::take(lp_arena),
std::mem::take(expr_arena),
);

// Get a copy of the arena's.
let arenas = nt.get_arenas();

// Pass the node visitor which allows the python callback to replace parts of the query plan.
// Remove "cuda" or specify better once we have multiple post-opt callbacks.
lambda.call1(py, (nt,)).map_err(
|e| polars_err!(ComputeError: "'cuda' conversion failed: {}", e),
)?;

// Unpack the arena's.
// At this point the `nt` is useless.

std::mem::swap(lp_arena, &mut *arenas.0.lock().unwrap());
std::mem::swap(expr_arena, &mut *arenas.1.lock().unwrap());

Ok(())
})
})
} else {
ldf.collect()
}
.map_err(PyPolarsErr::from)
})?;
Ok(df.into())
}
Expand Down
30 changes: 26 additions & 4 deletions py-polars/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl From<&ExprIR> for PyExprIR {
}

#[pyclass]
struct NodeTraverser {
pub(crate) struct NodeTraverser {
root: Node,
lp_arena: Arc<Mutex<Arena<IR>>>,
expr_arena: Arc<Mutex<Arena<AExpr>>>,
Expand All @@ -48,6 +48,22 @@ struct NodeTraverser {
}

impl NodeTraverser {
pub(crate) fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
Self {
root,
lp_arena: Arc::new(Mutex::new(lp_arena)),
expr_arena: Arc::new(Mutex::new(expr_arena)),
scratch: vec![],
expr_scratch: vec![],
expr_mapping: None,
}
}

#[allow(clippy::type_complexity)]
pub(crate) fn get_arenas(&self) -> (Arc<Mutex<Arena<IR>>>, Arc<Mutex<Arena<AExpr>>>) {
(self.lp_arena.clone(), self.expr_arena.clone())
}

fn fill_inputs(&mut self) {
let lp_arena = self.lp_arena.lock().unwrap();
let this_node = lp_arena.get(self.root);
Expand Down Expand Up @@ -120,12 +136,19 @@ impl NodeTraverser {
self.root = Node(node);
}

/// Get the current node in the plan.
fn get_node(&mut self) -> usize {
self.root.0
}

/// Set a python UDF that will replace the subtree location with this function src.
fn set_udf(&mut self, function: PyObject, schema: Wrap<Schema>) {
fn set_udf(&mut self, function: PyObject) {
let mut lp_arena = self.lp_arena.lock().unwrap();
let schema = lp_arena.get(self.root).schema(&lp_arena).into_owned();
let ir = IR::PythonScan {
options: PythonOptions {
scan_fn: Some(function.into()),
schema: Arc::new(schema.0),
schema,
output_schema: None,
with_columns: None,
pyarrow: false,
Expand All @@ -134,7 +157,6 @@ impl NodeTraverser {
},
predicate: None,
};
let mut lp_arena = self.lp_arena.lock().unwrap();
lp_arena.replace(self.root, ir);
}

Expand Down
8 changes: 4 additions & 4 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ static ALLOC: Jemalloc = Jemalloc;
static ALLOC: MiMalloc = MiMalloc;

#[pymodule]
fn nodes(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
fn _ir_nodes(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
use crate::lazyframe::visitor::nodes::*;
m.add_class::<PythonScan>().unwrap();
m.add_class::<Slice>().unwrap();
Expand All @@ -120,7 +120,7 @@ fn nodes(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
}

#[pymodule]
fn expr_nodes(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
fn _expr_nodes(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
use crate::lazyframe::visitor::expr_nodes::*;
use crate::lazyframe::PyExprIR;
// Expressions
Expand Down Expand Up @@ -164,9 +164,9 @@ fn polars(py: Python, m: &Bound<PyModule>) -> PyResult<()> {

// Submodules
// LogicalPlan objects
m.add_wrapped(wrap_pymodule!(nodes))?;
m.add_wrapped(wrap_pymodule!(_ir_nodes))?;
// Expr objects
m.add_wrapped(wrap_pymodule!(expr_nodes))?;
m.add_wrapped(wrap_pymodule!(_expr_nodes))?;
// Functions - eager
m.add_wrapped(wrap_pyfunction!(functions::concat_df))
.unwrap();
Expand Down
89 changes: 89 additions & 0 deletions py-polars/tests/unit/lazyframe/cuda/test_node_visitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from __future__ import annotations

import typing
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable

import polars as pl
from polars._utils.wrap import wrap_df
from polars.polars import _ir_nodes

if TYPE_CHECKING:
import pandas as pd


@typing.no_type_check
def test_run_on_pandas() -> None:
# Simple join example, missing multiple columns, slices, etc.
@typing.no_type_check
def join(inputs: list[Callable], obj: Any, _node_traverer: Any) -> Callable:
assert len(obj.left_on) == 1
assert len(obj.right_on) == 1
left_on = obj.left_on[0].output_name
right_on = obj.right_on[0].output_name

assert len(inputs) == 2

def run(inputs: list[Callable]):
# materialize inputs
inputs = [call() for call in inputs]
return inputs[0].merge(inputs[1], left_on=left_on, right_on=right_on)

return partial(run, inputs)

# Simple scan example, missing predicates, columns pruning, slices, etc.
@typing.no_type_check
def df_scan(_inputs: None, obj: Any, _: Any) -> pd.DataFrame:
assert obj.selection is None
return lambda: wrap_df(obj.df).to_pandas()

@lru_cache(1)
@typing.no_type_check
def get_node_converters():
return {
_ir_nodes.Join: join,
_ir_nodes.DataFrameScan: df_scan,
}

@typing.no_type_check
def get_input(node_traverser):
current_node = node_traverser.get_node()

inputs_callable = []
for inp in node_traverser.get_inputs():
node_traverser.set_node(inp)
inputs_callable.append(get_input(node_traverser))

node_traverser.set_node(current_node)
ir_node = node_traverser.view_current_node()
return get_node_converters()[ir_node.__class__](
inputs_callable, ir_node, node_traverser
)

@typing.no_type_check
def run_on_pandas(node_traverser) -> None:
current_node = node_traverser.get_node()

callback = get_input(node_traverser)

@typing.no_type_check
def run_callback(
columns: list[str] | None, _: Any, n_rows: int | None
) -> pl.DataFrame:
assert n_rows is None
assert columns is None

# produce a wrong result to ensure the callback has run.
return pl.from_pandas(callback() * 2)

node_traverser.set_node(current_node)
node_traverser.set_udf(run_callback)

# Polars query that will run on pandas
q1 = pl.LazyFrame({"foo": [1, 2, 3]})
q2 = pl.LazyFrame({"foo": [1], "bar": [2]})
q = q1.join(q2, on="foo")
assert q.collect(post_opt_callback=run_on_pandas).to_dict(as_series=False) == {
"foo": [2],
"bar": [4],
}

0 comments on commit 67424fa

Please sign in to comment.