diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index f6c7b64e471c..bc73d230f9de 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -29,7 +29,9 @@ pub use series_trait::{IsSorted, *}; use crate::chunked_array::Settings; #[cfg(feature = "zip_with")] use crate::series::arithmetic::coerce_lhs_rhs; -use crate::utils::{_split_offsets, handle_casting_failures, split_ca, split_series, Wrap}; +use crate::utils::{ + _split_offsets, handle_casting_failures, materialize_dyn_int, split_ca, split_series, Wrap, +}; use crate::POOL; /// # Series @@ -309,9 +311,39 @@ impl Series { /// Cast `[Series]` to another `[DataType]`. pub fn cast(&self, dtype: &DataType) -> PolarsResult { - // Best leave as is. - if !dtype.is_known() || (dtype.is_primitive() && dtype == self.dtype()) { - return Ok(self.clone()); + match dtype { + DataType::Unknown(kind) => { + return match kind { + // Best leave as is. + UnknownKind::Any => Ok(self.clone()), + UnknownKind::Int(v) => { + if self.dtype().is_integer() { + Ok(self.clone()) + } else { + self.cast(&materialize_dyn_int(*v).dtype()) + } + }, + UnknownKind::Float => { + if self.dtype().is_float() { + Ok(self.clone()) + } else { + self.cast(&DataType::Float64) + } + }, + UnknownKind::Str => { + if self.dtype().is_string() | self.dtype().is_categorical() { + Ok(self.clone()) + } else { + self.cast(&DataType::String) + } + }, + }; + }, + // Best leave as is. + dt if dt.is_primitive() && dt == self.dtype() => { + return Ok(self.clone()); + }, + _ => {}, } let ret = self.0.cast(dtype); let len = self.len(); diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs index 456f021d1e26..f0eb0051b803 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs @@ -240,7 +240,6 @@ pub(super) fn process_binary( right: node_right, })); }, - (Unknown(lhs), Unknown(rhs)) if lhs == rhs => return Ok(None), _ => { unpack!(early_escape(&type_left, &type_right)); }, diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs index c97b89e52613..d38d58b027ef 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs @@ -31,54 +31,6 @@ fn modify_supertype( type_left: &DataType, type_right: &DataType, ) -> DataType { - use AExpr::*; - - let dynamic_st_or_unknown = matches!(st, DataType::Unknown(_)); - - match (left, right) { - ( - Literal( - lv_left @ (LiteralValue::Int(_) - | LiteralValue::Float(_) - | LiteralValue::StrCat(_) - | LiteralValue::Null), - ), - Literal( - lv_right @ (LiteralValue::Int(_) - | LiteralValue::Float(_) - | LiteralValue::StrCat(_) - | LiteralValue::Null), - ), - ) => { - let lhs = lv_left.to_any_value().unwrap().dtype(); - let rhs = lv_right.to_any_value().unwrap().dtype(); - st = get_supertype(&lhs, &rhs).unwrap(); - return st; - }, - // Materialize dynamic types - ( - Literal( - lv_left @ (LiteralValue::Int(_) | LiteralValue::Float(_) | LiteralValue::StrCat(_)), - ), - _, - ) if dynamic_st_or_unknown => { - st = lv_left.to_any_value().unwrap().dtype(); - return st; - }, - ( - _, - Literal( - lv_right - @ (LiteralValue::Int(_) | LiteralValue::Float(_) | LiteralValue::StrCat(_)), - ), - ) if dynamic_st_or_unknown => { - st = lv_right.to_any_value().unwrap().dtype(); - return st; - }, - // do nothing - _ => {}, - } - // TODO! This must be removed and dealt properly with dynamic str. use DataType::*; match (type_left, type_right, left, right) { @@ -185,44 +137,9 @@ impl OptimizationRule for TypeCoercionRule { let (falsy, type_false) = unpack!(get_aexpr_and_type(expr_arena, falsy_node, &input_schema)); - match (&type_true, &type_false) { - (DataType::Unknown(lhs), DataType::Unknown(rhs)) => { - match (lhs, rhs) { - (UnknownKind::Any, _) | (_, UnknownKind::Any) => return Ok(None), - // continue - (UnknownKind::Int(_), UnknownKind::Float) - | (UnknownKind::Float, UnknownKind::Int(_)) => {}, - (lhs, rhs) if lhs == rhs => { - let falsy = materialize(falsy); - let truthy = materialize(truthy); - - if falsy.is_none() && truthy.is_none() { - return Ok(None); - } - - let falsy = if let Some(falsy) = falsy { - expr_arena.add(falsy) - } else { - falsy_node - }; - let truthy = if let Some(truthy) = truthy { - expr_arena.add(truthy) - } else { - truthy_node - }; - return Ok(Some(AExpr::Ternary { - truthy, - falsy, - predicate, - })); - }, - _ => {}, - } - }, - (lhs, rhs) if lhs == rhs => return Ok(None), - _ => {}, + if type_true == type_false { + return Ok(None); } - let st = unpack!(get_supertype(&type_true, &type_false)); let st = modify_supertype(st, truthy, falsy, &type_true, &type_false); @@ -612,13 +529,6 @@ fn inline_or_prune_cast( fn early_escape(type_self: &DataType, type_other: &DataType) -> Option<()> { match (type_self, type_other) { - (DataType::Unknown(lhs), DataType::Unknown(rhs)) => match (lhs, rhs) { - (UnknownKind::Any, _) | (_, UnknownKind::Any) => None, - (UnknownKind::Int(_), UnknownKind::Float) - | (UnknownKind::Float, UnknownKind::Int(_)) => Some(()), - (lhs, rhs) if lhs == rhs => None, - _ => Some(()), - }, (lhs, rhs) if lhs == rhs => None, _ => Some(()), } diff --git a/py-polars/tests/unit/functions/test_when_then.py b/py-polars/tests/unit/functions/test_when_then.py index 7625fece9987..8315b0801597 100644 --- a/py-polars/tests/unit/functions/test_when_then.py +++ b/py-polars/tests/unit/functions/test_when_then.py @@ -604,3 +604,22 @@ def test_when_then_supertype_15975() -> None: assert df.with_columns( pl.when(True).then(1 ** pl.col("a") + 1.0 * pl.col("a")) ).to_dict(as_series=False) == {"a": [1, 2, 3], "literal": [2.0, 3.0, 4.0]} + + +def test_when_then_supertype_15975_comment() -> None: + df = pl.LazyFrame({"foo": [1, 3, 4], "bar": [3, 4, 0]}) + + q = df.with_columns( + pl.when(pl.col("foo") == 1) + .then(1) + .when(pl.col("foo") == 2) + .then(4) + .when(pl.col("foo") == 3) + .then(1.5) + .when(pl.col("foo") == 4) + .then(16) + .otherwise(0) + .alias("val") + ) + + assert q.collect()["val"].to_list() == [1.0, 1.5, 16.0]