diff --git a/src/daft-sql/src/modules/partitioning.rs b/src/daft-sql/src/modules/partitioning.rs index b357ac810f..589c298e2f 100644 --- a/src/daft-sql/src/modules/partitioning.rs +++ b/src/daft-sql/src/modules/partitioning.rs @@ -1,11 +1,102 @@ +use daft_dsl::functions::partitioning::{self, PartitioningExpr}; + use super::SQLModule; -use crate::functions::SQLFunctions; +use crate::{ + ensure, + functions::{SQLFunction, SQLFunctions}, +}; pub struct SQLModulePartitioning; impl SQLModule for SQLModulePartitioning { - fn register(_parent: &mut SQLFunctions) { - // use FunctionExpr::Partitioning as f; - // TODO + fn register(parent: &mut SQLFunctions) { + parent.add_fn("partitioning_years", PartitioningExpr::Years); + parent.add_fn("partitioning_months", PartitioningExpr::Months); + parent.add_fn("partitioning_days", PartitioningExpr::Days); + parent.add_fn("partitioning_hours", PartitioningExpr::Hours); + parent.add_fn( + "partitioning_iceberg_bucket", + PartitioningExpr::IcebergBucket(0), + ); + parent.add_fn( + "partitioning_iceberg_truncate", + PartitioningExpr::IcebergTruncate(0), + ); + } +} + +impl SQLFunction for PartitioningExpr { + fn to_expr( + &self, + args: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> { + match self { + PartitioningExpr::Years => { + partitioning_helper(args, planner, "years", partitioning::years) + } + PartitioningExpr::Months => { + partitioning_helper(args, planner, "months", partitioning::months) + } + PartitioningExpr::Days => { + partitioning_helper(args, planner, "days", partitioning::days) + } + PartitioningExpr::Hours => { + partitioning_helper(args, planner, "hours", partitioning::hours) + } + PartitioningExpr::IcebergBucket(_) => { + ensure!(args.len() == 2, "iceberg_bucket takes exactly 2 arguments"); + let input = planner.plan_function_arg(&args[0])?; + let n = planner + .plan_function_arg(&args[1])? 
+ .as_literal() + .and_then(|l| l.as_i64()) + .ok_or_else(|| { + crate::error::PlannerError::unsupported_sql( + "Expected integer literal".to_string(), + ) + }) + .and_then(|n| { + if n > i32::MAX as i64 { + Err(crate::error::PlannerError::unsupported_sql( + "Integer literal too large".to_string(), + )) + } else { + Ok(n as i32) + } + })?; + + Ok(partitioning::iceberg_bucket(input, n)) + } + PartitioningExpr::IcebergTruncate(_) => { + ensure!( + args.len() == 2, + "iceberg_truncate takes exactly 2 arguments" + ); + let input = planner.plan_function_arg(&args[0])?; + let w = planner + .plan_function_arg(&args[1])? + .as_literal() + .and_then(|l| l.as_i64()) + .ok_or_else(|| { + crate::error::PlannerError::unsupported_sql( + "Expected integer literal".to_string(), + ) + })?; + + Ok(partitioning::iceberg_truncate(input, w)) + } + } } } + +fn partitioning_helper<F: Fn(daft_dsl::ExprRef) -> daft_dsl::ExprRef>( + args: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + method_name: &str, + f: F, +) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> { + ensure!(args.len() == 1, "{} takes exactly 1 argument", method_name); + let args = planner.plan_function_arg(&args[0])?; + Ok(f(args)) +} diff --git a/tests/sql/test_partitioning_exprs.py b/tests/sql/test_partitioning_exprs.py new file mode 100644 index 0000000000..04bd3d1447 --- /dev/null +++ b/tests/sql/test_partitioning_exprs.py @@ -0,0 +1,45 @@ +from datetime import datetime + +import daft +from daft.sql.sql import SQLCatalog + + +def test_partitioning_exprs(): + df = daft.from_pydict( + { + "id": [1, 2, 3, 4, 5], + "date": [ + datetime(2024, 1, 1), + datetime(2024, 2, 1), + datetime(2024, 3, 1), + datetime(2024, 4, 1), + datetime(2024, 5, 1), + ], + } + ) + catalog = SQLCatalog({"test": df}) + expected = ( + df.select( + daft.col("date").partitioning.days().alias("date_days"), + daft.col("date").partitioning.hours().alias("date_hours"), + daft.col("date").partitioning.months().alias("date_months"), + 
daft.col("date").partitioning.years().alias("date_years"), + daft.col("id").partitioning.iceberg_bucket(10).alias("id_bucket"), + daft.col("id").partitioning.iceberg_truncate(10).alias("id_truncate"), + ) + .collect() + .to_pydict() + ) + sql = """ + SELECT + partitioning_days(date) AS date_days, + partitioning_hours(date) AS date_hours, + partitioning_months(date) AS date_months, + partitioning_years(date) AS date_years, + partitioning_iceberg_bucket(id, 10) AS id_bucket, + partitioning_iceberg_truncate(id, 10) AS id_truncate + FROM test + """ + actual = daft.sql(sql, catalog).collect().to_pydict() + + assert actual == expected