From 1ad2d55f70d315bc69b2c785b429a32bd965d7a0 Mon Sep 17 00:00:00 2001 From: Ethan Donowitz Date: Thu, 23 May 2024 17:09:15 -0400 Subject: [PATCH] nom-sql: Add parsing for `EXTRACT` built-in This commit adds parsing for the built-in `EXTRACT` function. This function is present in both MySQL and PostgreSQL, but the supported fields across the two databases are different. To keep things simple and scoped, only support for the PostgreSQL fields have been added. Change-Id: Ic73ef858478e73b6c466695a84ddb0266d881e92 --- nom-sql/src/analysis.rs | 5 +- nom-sql/src/analysis/visit.rs | 1 + nom-sql/src/analysis/visit_mut.rs | 1 + nom-sql/src/common.rs | 183 +++++++++++++++++- nom-sql/src/expression.rs | 15 +- nom-sql/src/lib.rs | 2 +- .../src/controller/sql/mir/grouped.rs | 2 +- .../src/controller/sql/query_graph.rs | 4 +- 8 files changed, 206 insertions(+), 7 deletions(-) diff --git a/nom-sql/src/analysis.rs b/nom-sql/src/analysis.rs index bb22dd0993..e37fb70d51 100644 --- a/nom-sql/src/analysis.rs +++ b/nom-sql/src/analysis.rs @@ -95,6 +95,7 @@ impl<'a> ReferredColumnsIter<'a> { Avg { expr, .. } => self.visit_expr(expr), Count { expr, .. } => self.visit_expr(expr), CountStar => None, + Extract { expr, .. } => self.visit_expr(expr), Sum { expr, .. } => self.visit_expr(expr), Max(arg) => self.visit_expr(arg), Min(arg) => self.visit_expr(arg), @@ -206,6 +207,7 @@ impl<'a> ReferredColumnsMut<'a> { Avg { expr, .. } => self.visit_expr(expr), Count { expr, .. } => self.visit_expr(expr), CountStar => None, + Extract { expr, .. } => self.visit_expr(expr), Sum { expr, .. } => self.visit_expr(expr), Max(arg) => self.visit_expr(arg), Min(arg) => self.visit_expr(arg), @@ -345,7 +347,8 @@ pub fn is_aggregate(function: &FunctionExpr) -> bool { | FunctionExpr::Max(_) | FunctionExpr::Min(_) | FunctionExpr::GroupConcat { .. } => true, - FunctionExpr::Substring { .. } + FunctionExpr::Extract { .. } + | FunctionExpr::Substring { .. } // For now, assume all "generic" function calls are not aggregates | FunctionExpr::Call { .. } => false, } diff --git a/nom-sql/src/analysis/visit.rs b/nom-sql/src/analysis/visit.rs index 82751a0467..b71e0a059e 100644 --- a/nom-sql/src/analysis/visit.rs +++ b/nom-sql/src/analysis/visit.rs @@ -508,6 +508,7 @@ pub fn walk_function_expr<'ast, V: Visitor<'ast>>( FunctionExpr::Max(expr) => visitor.visit_expr(expr.as_ref()), FunctionExpr::Min(expr) => visitor.visit_expr(expr.as_ref()), FunctionExpr::GroupConcat { expr, .. } => visitor.visit_expr(expr.as_ref()), + FunctionExpr::Extract { expr, .. } => visitor.visit_expr(expr.as_ref()), FunctionExpr::Call { arguments, .. } => { for arg in arguments { visitor.visit_expr(arg)?; diff --git a/nom-sql/src/analysis/visit_mut.rs b/nom-sql/src/analysis/visit_mut.rs index d183b456a9..275b3b3f19 100644 --- a/nom-sql/src/analysis/visit_mut.rs +++ b/nom-sql/src/analysis/visit_mut.rs @@ -519,6 +519,7 @@ pub fn walk_function_expr<'ast, V: VisitorMut<'ast>>( FunctionExpr::Avg { expr, .. } => visitor.visit_expr(expr.as_mut()), FunctionExpr::Count { expr, .. } => visitor.visit_expr(expr.as_mut()), FunctionExpr::CountStar => Ok(()), + FunctionExpr::Extract { expr, .. } => visitor.visit_expr(expr.as_mut()), FunctionExpr::Sum { expr, .. } => visitor.visit_expr(expr.as_mut()), FunctionExpr::Max(expr) => visitor.visit_expr(expr.as_mut()), FunctionExpr::Min(expr) => visitor.visit_expr(expr.as_mut()), diff --git a/nom-sql/src/common.rs b/nom-sql/src/common.rs index dd2e61f630..61d1eb74dd 100644 --- a/nom-sql/src/common.rs +++ b/nom-sql/src/common.rs @@ -7,7 +7,7 @@ use std::str::FromStr; use itertools::Itertools; use nom::branch::alt; use nom::bytes::complete::{tag, tag_no_case, take_until}; -use nom::character::complete::{digit1, line_ending}; +use nom::character::complete::{char, digit1, line_ending}; use nom::combinator::{map, map_res, not, opt, peek}; use nom::error::{ErrorKind, ParseError}; use nom::multi::{separated_list0, separated_list1}; @@ -477,6 +477,127 @@ fn function_call_without_parens(i: LocatedSpan<&[u8]>) -> NomSqlResult<&[u8], Fu )) } +#[derive( + Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize, Arbitrary, +)] +pub enum TimestampField { + Century, + Day, + Decade, + Dow, + Doy, + Epoch, + Hour, + Isodow, + Isoyear, + Julian, + Microseconds, + Millennium, + Milliseconds, + Minute, + Month, + Quarter, + Second, + Timezone, + TimezoneHour, + TimezoneMinute, + Week, + Year, +} + +impl fmt::Display for TimestampField { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Century => write!(f, "CENTURY"), + Self::Day => write!(f, "DAY"), + Self::Decade => write!(f, "DECADE"), + Self::Dow => write!(f, "DOW"), + Self::Doy => write!(f, "DOY"), + Self::Epoch => write!(f, "EPOCH"), + Self::Hour => write!(f, "HOUR"), + Self::Isodow => write!(f, "ISODOW"), + Self::Isoyear => write!(f, "ISOYEAR"), + Self::Julian => write!(f, "JULIAN"), + Self::Microseconds => write!(f, "MICROSECONDS"), + Self::Millennium => write!(f, "MILLENNIUM"), + Self::Milliseconds => write!(f, "MILLISECONDS"), + Self::Minute => write!(f, "MINUTE"), + Self::Month => write!(f, "MONTH"), + Self::Quarter => write!(f, "QUARTER"), + Self::Second => write!(f, "SECOND"), + Self::Timezone => write!(f, "TIMEZONE"), + Self::TimezoneHour => write!(f, "TIMEZONE_HOUR"), + Self::TimezoneMinute => write!(f, "TIMEZONE_MINUTE"), + Self::Week => write!(f, "WEEK"), + Self::Year => write!(f, "YEAR"), + } + } +} + +fn timestamp_field() -> impl Fn(LocatedSpan<&[u8]>) -> NomSqlResult<&[u8], TimestampField> { + move |i| { + let alt1 = alt(( + map(tag_no_case("century"), |_| TimestampField::Century), + map(tag_no_case("day"), |_| TimestampField::Day), + map(tag_no_case("decade"), |_| TimestampField::Decade), + map(tag_no_case("dow"), |_| TimestampField::Dow), + map(tag_no_case("doy"), |_| TimestampField::Doy), + map(tag_no_case("epoch"), |_| TimestampField::Epoch), + map(tag_no_case("hour"), |_| TimestampField::Hour), + map(tag_no_case("isodow"), |_| TimestampField::Isodow), + map(tag_no_case("isoyear"), |_| TimestampField::Isoyear), + map(tag_no_case("julian"), |_| TimestampField::Julian), + map(tag_no_case("microseconds"), |_| { + TimestampField::Microseconds + }), + map(tag_no_case("millennium"), |_| TimestampField::Millennium), + map(tag_no_case("milliseconds"), |_| { + TimestampField::Milliseconds + }), + map(tag_no_case("minute"), |_| TimestampField::Minute), + map(tag_no_case("month"), |_| TimestampField::Month), + map(tag_no_case("quarter"), |_| TimestampField::Quarter), + map(tag_no_case("second"), |_| TimestampField::Second), + map(tag_no_case("timezone_hour"), |_| { + TimestampField::TimezoneHour + }), + map(tag_no_case("timezone_minute"), |_| { + TimestampField::TimezoneMinute + }), + map(tag_no_case("timezone"), |_| TimestampField::Timezone), + map(tag_no_case("week"), |_| TimestampField::Week), + )); + + // `alt` has an upper limit on the number of items it supports in tuples, so we have to + // split the parsing for these fields into separate invocations + alt((alt1, map(tag_no_case("year"), |_| TimestampField::Year)))(i) + } +} + +fn extract(dialect: Dialect) -> impl Fn(LocatedSpan<&[u8]>) -> NomSqlResult<&[u8], FunctionExpr> { + move |i| { + let (i, _) = tag_no_case("EXTRACT")(i)?; + let (i, _) = whitespace0(i)?; + let (i, _) = char('(')(i)?; + let (i, _) = whitespace0(i)?; + let (i, field) = timestamp_field()(i)?; + let (i, _) = whitespace1(i)?; + let (i, _) = tag_no_case("FROM")(i)?; + let (i, _) = whitespace1(i)?; + let (i, expr) = expression(dialect)(i)?; + let (i, _) = whitespace0(i)?; + let (i, _) = char(')')(i)?; + + Ok(( + i, + FunctionExpr::Extract { + field, + expr: Box::new(expr), + }, + )) + } +} + fn substring(dialect: Dialect) -> impl Fn(LocatedSpan<&[u8]>) -> NomSqlResult<&[u8], FunctionExpr> { move |i| { let (i, _) = alt((tag_no_case("substring"), tag_no_case("substr")))(i)?; @@ -581,6 +702,7 @@ pub fn function_expr( separator, }, ), + extract(dialect), substring(dialect), function_call(dialect), function_call_without_parens, @@ -1283,4 +1405,63 @@ mod tests { assert_eq!(res2, expected); } } + + mod extract { + use super::*; + + macro_rules! extract_test { + ($field:ident, $field_variant:ident, $field_expr:expr) => { + mod $field { + use super::*; + + #[test] + fn parse_extract_expr() { + let expr = format!("EXTRACT({} FROM \"col\")", $field_expr); + assert_eq!( + test_parse!(extract(Dialect::PostgreSQL), expr.as_bytes()), + FunctionExpr::Extract { + field: TimestampField::$field_variant, + expr: Box::new(Expr::Column(Column { + name: "col".into(), + table: None, + })), + }, + ); + } + + #[test] + fn format_round_trip() { + let expected = format!("EXTRACT({} FROM \"col\")", $field_expr); + let actual = test_parse!(extract(Dialect::PostgreSQL), expected.as_bytes()) + .display(Dialect::PostgreSQL) + .to_string(); + + assert_eq!(expected, actual); + } + } + }; + } + + extract_test!(century, Century, "CENTURY"); + extract_test!(decade, Decade, "DECADE"); + extract_test!(dow, Dow, "DOW"); + extract_test!(doy, Doy, "DOY"); + extract_test!(epoch, Epoch, "EPOCH"); + extract_test!(hour, Hour, "HOUR"); + extract_test!(isodow, Isodow, "ISODOW"); + extract_test!(isoyear, Isoyear, "ISOYEAR"); + extract_test!(julian, Julian, "JULIAN"); + extract_test!(microseconds, Microseconds, "MICROSECONDS"); + extract_test!(millennium, Millennium, "MILLENNIUM"); + extract_test!(milliseconds, Milliseconds, "MILLISECONDS"); + extract_test!(minute, Minute, "MINUTE"); + extract_test!(month, Month, "MONTH"); + extract_test!(quarter, Quarter, "QUARTER"); + extract_test!(second, Second, "SECOND"); + extract_test!(timezone_hour, TimezoneHour, "TIMEZONE_HOUR"); + extract_test!(timezone_minute, TimezoneMinute, "TIMEZONE_MINUTE"); + extract_test!(timezone, Timezone, "TIMEZONE"); + extract_test!(week, Week, "WEEK"); + extract_test!(year, Year, "YEAR"); + } } diff --git a/nom-sql/src/expression.rs b/nom-sql/src/expression.rs index b79345421f..bbf815b60c 100644 --- a/nom-sql/src/expression.rs +++ b/nom-sql/src/expression.rs @@ -19,7 +19,7 @@ use readyset_util::fmt::fmt_with; use serde::{Deserialize, Serialize}; use test_strategy::Arbitrary; -use crate::common::{column_identifier_no_alias, function_expr, ws_sep_comma}; +use crate::common::{column_identifier_no_alias, function_expr, ws_sep_comma, TimestampField}; use crate::literal::{literal, Double, Float}; use crate::select::nested_selection; use crate::set::{variable_scope_prefix, Variable}; @@ -41,6 +41,11 @@ pub enum FunctionExpr { /// `COUNT(*)` aggregation CountStar, + Extract { + field: TimestampField, + expr: Box, + }, + /// `SUM` aggregation Sum { expr: Box, distinct: bool }, @@ -89,7 +94,8 @@ impl FunctionExpr { | FunctionExpr::Sum { expr: arg, .. } | FunctionExpr::Max(arg) | FunctionExpr::Min(arg) - | FunctionExpr::GroupConcat { expr: arg, .. } => { + | FunctionExpr::GroupConcat { expr: arg, .. } + | FunctionExpr::Extract { expr: arg, .. } => { concrete_iter!(iter::once(arg.as_ref())) } FunctionExpr::CountStar => concrete_iter!(iter::empty()), @@ -156,6 +162,9 @@ impl DialectDisplay for FunctionExpr { write!(f, ")") } + FunctionExpr::Extract { field, expr } => { + write!(f, "EXTRACT({field} FROM {})", expr.display(dialect)) + } }) } } @@ -690,6 +699,8 @@ impl Arbitrary for Expr { Just(FunctionExpr::CountStar), (box_expr.clone(), any::()) .prop_map(|(expr, distinct)| FunctionExpr::Sum { expr, distinct }), + (box_expr.clone(), any::()) + .prop_map(|(expr, field)| FunctionExpr::Extract { expr, field }), box_expr.clone().prop_map(FunctionExpr::Max), box_expr.clone().prop_map(FunctionExpr::Min), (box_expr.clone(), any::>()).prop_map(|(expr, separator)| { diff --git a/nom-sql/src/lib.rs b/nom-sql/src/lib.rs index 48326f42af..a032cd5936 100644 --- a/nom-sql/src/lib.rs +++ b/nom-sql/src/lib.rs @@ -18,7 +18,7 @@ pub use self::alter::{ }; pub use self::column::{Column, ColumnConstraint, ColumnSpecification}; pub use self::comment::CommentStatement; -pub use self::common::{FieldDefinitionExpr, FieldReference, IndexType, TableKey}; +pub use self::common::{FieldDefinitionExpr, FieldReference, IndexType, TableKey, TimestampField}; pub use self::compound_select::{CompoundSelectOperator, CompoundSelectStatement}; pub use self::create::{ CacheInner, CreateCacheStatement, CreateTableBody, CreateTableStatement, CreateViewStatement, diff --git a/readyset-server/src/controller/sql/mir/grouped.rs b/readyset-server/src/controller/sql/mir/grouped.rs index d35f71e624..ddacf5b894 100644 --- a/readyset-server/src/controller/sql/mir/grouped.rs +++ b/readyset-server/src/controller/sql/mir/grouped.rs @@ -326,7 +326,7 @@ pub(super) fn post_lookup_aggregates( GroupConcat { separator, .. } => PostLookupAggregateFunction::GroupConcat { separator: separator.clone().unwrap_or_else(|| ",".to_owned()), }, - Call { .. } | Substring { .. } => continue, + Extract { .. } | Call { .. } | Substring { .. } => continue, }, }); } diff --git a/readyset-server/src/controller/sql/query_graph.rs b/readyset-server/src/controller/sql/query_graph.rs index 8990759461..d9aad9175d 100644 --- a/readyset-server/src/controller/sql/query_graph.rs +++ b/readyset-server/src/controller/sql/query_graph.rs @@ -869,7 +869,9 @@ fn default_row_for_select(st: &SelectStatement) -> Option> { FunctionExpr::Max(..) => DfValue::None, FunctionExpr::Min(..) => DfValue::None, FunctionExpr::GroupConcat { .. } => DfValue::None, - FunctionExpr::Call { .. } | FunctionExpr::Substring { .. } => DfValue::None, + FunctionExpr::Extract { .. } + | FunctionExpr::Call { .. } + | FunctionExpr::Substring { .. } => DfValue::None, }, _ => DfValue::None, })