Skip to content

Commit

Permalink
fix: Set intersection supertype
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 24, 2024
1 parent cc7fc32 commit 0a7c768
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 40 deletions.
31 changes: 24 additions & 7 deletions crates/polars-core/src/utils/supertype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,24 @@ pub fn try_get_supertype(l: &DataType, r: &DataType) -> PolarsResult<DataType> {
)
}

#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Default)]
pub struct SuperTypeOptions {
pub implode_list: bool,
}

pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
get_supertype_with_options(l, r, SuperTypeOptions::default())
}

/// Given two data types, determine the data type that both types can safely be cast to.
///
/// Returns [`None`] if no such data type exists.
pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
fn inner(l: &DataType, r: &DataType) -> Option<DataType> {
pub fn get_supertype_with_options(
l: &DataType,
r: &DataType,
options: SuperTypeOptions,
) -> Option<DataType> {
fn inner(l: &DataType, r: &DataType, options: SuperTypeOptions) -> Option<DataType> {
use DataType::*;
if l == r {
return Some(l.clone());
Expand Down Expand Up @@ -233,22 +246,26 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
}
(List(inner_left), List(inner_right)) => {
let st = get_supertype(inner_left, inner_right)?;
Some(DataType::List(Box::new(st)))
Some(List(Box::new(st)))
}
#[cfg(feature = "dtype-array")]
(List(inner_left), Array(inner_right, _)) | (Array(inner_left, _), List(inner_right)) => {
let st = get_supertype(inner_left, inner_right)?;
Some(DataType::List(Box::new(st)))
Some(List(Box::new(st)))
}
#[cfg(feature = "dtype-array")]
(Array(inner_left, width_left), Array(inner_right, width_right)) if *width_left == *width_right => {
let st = get_supertype(inner_left, inner_right)?;
Some(DataType::Array(Box::new(st), *width_left))
Some(Array(Box::new(st), *width_left))
}
(List(inner), other) | (other, List(inner)) if options.implode_list => {
let st = get_supertype(inner, other)?;
Some(List(Box::new(st)))
}
#[cfg(feature = "dtype-array")]
(Array(inner_left, _), Array(inner_right, _)) => {
let st = get_supertype(inner_left, inner_right)?;
Some(DataType::List(Box::new(st)))
Some(List(Box::new(st)))
}
#[cfg(feature = "dtype-struct")]
(Struct(inner), right @ Unknown(UnknownKind::Float | UnknownKind::Int(_))) => {
Expand Down Expand Up @@ -328,7 +345,7 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
}
}

inner(l, r).or_else(|| inner(r, l))
inner(l, r, options).or_else(|| inner(r, l, options))
}

/// Given multiple data types, determine the data type that all types can safely be cast to.
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/functions/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub fn concat_expr<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
cast_to_supertypes: true,
cast_to_supertypes: Some(Default::default()),
..Default::default()
},
})
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-plan/src/dsl/functions/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr {
function,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: true,
cast_to_supertypes: Some(Default::default()),
returns_scalar: true,
..Default::default()
},
Expand All @@ -35,7 +35,7 @@ pub fn pearson_corr(a: Expr, b: Expr, ddof: u8) -> Expr {
function,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: true,
cast_to_supertypes: Some(Default::default()),
returns_scalar: true,
..Default::default()
},
Expand Down Expand Up @@ -63,7 +63,7 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E
function,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: true,
cast_to_supertypes: Some(Default::default()),
returns_scalar: true,
..Default::default()
},
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ pub fn sum_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: false,
cast_to_supertypes: false,
cast_to_supertypes: None,
..Default::default()
},
})
Expand All @@ -293,7 +293,7 @@ pub fn mean_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: false,
cast_to_supertypes: false,
cast_to_supertypes: None,
..Default::default()
},
})
Expand All @@ -309,7 +309,7 @@ pub fn coalesce(exprs: &[Expr]) -> Expr {
function: FunctionExpr::Coalesce,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_to_supertypes: true,
cast_to_supertypes: Some(Default::default()),
input_wildcard_expansion: true,
..Default::default()
},
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/functions/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub fn datetime_range(
}),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: true,
cast_to_supertypes: Some(Default::default()),
allow_rename: true,
..Default::default()
},
Expand Down Expand Up @@ -118,7 +118,7 @@ pub fn datetime_ranges(
}),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: true,
cast_to_supertypes: Some(Default::default()),
allow_rename: true,
..Default::default()
},
Expand Down
24 changes: 13 additions & 11 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::sync::RwLock;
use polars_core::prelude::*;
#[cfg(feature = "diff")]
use polars_core::series::ops::NullBehavior;
#[cfg(feature = "list_sets")]
use polars_core::utils::SuperTypeOptions;

