Skip to content

Commit

Permalink
refactor: Make expression output type known (#19195)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 11, 2024
1 parent 1cc2cb8 commit 251e171
Show file tree
Hide file tree
Showing 19 changed files with 180 additions and 228 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ impl DataType {
ArrowDataType::Extension(name, _, _) if name.as_str() == "POLARS_EXTENSION_TYPE" => {
#[cfg(feature = "object")]
{
DataType::Object("extension", None)
DataType::Object("object", None)
}
#[cfg(not(feature = "object"))]
{
Expand Down
58 changes: 21 additions & 37 deletions crates/polars-expr/src/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ pub struct ApplyExpr {
function_operates_on_scalar: bool,
allow_rename: bool,
pass_name_to_apply: bool,
input_schema: Option<SchemaRef>,
input_schema: SchemaRef,
allow_threading: bool,
check_lengths: bool,
allow_group_aware: bool,
output_dtype: Option<DataType>,
output_field: Field,
}

impl ApplyExpr {
Expand All @@ -38,8 +38,8 @@ impl ApplyExpr {
expr: Expr,
options: FunctionOptions,
allow_threading: bool,
input_schema: Option<SchemaRef>,
output_dtype: Option<DataType>,
input_schema: SchemaRef,
output_field: Field,
returns_scalar: bool,
) -> Self {
#[cfg(debug_assertions)]
Expand All @@ -62,30 +62,7 @@ impl ApplyExpr {
allow_threading,
check_lengths: options.check_lengths(),
allow_group_aware: options.flags.contains(FunctionFlags::ALLOW_GROUP_AWARE),
output_dtype,
}
}

pub(crate) fn new_minimal(
inputs: Vec<Arc<dyn PhysicalExpr>>,
function: SpecialEq<Arc<dyn ColumnsUdf>>,
expr: Expr,
collect_groups: ApplyOptions,
) -> Self {
Self {
inputs,
function,
expr,
collect_groups,
function_returns_scalar: false,
function_operates_on_scalar: false,
allow_rename: false,
pass_name_to_apply: false,
input_schema: None,
allow_threading: true,
check_lengths: true,
allow_group_aware: true,
output_dtype: None,
output_field,
}
}

Expand Down Expand Up @@ -123,19 +100,16 @@ impl ApplyExpr {
Ok(ac)
}

fn get_input_schema(&self, df: &DataFrame) -> Cow<Schema> {
match &self.input_schema {
Some(schema) => Cow::Borrowed(schema.as_ref()),
None => Cow::Owned(df.schema()),
}
fn get_input_schema(&self, _df: &DataFrame) -> Cow<Schema> {
Cow::Borrowed(self.input_schema.as_ref())
}

/// Evaluates and flattens `Option<Column>` to `Column`.
fn eval_and_flatten(&self, inputs: &mut [Column]) -> PolarsResult<Column> {
if let Some(out) = self.function.call_udf(inputs)? {
Ok(out)
} else {
let field = self.to_field(self.input_schema.as_ref().unwrap()).unwrap();
let field = self.to_field(self.input_schema.as_ref()).unwrap();
Ok(Column::full_null(field.name().clone(), 1, field.dtype()))
}
}
Expand Down Expand Up @@ -179,9 +153,11 @@ impl ApplyExpr {
};

let ca: ListChunked = if self.allow_threading {
let dtype = match &self.output_dtype {
Some(dtype) if dtype.is_known() && !dtype.is_null() => Some(dtype.clone()),
_ => None,
let dtype = if self.output_field.dtype.is_known() && !self.output_field.dtype.is_null()
{
Some(self.output_field.dtype.clone())
} else {
None
};

let lst = agg.list().unwrap();
Expand Down Expand Up @@ -287,6 +263,7 @@ impl ApplyExpr {
}
builder.finish()
} else {
// We still need this branch to materialize unknown / data-dependent types in eager. :(
(0..len)
.map(|_| {
container.clear();
Expand All @@ -303,6 +280,13 @@ impl ApplyExpr {
.collect::<PolarsResult<ListChunked>>()?
.with_name(field.name.clone())
};
#[cfg(debug_assertions)]
{
let inner = ca.dtype().inner_dtype().unwrap();
if field.dtype.is_known() {
assert_eq!(inner, &field.dtype);
}
}

drop(iters);

Expand Down
61 changes: 28 additions & 33 deletions crates/polars-expr/src/expressions/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExp
pub struct ColumnExpr {
name: PlSmallStr,
expr: Expr,
schema: Option<SchemaRef>,
schema: SchemaRef,
}

impl ColumnExpr {
pub fn new(name: PlSmallStr, expr: Expr, schema: Option<SchemaRef>) -> Self {
pub fn new(name: PlSmallStr, expr: Expr, schema: SchemaRef) -> Self {
Self { name, expr, schema }
}
}
Expand Down Expand Up @@ -141,42 +141,37 @@ impl PhysicalExpr for ColumnExpr {
Some(&self.expr)
}
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Series> {
let out = match &self.schema {
None => self.process_by_linear_search(df, state, false),
Some(schema) => {
match schema.get_full(&self.name) {
Some((idx, _, _)) => {
// check if the schema was correct
// if not do O(n) search
match df.get_columns().get(idx) {
Some(out) => self.process_by_idx(
out.as_materialized_series(),
state,
schema,
df,
true,
),
None => {
// partitioned group_by special case
if let Some(schema) = state.get_schema() {
self.process_from_state_schema(df, state, &schema)
} else {
self.process_by_linear_search(df, state, true)
}
},
}
},
// in the future we will throw an error here
// now we do a linear search first as the lazy reported schema may still be incorrect
// in debug builds we panic so that it can be fixed when occurring
let out = match self.schema.get_full(&self.name) {
Some((idx, _, _)) => {
// check if the schema was correct
// if not do O(n) search
match df.get_columns().get(idx) {
Some(out) => self.process_by_idx(
out.as_materialized_series(),
state,
&self.schema,
df,
true,
),
None => {
if self.name.starts_with(CSE_REPLACED) {
return self.process_cse(df, schema);
// partitioned group_by special case
if let Some(schema) = state.get_schema() {
self.process_from_state_schema(df, state, &schema)
} else {
self.process_by_linear_search(df, state, true)
}
self.process_by_linear_search(df, state, true)
},
}
},
// in the future we will throw an error here
// now we do a linear search first as the lazy reported schema may still be incorrect
// in debug builds we panic so that it can be fixed when occurring
None => {
if self.name.starts_with(CSE_REPLACED) {
return self.process_cse(df, &self.schema);
}
self.process_by_linear_search(df, state, true)
},
};
self.check_external_context(out, state)
}
Expand Down
66 changes: 36 additions & 30 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn create_physical_expressions_from_irs(
exprs: &[ExprIR],
context: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
create_physical_expressions_check_state(exprs, context, expr_arena, schema, state, ok_checker)
Expand All @@ -35,7 +35,7 @@ pub(crate) fn create_physical_expressions_check_state<F>(
exprs: &[ExprIR],
context: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
checker: F,
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
Expand All @@ -57,7 +57,7 @@ pub(crate) fn create_physical_expressions_from_nodes(
exprs: &[Node],
context: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
create_physical_expressions_from_nodes_check_state(
Expand All @@ -69,7 +69,7 @@ pub(crate) fn create_physical_expressions_from_nodes_check_state<F>(
exprs: &[Node],
context: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
checker: F,
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
Expand Down Expand Up @@ -165,7 +165,7 @@ pub fn create_physical_expr(
expr_ir: &ExprIR,
ctxt: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
) -> PolarsResult<Arc<dyn PhysicalExpr>> {
let phys_expr = create_physical_expr_inner(expr_ir.node(), ctxt, expr_arena, schema, state)?;
Expand All @@ -185,7 +185,7 @@ fn create_physical_expr_inner(
expression: Node,
ctxt: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
) -> PolarsResult<Arc<dyn PhysicalExpr>> {
use AExpr::*;
Expand Down Expand Up @@ -309,7 +309,7 @@ fn create_physical_expr_inner(
Column(column) => Ok(Arc::new(ColumnExpr::new(
column.clone(),
node_to_expr(expression, expr_arena),
schema.cloned(),
schema.clone(),
))),
Sort { expr, options } => {
let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
Expand Down Expand Up @@ -410,22 +410,18 @@ fn create_physical_expr_inner(
return Ok(Arc::new(AggQuantileExpr::new(input, quantile, *interpol)));
}

let field = schema
.map(|schema| {
expr_arena.get(expression).to_field(
schema,
Context::Aggregation,
expr_arena,
)
})
.transpose()?;
let field = expr_arena.get(expression).to_field(
schema,
Context::Aggregation,
expr_arena,
)?;

let groupby = GroupByMethod::from(agg.clone());
let agg_type = AggregationType {
groupby,
allow_threading: false,
};
Ok(Arc::new(AggregationExpr::new(input, agg_type, field)))
Ok(Arc::new(AggregationExpr::new(input, agg_type, Some(field))))
},
}
},
Expand Down Expand Up @@ -475,12 +471,10 @@ fn create_physical_expr_inner(
options,
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let output_dtype = schema.and_then(|schema| {
let output_dtype =
expr_arena
.get(expression)
.to_dtype(schema, Context::Default, expr_arena)
.ok()
});
.to_field(schema, Context::Default, expr_arena)?;

let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR)
&& matches!(options.collect_groups, ApplyOptions::GroupWise);
Expand All @@ -504,7 +498,7 @@ fn create_physical_expr_inner(
node_to_expr(expression, expr_arena),
*options,
state.allow_threading,
schema.cloned(),
schema.clone(),
output_dtype,
is_scalar,
)))
Expand All @@ -516,12 +510,10 @@ fn create_physical_expr_inner(
..
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let output_dtype = schema.and_then(|schema| {
let output_field =
expr_arena
.get(expression)
.to_dtype(schema, Context::Default, expr_arena)
.ok()
});
.to_field(schema, Context::Default, expr_arena)?;
let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR)
&& matches!(options.collect_groups, ApplyOptions::GroupWise);
// Will be reset in the function so get that here.
Expand All @@ -544,8 +536,8 @@ fn create_physical_expr_inner(
node_to_expr(expression, expr_arena),
*options,
state.allow_threading,
schema.cloned(),
output_dtype,
schema.clone(),
output_field,
is_scalar,
)))
},
Expand All @@ -570,11 +562,25 @@ fn create_physical_expr_inner(
let function = SpecialEq::new(Arc::new(
move |c: &mut [polars_core::frame::column::Column]| c[0].explode().map(Some),
) as Arc<dyn ColumnsUdf>);
Ok(Arc::new(ApplyExpr::new_minimal(

let field = expr_arena
.get(expression)
.to_field(schema, ctxt, expr_arena)?;
Ok(Arc::new(ApplyExpr::new(
vec![input],
function,
node_to_expr(expression, expr_arena),
ApplyOptions::GroupWise,
FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
fmt_str: "",
cast_to_supertypes: None,
check_lengths: Default::default(),
flags: Default::default(),
},
state.allow_threading,
schema.clone(),
field,
false,
)))
},
Alias(input, name) => {
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-lazy/src/dsl/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized {

// Ensure we get the new schema.
let output_field = eval_field_to_dtype(c.field().as_ref(), &expr, false);
let schema = Arc::new(Schema::from_iter(std::iter::once(output_field.clone())));

let expr = expr.clone();
let mut arena = Arena::with_capacity(10);
Expand All @@ -60,7 +61,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized {
&aexpr,
Context::Default,
&arena,
None,
&schema,
&mut ExpressionConversionState::new(true, 0),
)?;

Expand Down
Loading

0 comments on commit 251e171

Please sign in to comment.