Skip to content

Commit

Permalink
fix: Ternary supertype dynamics
Browse files (browse the repository at this point in the history)
  • Loading branch information
ritchie46 committed May 1, 2024
1 parent 8929395 commit 9de4ea4
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 97 deletions.
40 changes: 36 additions & 4 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ pub use series_trait::{IsSorted, *};
use crate::chunked_array::Settings;
#[cfg(feature = "zip_with")]
use crate::series::arithmetic::coerce_lhs_rhs;
use crate::utils::{_split_offsets, handle_casting_failures, split_ca, split_series, Wrap};
use crate::utils::{
_split_offsets, handle_casting_failures, materialize_dyn_int, split_ca, split_series, Wrap,
};
use crate::POOL;

/// # Series
Expand Down Expand Up @@ -309,9 +311,39 @@ impl Series {

/// Cast `[Series]` to another `[DataType]`.
pub fn cast(&self, dtype: &DataType) -> PolarsResult<Self> {
// Best leave as is.
if !dtype.is_known() || (dtype.is_primitive() && dtype == self.dtype()) {
return Ok(self.clone());
match dtype {
DataType::Unknown(kind) => {
return match kind {
// Best leave as is.
UnknownKind::Any => Ok(self.clone()),
UnknownKind::Int(v) => {
if self.dtype().is_integer() {
Ok(self.clone())
} else {
self.cast(&materialize_dyn_int(*v).dtype())
}
},
UnknownKind::Float => {
if self.dtype().is_float() {
Ok(self.clone())
} else {
self.cast(&DataType::Float64)
}
},
UnknownKind::Str => {
if self.dtype().is_string() | self.dtype().is_categorical() {
Ok(self.clone())
} else {
self.cast(&DataType::String)
}
},
};
},
// Best leave as is.
dt if dt.is_primitive() && dt == self.dtype() => {
return Ok(self.clone());
},
_ => {},
}
let ret = self.0.cast(dtype);
let len = self.len();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,6 @@ pub(super) fn process_binary(
right: node_right,
}));
},
(Unknown(lhs), Unknown(rhs)) if lhs == rhs => return Ok(None),
_ => {
unpack!(early_escape(&type_left, &type_right));
},
Expand Down
94 changes: 2 additions & 92 deletions crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,54 +31,6 @@ fn modify_supertype(
type_left: &DataType,
type_right: &DataType,
) -> DataType {
use AExpr::*;

let dynamic_st_or_unknown = matches!(st, DataType::Unknown(_));

match (left, right) {
(
Literal(
lv_left @ (LiteralValue::Int(_)
| LiteralValue::Float(_)
| LiteralValue::StrCat(_)
| LiteralValue::Null),
),
Literal(
lv_right @ (LiteralValue::Int(_)
| LiteralValue::Float(_)
| LiteralValue::StrCat(_)
| LiteralValue::Null),
),
) => {
let lhs = lv_left.to_any_value().unwrap().dtype();
let rhs = lv_right.to_any_value().unwrap().dtype();
st = get_supertype(&lhs, &rhs).unwrap();
return st;
},
// Materialize dynamic types
(
Literal(
lv_left @ (LiteralValue::Int(_) | LiteralValue::Float(_) | LiteralValue::StrCat(_)),
),
_,
) if dynamic_st_or_unknown => {
st = lv_left.to_any_value().unwrap().dtype();
return st;
},
(
_,
Literal(
lv_right
@ (LiteralValue::Int(_) | LiteralValue::Float(_) | LiteralValue::StrCat(_)),
),
) if dynamic_st_or_unknown => {
st = lv_right.to_any_value().unwrap().dtype();
return st;
},
// do nothing
_ => {},
}

// TODO! This must be removed and dealt properly with dynamic str.
use DataType::*;
match (type_left, type_right, left, right) {
Expand Down Expand Up @@ -185,44 +137,9 @@ impl OptimizationRule for TypeCoercionRule {
let (falsy, type_false) =
unpack!(get_aexpr_and_type(expr_arena, falsy_node, &input_schema));

match (&type_true, &type_false) {
(DataType::Unknown(lhs), DataType::Unknown(rhs)) => {
match (lhs, rhs) {
(UnknownKind::Any, _) | (_, UnknownKind::Any) => return Ok(None),
// continue
(UnknownKind::Int(_), UnknownKind::Float)
| (UnknownKind::Float, UnknownKind::Int(_)) => {},
(lhs, rhs) if lhs == rhs => {
let falsy = materialize(falsy);
let truthy = materialize(truthy);

if falsy.is_none() && truthy.is_none() {
return Ok(None);
}

let falsy = if let Some(falsy) = falsy {
expr_arena.add(falsy)
} else {
falsy_node
};
let truthy = if let Some(truthy) = truthy {
expr_arena.add(truthy)
} else {
truthy_node
};
return Ok(Some(AExpr::Ternary {
truthy,
falsy,
predicate,
}));
},
_ => {},
}
},
(lhs, rhs) if lhs == rhs => return Ok(None),
_ => {},
if type_true == type_false {
return Ok(None);
}

let st = unpack!(get_supertype(&type_true, &type_false));
let st = modify_supertype(st, truthy, falsy, &type_true, &type_false);

Expand Down Expand Up @@ -612,13 +529,6 @@ fn inline_or_prune_cast(

fn early_escape(type_self: &DataType, type_other: &DataType) -> Option<()> {
match (type_self, type_other) {
(DataType::Unknown(lhs), DataType::Unknown(rhs)) => match (lhs, rhs) {
(UnknownKind::Any, _) | (_, UnknownKind::Any) => None,
(UnknownKind::Int(_), UnknownKind::Float)
| (UnknownKind::Float, UnknownKind::Int(_)) => Some(()),
(lhs, rhs) if lhs == rhs => None,
_ => Some(()),
},
(lhs, rhs) if lhs == rhs => None,
_ => Some(()),
}
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/functions/test_when_then.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,22 @@ def test_when_then_supertype_15975() -> None:
assert df.with_columns(
pl.when(True).then(1 ** pl.col("a") + 1.0 * pl.col("a"))
).to_dict(as_series=False) == {"a": [1, 2, 3], "literal": [2.0, 3.0, 4.0]}


def test_when_then_supertype_15975_comment() -> None:
    """Chained ``when/then`` arms that mix integer and float literals.

    The arms supply the literals 1, 4, 1.5, 16 and a fallback of 0. The rows
    with ``foo`` equal to 1, 3 and 4 are expected to come back as the floats
    1.0, 1.5 and 16.0 in the ``val`` column.

    NOTE(review): the test name suggests this covers a follow-up comment on
    issue 15975 -- confirm against the tracker.
    """
    lf = pl.LazyFrame({"foo": [1, 3, 4], "bar": [3, 4, 0]})
    foo = pl.col("foo")

    # Only the third arm carries a float literal; every other arm is an int.
    val = (
        pl.when(foo == 1)
        .then(1)
        .when(foo == 2)
        .then(4)
        .when(foo == 3)
        .then(1.5)
        .when(foo == 4)
        .then(16)
        .otherwise(0)
        .alias("val")
    )

    collected = lf.with_columns(val).collect()
    expected = [1.0, 1.5, 16.0]
    assert collected["val"].to_list() == expected

0 comments on commit 9de4ea4

Please sign in to comment.