From ab5c0ee4f963cf82ceefac11bffe9c8bbcffe469 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Sun, 31 Mar 2024 15:59:33 +0800 Subject: [PATCH] fix: `to_any_value` should supports all LiteralValue type (#15387) --- crates/polars-plan/src/logical_plan/lit.rs | 38 ++++++++++++++++++- py-polars/tests/unit/functions/test_repeat.py | 2 + 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/crates/polars-plan/src/logical_plan/lit.rs b/crates/polars-plan/src/logical_plan/lit.rs index efea5e55b4b8..ac206ba7e9fa 100644 --- a/crates/polars-plan/src/logical_plan/lit.rs +++ b/crates/polars-plan/src/logical_plan/lit.rs @@ -116,7 +116,43 @@ impl LiteralValue { DateTime(v, tu, tz) => AnyValue::Datetime(*v, *tu, tz), #[cfg(feature = "dtype-time")] Time(v) => AnyValue::Time(*v), - _ => return None, + Series(s) => AnyValue::List(s.0.clone().into_series()), + Range { + low, + high, + data_type, + } => { + let opt_s = match data_type { + DataType::Int32 => { + if *low < i32::MIN as i64 || *high > i32::MAX as i64 { + return None; + } + + let low = *low as i32; + let high = *high as i32; + new_int_range::(low, high, 1, "range").ok() + }, + DataType::Int64 => { + let low = *low; + let high = *high; + new_int_range::(low, high, 1, "range").ok() + }, + DataType::UInt32 => { + if *low < 0 || *high > u32::MAX as i64 { + return None; + } + let low = *low as u32; + let high = *high as u32; + new_int_range::(low, high, 1, "range").ok() + }, + _ => return None, + }; + match opt_s { + Some(s) => AnyValue::List(s), + None => return None, + } + }, + Binary(v) => AnyValue::Binary(v), }; Some(av) } diff --git a/py-polars/tests/unit/functions/test_repeat.py b/py-polars/tests/unit/functions/test_repeat.py index 4b1d3138b592..b9c37aded947 100644 --- a/py-polars/tests/unit/functions/test_repeat.py +++ b/py-polars/tests/unit/functions/test_repeat.py @@ -28,6 +28,8 @@ (8, 2, pl.UInt8, pl.UInt8), (date(2023, 2, 2), 3, pl.Datetime, pl.Datetime), (7.5, 5, pl.UInt16, pl.UInt16), + ([1, 2, 3], 2, pl.List(pl.Int64), pl.List(pl.Int64)), + (b"ab12", 3, pl.Binary, pl.Binary), ], ) def test_repeat(