diff --git a/Cargo.lock b/Cargo.lock index 661c396729e82..99c5dc7fbd3c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3122,6 +3122,7 @@ name = "polars-sql" version = "0.39.2" dependencies = [ "hex", + "once_cell", "polars-arrow", "polars-core", "polars-error", diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 1f2d32413563f..6eae9faa22273 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -16,6 +16,7 @@ polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_ polars-plan = { workspace = true } hex = { workspace = true } +once_cell = { workspace = true } rand = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 6fc6ac5599687..e594ce5b9e0c6 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -370,8 +370,9 @@ impl SQLContext { let mut contains_wildcard_exclude = false; // Filter expression. + let schema = Some(lf.schema()?); if let Some(expr) = select_stmt.selection.as_ref() { - let mut filter_expression = parse_sql_expr(expr, self)?; + let mut filter_expression = parse_sql_expr(expr, self, schema.as_deref())?; lf = self.process_subqueries(lf, vec![&mut filter_expression]); lf = lf.filter(filter_expression); } @@ -382,9 +383,9 @@ impl SQLContext { .iter() .map(|select_item| { Ok(match select_item { - SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, self)?, + SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, self, schema.as_deref())?, SelectItem::ExprWithAlias { expr, alias } => { - let expr = parse_sql_expr(expr, self)?; + let expr = parse_sql_expr(expr, self, schema.as_deref())?; expr.alias(&alias.value) }, SelectItem::QualifiedWildcard(oname, wildcard_options) => self @@ -427,7 +428,7 @@ impl SQLContext { ComputeError: "group_by error: a positive number or an expression expected", )), - _ => parse_sql_expr(e, self), + _ => parse_sql_expr(e, self, schema.as_deref()), }) .collect::>()? } else { @@ -506,8 +507,9 @@ impl SQLContext { lf = self.process_order_by(lf, &query.order_by)?; // Apply optional 'having' clause, post-aggregation. + let schema = Some(lf.schema()?); match select_stmt.having.as_ref() { - Some(expr) => lf.filter(parse_sql_expr(expr, self)?), + Some(expr) => lf.filter(parse_sql_expr(expr, self, schema.as_deref())?), None => lf, } }; @@ -517,10 +519,11 @@ impl SQLContext { Some(Distinct::Distinct) => lf.unique_stable(None, UniqueKeepStrategy::Any), Some(Distinct::On(exprs)) => { // TODO: support exprs in `unique` see https://github.com/pola-rs/polars/issues/5760 + let schema = Some(lf.schema()?); let cols = exprs .iter() .map(|e| { - let expr = parse_sql_expr(e, self)?; + let expr = parse_sql_expr(e, self, schema.as_deref())?; if let Expr::Column(name) = expr { Ok(name.to_string()) } else { @@ -664,8 +667,9 @@ impl SQLContext { let mut by = Vec::with_capacity(ob.len()); let mut descending = Vec::with_capacity(ob.len()); + let schema = Some(lf.schema()?); for ob in ob { - by.push(parse_sql_expr(&ob.expr, self)?); + by.push(parse_sql_expr(&ob.expr, self, schema.as_deref())?); descending.push(!ob.asc.unwrap_or(true)); polars_ensure!( ob.nulls_first.is_none(), diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 05cd4bc8959ad..78eb5023d6a48 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -1074,7 +1074,7 @@ impl SQLFunctionVisitor<'_> { .into_iter() .map(|arg| { if let FunctionArgExpr::Expr(e) = arg { - parse_sql_expr(e, self.ctx) + parse_sql_expr(e, self.ctx, None) } else { polars_bail!(ComputeError: "Only expressions are supported in UDFs") } @@ -1130,7 +1130,7 @@ impl SQLFunctionVisitor<'_> { let (order_by, desc): (Vec, Vec) = order_by .iter() .map(|o| { - let expr = parse_sql_expr(&o.expr, self.ctx)?; + let expr = parse_sql_expr(&o.expr, self.ctx, None)?; Ok(match o.asc { Some(b) => (expr, !b), None => (expr, false), @@ -1157,7 +1157,7 @@ impl SQLFunctionVisitor<'_> { let args = extract_args(self.func); match args.as_slice() { [FunctionArgExpr::Expr(sql_expr)] => { - let expr = parse_sql_expr(sql_expr, self.ctx)?; + let expr = parse_sql_expr(sql_expr, self.ctx, None)?; // apply the function on the inner expr -- e.g. SUM(a) -> SUM Ok(f(expr)) }, @@ -1179,7 +1179,7 @@ impl SQLFunctionVisitor<'_> { let args = extract_args(self.func); match args.as_slice() { [FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2)] => { - let expr1 = parse_sql_expr(sql_expr1, self.ctx)?; + let expr1 = parse_sql_expr(sql_expr1, self.ctx, None)?; let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?; f(expr1, expr2) }, @@ -1199,7 +1199,7 @@ impl SQLFunctionVisitor<'_> { let mut expr_args = vec![]; for arg in args { if let FunctionArgExpr::Expr(sql_expr) = arg { - expr_args.push(parse_sql_expr(sql_expr, self.ctx)?); + expr_args.push(parse_sql_expr(sql_expr, self.ctx, None)?); } else { return self.not_supported_error(); }; @@ -1215,7 +1215,7 @@ impl SQLFunctionVisitor<'_> { match args.as_slice() { [FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2), FunctionArgExpr::Expr(sql_expr3)] => { - let expr1 = parse_sql_expr(sql_expr1, self.ctx)?; + let expr1 = parse_sql_expr(sql_expr1, self.ctx, None)?; let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?; let expr3 = Arg::from_sql_expr(sql_expr3, self.ctx)?; f(expr1, expr2, expr3) @@ -1239,7 +1239,7 @@ impl SQLFunctionVisitor<'_> { (false, []) => Ok(len()), // count(column_name) (false, [FunctionArgExpr::Expr(sql_expr)]) => { - let expr = parse_sql_expr(sql_expr, self.ctx)?; + let expr = parse_sql_expr(sql_expr, self.ctx, None)?; let expr = self.apply_window_spec(expr, &self.func.over)?; Ok(expr.count()) }, @@ -1247,7 +1247,7 @@ impl SQLFunctionVisitor<'_> { (false, [FunctionArgExpr::Wildcard]) => Ok(len()), // count(distinct column_name) (true, [FunctionArgExpr::Expr(sql_expr)]) => { - let expr = parse_sql_expr(sql_expr, self.ctx)?; + let expr = parse_sql_expr(sql_expr, self.ctx, None)?; let expr = self.apply_window_spec(expr, &self.func.over)?; Ok(expr.n_unique()) }, @@ -1267,7 +1267,7 @@ impl SQLFunctionVisitor<'_> { .order_by .iter() .map(|o| { - let e = parse_sql_expr(&o.expr, self.ctx)?; + let e = parse_sql_expr(&o.expr, self.ctx, None)?; Ok(o.asc.map_or(e.clone(), |b| { e.sort(SortOptions::default().with_order_descending(!b)) })) @@ -1279,7 +1279,7 @@ impl SQLFunctionVisitor<'_> { let partition_by = window_spec .partition_by .iter() - .map(|p| parse_sql_expr(p, self.ctx)) + .map(|p| parse_sql_expr(p, self.ctx, None)) .collect::>>()?; expr.over(partition_by) } @@ -1388,6 +1388,6 @@ impl FromSQLExpr for Expr { where Self: Sized, { - parse_sql_expr(expr, ctx) + parse_sql_expr(expr, ctx, None) } } diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index b20fde159b4fc..eac1a56fbe052 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -8,6 +8,7 @@ use polars_plan::prelude::typed_lit; use polars_plan::prelude::LiteralValue::Null; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; +use regex::{Regex, RegexBuilder}; #[cfg(feature = "dtype-decimal")] use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::{ @@ -22,6 +23,21 @@ use sqlparser::parser::{Parser, ParserOptions}; use crate::functions::SQLFunctionVisitor; use crate::SQLContext; +static DATE_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); +static TIME_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); + +fn timeunit_from_precision(prec: &Option) -> PolarsResult { + Ok(match prec { + None => TimeUnit::Microseconds, + Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds, + Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds, + Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds, + Some(n) => { + polars_bail!(ComputeError: "invalid temporal type precision; expected 1-9, found {}", n) + }, + }) +} + pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult { Ok(match data_type { // --------------------------------- @@ -106,22 +122,12 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult { - let tu = match prec { - None => TimeUnit::Microseconds, - Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds, - Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds, - Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds, - Some(n) => { - polars_bail!(ComputeError: "unsupported `timestamp` precision; expected a value between 1 and 9, found {}", n) - }, - }; - match tz { - TimezoneInfo::None => DataType::Datetime(tu, None), - _ => { - polars_bail!(ComputeError: "`timestamp` with timezone is not (yet) supported; found tz={}", tz) - }, - } + SQLDataType::Datetime(prec) => DataType::Datetime(timeunit_from_precision(prec)?, None), + SQLDataType::Timestamp(prec, tz) => match tz { + TimezoneInfo::None => DataType::Datetime(timeunit_from_precision(prec)?, None), + _ => { + polars_bail!(ComputeError: "`timestamp` with timezone is not (yet) supported; found tz={}", tz) + }, }, // --------------------------------- @@ -173,6 +179,7 @@ pub enum SubqueryRestriction { /// Recursively walks a SQL Expr to create a polars Expr pub(crate) struct SQLExprVisitor<'a> { ctx: &'a mut SQLContext, + active_schema: Option<&'a Schema>, } impl SQLExprVisitor<'_> { @@ -396,9 +403,70 @@ impl SQLExprVisitor<'_> { } } + /// Handle implicit temporal string comparisons. + /// + /// eg: "dt >= '2024-04-30'", or "dtm::date = '2077-10-10'" + fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr { + if let (Some(name), Some(s), expr_dtype) = match (left, right) { + // identify "col string" expressions + (Expr::Column(name), Expr::Literal(LiteralValue::String(s))) => { + (Some(name.clone()), Some(s), None) + }, + // identify "CAST(expr AS type) string" and/or "expr::type string" expressions + ( + Expr::Cast { + expr, data_type, .. + }, + Expr::Literal(LiteralValue::String(s)), + ) => { + if let Expr::Column(name) = &**expr { + (Some(name.clone()), Some(s), Some(data_type)) + } else { + (None, Some(s), Some(data_type)) + } + }, + _ => (None, None, None), + } { + if expr_dtype.is_none() && self.active_schema.is_none() { + right.clone() + } else { + let left_dtype = expr_dtype + .unwrap_or_else(|| self.active_schema.as_ref().unwrap().get(&name).unwrap()); + + let dt_regex = DATE_LITERAL_RE + .get_or_init(|| RegexBuilder::new(r"^\d{4}-[01]\d-[0-3]\d").build().unwrap()); + let tm_regex = TIME_LITERAL_RE.get_or_init(|| { + RegexBuilder::new(r"^[012]\d:[0-5]\d:[0-5]\d") + .build() + .unwrap() + }); + + match left_dtype { + DataType::Time if tm_regex.is_match(s) => { + right.clone().strict_cast(left_dtype.clone()) + }, + DataType::Date if dt_regex.is_match(s) => { + right.clone().strict_cast(left_dtype.clone()) + }, + DataType::Datetime(_, _) if dt_regex.is_match(s) => { + if s.len() == 10 { + // handle upcast from ISO date string (10 chars) to datetime + lit(format!("{}T00:00:00", s)).strict_cast(left_dtype.clone()) + } else { + lit(s.replacen(' ', "T", 1)).strict_cast(left_dtype.clone()) + } + }, + _ => right.clone(), + } + } + } else { + right.clone() + } + } + /// Visit a SQL binary operator. /// - /// e.g. column + 1 or column1 / column2 + /// e.g. "column + 1", "column1 <= column2" fn visit_binary_op( &mut self, left: &SQLExpr, @@ -406,7 +474,9 @@ impl SQLExprVisitor<'_> { right: &SQLExpr, ) -> PolarsResult { let left = self.visit_expr(left)?; - let right = self.visit_expr(right)?; + let mut right = self.visit_expr(right)?; + right = self.convert_temporal_strings(&left, &right); + Ok(match op { SQLBinaryOperator::And => left.and(right), SQLBinaryOperator::Divide => left / right, @@ -747,8 +817,25 @@ impl SQLExprVisitor<'_> { } }) .collect::>>()?; - let s = Series::from_any_values("", &list, true)?; + let mut s = Series::from_any_values("", &list, true)?; + + // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')". + // (not yet as versatile as the temporal string conversions in visit_binary_op) + if s.dtype() == &DataType::String { + // handle implicit temporal string comparisons, eg: "dt >= '2024-04-30'" + if let Expr::Column(name) = &expr { + if self.active_schema.is_some() { + let schema = self.active_schema.as_ref().unwrap(); + let left_dtype = schema.get(name); + if let Some(DataType::Date | DataType::Time | DataType::Datetime(_, _)) = + left_dtype + { + s = s.strict_cast(&left_dtype.unwrap().clone())?; + } + } + } + } if negated { Ok(expr.is_in(lit(s)).not()) } else { @@ -1011,16 +1098,20 @@ pub fn sql_expr>(s: S) -> PolarsResult { Ok(match &expr { SelectItem::ExprWithAlias { expr, alias } => { - let expr = parse_sql_expr(expr, &mut ctx)?; + let expr = parse_sql_expr(expr, &mut ctx, None)?; expr.alias(&alias.value) }, - SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx)?, + SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?, _ => polars_bail!(InvalidOperation: "Unable to parse '{}' as Expr", s.as_ref()), }) } -pub(crate) fn parse_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult { - let mut visitor = SQLExprVisitor { ctx }; +pub(crate) fn parse_sql_expr( + expr: &SQLExpr, + ctx: &mut SQLContext, + active_schema: Option<&Schema>, +) -> PolarsResult { + let mut visitor = SQLExprVisitor { ctx, active_schema }; visitor.visit_expr(expr) } diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index 92a69a03ea0c4..9a1338adf5fee 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -144,6 +144,37 @@ fn test_literal_exprs() { assert!(df_sql.equals_missing(&df_pl)); } +#[test] +fn test_implicit_date_string() { + let df = df! { + "idx" => &[Some(0), Some(1), Some(2), Some(3)], + "dt" => &[Some("1955-10-01"), None, Some("2007-07-05"), Some("2077-06-11")], + } + .unwrap() + .lazy() + .select(vec![col("idx"), col("dt").cast(DataType::Date)]) + .collect() + .unwrap(); + + let mut context = SQLContext::new(); + context.register("frame", df.clone().lazy()); + for sql in [ + "SELECT idx, dt FROM frame WHERE dt >= '2007-07-05'", + "SELECT idx, dt FROM frame WHERE dt::date >= '2007-07-05'", + "SELECT idx, dt FROM frame WHERE dt::datetime >= '2007-07-05 00:00:00'", + "SELECT idx, dt FROM frame WHERE dt::timestamp >= '2007-07-05 00:00:00'", + ] { + let df_sql = context.execute(sql).unwrap().collect().unwrap(); + let df_pl = df + .clone() + .lazy() + .filter(col("idx").gt_eq(lit(2))) + .collect() + .unwrap(); + assert!(df_sql.equals(&df_pl)); + } +} + #[test] fn test_prefixed_column_names() { let df = create_sample_df().unwrap(); @@ -331,7 +362,7 @@ fn test_agg_functions() { } #[test] -fn create_table() { +fn test_create_table() { let df = create_sample_df().unwrap(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py index 4babd435374f1..c9d95c84c5ad2 100644 --- a/py-polars/tests/unit/sql/test_temporal.py +++ b/py-polars/tests/unit/sql/test_temporal.py @@ -147,6 +147,63 @@ def test_extract_century_millennium(dt: date, expected: list[int]) -> None: ) +@pytest.mark.parametrize( + ("constraint", "expected"), + [ + ("dtm >= '2020-12-30T10:30:45.987'", [0, 2]), + ("dtm::date > '2006-01-01'", [0, 2]), + ("dtm > '2006-01-01'", [0, 1, 2]), # << implies '2006-01-01 00:00:00' + ("dtm <= '2006-01-01'", []), # << implies '2006-01-01 00:00:00' + ("dt != '1960-01-07'", [0, 1]), + ("dt::datetime = '1960-01-07'", [2]), + ("dt::datetime = '1960-01-07 00:00:00'", [2]), + ("dt IN ('1960-01-07','2077-01-01','2222-02-22')", [1, 2]), + ( + "dtm = '2024-01-07 01:02:03.123456000' OR dtm = '2020-12-30 10:30:45.987654'", + [0, 2], + ), + ], +) +def test_implicit_temporal_strings(constraint: str, expected: list[int]) -> None: + df = pl.DataFrame( + { + "idx": [0, 1, 2], + "dtm": [ + datetime(2024, 1, 7, 1, 2, 3, 123456), + datetime(2006, 1, 1, 23, 59, 59, 555555), + datetime(2020, 12, 30, 10, 30, 45, 987654), + ], + "dt": [ + date(2020, 12, 30), + date(2077, 1, 1), + date(1960, 1, 7), + ], + } + ) + res = df.sql(f"SELECT idx FROM self WHERE {constraint}") + actual = sorted(res["idx"]) + assert actual == expected + + +@pytest.mark.parametrize( + "dtval", + [ + "2020-12-30T10:30:45", + "yyyy-mm-dd", + "2222-22-22", + "10:30:45", + ], +) +def test_implicit_temporal_string_errors(dtval: str) -> None: + df = pl.DataFrame({"dt": [date(2020, 12, 30)]}) + + with pytest.raises( + ComputeError, + match="(conversion.*failed)|(cannot compare.*string.*temporal)", + ): + df.sql(f"SELECT * FROM self WHERE dt = '{dtval}'") + + @pytest.mark.parametrize( ("unit", "expected"), [ @@ -182,6 +239,6 @@ def test_timestamp_time_unit_errors() -> None: for prec in (0, 15): with pytest.raises( ComputeError, - match=f"unsupported `timestamp` precision; expected a value between 1 and 9, found {prec}", + match=f"invalid temporal type precision; expected 1-9, found {prec}", ): ctx.execute(f"SELECT ts::timestamp({prec}) FROM frame_data")