Skip to content

Commit

Permalink
fix: to_any_value supports all LiteralValue type
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Mar 29, 2024
1 parent 9c46183 commit d55783d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
41 changes: 40 additions & 1 deletion crates/polars-plan/src/logical_plan/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::hash::{Hash, Hasher};
#[cfg(feature = "temporal")]
use polars_core::export::chrono::{Duration as ChronoDuration, NaiveDate, NaiveDateTime};
use polars_core::prelude::*;
use polars_core::utils::NoNull;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -116,7 +117,45 @@ impl LiteralValue {
DateTime(v, tu, tz) => AnyValue::Datetime(*v, *tu, tz),
#[cfg(feature = "dtype-time")]
Time(v) => AnyValue::Time(*v),
_ => return None,
Series(s) => {
AnyValue::List(s.0.clone().into_series())
},
Range {
low,
high,
data_type,
} => {
let s = match data_type {
DataType::Int32 => {
if *low < i32::MIN as i64 || *high > i32::MAX as i64 {
return None;
}

let low = *low as i32;
let high = *high as i32;
let ca: NoNull<Int32Chunked> = (low..high).collect();
ca.into_inner().into_series()
},
DataType::Int64 => {
let low = *low;
let high = *high;
let ca: NoNull<Int64Chunked> = (low..high).collect();
ca.into_inner().into_series()
},
DataType::UInt32 => {
if *low < 0 || *high > u32::MAX as i64 {
return None;
}
let low = *low as u32;
let high = *high as u32;
let ca: NoNull<UInt32Chunked> = (low..high).collect();
ca.into_inner().into_series()
},
_ => return None,
};
AnyValue::List(s)
},
Binary(v) => AnyValue::Binary(v),
};
Some(av)
}
Expand Down
2 changes: 2 additions & 0 deletions py-polars/tests/unit/functions/test_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
(8, 2, pl.UInt8, pl.UInt8),
(date(2023, 2, 2), 3, pl.Datetime, pl.Datetime),
(7.5, 5, pl.UInt16, pl.UInt16),
([1, 2, 3], 2, pl.List(pl.Int64), pl.List(pl.Int64)),
(b"ab12", 3, pl.Binary, pl.Binary),
],
)
def test_repeat(
Expand Down

0 comments on commit d55783d

Please sign in to comment.