Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow implicit string → temporal conversion in SQL comparisons #15958

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/polars-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_
polars-plan = { workspace = true }

hex = { workspace = true }
once_cell = { workspace = true }
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
Expand Down
18 changes: 11 additions & 7 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,9 @@ impl SQLContext {
let mut contains_wildcard_exclude = false;

// Filter expression.
let schema = Some(lf.schema()?);
if let Some(expr) = select_stmt.selection.as_ref() {
let mut filter_expression = parse_sql_expr(expr, self)?;
let mut filter_expression = parse_sql_expr(expr, self, schema.as_deref())?;
lf = self.process_subqueries(lf, vec![&mut filter_expression]);
lf = lf.filter(filter_expression);
}
Expand All @@ -382,9 +383,9 @@ impl SQLContext {
.iter()
.map(|select_item| {
Ok(match select_item {
SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, self)?,
SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, self, schema.as_deref())?,
SelectItem::ExprWithAlias { expr, alias } => {
let expr = parse_sql_expr(expr, self)?;
let expr = parse_sql_expr(expr, self, schema.as_deref())?;
expr.alias(&alias.value)
},
SelectItem::QualifiedWildcard(oname, wildcard_options) => self
Expand Down Expand Up @@ -427,7 +428,7 @@ impl SQLContext {
ComputeError:
"group_by error: a positive number or an expression expected",
)),
_ => parse_sql_expr(e, self),
_ => parse_sql_expr(e, self, schema.as_deref()),
})
.collect::<PolarsResult<_>>()?
} else {
Expand Down Expand Up @@ -506,8 +507,9 @@ impl SQLContext {
lf = self.process_order_by(lf, &query.order_by)?;

// Apply optional 'having' clause, post-aggregation.
let schema = Some(lf.schema()?);
match select_stmt.having.as_ref() {
Some(expr) => lf.filter(parse_sql_expr(expr, self)?),
Some(expr) => lf.filter(parse_sql_expr(expr, self, schema.as_deref())?),
None => lf,
}
};
Expand All @@ -517,10 +519,11 @@ impl SQLContext {
Some(Distinct::Distinct) => lf.unique_stable(None, UniqueKeepStrategy::Any),
Some(Distinct::On(exprs)) => {
// TODO: support exprs in `unique` see https://github.com/pola-rs/polars/issues/5760
let schema = Some(lf.schema()?);
let cols = exprs
.iter()
.map(|e| {
let expr = parse_sql_expr(e, self)?;
let expr = parse_sql_expr(e, self, schema.as_deref())?;
if let Expr::Column(name) = expr {
Ok(name.to_string())
} else {
Expand Down Expand Up @@ -664,8 +667,9 @@ impl SQLContext {
let mut by = Vec::with_capacity(ob.len());
let mut descending = Vec::with_capacity(ob.len());

let schema = Some(lf.schema()?);
for ob in ob {
by.push(parse_sql_expr(&ob.expr, self)?);
by.push(parse_sql_expr(&ob.expr, self, schema.as_deref())?);
descending.push(!ob.asc.unwrap_or(true));
polars_ensure!(
ob.nulls_first.is_none(),
Expand Down
22 changes: 11 additions & 11 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ impl SQLFunctionVisitor<'_> {
.into_iter()
.map(|arg| {
if let FunctionArgExpr::Expr(e) = arg {
parse_sql_expr(e, self.ctx)
parse_sql_expr(e, self.ctx, None)
} else {
polars_bail!(ComputeError: "Only expressions are supported in UDFs")
}
Expand Down Expand Up @@ -1130,7 +1130,7 @@ impl SQLFunctionVisitor<'_> {
let (order_by, desc): (Vec<Expr>, Vec<bool>) = order_by
.iter()
.map(|o| {
let expr = parse_sql_expr(&o.expr, self.ctx)?;
let expr = parse_sql_expr(&o.expr, self.ctx, None)?;
Ok(match o.asc {
Some(b) => (expr, !b),
None => (expr, false),
Expand All @@ -1157,7 +1157,7 @@ impl SQLFunctionVisitor<'_> {
let args = extract_args(self.func);
match args.as_slice() {
[FunctionArgExpr::Expr(sql_expr)] => {
let expr = parse_sql_expr(sql_expr, self.ctx)?;
let expr = parse_sql_expr(sql_expr, self.ctx, None)?;
// apply the function on the inner expr -- e.g. SUM(a) -> SUM
Ok(f(expr))
},
Expand All @@ -1179,7 +1179,7 @@ impl SQLFunctionVisitor<'_> {
let args = extract_args(self.func);
match args.as_slice() {
[FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2)] => {
let expr1 = parse_sql_expr(sql_expr1, self.ctx)?;
let expr1 = parse_sql_expr(sql_expr1, self.ctx, None)?;
let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
f(expr1, expr2)
},
Expand All @@ -1199,7 +1199,7 @@ impl SQLFunctionVisitor<'_> {
let mut expr_args = vec![];
for arg in args {
if let FunctionArgExpr::Expr(sql_expr) = arg {
expr_args.push(parse_sql_expr(sql_expr, self.ctx)?);
expr_args.push(parse_sql_expr(sql_expr, self.ctx, None)?);
} else {
return self.not_supported_error();
};
Expand All @@ -1215,7 +1215,7 @@ impl SQLFunctionVisitor<'_> {
match args.as_slice() {
[FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2), FunctionArgExpr::Expr(sql_expr3)] =>
{
let expr1 = parse_sql_expr(sql_expr1, self.ctx)?;
let expr1 = parse_sql_expr(sql_expr1, self.ctx, None)?;
let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
let expr3 = Arg::from_sql_expr(sql_expr3, self.ctx)?;
f(expr1, expr2, expr3)
Expand All @@ -1239,15 +1239,15 @@ impl SQLFunctionVisitor<'_> {
(false, []) => Ok(len()),
// count(column_name)
(false, [FunctionArgExpr::Expr(sql_expr)]) => {
let expr = parse_sql_expr(sql_expr, self.ctx)?;
let expr = parse_sql_expr(sql_expr, self.ctx, None)?;
let expr = self.apply_window_spec(expr, &self.func.over)?;
Ok(expr.count())
},
// count(*)
(false, [FunctionArgExpr::Wildcard]) => Ok(len()),
// count(distinct column_name)
(true, [FunctionArgExpr::Expr(sql_expr)]) => {
let expr = parse_sql_expr(sql_expr, self.ctx)?;
let expr = parse_sql_expr(sql_expr, self.ctx, None)?;
let expr = self.apply_window_spec(expr, &self.func.over)?;
Ok(expr.n_unique())
},
Expand All @@ -1267,7 +1267,7 @@ impl SQLFunctionVisitor<'_> {
.order_by
.iter()
.map(|o| {
let e = parse_sql_expr(&o.expr, self.ctx)?;
let e = parse_sql_expr(&o.expr, self.ctx, None)?;
Ok(o.asc.map_or(e.clone(), |b| {
e.sort(SortOptions::default().with_order_descending(!b))
}))
Expand All @@ -1279,7 +1279,7 @@ impl SQLFunctionVisitor<'_> {
let partition_by = window_spec
.partition_by
.iter()
.map(|p| parse_sql_expr(p, self.ctx))
.map(|p| parse_sql_expr(p, self.ctx, None))
.collect::<PolarsResult<Vec<_>>>()?;
expr.over(partition_by)
}
Expand Down Expand Up @@ -1388,6 +1388,6 @@ impl FromSQLExpr for Expr {
where
Self: Sized,
{
parse_sql_expr(expr, ctx)
parse_sql_expr(expr, ctx, None)
}
}
137 changes: 114 additions & 23 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use polars_plan::prelude::typed_lit;
use polars_plan::prelude::LiteralValue::Null;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use regex::{Regex, RegexBuilder};
#[cfg(feature = "dtype-decimal")]
use sqlparser::ast::ExactNumberInfo;
use sqlparser::ast::{
Expand All @@ -22,6 +23,21 @@ use sqlparser::parser::{Parser, ParserOptions};
use crate::functions::SQLFunctionVisitor;
use crate::SQLContext;

/// Pattern matching the leading ISO date portion (`YYYY-MM-DD`) of a string literal.
/// Compiled on first use inside `convert_temporal_strings` and cached for the process.
static DATE_LITERAL_RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
/// Pattern matching the leading time portion (`HH:MM:SS`) of a string literal.
/// Compiled on first use inside `convert_temporal_strings` and cached for the process.
static TIME_LITERAL_RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();

/// Map an optional SQL fractional-seconds precision onto a polars `TimeUnit`.
///
/// * no precision, or 4-6 digits -> microseconds
/// * 1-3 digits -> milliseconds
/// * 7-9 digits -> nanoseconds
///
/// # Errors
/// Returns a `ComputeError` for any other precision value.
fn timeunit_from_precision(prec: &Option<u64>) -> PolarsResult<TimeUnit> {
    let unit = match *prec {
        None | Some(4..=6) => TimeUnit::Microseconds,
        Some(1..=3) => TimeUnit::Milliseconds,
        Some(7..=9) => TimeUnit::Nanoseconds,
        Some(n) => {
            polars_bail!(ComputeError: "invalid temporal type precision; expected 1-9, found {}", n)
        },
    };
    Ok(unit)
}

pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<DataType> {
Ok(match data_type {
// ---------------------------------
Expand Down Expand Up @@ -106,22 +122,12 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<D
polars_bail!(ComputeError: "`time` with timezone is not supported; found tz={}", tz)
},
},
SQLDataType::Timestamp(prec, tz) => {
let tu = match prec {
None => TimeUnit::Microseconds,
Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds,
Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds,
Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds,
Some(n) => {
polars_bail!(ComputeError: "unsupported `timestamp` precision; expected a value between 1 and 9, found {}", n)
},
};
match tz {
TimezoneInfo::None => DataType::Datetime(tu, None),
_ => {
polars_bail!(ComputeError: "`timestamp` with timezone is not (yet) supported; found tz={}", tz)
},
}
SQLDataType::Datetime(prec) => DataType::Datetime(timeunit_from_precision(prec)?, None),
SQLDataType::Timestamp(prec, tz) => match tz {
TimezoneInfo::None => DataType::Datetime(timeunit_from_precision(prec)?, None),
_ => {
polars_bail!(ComputeError: "`timestamp` with timezone is not (yet) supported; found tz={}", tz)
},
},

// ---------------------------------
Expand Down Expand Up @@ -173,6 +179,7 @@ pub enum SubqueryRestriction {
/// Recursively walks a SQL Expr to create a polars Expr
pub(crate) struct SQLExprVisitor<'a> {
/// Mutable handle to the SQL context that was passed to `parse_sql_expr`.
ctx: &'a mut SQLContext,
/// Schema of the frame the expression is evaluated against, when the caller
/// has one; used to look up column dtypes for implicit string -> temporal
/// conversion. `None` disables schema-based lookups.
active_schema: Option<&'a Schema>,
}

impl SQLExprVisitor<'_> {
Expand Down Expand Up @@ -396,17 +403,80 @@ impl SQLExprVisitor<'_> {
}
}

/// Handle implicit temporal string comparisons.
///
/// eg: "dt >= '2024-04-30'", or "dtm::date = '2077-10-10'"
///
/// Takes the already-visited operands of a binary expression and returns the
/// expression to use as the right-hand operand. When `right` is a string
/// literal that starts like an ISO date/time and `left` is a column (or a
/// cast of a column) whose dtype is `Date`, `Time` or `Datetime`, the literal
/// is strictly cast to that dtype. In every other case, including when the
/// column's dtype cannot be resolved, `right` is returned unchanged.
fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
    // only string literals are candidates for conversion
    let Expr::Literal(LiteralValue::String(s)) = right else {
        return right.clone();
    };

    // work out the dtype the string is being compared against
    let left_dtype = match left {
        // "col <op> string": look the column up in the active schema. Both the
        // schema and the column entry are optional; a missing entry must not
        // panic, as the column name comes straight from the user's query.
        Expr::Column(name) => self.active_schema.and_then(|schema| schema.get(name)),
        // "CAST(col AS type) <op> string" and/or "col::type <op> string":
        // the cast target is the dtype, so no schema lookup is required.
        Expr::Cast {
            expr, data_type, ..
        } if matches!(&**expr, Expr::Column(_)) => Some(data_type),
        _ => None,
    };
    let Some(left_dtype) = left_dtype else {
        return right.clone();
    };

    let dt_regex = DATE_LITERAL_RE.get_or_init(|| {
        RegexBuilder::new(r"^\d{4}-[01]\d-[0-3]\d")
            .build()
            .expect("hard-coded date pattern is a valid regex")
    });
    let tm_regex = TIME_LITERAL_RE.get_or_init(|| {
        RegexBuilder::new(r"^[012]\d:[0-5]\d:[0-5]\d")
            .build()
            .expect("hard-coded time pattern is a valid regex")
    });

    match left_dtype {
        DataType::Time if tm_regex.is_match(s) => right.clone().strict_cast(left_dtype.clone()),
        DataType::Date if dt_regex.is_match(s) => right.clone().strict_cast(left_dtype.clone()),
        DataType::Datetime(_, _) if dt_regex.is_match(s) => {
            if s.len() == 10 {
                // handle upcast from ISO date string (10 chars) to datetime
                lit(format!("{}T00:00:00", s)).strict_cast(left_dtype.clone())
            } else {
                // normalise "YYYY-MM-DD HH:MM:SS" to the 'T'-separated form before casting
                lit(s.replacen(' ', "T", 1)).strict_cast(left_dtype.clone())
            }
        },
        _ => right.clone(),
    }
}

/// Visit a SQL binary operator.
///
/// e.g. column + 1 or column1 / column2
/// e.g. "column + 1", "column1 <= column2"
fn visit_binary_op(
&mut self,
left: &SQLExpr,
op: &BinaryOperator,
right: &SQLExpr,
) -> PolarsResult<Expr> {
let left = self.visit_expr(left)?;
let right = self.visit_expr(right)?;
let mut right = self.visit_expr(right)?;
right = self.convert_temporal_strings(&left, &right);

Ok(match op {
SQLBinaryOperator::And => left.and(right),
SQLBinaryOperator::Divide => left / right,
Expand Down Expand Up @@ -747,8 +817,25 @@ impl SQLExprVisitor<'_> {
}
})
.collect::<PolarsResult<Vec<_>>>()?;
let s = Series::from_any_values("", &list, true)?;

let mut s = Series::from_any_values("", &list, true)?;

// handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
// (not yet as versatile as the temporal string conversions in visit_binary_op)
if s.dtype() == &DataType::String {
// only a plain column reference on the left can be resolved against the active schema
if let Expr::Column(name) = &expr {
if self.active_schema.is_some() {
let schema = self.active_schema.as_ref().unwrap();
let left_dtype = schema.get(name);
if let Some(DataType::Date | DataType::Time | DataType::Datetime(_, _)) =
left_dtype
{
s = s.strict_cast(&left_dtype.unwrap().clone())?;
}
}
}
}
if negated {
Ok(expr.is_in(lit(s)).not())
} else {
Expand Down Expand Up @@ -1011,16 +1098,20 @@ pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {

Ok(match &expr {
SelectItem::ExprWithAlias { expr, alias } => {
let expr = parse_sql_expr(expr, &mut ctx)?;
let expr = parse_sql_expr(expr, &mut ctx, None)?;
expr.alias(&alias.value)
},
SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx)?,
SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
_ => polars_bail!(InvalidOperation: "Unable to parse '{}' as Expr", s.as_ref()),
})
}

pub(crate) fn parse_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Expr> {
let mut visitor = SQLExprVisitor { ctx };
/// Convert a parsed SQL expression into a polars `Expr`.
///
/// `active_schema` is handed to the visitor so that column dtypes can be
/// looked up while converting (e.g. for implicit string -> temporal
/// comparisons); pass `None` when no schema is available.
pub(crate) fn parse_sql_expr(
    expr: &SQLExpr,
    ctx: &mut SQLContext,
    active_schema: Option<&Schema>,
) -> PolarsResult<Expr> {
    SQLExprVisitor { ctx, active_schema }.visit_expr(expr)
}

Expand Down
Loading