[FEAT]: add partitioning_* functions to sql (#2869)
universalmind303 authored Sep 24, 2024
1 parent c66e384 commit e3dd671
Showing 2 changed files with 140 additions and 4 deletions.
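For context, the new SQL functions map one-to-one onto the existing `Expression.partitioning.*` dataframe API that the test file below compares against. A minimal usage sketch, modeled on that test (the `test` table and column names are illustrative):

from datetime import datetime

import daft
from daft.sql.sql import SQLCatalog

df = daft.from_pydict({"date": [datetime(2024, 1, 1), datetime(2024, 6, 15)]})
catalog = SQLCatalog({"test": df})

# partitioning_years(ts) applies the Iceberg `years` partition transform
# (years since epoch), like daft.col("date").partitioning.years().
daft.sql("SELECT partitioning_years(date) AS y FROM test", catalog).show()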
99 changes: 95 additions & 4 deletions src/daft-sql/src/modules/partitioning.rs
@@ -1,11 +1,102 @@
use daft_dsl::functions::partitioning::{self, PartitioningExpr};

use super::SQLModule;
use crate::{
    ensure,
    functions::{SQLFunction, SQLFunctions},
};

pub struct SQLModulePartitioning;

impl SQLModule for SQLModulePartitioning {
    fn register(parent: &mut SQLFunctions) {
        parent.add_fn("partitioning_years", PartitioningExpr::Years);
        parent.add_fn("partitioning_months", PartitioningExpr::Months);
        parent.add_fn("partitioning_days", PartitioningExpr::Days);
        parent.add_fn("partitioning_hours", PartitioningExpr::Hours);
        // The `0`s are placeholders; the real bucket count and truncation
        // width are read from the SQL arguments in `to_expr` below.
        parent.add_fn(
            "partitioning_iceberg_bucket",
            PartitioningExpr::IcebergBucket(0),
        );
        parent.add_fn(
            "partitioning_iceberg_truncate",
            PartitioningExpr::IcebergTruncate(0),
        );
    }
}

impl SQLFunction for PartitioningExpr {
    fn to_expr(
        &self,
        args: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        match self {
            PartitioningExpr::Years => {
                partitioning_helper(args, planner, "years", partitioning::years)
            }
            PartitioningExpr::Months => {
                partitioning_helper(args, planner, "months", partitioning::months)
            }
            PartitioningExpr::Days => {
                partitioning_helper(args, planner, "days", partitioning::days)
            }
            PartitioningExpr::Hours => {
                partitioning_helper(args, planner, "hours", partitioning::hours)
            }
            PartitioningExpr::IcebergBucket(_) => {
                ensure!(args.len() == 2, "iceberg_bucket takes exactly 2 arguments");
                let input = planner.plan_function_arg(&args[0])?;
                // The bucket count must be an integer literal that fits in an `i32`.
                let n = planner
                    .plan_function_arg(&args[1])?
                    .as_literal()
                    .and_then(|l| l.as_i64())
                    .ok_or_else(|| {
                        crate::error::PlannerError::unsupported_sql(
                            "Expected integer literal".to_string(),
                        )
                    })
                    .and_then(|n| {
                        if n > i32::MAX as i64 {
                            Err(crate::error::PlannerError::unsupported_sql(
                                "Integer literal too large".to_string(),
                            ))
                        } else {
                            Ok(n as i32)
                        }
                    })?;

                Ok(partitioning::iceberg_bucket(input, n))
            }
            PartitioningExpr::IcebergTruncate(_) => {
                ensure!(
                    args.len() == 2,
                    "iceberg_truncate takes exactly 2 arguments"
                );
                let input = planner.plan_function_arg(&args[0])?;
                // The truncation width must be an integer literal.
                let w = planner
                    .plan_function_arg(&args[1])?
                    .as_literal()
                    .and_then(|l| l.as_i64())
                    .ok_or_else(|| {
                        crate::error::PlannerError::unsupported_sql(
                            "Expected integer literal".to_string(),
                        )
                    })?;

                Ok(partitioning::iceberg_truncate(input, w))
            }
        }
    }
}

// Plans a single-argument partitioning transform (`years`, `months`,
// `days`, `hours`) by wrapping the planned input expression with `f`.
fn partitioning_helper<F: FnOnce(daft_dsl::ExprRef) -> daft_dsl::ExprRef>(
    args: &[sqlparser::ast::FunctionArg],
    planner: &crate::planner::SQLPlanner,
    method_name: &str,
    f: F,
) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
    ensure!(args.len() == 1, "{} takes exactly 1 argument", method_name);
    let arg = planner.plan_function_arg(&args[0])?;
    Ok(f(arg))
}
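One design note worth calling out: `IcebergBucket(0)` and `IcebergTruncate(0)` are registered with placeholder values that are never read; `to_expr` takes the real bucket count or truncation width from the second SQL argument, which must be an integer literal (and the bucket count must also fit in an `i32`). A hedged sketch of the resulting behavior (the exact Python exception type is not shown in this diff):

import daft
from daft.sql.sql import SQLCatalog

df = daft.from_pydict({"id": [1, 2, 3]})
catalog = SQLCatalog({"test": df})

# OK: the bucket count is an integer literal.
daft.sql("SELECT partitioning_iceberg_bucket(id, 10) AS b FROM test", catalog)

# Rejected while planning with "Expected integer literal", because the
# second argument is a column reference rather than a literal:
# daft.sql("SELECT partitioning_iceberg_bucket(id, id) AS b FROM test", catalog)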
45 changes: 45 additions & 0 deletions tests/sql/test_partitioning_exprs.py
@@ -0,0 +1,45 @@
from datetime import datetime

import daft
from daft.sql.sql import SQLCatalog


def test_partitioning_exprs():
    df = daft.from_pydict(
        {
            "id": [1, 2, 3, 4, 5],
            "date": [
                datetime(2024, 1, 1),
                datetime(2024, 2, 1),
                datetime(2024, 3, 1),
                datetime(2024, 4, 1),
                datetime(2024, 5, 1),
            ],
        }
    )
    catalog = SQLCatalog({"test": df})
    expected = (
        df.select(
            daft.col("date").partitioning.days().alias("date_days"),
            daft.col("date").partitioning.hours().alias("date_hours"),
            daft.col("date").partitioning.months().alias("date_months"),
            daft.col("date").partitioning.years().alias("date_years"),
            daft.col("id").partitioning.iceberg_bucket(10).alias("id_bucket"),
            daft.col("id").partitioning.iceberg_truncate(10).alias("id_truncate"),
        )
        .collect()
        .to_pydict()
    )
    sql = """
    SELECT
        partitioning_days(date) AS date_days,
        partitioning_hours(date) AS date_hours,
        partitioning_months(date) AS date_months,
        partitioning_years(date) AS date_years,
        partitioning_iceberg_bucket(id, 10) AS id_bucket,
        partitioning_iceberg_truncate(id, 10) AS id_truncate
    FROM test
    """
    actual = daft.sql(sql, catalog).collect().to_pydict()

    assert actual == expected
