Skip to content

Commit

Permalink
feat: Allow implicit string → temporal conversion in SQL comparisons (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored and Wouittone committed Jun 22, 2024
1 parent 51be09b commit 8fa4e8f
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 43 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/polars-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_
polars-plan = { workspace = true }

hex = { workspace = true }
once_cell = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
Expand Down
18 changes: 11 additions & 7 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,9 @@ impl SQLContext {
let mut contains_wildcard_exclude = false;

// Filter expression.
let schema = Some(lf.schema()?);
if let Some(expr) = select_stmt.selection.as_ref() {
let mut filter_expression = parse_sql_expr(expr, self)?;
let mut filter_expression = parse_sql_expr(expr, self, schema.as_deref())?;
lf = self.process_subqueries(lf, vec![&mut filter_expression]);
lf = lf.filter(filter_expression);
}
Expand All @@ -382,9 +383,9 @@ impl SQLContext {
.iter()
.map(|select_item| {
Ok(match select_item {
SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, self)?,
SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, self, schema.as_deref())?,
SelectItem::ExprWithAlias { expr, alias } => {
let expr = parse_sql_expr(expr, self)?;
let expr = parse_sql_expr(expr, self, schema.as_deref())?;
expr.alias(&alias.value)
},
SelectItem::QualifiedWildcard(oname, wildcard_options) => self
Expand Down Expand Up @@ -427,7 +428,7 @@ impl SQLContext {
ComputeError:
"group_by error: a positive number or an expression expected",
)),
_ => parse_sql_expr(e, self),
_ => parse_sql_expr(e, self, schema.as_deref()),
})
.collect::<PolarsResult<_>>()?
} else {
Expand Down Expand Up @@ -506,8 +507,9 @@ impl SQLContext {
lf = self.process_order_by(lf, &query.order_by)?;

// Apply optional 'having' clause, post-aggregation.
let schema = Some(lf.schema()?);
match select_stmt.having.as_ref() {
Some(expr) => lf.filter(parse_sql_expr(expr, self)?),
Some(expr) => lf.filter(parse_sql_expr(expr, self, schema.as_deref())?),
None => lf,
}
};
Expand All @@ -517,10 +519,11 @@ impl SQLContext {
Some(Distinct::Distinct) => lf.unique_stable(None, UniqueKeepStrategy::Any),
Some(Distinct::On(exprs)) => {
// TODO: support exprs in `unique` see https://github.com/pola-rs/polars/issues/5760
let schema = Some(lf.schema()?);
let cols = exprs
.iter()
.map(|e| {
let expr = parse_sql_expr(e, self)?;
let expr = parse_sql_expr(e, self, schema.as_deref())?;
if let Expr::Column(name) = expr {
Ok(name.to_string())
} else {
Expand Down Expand Up @@ -664,8 +667,9 @@ impl SQLContext {
let mut by = Vec::with_capacity(ob.len());
let mut descending = Vec::with_capacity(ob.len());

let schema = Some(lf.schema()?);
for ob in ob {
by.push(parse_sql_expr(&ob.expr, self)?);
by.push(parse_sql_expr(&ob.expr, self, schema.as_deref())?);
descending.push(!ob.asc.unwrap_or(true));
polars_ensure!(
ob.nulls_first.is_none(),
Expand Down
22 changes: 11 additions & 11 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ impl SQLFunctionVisitor<'_> {
.into_iter()
.map(|arg| {
if let FunctionArgExpr::Expr(e) = arg {
parse_sql_expr(e, self.ctx)
parse_sql_expr(e, self.ctx, None)
} else {
polars_bail!(ComputeError: "Only expressions are supported in UDFs")
}
Expand Down Expand Up @@ -1130,7 +1130,7 @@ impl SQLFunctionVisitor<'_> {
let (order_by, desc): (Vec<Expr>, Vec<bool>) = order_by
.iter()
.map(|o| {
let expr = parse_sql_expr(&o.expr, self.ctx)?;
let expr = parse_sql_expr(&o.expr, self.ctx, None)?;
Ok(match o.asc {
Some(b) => (expr, !b),
None => (expr, false),
Expand All @@ -1157,7 +1157,7 @@ impl SQLFunctionVisitor<'_> {
let args = extract_args(self.func);
match args.as_slice() {
[FunctionArgExpr::Expr(sql_expr)] => {
let expr = parse_sql_expr(sql_expr, self.ctx)?;
let expr = parse_sql_expr(sql_expr, self.ctx, None)?;
// apply the function on the inner expr -- e.g. SUM(a) -> SUM
Ok(f(expr))
},
Expand All @@ -1179,7 +1179,7 @@ impl SQLFunctionVisitor<'_> {
let args = extract_args(self.func);
match args.as_slice() {
[FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2)] => {
let expr1 = parse_sql_expr(sql_expr1, self.ctx)?;
let expr1 = parse_sql_expr(sql_expr1, self.ctx, None)?;
let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
f(expr1, expr2)
},
Expand All @@ -1199,7 +1199,7 @@ impl SQLFunctionVisitor<'_> {
let mut expr_args = vec![];
for arg in args {
if let FunctionArgExpr::Expr(sql_expr) = arg {
expr_args.push(parse_sql_expr(sql_expr, self.ctx)?);
expr_args.push(parse_sql_expr(sql_expr, self.ctx, None)?);
} else {
return self.not_supported_error();
};
Expand All @@ -1215,7 +1215,7 @@ impl SQLFunctionVisitor<'_> {
match args.as_slice() {
[FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2), FunctionArgExpr::Expr(sql_expr3)] =>
{
let expr1 = parse_sql_expr(sql_expr1, self.ctx)?;
let expr1 = parse_sql_expr(sql_expr1, self.ctx, None)?;
let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
let expr3 = Arg::from_sql_expr(sql_expr3, self.ctx)?;
f(expr1, expr2, expr3)
Expand All @@ -1239,15 +1239,15 @@ impl SQLFunctionVisitor<'_> {
(false, []) => Ok(len()),
// count(column_name)
(false, [FunctionArgExpr::Expr(sql_expr)]) => {
let expr = parse_sql_expr(sql_expr, self.ctx)?;
let expr = parse_sql_expr(sql_expr, self.ctx, None)?;
let expr = self.apply_window_spec(expr, &self.func.over)?;
Ok(expr.count())
},
// count(*)
(false, [FunctionArgExpr::Wildcard]) => Ok(len()),
// count(distinct column_name)
(true, [FunctionArgExpr::Expr(sql_expr)]) => {
let expr = parse_sql_expr(sql_expr, self.ctx)?;
let expr = parse_sql_expr(sql_expr, self.ctx, None)?;
let expr = self.apply_window_spec(expr, &self.func.over)?;
Ok(expr.n_unique())
},
Expand All @@ -1267,7 +1267,7 @@ impl SQLFunctionVisitor<'_> {
.order_by
.iter()
.map(|o| {
let e = parse_sql_expr(&o.expr, self.ctx)?;
let e = parse_sql_expr(&o.expr, self.ctx, None)?;
Ok(o.asc.map_or(e.clone(), |b| {
e.sort(SortOptions::default().with_order_descending(!b))
}))
Expand All @@ -1279,7 +1279,7 @@ impl SQLFunctionVisitor<'_> {
let partition_by = window_spec
.partition_by
.iter()
.map(|p| parse_sql_expr(p, self.ctx))
.map(|p| parse_sql_expr(p, self.ctx, None))
.collect::<PolarsResult<Vec<_>>>()?;
expr.over(partition_by)
}
Expand Down Expand Up @@ -1388,6 +1388,6 @@ impl FromSQLExpr for Expr {
where
Self: Sized,
{
parse_sql_expr(expr, ctx)
parse_sql_expr(expr, ctx, None)
}
}
137 changes: 114 additions & 23 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use polars_plan::prelude::typed_lit;
use polars_plan::prelude::LiteralValue::Null;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use regex::{Regex, RegexBuilder};
#[cfg(feature = "dtype-decimal")]
use sqlparser::ast::ExactNumberInfo;
use sqlparser::ast::{
Expand All @@ -22,6 +23,21 @@ use sqlparser::parser::{Parser, ParserOptions};
use crate::functions::SQLFunctionVisitor;
use crate::SQLContext;

// Process-wide caches for the regexes that `convert_temporal_strings` uses to
// recognise string literals beginning with a date (`YYYY-MM-DD`) or a time
// (`HH:MM:SS`). Both start empty and are compiled on first use via `get_or_init`.
static DATE_LITERAL_RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
static TIME_LITERAL_RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();

fn timeunit_from_precision(prec: &Option<u64>) -> PolarsResult<TimeUnit> {
Ok(match prec {
None => TimeUnit::Microseconds,
Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds,
Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds,
Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds,
Some(n) => {
polars_bail!(ComputeError: "invalid temporal type precision; expected 1-9, found {}", n)
},
})
}

pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<DataType> {
Ok(match data_type {
// ---------------------------------
Expand Down Expand Up @@ -106,22 +122,12 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<D
polars_bail!(ComputeError: "`time` with timezone is not supported; found tz={}", tz)
},
},
SQLDataType::Timestamp(prec, tz) => {
let tu = match prec {
None => TimeUnit::Microseconds,
Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds,
Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds,
Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds,
Some(n) => {
polars_bail!(ComputeError: "unsupported `timestamp` precision; expected a value between 1 and 9, found {}", n)
},
};
match tz {
TimezoneInfo::None => DataType::Datetime(tu, None),
_ => {
polars_bail!(ComputeError: "`timestamp` with timezone is not (yet) supported; found tz={}", tz)
},
}
SQLDataType::Datetime(prec) => DataType::Datetime(timeunit_from_precision(prec)?, None),
SQLDataType::Timestamp(prec, tz) => match tz {
TimezoneInfo::None => DataType::Datetime(timeunit_from_precision(prec)?, None),
_ => {
polars_bail!(ComputeError: "`timestamp` with timezone is not (yet) supported; found tz={}", tz)
},
},

// ---------------------------------
Expand Down Expand Up @@ -173,6 +179,7 @@ pub enum SubqueryRestriction {
/// Recursively walks a SQL Expr to create a polars Expr
///
/// Built by `parse_sql_expr`, which forwards its `ctx` and `active_schema`
/// arguments into the fields below unchanged.
pub(crate) struct SQLExprVisitor<'a> {
/// Mutable handle to the `SQLContext` supplied by the caller.
ctx: &'a mut SQLContext,
/// Schema of the frame the expression is evaluated against, when the caller
/// has one. It is consulted to look up column dtypes for implicit
/// string → temporal conversion; with `None`, no schema-based lookup is done.
active_schema: Option<&'a Schema>,
}

impl SQLExprVisitor<'_> {
Expand Down Expand Up @@ -396,17 +403,80 @@ impl SQLExprVisitor<'_> {
}
}

/// Handle implicit temporal string comparisons.
///
/// eg: "dt >= '2024-04-30'", or "dtm::date = '2077-10-10'"
///
/// Given the already-visited operands of a binary operation, returns the
/// expression to use as the right-hand side:
///
/// * if `right` is a string literal and `left` is either a column whose dtype
///   in the active schema is temporal, or a cast of a column to a temporal
///   dtype, and the literal starts with a matching date/time pattern, the
///   literal is wrapped in a strict cast to that dtype;
/// * in every other case `right` is returned unchanged (cloned).
///
/// A column that is missing from the active schema is treated as "dtype
/// unknown" and left untouched rather than panicking, so that the regular
/// column-resolution error can be raised further down the line.
fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
    // Only a string literal on the right-hand side is a conversion candidate.
    let Expr::Literal(LiteralValue::String(s)) = right else {
        return right.clone();
    };

    // Work out which dtype the literal is being compared against.
    let left_dtype = match left {
        // "col <op> string": the dtype comes from the active schema (if any);
        // an unknown column yields `None` instead of an unwrap panic.
        Expr::Column(name) => self.active_schema.and_then(|schema| schema.get(name)),
        // "CAST(col AS type) <op> string" / "col::type <op> string": the cast
        // target is the dtype. Only casts applied directly to a column qualify.
        Expr::Cast {
            expr, data_type, ..
        } if matches!(&**expr, Expr::Column(_)) => Some(data_type),
        _ => None,
    };
    let Some(left_dtype) = left_dtype else {
        return right.clone();
    };

    let dt_regex = DATE_LITERAL_RE.get_or_init(|| {
        RegexBuilder::new(r"^\d{4}-[01]\d-[0-3]\d")
            .build()
            .expect("date literal pattern is a valid regex")
    });
    let tm_regex = TIME_LITERAL_RE.get_or_init(|| {
        RegexBuilder::new(r"^[012]\d:[0-5]\d:[0-5]\d")
            .build()
            .expect("time literal pattern is a valid regex")
    });

    match left_dtype {
        DataType::Time if tm_regex.is_match(s) => right.clone().strict_cast(left_dtype.clone()),
        DataType::Date if dt_regex.is_match(s) => right.clone().strict_cast(left_dtype.clone()),
        DataType::Datetime(_, _) if dt_regex.is_match(s) => {
            let iso = if s.len() == 10 {
                // handle upcast from ISO date string (10 chars) to datetime
                format!("{}T00:00:00", s)
            } else {
                // normalise "date time" to the ISO "dateTtime" form
                s.replacen(' ', "T", 1)
            };
            lit(iso).strict_cast(left_dtype.clone())
        },
        _ => right.clone(),
    }
}

/// Visit a SQL binary operator.
///
/// e.g. column + 1 or column1 / column2
/// e.g. "column + 1", "column1 <= column2"
fn visit_binary_op(
&mut self,
left: &SQLExpr,
op: &BinaryOperator,
right: &SQLExpr,
) -> PolarsResult<Expr> {
let left = self.visit_expr(left)?;
let right = self.visit_expr(right)?;
let mut right = self.visit_expr(right)?;
right = self.convert_temporal_strings(&left, &right);

Ok(match op {
SQLBinaryOperator::And => left.and(right),
SQLBinaryOperator::Divide => left / right,
Expand Down Expand Up @@ -747,8 +817,25 @@ impl SQLExprVisitor<'_> {
}
})
.collect::<PolarsResult<Vec<_>>>()?;
let s = Series::from_any_values("", &list, true)?;

let mut s = Series::from_any_values("", &list, true)?;

// handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
// (not yet as versatile as the temporal string conversions in visit_binary_op)
if s.dtype() == &DataType::String {
// only a plain column reference can be resolved against the active schema
if let Expr::Column(name) = &expr {
if self.active_schema.is_some() {
let schema = self.active_schema.as_ref().unwrap();
let left_dtype = schema.get(name);
if let Some(DataType::Date | DataType::Time | DataType::Datetime(_, _)) =
left_dtype
{
s = s.strict_cast(&left_dtype.unwrap().clone())?;
}
}
}
}
if negated {
Ok(expr.is_in(lit(s)).not())
} else {
Expand Down Expand Up @@ -1011,16 +1098,20 @@ pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {

Ok(match &expr {
SelectItem::ExprWithAlias { expr, alias } => {
let expr = parse_sql_expr(expr, &mut ctx)?;
let expr = parse_sql_expr(expr, &mut ctx, None)?;
expr.alias(&alias.value)
},
SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx)?,
SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
_ => polars_bail!(InvalidOperation: "Unable to parse '{}' as Expr", s.as_ref()),
})
}

pub(crate) fn parse_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Expr> {
let mut visitor = SQLExprVisitor { ctx };
/// Translate a parsed SQL expression into a polars `Expr`.
///
/// `active_schema`, when provided, is handed to the visitor so that column
/// dtypes can be looked up while walking the expression; pass `None` when no
/// schema is available.
pub(crate) fn parse_sql_expr(
    expr: &SQLExpr,
    ctx: &mut SQLContext,
    active_schema: Option<&Schema>,
) -> PolarsResult<Expr> {
    SQLExprVisitor { ctx, active_schema }.visit_expr(expr)
}

Expand Down
Loading

0 comments on commit 8fa4e8f

Please sign in to comment.