use crate::prelude::function_expr::ListFunction;
use crate::prelude::*;
Expand Down Expand Up @@ -359,17 +361,17 @@ impl ListNameSpace {

#[cfg(feature = "list_sets")]
fn set_operation(self, other: Expr, set_operation: SetOperation) -> Expr {
self.0
.map_many_private(
FunctionExpr::ListExpr(ListFunction::SetOperation(set_operation)),
&[other],
false,
true,
)
.with_function_options(|mut options| {
options.input_wildcard_expansion = true;
options
})
Expr::Function {
input: vec![self.0, other],
function: FunctionExpr::ListExpr(ListFunction::SetOperation(set_operation)),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
returns_scalar: false,
cast_to_supertypes: Some(SuperTypeOptions { implode_list: true }),
input_wildcard_expansion: true,
..Default::default()
},
}
}

/// Return the SET UNION between both list arrays.
Expand Down
16 changes: 14 additions & 2 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ impl Expr {
collect_groups: ApplyOptions::GroupWise,
returns_scalar: true,
fmt_str: "search_sorted",
cast_to_supertypes: true,
cast_to_supertypes: Some(Default::default()),
..Default::default()
},
}
Expand Down Expand Up @@ -685,6 +685,12 @@ impl Expr {
input.push(self);
input.extend_from_slice(arguments);

let cast_to_supertypes = if cast_to_supertypes {
Some(Default::default())
} else {
None
};

Expr::Function {
input,
function: function_expr,
Expand All @@ -708,6 +714,12 @@ impl Expr {
input.push(self);
input.extend_from_slice(arguments);

let cast_to_supertypes = if cast_to_supertypes {
Some(Default::default())
} else {
None
};

Expr::Function {
input,
function: function_expr,
Expand Down Expand Up @@ -1007,7 +1019,7 @@ impl Expr {
function: FunctionExpr::FillNull,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_to_supertypes: true,
cast_to_supertypes: Some(Default::default()),
..Default::default()
},
}
Expand Down
18 changes: 11 additions & 7 deletions crates/polars-plan/src/plans/conversion/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use arrow::legacy::utils::CustomIterTools;
use binary::process_binary;
use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;
use polars_core::utils::{get_supertype, materialize_dyn_int};
use polars_core::utils::{get_supertype, get_supertype_with_options, materialize_dyn_int};
use polars_utils::idx_vec::UnitVec;
use polars_utils::{format_list, unitvec};

Expand Down Expand Up @@ -347,7 +347,7 @@ impl OptimizationRule for TypeCoercionRule {
ref function,
ref input,
mut options,
} if options.cast_to_supertypes => {
} if options.cast_to_supertypes.is_some() => {
let input_schema = get_schema(lp_arena, lp_node);
let mut dtypes = Vec::with_capacity(input.len());
for e in input {
Expand All @@ -357,15 +357,15 @@ impl OptimizationRule for TypeCoercionRule {
// We will raise if we cannot find the supertype later.
match dtype {
DataType::Unknown(UnknownKind::Any) => {
options.cast_to_supertypes = false;
options.cast_to_supertypes = None;
return Ok(None);
},
_ => dtypes.push(dtype),
}
}

if dtypes.iter().all_equal() {
options.cast_to_supertypes = false;
options.cast_to_supertypes = None;
return Ok(None);
}

Expand All @@ -379,7 +379,11 @@ impl OptimizationRule for TypeCoercionRule {
let (other, type_other) =
unpack!(get_aexpr_and_type(expr_arena, other.node(), &input_schema));

let Some(new_st) = get_supertype(&super_type, &type_other) else {
let Some(new_st) = get_supertype_with_options(
&super_type,
&type_other,
options.cast_to_supertypes.unwrap(),
) else {
polars_bail!(InvalidOperation: "could not determine supertype of: {}", format_list!(dtypes));
};
if input.len() == 2 {
Expand Down Expand Up @@ -432,8 +436,8 @@ impl OptimizationRule for TypeCoercionRule {
})
.collect::<Vec<_>>();

// ensure we don't go through this on next iteration
options.cast_to_supertypes = false;
// Ensure we don't go through this on next iteration.
options.cast_to_supertypes = None;
Some(AExpr::Function {
function,
input,
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/optimizer/fused.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn get_expr(input: &[Node], op: FusedOperator, expr_arena: &Arena<AExpr>) -> AEx
.collect();
let mut options = FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_to_supertypes: true,
cast_to_supertypes: Some(Default::default()),
..Default::default()
};
// order of operations change because of FMA
Expand Down
8 changes: 6 additions & 2 deletions crates/polars-plan/src/plans/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::num::NonZeroUsize;
use std::path::PathBuf;

use polars_core::prelude::*;
use polars_core::utils::SuperTypeOptions;
#[cfg(feature = "csv")]
use polars_io::csv::write::CsvWriterOptions;
#[cfg(feature = "ipc")]
Expand Down Expand Up @@ -137,7 +138,10 @@ pub struct FunctionOptions {
/// sum(x) -> {4}
pub returns_scalar: bool,
// if the expression and its inputs should be cast to supertypes
pub cast_to_supertypes: bool,
// `None` -> Don't cast.
// `Some` -> cast with given options.
#[cfg_attr(feature = "serde", serde(skip))]
pub cast_to_supertypes: Option<SuperTypeOptions>,
// The physical expression may rename the output of this function.
// If set to `false` the physical engine will ensure the left input
// expression is the output name.
Expand Down Expand Up @@ -179,7 +183,7 @@ impl Default for FunctionOptions {
input_wildcard_expansion: false,
returns_scalar: false,
fmt_str: "",
cast_to_supertypes: false,
cast_to_supertypes: None,
allow_rename: false,
pass_name_to_apply: false,
changes_length: false,
Expand Down
8 changes: 7 additions & 1 deletion py-polars/src/functions/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ pub fn register_plugin_function(
ApplyOptions::GroupWise
};

let cast_to_supertypes = if cast_to_supertype {
Some(Default::default())
} else {
None
};

Ok(Expr::Function {
input: args.to_exprs(),
function: FunctionExpr::FfiPlugin {
Expand All @@ -45,7 +51,7 @@ pub fn register_plugin_function(
collect_groups,
input_wildcard_expansion,
returns_scalar,
cast_to_supertypes: cast_to_supertype,
cast_to_supertypes,
pass_name_to_apply,
changes_length,
..Default::default()
Expand Down
15 changes: 15 additions & 0 deletions py-polars/tests/unit/operations/test_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ def test_set_intersection_13765() -> None:
df.select(pl.col("a").list.set_intersection("a_other")).to_dict(as_series=False)


def test_set_intersection_st_17129() -> None:
df = pl.DataFrame({"a": [1, 2, 2], "b": [2, 2, 4]})

assert df.with_columns(
pl.col("b")
.over("a", mapping_strategy="join")
.list.set_intersection([4, 8])
.alias("intersect")
).to_dict(as_series=False) == {
"a": [1, 2, 2],
"b": [2, 2, 4],
"intersect": [[], [4], [4]],
}


@pytest.mark.parametrize(
("set_operation", "outcome"),
[
Expand Down

0 comments on commit 0a7c768

Please sign in to comment.