diff --git a/Cargo.lock b/Cargo.lock index d3171fecec..65dc2c312f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2012,6 +2012,7 @@ dependencies = [ "rstest", "snafu", "sqlparser", + "strum 0.26.3", ] [[package]] diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 5487542b2a..4aac933e85 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -943,6 +943,14 @@ impl Expr { .and_then(|_| String::from_utf8(buffer).ok()) } + /// If the expression is a literal, return it. Otherwise, return None. + pub fn as_literal(&self) -> Option<&lit::LiteralValue> { + match self { + Expr::Literal(lit) => Some(lit), + _ => None, + } + } + pub fn has_agg(&self) -> bool { use Expr::*; diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 201ae30eb6..1531d2b61c 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -247,6 +247,71 @@ impl LiteralValue { Python(..) => display_sql_err, } } + + /// If the liter is a boolean, return it. Otherwise, return None. + pub fn as_bool(&self) -> Option { + match self { + LiteralValue::Boolean(b) => Some(*b), + _ => None, + } + } + + /// If the literal is a string, return it. Otherwise, return None. + pub fn as_str(&self) -> Option<&str> { + match self { + LiteralValue::Utf8(s) => Some(s), + _ => None, + } + } + /// If the literal is `Binary`, return it. Otherwise, return None. + pub fn as_binary(&self) -> Option<&[u8]> { + match self { + LiteralValue::Binary(b) => Some(b), + _ => None, + } + } + /// If the literal is `Int32`, return it. Otherwise, return None. + pub fn as_i32(&self) -> Option { + match self { + LiteralValue::Int32(i) => Some(*i), + _ => None, + } + } + /// If the literal is `UInt32`, return it. Otherwise, return None. + pub fn as_u32(&self) -> Option { + match self { + LiteralValue::UInt32(i) => Some(*i), + _ => None, + } + } + /// If the literal is `Int64`, return it. Otherwise, return None. + pub fn as_i64(&self) -> Option { + match self { + LiteralValue::Int64(i) => Some(*i), + _ => None, + } + } + /// If the literal is `UInt64`, return it. Otherwise, return None. + pub fn as_u64(&self) -> Option { + match self { + LiteralValue::UInt64(i) => Some(*i), + _ => None, + } + } + /// If the literal is `Float64`, return it. Otherwise, return None. + pub fn as_f64(&self) -> Option { + match self { + LiteralValue::Float64(f) => Some(*f), + _ => None, + } + } + /// If the literal is a series, return it. Otherwise, return None. + pub fn as_series(&self) -> Option<&Series> { + match self { + LiteralValue::Series(series) => Some(series), + _ => None, + } + } } pub trait Literal { diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index 99276adba1..75810e41a7 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -5,6 +5,7 @@ daft-dsl = {path = "../daft-dsl"} daft-plan = {path = "../daft-plan"} pyo3 = {workspace = true, optional = true} sqlparser = {workspace = true} +strum = {version = "0.26.3", features = ["derive"]} snafu.workspace = true [dev-dependencies] diff --git a/src/daft-sql/src/error.rs b/src/daft-sql/src/error.rs index 9db84d0c01..d4a923afe8 100644 --- a/src/daft-sql/src/error.rs +++ b/src/daft-sql/src/error.rs @@ -42,6 +42,13 @@ impl From for PlannerError { PlannerError::SQLParserError { source: value } } } +impl From for PlannerError { + fn from(value: strum::ParseError) -> Self { + PlannerError::ParseError { + message: value.to_string(), + } + } +} impl PlannerError { pub fn column_not_found, B: Into>(column_name: A, relation: B) -> Self { @@ -98,6 +105,14 @@ macro_rules! invalid_operation_err { return Err($crate::error::PlannerError::invalid_operation(format!($($arg)*))) }; } +#[macro_export] +macro_rules! ensure { + ($condition:expr, $($arg:tt)*) => { + if !$condition { + return Err($crate::error::PlannerError::invalid_operation(format!($($arg)*))) + } + }; +} impl From for DaftError { fn from(value: PlannerError) -> Self { diff --git a/src/daft-sql/src/functions.rs b/src/daft-sql/src/functions.rs index b62cf2a860..dcc49f9928 100644 --- a/src/daft-sql/src/functions.rs +++ b/src/daft-sql/src/functions.rs @@ -1,9 +1,9 @@ -use std::str::FromStr; - -use daft_dsl::{Expr, ExprRef, LiteralValue}; +use daft_dsl::{ExprRef, LiteralValue}; use sqlparser::ast::{Function, FunctionArg, FunctionArgExpr}; +use strum::EnumString; use crate::{ + ensure, error::{PlannerError, SQLPlannerResult}, invalid_operation_err, planner::{Relation, SQLPlanner}, @@ -11,29 +11,305 @@ use crate::{ }; // TODO: expand this to support more functions +#[derive(EnumString)] +#[strum(serialize_all = "snake_case")] +#[strum(ascii_case_insensitive)] pub enum SQLFunctions { + // ------------------------------------------------ + // Numeric Functions + // ------------------------------------------------ /// SQL `abs()` function + /// # Example + /// ```sql + /// SELECT abs(-1); + /// ``` Abs, + /// SQL `ceil()` function + /// # Example + /// ```sql + /// SELECT ceil(1.1); + /// ``` + Ceil, + /// SQL `floor()` function + /// # Example + /// ```sql + /// SELECT floor(1.1); + /// ``` + Floor, /// SQL `sign()` function + /// # Example + /// ```sql + /// SELECT sign(-1); + /// ``` Sign, - /// SQL `round()` function + /// SQL `round()` function + /// # Example + /// ```sql + /// SELECT round(1.1); + /// ``` Round, - /// SQL `max` function - Max, -} + /// SQL `sqrt()` function + /// # Example + /// ```sql + /// SELECT sqrt(4); + /// ``` + Sqrt, + /// SQL `sin()` function + /// # Example + /// ```sql + /// SELECT sin(0); + /// ``` + Sin, + /// SQL `cos()` function + /// # Example + /// ```sql + /// SELECT cos(0); + /// ``` + Cos, + /// SQL `tan()` function + /// # Example + /// ```sql + /// SELECT tan(0); + /// ``` + Tan, + /// SQL `cot()` function + /// # Example + /// ```sql + /// SELECT cot(0); + /// ``` + Cot, + /// SQL `asin()` function + /// # Example + /// ```sql + /// SELECT asin(0); + /// ``` + #[strum(serialize = "asin", serialize = "arcsin")] + ArcSin, + /// SQL `acos()` function + /// # Example + /// ```sql + /// SELECT acos(0); + /// ``` + #[strum(serialize = "acos", serialize = "arccos")] + ArcCos, + /// SQL `atan()` function + /// # Example + /// ```sql + /// SELECT atan(0); + /// ``` + #[strum(serialize = "atan", serialize = "arctan")] + ArcTan, + /// SQL `atan2()` function + /// # Example + /// ```sql + /// SELECT atan2(0); + /// ``` + #[strum(serialize = "atan2", serialize = "arctan2")] + ArcTan2, + /// SQL `radians()` function + /// # Example + /// ```sql + /// SELECT radians(0); + /// ``` + Radians, + /// SQL `degrees()` function + /// # Example + /// ```sql + /// SELECT degrees(0); + /// ``` + Degrees, + /// SQL `log2()` function + /// # Example + /// ```sql + /// SELECT log2(1); + /// ``` + Log2, + /// SQL `log10()` function + /// # Example + /// ```sql + /// SELECT log10(1); + /// ``` + Log10, + /// SQL `ln()` function + /// # Example + /// ```sql + /// SELECT ln(1); + /// ``` + Ln, + /// SQL `log()` function + /// # Example + /// ```sql + /// SELECT log(1, 10); + /// ``` + Log, + /// SQL `exp()` function + /// # Example + /// ```sql + /// SELECT exp(1); + /// ``` + Exp, + /// SQL `atanh()` function + /// # Example + /// ```sql + /// SELECT atanh(0); + /// ``` + #[strum(serialize = "atanh", serialize = "arctanh")] + ArcTanh, + /// SQL `acosh()` function + /// # Example + /// ```sql + /// SELECT acosh(1); + /// ``` + #[strum(serialize = "acosh", serialize = "arccosh")] + ArcCosh, + /// SQL `asinh()` function + /// # Example + /// ```sql + /// SELECT asinh(0); + /// ``` + #[strum(serialize = "asinh", serialize = "arcsinh")] + ArcSinh, -impl FromStr for SQLFunctions { - type Err = PlannerError; + // ------------------------------------------------ + // String Functions + // ------------------------------------------------ + /// SQL `ends_with()` function + /// # Example + /// ```sql + /// SELECT ends_with('hello', 'lo'); + /// ``` + EndsWith, + /// SQL `starts_with()` function + /// # Example + /// ```sql + /// SELECT starts_with('hello', 'he'); + /// ``` + StartsWith, + /// SQL `contains()` function + /// # Example + /// ```sql + /// SELECT contains('hello', 'el'); + /// ``` + Contains, + /// SQL `split()` function + /// # Example + /// ```sql + /// SELECT split('hello,world', ','); + /// ``` + Split, - fn from_str(value: &str) -> Result { - match value { - "abs" => Ok(SQLFunctions::Abs), - "sign" => Ok(SQLFunctions::Sign), - "round" => Ok(SQLFunctions::Round), - "max" => Ok(SQLFunctions::Max), - _ => unsupported_sql_err!("unknown function: '{value}'"), - } - } + Extract, + ExtractAll, + /// SQL 'replace()' function + /// # Example + /// ```sql + /// SELECT replace('hello', 'l', 'r'); + /// ``` + Replace, + /// SQL 'length()' function + /// # Example + /// ```sql + /// SELECT length('hello'); + /// ``` + Length, + /// SQL 'lower()' function + /// # Example + /// ```sql + /// SELECT lower('HELLO'); + /// ``` + Lower, + /// SQL 'upper()' function + /// # Example + /// ```sql + /// SELECT upper('hello'); + /// ``` + Upper, + /// SQL 'lstrip()' function + /// # Example + /// ```sql + /// SELECT lstrip(' hello'); + /// ``` + Lstrip, + /// SQL 'rstrip()' function + /// # Example + /// ```sql + /// SELECT rstrip('hello '); + /// ``` + Rstrip, + /// SQL 'reverse()' function + /// # Example + /// ```sql + /// SELECT reverse('olleh'); + /// ``` + Reverse, + /// SQL 'capitalize()' function + /// # Example + /// ```sql + /// SELECT capitalize('hello'); + /// ``` + Capitalize, + /// SQL 'left()' function + /// # Example + /// ```sql + /// SELECT left('hello', 2); + /// ``` + Left, + + /// SQL 'right()' function + /// # Example + /// ```sql + /// SELECT right('hello', 2); + /// ``` + Right, + /// SQL 'find()' function + /// # Example + /// ```sql + /// SELECT find('hello', 'l'); + /// ``` + Find, + /// SQL 'rpad()' function + /// # Example + /// ```sql + /// SELECT rpad('hello', 10, ' '); + /// ``` + Rpad, + /// SQL 'lpad()' function + /// # Example + /// ```sql + /// SELECT lpad('hello', 10, ' '); + /// ``` + Lpad, + /// SQL 'repeat()' function + /// # Example + /// ```sql + /// SELECT repeat('X', 2); + /// ``` + Repeat, + // Like, + // Ilike, + /// SQL 'substring()' function + /// # Example + /// ```sql + /// SELECT substring('hello', 1, 2); + /// ``` + Substr, + /// SQL 'to_date()` function + /// # Example + /// ```sql + /// SELECT to_date('2021-01-01', 'YYYY-MM-DD'); + /// ``` + ToDate, + /// SQL 'to_datetime()' function + /// # Example + /// ```sql + /// SELECT to_datetime('2021-01-01 00:00:00', 'YYYY-MM-DD HH:MM:SS'); + /// ``` + ToDatetime, + + // ------------------------------------------------ + // Aggregate Functions + // ------------------------------------------------ + Max, } impl SQLPlanner { @@ -78,38 +354,283 @@ impl SQLPlanner { match func { SQLFunctions::Abs => { - if args.len() != 1 { - invalid_operation_err!("abs takes exactly one argument"); - } + ensure!(args.len() == 1, "abs takes exactly one argument"); Ok(daft_dsl::functions::numeric::abs(args[0].clone())) } + SQLFunctions::Ceil => { + ensure!(args.len() == 1, "ceil takes exactly one argument"); + Ok(daft_dsl::functions::numeric::ceil(args[0].clone())) + } + SQLFunctions::Floor => { + ensure!(args.len() == 1, "floor takes exactly one argument"); + Ok(daft_dsl::functions::numeric::floor(args[0].clone())) + } SQLFunctions::Sign => { - if args.len() != 1 { - invalid_operation_err!("sign takes exactly one argument"); - } + ensure!(args.len() == 1, "sign takes exactly one argument"); Ok(daft_dsl::functions::numeric::sign(args[0].clone())) } SQLFunctions::Round => { - if args.len() != 2 { - invalid_operation_err!("round takes exactly one argument"); - } - - let precision = match args[1].as_ref() { - Expr::Literal(LiteralValue::Int32(i)) => *i, - Expr::Literal(LiteralValue::UInt32(u)) => *u as i32, - Expr::Literal(LiteralValue::Int64(i)) => *i as i32, + ensure!(args.len() == 2, "round takes exactly two arguments"); + let precision = match args[1].as_ref().as_literal() { + Some(LiteralValue::Int32(i)) => *i, + Some(LiteralValue::UInt32(u)) => *u as i32, + Some(LiteralValue::Int64(i)) => *i as i32, _ => invalid_operation_err!("round precision must be an integer"), }; - Ok(daft_dsl::functions::numeric::round( args[0].clone(), precision, )) } + SQLFunctions::Sqrt => { + ensure!(args.len() == 1, "sqrt takes exactly one argument"); + Ok(daft_dsl::functions::numeric::sqrt(args[0].clone())) + } + SQLFunctions::Sin => { + ensure!(args.len() == 1, "sin takes exactly one argument"); + Ok(daft_dsl::functions::numeric::sin(args[0].clone())) + } + SQLFunctions::Cos => { + ensure!(args.len() == 1, "cos takes exactly one argument"); + Ok(daft_dsl::functions::numeric::cos(args[0].clone())) + } + SQLFunctions::Tan => { + ensure!(args.len() == 1, "tan takes exactly one argument"); + Ok(daft_dsl::functions::numeric::tan(args[0].clone())) + } + SQLFunctions::Cot => { + ensure!(args.len() == 1, "cot takes exactly one argument"); + Ok(daft_dsl::functions::numeric::cot(args[0].clone())) + } + SQLFunctions::ArcSin => { + ensure!(args.len() == 1, "asin takes exactly one argument"); + Ok(daft_dsl::functions::numeric::arcsin(args[0].clone())) + } + SQLFunctions::ArcCos => { + ensure!(args.len() == 1, "acos takes exactly one argument"); + Ok(daft_dsl::functions::numeric::arccos(args[0].clone())) + } + SQLFunctions::ArcTan => { + ensure!(args.len() == 1, "atan takes exactly one argument"); + Ok(daft_dsl::functions::numeric::arctan(args[0].clone())) + } + SQLFunctions::ArcTan2 => { + ensure!(args.len() == 2, "atan2 takes exactly two arguments"); + Ok(daft_dsl::functions::numeric::arctan2( + args[0].clone(), + args[1].clone(), + )) + } + SQLFunctions::Degrees => { + ensure!(args.len() == 1, "degrees takes exactly one argument"); + Ok(daft_dsl::functions::numeric::degrees(args[0].clone())) + } + SQLFunctions::Radians => { + ensure!(args.len() == 1, "radians takes exactly one argument"); + Ok(daft_dsl::functions::numeric::radians(args[0].clone())) + } + SQLFunctions::Log2 => { + ensure!(args.len() == 1, "log2 takes exactly one argument"); + Ok(daft_dsl::functions::numeric::log2(args[0].clone())) + } + SQLFunctions::Log10 => { + ensure!(args.len() == 1, "log10 takes exactly one argument"); + Ok(daft_dsl::functions::numeric::log10(args[0].clone())) + } + SQLFunctions::Ln => { + ensure!(args.len() == 1, "ln takes exactly one argument"); + Ok(daft_dsl::functions::numeric::ln(args[0].clone())) + } + SQLFunctions::Log => { + ensure!(args.len() == 2, "log takes exactly two arguments"); + let base = args[1] + .as_literal() + .and_then(|lit| match lit { + LiteralValue::Float64(f) => Some(*f), + LiteralValue::Int32(i) => Some(*i as f64), + LiteralValue::UInt32(u) => Some(*u as f64), + LiteralValue::Int64(i) => Some(*i as f64), + LiteralValue::UInt64(u) => Some(*u as f64), + _ => None, + }) + .ok_or_else(|| PlannerError::InvalidOperation { + message: "log base must be a float or a number".to_string(), + })?; + + Ok(daft_dsl::functions::numeric::log(args[0].clone(), base)) + } + SQLFunctions::Exp => { + ensure!(args.len() == 1, "exp takes exactly one argument"); + Ok(daft_dsl::functions::numeric::exp(args[0].clone())) + } + SQLFunctions::ArcTanh => { + ensure!(args.len() == 1, "atanh takes exactly one argument"); + Ok(daft_dsl::functions::numeric::arctanh(args[0].clone())) + } + SQLFunctions::ArcCosh => { + ensure!(args.len() == 1, "acosh takes exactly one argument"); + Ok(daft_dsl::functions::numeric::arccosh(args[0].clone())) + } + SQLFunctions::ArcSinh => { + ensure!(args.len() == 1, "asinh takes exactly one argument"); + Ok(daft_dsl::functions::numeric::arcsinh(args[0].clone())) + } + SQLFunctions::EndsWith => { + ensure!(args.len() == 2, "endswith takes exactly two arguments"); + Ok(daft_dsl::functions::utf8::endswith( + args[0].clone(), + args[1].clone(), + )) + } + SQLFunctions::StartsWith => { + ensure!(args.len() == 2, "startswith takes exactly two arguments"); + Ok(daft_dsl::functions::utf8::startswith( + args[0].clone(), + args[1].clone(), + )) + } + SQLFunctions::Contains => { + ensure!(args.len() == 2, "contains takes exactly two arguments"); + Ok(daft_dsl::functions::utf8::contains( + args[0].clone(), + args[1].clone(), + )) + } + SQLFunctions::Split => { + ensure!(args.len() == 2, "split takes exactly two arguments"); + Ok(daft_dsl::functions::utf8::split( + args[0].clone(), + args[1].clone(), + false, + )) + } + + SQLFunctions::Extract => { + unsupported_sql_err!("extract") + } + SQLFunctions::ExtractAll => { + unsupported_sql_err!("extract_all") + } + SQLFunctions::Replace => { + ensure!(args.len() == 3, "replace takes exactly three arguments"); + Ok(daft_dsl::functions::utf8::replace( + args[0].clone(), + args[1].clone(), + args[2].clone(), + false, + )) + } + SQLFunctions::Length => { + ensure!(args.len() == 1, "length takes exactly one argument"); + Ok(daft_dsl::functions::utf8::length(args[0].clone())) + } + SQLFunctions::Lower => { + ensure!(args.len() == 1, "lower takes exactly one argument"); + Ok(daft_dsl::functions::utf8::lower(args[0].clone())) + } + SQLFunctions::Upper => { + ensure!(args.len() == 1, "upper takes exactly one argument"); + Ok(daft_dsl::functions::utf8::upper(args[0].clone())) + } + SQLFunctions::Lstrip => { + ensure!(args.len() == 1, "lstrip takes exactly one argument"); + Ok(daft_dsl::functions::utf8::lstrip(args[0].clone())) + } + SQLFunctions::Rstrip => { + ensure!(args.len() == 1, "rstrip takes exactly one argument"); + Ok(daft_dsl::functions::utf8::rstrip(args[0].clone())) + } + SQLFunctions::Reverse => { + ensure!(args.len() == 1, "reverse takes exactly one argument"); + Ok(daft_dsl::functions::utf8::reverse(args[0].clone())) + } + SQLFunctions::Capitalize => { + ensure!(args.len() == 1, "capitalize takes exactly one argument"); + Ok(daft_dsl::functions::utf8::capitalize(args[0].clone())) + } + SQLFunctions::Left => { + ensure!(args.len() == 2, "left takes exactly two arguments"); + Ok(daft_dsl::functions::utf8::left( + args[0].clone(), + args[1].clone(), + )) + } + SQLFunctions::Right => { + ensure!(args.len() == 2, "right takes exactly two arguments"); + Ok(daft_dsl::functions::utf8::right( + args[0].clone(), + args[1].clone(), + )) + } + SQLFunctions::Find => { + ensure!(args.len() == 2, "find takes exactly two arguments"); + Ok(daft_dsl::functions::utf8::find( + args[0].clone(), + args[1].clone(), + )) + } + SQLFunctions::Rpad => { + ensure!(args.len() == 3, "rpad takes exactly three arguments"); + Ok(daft_dsl::functions::utf8::rpad( + args[0].clone(), + args[1].clone(), + args[2].clone(), + )) + } + SQLFunctions::Lpad => { + ensure!(args.len() == 3, "lpad takes exactly three arguments"); + Ok(daft_dsl::functions::utf8::lpad( + args[0].clone(), + args[1].clone(), + args[2].clone(), + )) + } + SQLFunctions::Repeat => { + ensure!(args.len() == 2, "repeat takes exactly two arguments"); + Ok(daft_dsl::functions::utf8::repeat( + args[0].clone(), + args[1].clone(), + )) + } + SQLFunctions::Substr => { + ensure!(args.len() == 3, "substr takes exactly three arguments"); + Ok(daft_dsl::functions::utf8::substr( + args[0].clone(), + args[1].clone(), + args[2].clone(), + )) + } + SQLFunctions::ToDate => { + ensure!(args.len() == 2, "to_date takes exactly two arguments"); + let fmt = match args[1].as_ref().as_literal() { + Some(LiteralValue::Utf8(s)) => s, + _ => invalid_operation_err!("to_date format must be a string"), + }; + Ok(daft_dsl::functions::utf8::to_date(args[0].clone(), fmt)) + } + SQLFunctions::ToDatetime => { + ensure!( + args.len() >= 2, + "to_datetime takes either two or three arguments" + ); + let fmt = match args[1].as_ref().as_literal() { + Some(LiteralValue::Utf8(s)) => s, + _ => invalid_operation_err!("to_datetime format must be a string"), + }; + let tz = match args.get(2).and_then(|e| e.as_ref().as_literal()) { + Some(LiteralValue::Utf8(s)) => Some(s.as_str()), + _ => invalid_operation_err!("to_datetime timezone must be a string"), + }; + + Ok(daft_dsl::functions::utf8::to_datetime( + args[0].clone(), + fmt, + tz, + )) + } + SQLFunctions::Max => { - if args.len() != 1 { - invalid_operation_err!("max takes exactly one argument"); - } + ensure!(args.len() == 1, "max takes exactly one argument"); Ok(args[0].clone().max()) } } diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index bf9cdbb22a..28f6c6ad15 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -20,6 +20,8 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { mod tests { use std::sync::Arc; + use crate::planner::SQLPlanner; + use super::*; use catalog::SQLCatalog; use daft_core::{ @@ -33,9 +35,9 @@ mod tests { LogicalPlanBuilder, LogicalPlanRef, SourceInfo, }; use error::SQLPlannerResult; - use planner::SQLPlanner; - use rstest::rstest; + use rstest::{fixture, rstest}; + #[fixture] fn tbl_1() -> LogicalPlanRef { let schema = Arc::new( Schema::new(vec![ @@ -62,6 +64,8 @@ mod tests { }) .arced() } + + #[fixture] fn tbl_2() -> LogicalPlanRef { let schema = Arc::new( Schema::new(vec![ @@ -81,6 +85,7 @@ mod tests { .arced() } + #[fixture] fn tbl_3() -> LogicalPlanRef { let schema = Arc::new( Schema::new(vec![ @@ -101,7 +106,8 @@ mod tests { .arced() } - fn setup() -> SQLPlanner { + #[fixture] + fn planner() -> SQLPlanner { let mut catalog = SQLCatalog::new(); catalog.register_table("tbl1", tbl_1()); @@ -148,35 +154,31 @@ mod tests { #[case::orderby("select * from tbl1 order by i32 asc")] #[case::orderby_multi("select * from tbl1 order by i32 desc, f32 asc")] #[case::whenthen("select case when i32 = 1 then 'a' else 'b' end from tbl1")] - fn test_compiles(#[case] query: &str) -> SQLPlannerResult<()> { - let planner = setup(); - + fn test_compiles(planner: SQLPlanner, #[case] query: &str) -> SQLPlannerResult<()> { let plan = planner.plan_sql(query); assert!(plan.is_ok(), "query: {}\nerror: {:?}", query, plan); Ok(()) } - #[test] - fn test_parse_sql() { - let planner = setup(); + #[rstest] + fn test_parse_sql(planner: SQLPlanner, tbl_1: LogicalPlanRef) { let sql = "select test as a from tbl1"; let plan = planner.plan_sql(sql).unwrap(); - let expected = LogicalPlanBuilder::new(tbl_1()) + let expected = LogicalPlanBuilder::new(tbl_1) .select(vec![col("test").alias("a")]) .unwrap() .build(); assert_eq!(plan, expected); } - #[test] - fn test_where_clause() -> SQLPlannerResult<()> { - let planner = setup(); + #[rstest] + fn test_where_clause(planner: SQLPlanner, tbl_1: LogicalPlanRef) -> SQLPlannerResult<()> { let sql = "select test as a from tbl1 where test = 'a'"; let plan = planner.plan_sql(sql)?; - let expected = LogicalPlanBuilder::new(tbl_1()) + let expected = LogicalPlanBuilder::new(tbl_1) .filter(col("test").eq(lit("a")))? .select(vec![col("test").alias("a")])? .build(); @@ -184,13 +186,12 @@ mod tests { assert_eq!(plan, expected); Ok(()) } - #[test] - fn test_limit() -> SQLPlannerResult<()> { - let planner = setup(); + #[rstest] + fn test_limit(planner: SQLPlanner, tbl_1: LogicalPlanRef) -> SQLPlannerResult<()> { let sql = "select test as a from tbl1 limit 10"; let plan = planner.plan_sql(sql)?; - let expected = LogicalPlanBuilder::new(tbl_1()) + let expected = LogicalPlanBuilder::new(tbl_1) .select(vec![col("test").alias("a")])? .limit(10, true)? .build(); @@ -199,13 +200,12 @@ mod tests { Ok(()) } - #[test] - fn test_orderby() -> SQLPlannerResult<()> { - let planner = setup(); + #[rstest] + fn test_orderby(planner: SQLPlanner, tbl_1: LogicalPlanRef) -> SQLPlannerResult<()> { let sql = "select utf8 from tbl1 order by utf8 desc"; let plan = planner.plan_sql(sql)?; - let expected = LogicalPlanBuilder::new(tbl_1()) + let expected = LogicalPlanBuilder::new(tbl_1) .select(vec![col("utf8")])? .sort(vec![col("utf8")], vec![true])? .build(); @@ -214,10 +214,9 @@ mod tests { Ok(()) } - #[test] - fn test_cast() -> SQLPlannerResult<()> { - let planner = setup(); - let builder = LogicalPlanBuilder::new(tbl_1()); + #[rstest] + fn test_cast(planner: SQLPlanner, tbl_1: LogicalPlanRef) -> SQLPlannerResult<()> { + let builder = LogicalPlanBuilder::new(tbl_1); let cases = vec![ ( "select bool::text from tbl1", @@ -247,14 +246,17 @@ mod tests { Ok(()) } - #[test] - fn test_join() -> SQLPlannerResult<()> { - let planner = setup(); + #[rstest] + fn test_join( + planner: SQLPlanner, + tbl_2: LogicalPlanRef, + tbl_3: LogicalPlanRef, + ) -> SQLPlannerResult<()> { let sql = "select * from tbl2 join tbl3 on tbl2.id = tbl3.id"; let plan = planner.plan_sql(sql)?; - let expected = LogicalPlanBuilder::new(tbl_2()) + let expected = LogicalPlanBuilder::new(tbl_2) .join( - tbl_3(), + tbl_3, vec![col("id")], vec![col("id")], daft_core::JoinType::Inner, @@ -264,4 +266,54 @@ mod tests { assert_eq!(plan, expected); Ok(()) } + + #[rstest] + #[case::abs("select abs(i32) as abs from tbl1")] + #[case::ceil("select ceil(i32) as ceil from tbl1")] + #[case::floor("select floor(i32) as floor from tbl1")] + #[case::sign("select sign(i32) as sign from tbl1")] + #[case::round("select round(i32, 1) as round from tbl1")] + #[case::sqrt("select sqrt(i32) as sqrt from tbl1")] + #[case::sin("select sin(i32) as sin from tbl1")] + #[case::cos("select cos(i32) as cos from tbl1")] + #[case::tan("select tan(i32) as tan from tbl1")] + #[case::asin("select asin(i32) as asin from tbl1")] + #[case::acos("select acos(i32) as acos from tbl1")] + #[case::atan("select atan(i32) as atan from tbl1")] + #[case::atan2("select atan2(i32, 1) as atan2 from tbl1")] + #[case::radians("select radians(i32) as radians from tbl1")] + #[case::degrees("select degrees(i32) as degrees from tbl1")] + #[case::log2("select log2(i32) as log2 from tbl1")] + #[case::log10("select log10(i32) as log10 from tbl1")] + #[case::ln("select ln(i32) as ln from tbl1")] + #[case::exp("select exp(i32) as exp from tbl1")] + #[case::atanh("select atanh(i32) as atanh from tbl1")] + #[case::acosh("select acosh(i32) as acosh from tbl1")] + #[case::asinh("select asinh(i32) as asinh from tbl1")] + #[case::ends_with("select ends_with(utf8, 'a') as ends_with from tbl1")] + #[case::starts_with("select starts_with(utf8, 'a') as starts_with from tbl1")] + #[case::contains("select contains(utf8, 'a') as contains from tbl1")] + #[case::split("select split(utf8, '.') as split from tbl1")] + #[case::replace("select replace(utf8, 'a', 'b') as replace from tbl1")] + #[case::length("select length(utf8) as length from tbl1")] + #[case::lower("select lower(utf8) as lower from tbl1")] + #[case::upper("select upper(utf8) as upper from tbl1")] + #[case::lstrip("select lstrip(utf8) as lstrip from tbl1")] + #[case::rstrip("select rstrip(utf8) as rstrip from tbl1")] + #[case::reverse("select reverse(utf8) as reverse from tbl1")] + #[case::capitalize("select capitalize(utf8) as capitalize from tbl1")] + #[case::left("select left(utf8, 1) as left from tbl1")] + #[case::right("select right(utf8, 1) as right from tbl1")] + #[case::find("select find(utf8, 'a') as find from tbl1")] + #[case::rpad("select rpad(utf8, 1, 'a') as rpad from tbl1")] + #[case::lpad("select lpad(utf8, 1, 'a') as lpad from tbl1")] + #[case::repeat("select repeat(utf8, 1) as repeat from tbl1")] + #[case::to_date("select to_date(utf8, 'YYYY-MM-DD') as to_date from tbl1")] + // #[case::to_datetime("select to_datetime(utf8, 'YYYY-MM-DD') as to_datetime from tbl1")] + fn test_compiles_funcs(planner: SQLPlanner, #[case] query: &str) -> SQLPlannerResult<()> { + let plan = planner.plan_sql(query); + assert!(plan.is_ok(), "query: {}\nerror: {:?}", query, plan); + + Ok(()) + } }