Skip to content

Commit

Permalink
[FEAT] Infer timedelta literal as duration (#3011)
Browse files Browse the repository at this point in the history
Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
colin-ho and Colin Ho authored Oct 7, 2024
1 parent 396c004 commit f5cf5af
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 6 deletions.
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,7 @@ def lit(item: Any) -> PyExpr: ...
def date_lit(item: int) -> PyExpr: ...
def time_lit(item: int, tu: PyTimeUnit) -> PyExpr: ...
def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ...
def duration_lit(item: int, tu: PyTimeUnit) -> PyExpr: ...
def decimal_lit(sign: bool, digits: tuple[int, ...], exp: int) -> PyExpr: ...
def series_lit(item: PySeries) -> PyExpr: ...
def stateless_udf(
Expand Down
9 changes: 8 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import os
import warnings
from datetime import date, datetime, time
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import (
TYPE_CHECKING,
Expand All @@ -23,6 +23,7 @@
from daft.daft import col as _col
from daft.daft import date_lit as _date_lit
from daft.daft import decimal_lit as _decimal_lit
from daft.daft import duration_lit as _duration_lit
from daft.daft import list_sort as _list_sort
from daft.daft import lit as _lit
from daft.daft import series_lit as _series_lit
Expand Down Expand Up @@ -115,6 +116,12 @@ def lit(value: object) -> Expression:
i64_value = pa_time.cast(pa.int64()).as_py()
time_unit = TimeUnit.from_str(pa.type_for_alias(str(pa_time.type)).unit)._timeunit
lit_value = _time_lit(i64_value, time_unit)
elif isinstance(value, timedelta):
# Workaround: pyo3's timedelta binding (PyDelta) is not available when running in abi3 mode, so go through a pyarrow scalar instead
pa_duration = pa.scalar(value)
i64_value = pa_duration.cast(pa.int64()).as_py()
time_unit = TimeUnit.from_str(pa_duration.type.unit)._timeunit
lit_value = _duration_lit(i64_value, time_unit)
elif isinstance(value, Decimal):
sign, digits, exponent = value.as_tuple()
assert isinstance(exponent, int)
Expand Down
20 changes: 18 additions & 2 deletions src/daft-core/src/array/ops/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use crate::{
NullArray, UInt64Array, Utf8Array,
},
series::Series,
utils::display::{display_date32, display_decimal128, display_time64, display_timestamp},
utils::display::{
display_date32, display_decimal128, display_duration, display_time64, display_timestamp,
},
with_match_daft_types,
};

Expand All @@ -34,7 +36,6 @@ macro_rules! impl_array_str_value {

impl_array_str_value!(BooleanArray, "{}");
impl_array_str_value!(ExtensionArray, "{:?}");
impl_array_str_value!(DurationArray, "{}");

fn pretty_print_bytes(bytes: &[u8], max_len: usize) -> DaftResult<String> {
/// Influenced by Python's bytes repr
Expand Down Expand Up @@ -192,6 +193,21 @@ impl TimestampArray {
}
}

impl DurationArray {
    /// Returns the display string for the element at `idx`.
    ///
    /// A null slot is rendered as `"None"`; any other value is formatted by
    /// `display_duration` using the time unit carried by this array's dtype.
    ///
    /// # Panics
    ///
    /// Panics if the slot holds a value and the array's dtype is not
    /// `DataType::Duration`.
    pub fn str_value(&self, idx: usize) -> DaftResult<String> {
        // Null slots never look at the dtype, so handle them first.
        let Some(raw) = self.get(idx) else {
            return Ok("None".to_string());
        };

        match &self.field.dtype {
            DataType::Duration(time_unit) => Ok(display_duration(raw, time_unit)),
            other => panic!("Wrong dtype for DurationArray: {}", other),
        }
    }
}

impl Decimal128Array {
pub fn str_value(&self, idx: usize) -> DaftResult<String> {
let res = self.get(idx).map_or_else(
Expand Down
48 changes: 48 additions & 0 deletions src/daft-core/src/utils/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,54 @@ pub fn display_timestamp(val: i64, unit: &TimeUnit, timezone: &Option<String>) -
)
}

/// Labels for the whole-unit components of a rendered duration, largest first:
/// days, hours, minutes, seconds.
const UNITS: [&str; 4] = ["d", "h", "m", "s"];

/// Builds one row of [`SIZES`]: how many ticks make up a day, an hour, a minute
/// and a second when one second is `ticks_per_second` ticks long.
const fn duration_unit_sizes(ticks_per_second: i64) -> [i64; 4] {
    [
        86_400 * ticks_per_second,
        3_600 * ticks_per_second,
        60 * ticks_per_second,
        ticks_per_second,
    ]
}

/// Divisors for each entry of [`UNITS`], one row per time-unit resolution.
const SIZES: [[i64; 4]; 4] = [
    duration_unit_sizes(1_000_000_000), // Nanoseconds
    duration_unit_sizes(1_000_000),     // Microseconds
    duration_unit_sizes(1_000),         // Milliseconds
    duration_unit_sizes(1),             // Seconds
];

/// Formats a duration of `val` ticks, measured in `unit`, as a human-readable string.
///
/// The value is broken down into days, hours, minutes and seconds (zero
/// components are omitted, components are separated by a single space), followed
/// by whatever sub-second remainder is left, suffixed with the unit itself
/// (`ns`, `µs` or `ms`). A zero duration is rendered as `0` plus the unit suffix.
///
/// Integer division truncates toward zero, so for a negative `val` every
/// emitted component carries its own minus sign.
pub fn display_duration(val: i64, unit: &TimeUnit) -> String {
    // Select the divisor row matching the resolution and the suffix used for
    // the sub-second remainder.
    let (sizes, suffix) = match unit {
        TimeUnit::Nanoseconds => (&SIZES[0], "ns"),
        TimeUnit::Microseconds => (&SIZES[1], "µs"),
        TimeUnit::Milliseconds => (&SIZES[2], "ms"),
        TimeUnit::Seconds => (&SIZES[3], "s"),
    };

    if val == 0 {
        return format!("0{}", suffix);
    }

    let mut rendered = String::new();
    // Ticks not yet accounted for by a larger component. Each divisor in a row
    // divides the one before it, so after processing `size` this equals `val % size`.
    let mut rest = val;
    for (&size, label) in sizes.iter().zip(UNITS.iter()) {
        let whole = rest / size;
        rest %= size;
        if whole == 0 {
            continue;
        }
        rendered.push_str(&whole.to_string());
        rendered.push_str(label);
        // Only add a separator when something smaller is still to be printed.
        if rest != 0 {
            rendered.push(' ');
        }
    }

    // What is left is smaller than one second; second-resolution values have
    // no sub-second part to show.
    if rest != 0 && suffix != "s" {
        rendered.push_str(&rest.to_string());
        rendered.push_str(suffix);
    }

    rendered
}

pub fn display_decimal128(val: i128, _precision: u8, scale: i8) -> String {
if scale < 0 {
unimplemented!();
Expand Down
1 change: 1 addition & 0 deletions src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
parent.add_function(wrap_pyfunction_bound!(python::date_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::time_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::timestamp_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::duration_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::decimal_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::series_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::stateless_udf, parent)?)?;
Expand Down
18 changes: 15 additions & 3 deletions src/daft-dsl/src/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use common_hashable_float_wrapper::FloatWrapper;
use daft_core::{
prelude::*,
utils::display::{
display_date32, display_decimal128, display_series_literal, display_time64,
display_timestamp,
display_date32, display_decimal128, display_duration, display_series_literal,
display_time64, display_timestamp,
},
};
use indexmap::IndexMap;
Expand Down Expand Up @@ -60,6 +60,8 @@ pub enum LiteralValue {
Date(i32),
/// An [`i64`] representing a time in microseconds or nanoseconds since midnight.
Time(i64, TimeUnit),
/// An [`i64`] representing a measure of elapsed time. This elapsed time is a physical duration (i.e. 1s as defined in S.I.)
Duration(i64, TimeUnit),
/// A 64-bit floating point number.
Float64(f64),
/// An [`i128`] representing a decimal number with the provided precision and scale.
Expand Down Expand Up @@ -99,6 +101,10 @@ impl Hash for LiteralValue {
tu.hash(state);
tz.hash(state);
}
Duration(n, tu) => {
n.hash(state);
tu.hash(state);
}
// Wrap float64 in hashable newtype.
Float64(n) => FloatWrapper(*n).hash(state),
Decimal(n, precision, scale) => {
Expand Down Expand Up @@ -141,6 +147,7 @@ impl Display for LiteralValue {
Date(val) => write!(f, "{}", display_date32(*val)),
Time(val, tu) => write!(f, "{}", display_time64(*val, tu)),
Timestamp(val, tu, tz) => write!(f, "{}", display_timestamp(*val, tu, tz)),
Duration(val, tu) => write!(f, "{}", display_duration(*val, tu)),
Float64(val) => write!(f, "{val:.1}"),
Decimal(val, precision, scale) => {
write!(f, "{}", display_decimal128(*val, *precision, *scale))
Expand Down Expand Up @@ -181,6 +188,7 @@ impl LiteralValue {
Date(_) => DataType::Date,
Time(_, tu) => DataType::Time(*tu),
Timestamp(_, tu, tz) => DataType::Timestamp(*tu, tz.clone()),
Duration(_, tu) => DataType::Duration(*tu),
Float64(_) => DataType::Float64,
Decimal(_, precision, scale) => {
DataType::Decimal128(*precision as usize, *scale as usize)
Expand Down Expand Up @@ -215,6 +223,10 @@ impl LiteralValue {
let physical = Int64Array::from(("literal", [*val].as_slice()));
TimestampArray::new(Field::new("literal", self.get_type()), physical).into_series()
}
Duration(val, ..) => {
let physical = Int64Array::from(("literal", [*val].as_slice()));
DurationArray::new(Field::new("literal", self.get_type()), physical).into_series()
}
Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(),
Decimal(val, ..) => {
let physical = Int128Array::from(("literal", [*val].as_slice()));
Expand Down Expand Up @@ -259,7 +271,7 @@ impl LiteralValue {
display_timestamp(*val, tu, tz).replace('T', " ")
),
// TODO(Colin): Implement the rest of the types in future work for SQL pushdowns.
Decimal(..) | Series(..) | Time(..) | Binary(..) => display_sql_err,
Decimal(..) | Series(..) | Time(..) | Binary(..) | Duration(..) => display_sql_err,
#[cfg(feature = "python")]
Python(..) => display_sql_err,
Struct(..) => display_sql_err,
Expand Down
6 changes: 6 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ pub fn timestamp_lit(val: i64, tu: PyTimeUnit, tz: Option<String>) -> PyResult<P
Ok(expr.into())
}

// Python-exposed constructor for a duration literal expression: wraps the raw
// tick count `val` and the time unit carried by `tu` in `LiteralValue::Duration`.
#[pyfunction]
pub fn duration_lit(val: i64, tu: PyTimeUnit) -> PyResult<PyExpr> {
    Ok(Expr::Literal(LiteralValue::Duration(val, tu.timeunit)).into())
}

fn decimal_from_digits(digits: Vec<u8>, exp: i32) -> Option<(i128, usize)> {
const MAX_ABS_DEC: i128 = 10_i128.pow(38) - 1;
let mut v = 0_i128;
Expand Down
27 changes: 27 additions & 0 deletions tests/dataframe/test_temporals.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,33 @@ def test_python_duration() -> None:
assert res == duration


# Checks that a plain Python `timedelta` can be added to and subtracted from
# datetime-valued and duration-valued columns, and that the collected results
# match the equivalent pure-Python arithmetic.
# (Presumably the `timedelta` operands are turned into duration literals, per the test name — confirm.)
def test_temporal_arithmetic_with_duration_lit() -> None:
df = daft.from_pydict(
{
"duration": [timedelta(days=1)],
# NOTE(review): "date" is built from a `datetime`, exactly like "timestamp" below — confirm whether a `date` value was intended.
"date": [datetime(2021, 1, 1)],
"timestamp": [datetime(2021, 1, 1)],
}
)

# One add and one subtract per source column, each against a one-day timedelta.
df = df.select(
(df["date"] + timedelta(days=1)).alias("add_date"),
(df["date"] - timedelta(days=1)).alias("sub_date"),
(df["timestamp"] + timedelta(days=1)).alias("add_timestamp"),
(df["timestamp"] - timedelta(days=1)).alias("sub_timestamp"),
(df["duration"] + timedelta(days=1)).alias("add_dur"),
(df["duration"] - timedelta(days=1)).alias("sub_dur"),
)

result = df.to_pydict()
assert result["add_date"] == [datetime(2021, 1, 2)]
assert result["sub_date"] == [datetime(2020, 12, 31)]
assert result["add_timestamp"] == [datetime(2021, 1, 2)]
assert result["sub_timestamp"] == [datetime(2020, 12, 31)]
assert result["add_dur"] == [timedelta(days=2)]
assert result["sub_dur"] == [timedelta(0)]


@pytest.mark.parametrize(
"timeunit",
["s", "ms", "us", "ns"],
Expand Down
27 changes: 27 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,33 @@ def test_datetime_lit_different_timeunits(timeunit, expected) -> None:
assert timestamp_repr == expected


# Checks the repr of `lit(<timedelta>)` for several timedeltas.
# Each case pairs an input timedelta with the exact repr string expected.
@pytest.mark.parametrize(
"input, expected",
[
# Whole days only.
(
timedelta(days=1),
"lit(1d)",
),
# Days, hours, minutes and seconds, separated by single spaces.
(
timedelta(days=1, hours=12, minutes=30, seconds=59),
"lit(1d 12h 30m 59s)",
),
# A sub-second part is appended as a microsecond remainder.
(
timedelta(days=1, hours=12, minutes=30, seconds=59, microseconds=123456),
"lit(1d 12h 30m 59s 123456µs)",
),
# Weeks are expected to be folded into the day count (1 week + 1 day -> 8d).
(
timedelta(weeks=1, days=1, hours=12, minutes=30, seconds=59, microseconds=123456),
"lit(8d 12h 30m 59s 123456µs)",
),
],
)
def test_duration_lit(input, expected) -> None:
d = lit(input)
output = repr(d)
assert output == expected


def test_repr_series_lit() -> None:
s = lit(Series.from_pylist([1, 2, 3]))
output = repr(s)
Expand Down

0 comments on commit f5cf5af

Please sign in to comment.