Skip to content

Commit

Permalink
fix: Fix struct name resolving (pola-rs#15507)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Apr 6, 2024
1 parent 7dfc53e commit 93b194e
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 31 deletions.
15 changes: 13 additions & 2 deletions crates/polars-arrow/src/legacy/index.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
use std::fmt::Display;

use num_traits::{NumCast, Signed, Zero};
use polars_error::{polars_err, PolarsResult};
use polars_utils::IdxSize;

use crate::array::PrimitiveArray;

pub trait IndexToUsize {
pub trait IndexToUsize: Display {
/// Translate the negative index to an offset.
fn negative_to_usize(self, len: usize) -> Option<usize>;

fn try_negative_to_usize(self, len: usize) -> PolarsResult<usize>
where
Self: Sized + Copy,
{
self.negative_to_usize(len)
.ok_or_else(|| polars_err!(OutOfBounds: "index {} for length: {}", self, len))
}
}

impl<I> IndexToUsize for I
where
I: PartialOrd + PartialEq + NumCast + Signed + Zero,
I: PartialOrd + PartialEq + NumCast + Signed + Zero + Display,
{
#[inline]
fn negative_to_usize(self, len: usize) -> Option<usize> {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub(super) use self::rolling::RollingFunction;
#[cfg(feature = "strings")]
pub(crate) use self::strings::StringFunction;
#[cfg(feature = "dtype-struct")]
pub(super) use self::struct_::StructFunction;
pub(crate) use self::struct_::StructFunction;
#[cfg(feature = "trigonometry")]
pub(super) use self::trigonometry::TrigonometricFunction;
use super::*;
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,14 @@ impl FunctionExpr {
ExtendConstant => mapper.with_same_dtype(),
}
}

pub(crate) fn output_name(&self) -> Option<&ColumnName> {
match self {
#[cfg(feature = "dtype-struct")]
FunctionExpr::StructExpr(StructFunction::FieldByName(name)) => Some(name),
_ => None,
}
}
}

pub struct FieldsMapper<'a> {
Expand Down
10 changes: 1 addition & 9 deletions crates/polars-plan/src/dsl/function_expr/struct_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl From<StructFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
fn from(func: StructFunction) -> Self {
use StructFunction::*;
match func {
FieldByIndex(index) => map!(struct_::get_by_index, index),
FieldByIndex(_) => panic!("should be replaced"),
FieldByName(name) => map!(struct_::get_by_name, name.clone()),
RenameFields(names) => map!(struct_::rename_fields, names.clone()),
PrefixFields(prefix) => map!(struct_::prefix_fields, prefix.clone()),
Expand All @@ -124,14 +124,6 @@ impl From<StructFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
}
}

pub(super) fn get_by_index(s: &Series, index: i64) -> PolarsResult<Series> {
let s = s.struct_()?;
let (index, _) = slice_offsets(index, 0, s.fields().len());
s.fields()
.get(index)
.cloned()
.ok_or_else(|| polars_err!(ComputeError: "struct field index out of bounds"))
}
pub(super) fn get_by_name(s: &Series, name: Arc<str>) -> PolarsResult<Series> {
let ca = s.struct_()?;
ca.field_by_name(name.as_ref())
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-plan/src/logical_plan/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,11 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena<AExpr>, state: &mut ConversionSta
function,
options,
} => {
if state.output_name.is_none() {
if let Some(name) = function.output_name() {
state.output_name = OutputName::ColumnLhs(name.clone())
}
}
state.prune_alias = false;
AExpr::Function {
input: to_aexprs(input, arena, state),
Expand Down
96 changes: 77 additions & 19 deletions crates/polars-plan/src/logical_plan/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,64 @@ fn expand_columns(

/// This replaces the dtypes Expr with a Column Expr. It also removes the Exclude Expr from the
/// expression chain.
pub(super) fn replace_dtype_with_column(expr: Expr, column_name: Arc<str>) -> Expr {
fn replace_dtype_with_column(expr: Expr, column_name: Arc<str>) -> Expr {
expr.map_expr(|e| match e {
Expr::DtypeColumn(_) => Expr::Column(column_name.clone()),
Expr::Exclude(input, _) => Arc::unwrap_or_clone(input),
e => e,
})
}

fn set_null_st(e: Expr, schema: &Schema) -> Expr {
e.map_expr(|mut e| {
if let Expr::Function {
input,
function: FunctionExpr::FillNull { super_type },
..
} = &mut e
{
if let Some(new_st) = early_supertype(input, schema) {
*super_type = new_st;
}
}
e
})
}

#[cfg(feature = "dtype-struct")]
fn struct_index_to_field(expr: Expr, schema: &Schema) -> PolarsResult<Expr> {
expr.try_map_expr(|e| match e {
Expr::Function {
input,
function: FunctionExpr::StructExpr(sf),
options,
} => {
if let StructFunction::FieldByIndex(index) = sf {
let dtype = input[0].to_field(schema, Context::Default)?.dtype;
let DataType::Struct(fields) = dtype else {
polars_bail!(InvalidOperation: "expected 'struct' dtype, got {:?}", dtype)
};
let index = index.try_negative_to_usize(fields.len())?;
let name = fields[index].name.as_str();
Ok(Expr::Function {
input,
function: FunctionExpr::StructExpr(StructFunction::FieldByName(
ColumnName::from(name),
)),
options,
})
} else {
Ok(Expr::Function {
input,
function: FunctionExpr::StructExpr(sf),
options,
})
}
},
e => Ok(e),
})
}

/// This replaces the columns Expr with a Column Expr. It also removes the Exclude Expr from the
/// expression chain.
pub(super) fn replace_columns_with_column(
Expand Down Expand Up @@ -374,6 +424,8 @@ struct ExpansionFlags {
replace_fill_null_type: bool,
has_selector: bool,
has_exclude: bool,
#[cfg(feature = "dtype-struct")]
has_struct_field_by_index: bool,
}

fn find_flags(expr: &Expr) -> ExpansionFlags {
Expand All @@ -383,9 +435,11 @@ fn find_flags(expr: &Expr) -> ExpansionFlags {
let mut replace_fill_null_type = false;
let mut has_selector = false;
let mut has_exclude = false;
#[cfg(feature = "dtype-struct")]
let mut has_struct_field_by_index = false;

// do a single pass and collect all flags at once.
// supertypes/modification that can be done in place are also don e in that pass
// Do a single pass and collect all flags at once.
// Supertypes/modification that can be done in place are also done in that pass
for expr in expr {
match expr {
Expr::Columns(_) | Expr::DtypeColumn(_) => multiple_columns = true,
Expand All @@ -396,6 +450,13 @@ fn find_flags(expr: &Expr) -> ExpansionFlags {
function: FunctionExpr::FillNull { .. },
..
} => replace_fill_null_type = true,
#[cfg(feature = "dtype-struct")]
Expr::Function {
function: FunctionExpr::StructExpr(StructFunction::FieldByIndex(_)),
..
} => {
has_struct_field_by_index = true;
},
Expr::Exclude(_, _) => has_exclude = true,
_ => {},
}
Expand All @@ -407,6 +468,8 @@ fn find_flags(expr: &Expr) -> ExpansionFlags {
replace_fill_null_type,
has_selector,
has_exclude,
#[cfg(feature = "dtype-struct")]
has_struct_field_by_index,
}
}

Expand All @@ -422,7 +485,7 @@ pub(crate) fn rewrite_projections(
for mut expr in exprs {
let result_offset = result.len();

// functions can have col(["a", "b"]) or col(String) as inputs
// Functions can have col(["a", "b"]) or col(String) as inputs.
expr = expand_function_inputs(expr, schema);

let mut flags = find_flags(&expr);
Expand All @@ -434,27 +497,22 @@ pub(crate) fn rewrite_projections(

replace_and_add_to_results(expr, flags, &mut result, schema, keys)?;

// this is done after all expansion (wildcard, column, dtypes)
// This is done after all expansion (wildcard, column, dtypes)
// have been done. This will ensure the conversion to aexpr does
// not panic because of an unexpected wildcard etc.

// the expanded expressions are written to result, so we pick
// The expanded expressions are written to result, so we pick
// them up there.
if flags.replace_fill_null_type {
for e in &mut result[result_offset..] {
*e = e.clone().map_expr(|mut e| {
if let Expr::Function {
input,
function: FunctionExpr::FillNull { super_type },
..
} = &mut e
{
if let Some(new_st) = early_supertype(input, schema) {
*super_type = new_st;
}
}
e
});
*e = set_null_st(std::mem::take(e), schema);
}
}

#[cfg(feature = "dtype-struct")]
if flags.has_struct_field_by_index {
for e in &mut result[result_offset..] {
*e = struct_index_to_field(std::mem::take(e), schema)?;
}
}
}
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/unit/test_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,23 @@ def test_exclude_selection() -> None:
assert ldf.select(pl.all().exclude(pl.Boolean)).columns == ["a", "b"]
assert ldf.select(pl.all().exclude([pl.Boolean])).columns == ["a", "b"]
assert ldf.select(pl.all().exclude(NUMERIC_DTYPES)).columns == ["c"]


def test_struct_name_resolving_15430() -> None:
q = pl.LazyFrame([{"a": {"b": "c"}}])
a = (
q.with_columns(pl.col("a").struct.field("b"))
.drop("a")
.collect(projection_pushdown=True)
)

b = (
q.with_columns(pl.col("a").struct[0])
.drop("a")
.collect(projection_pushdown=True)
)

assert a["b"].item() == "c"
assert b["b"].item() == "c"
assert a.columns == ["b"]
assert b.columns == ["b"]

0 comments on commit 93b194e

Please sign in to comment.