Skip to content

Commit

Permalink
fix(rust): get_dtype needs to look at input schema for Select/HStack
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed May 30, 2024
1 parent 9241ebc commit 3a1bf85
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions py-polars/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,17 @@ impl NodeTraverser {
let lp_arena = self.lp_arena.lock().unwrap();
let ir_node = lp_arena.get(self.root);
let expr_arena = self.expr_arena.lock().unwrap();
let schema = {
let schema = if let Some(schema) = ir_node.input_schema(&lp_arena) {
// TODO: This is a hack for CSE expressions when
// determining the dtype. It should be removed once
// to_field, or its moral equivalent can handle this in a
// proper way. The schema needs to include the dtype of
// CSE expressions for to_field to work with expressions
// that reference them, but is not part of the public
// schema of the node.
let schema = ir_node.schema(&lp_arena);
// schema of the input.
match ir_node {
// Both select and hstack must augment with any CSE
// expressions.
IR::Select { expr, .. } | IR::HStack { exprs: expr, .. } => {
let cse_exprs = expr.cse_exprs();
if cse_exprs.is_empty() {
Expand All @@ -153,6 +154,8 @@ impl NodeTraverser {
},
_ => schema,
}
} else {
raise_err!("Not able to compute input schema", ComputeError)
};
let field = expr_arena
.get(expr_node)
Expand Down

0 comments on commit 3a1bf85

Please sign in to comment.