Skip to content

Commit

Permalink
[BUG]: Sql groupby fix (#2843)
Browse files Browse the repository at this point in the history
addresses #2835
  • Loading branch information
universalmind303 authored Sep 17, 2024
1 parent 72b1440 commit 07e92f6
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 115 deletions.
287 changes: 174 additions & 113 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::sync::Arc;

use common_error::DaftResult;
use daft_core::prelude::*;
use daft_dsl::{
col,
Expand Down Expand Up @@ -96,30 +97,7 @@ impl SQLPlanner {
}

fn plan_query(&mut self, query: &Query) -> SQLPlannerResult<LogicalPlanBuilder> {
if let Some(with) = &query.with {
unsupported_sql_err!("WITH: {with}")
}
if !query.limit_by.is_empty() {
unsupported_sql_err!("LIMIT BY");
}
if query.offset.is_some() {
unsupported_sql_err!("OFFSET");
}
if query.fetch.is_some() {
unsupported_sql_err!("FETCH");
}
if !query.locks.is_empty() {
unsupported_sql_err!("LOCKS");
}
if let Some(for_clause) = &query.for_clause {
unsupported_sql_err!("{for_clause}");
}
if query.settings.is_some() {
unsupported_sql_err!("SETTINGS");
}
if let Some(format_clause) = &query.format_clause {
unsupported_sql_err!("{format_clause}");
}
check_query_features(query)?;

let selection = query.body.as_select().ok_or_else(|| {
PlannerError::invalid_operation(format!(
Expand All @@ -128,91 +106,7 @@ impl SQLPlanner {
))
})?;

self.plan_select(selection)?;

if let Some(order_by) = &query.order_by {
if order_by.interpolate.is_some() {
unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]");
}
// TODO: if ordering by a column not in the projection, this will fail.
let (exprs, descending) = self.plan_order_by_exprs(order_by.exprs.as_slice())?;
let rel = self.relation_mut();
rel.inner = rel.inner.sort(exprs, descending)?;
}

if let Some(limit) = &query.limit {
let limit = self.plan_expr(limit)?;
if let Expr::Literal(LiteralValue::Int64(limit)) = limit.as_ref() {
let rel = self.relation_mut();
rel.inner = rel.inner.limit(*limit, true)?; // TODO: Should this be eager or not?
} else {
invalid_operation_err!(
"LIMIT <n> must be a constant integer, instead got: {limit}"
);
}
}

Ok(self.current_relation.clone().unwrap().inner)
}

fn plan_order_by_exprs(
&self,
expr: &[sqlparser::ast::OrderByExpr],
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<bool>)> {
let mut exprs = Vec::with_capacity(expr.len());
let mut desc = Vec::with_capacity(expr.len());
for order_by_expr in expr {
if order_by_expr.nulls_first.is_some() {
unsupported_sql_err!("NULLS FIRST");
}
if order_by_expr.with_fill.is_some() {
unsupported_sql_err!("WITH FILL");
}
let expr = self.plan_expr(&order_by_expr.expr)?;
desc.push(!order_by_expr.asc.unwrap_or(true));

exprs.push(expr);
}
Ok((exprs, desc))
}

fn plan_select(&mut self, selection: &sqlparser::ast::Select) -> SQLPlannerResult<()> {
if selection.top.is_some() {
unsupported_sql_err!("TOP");
}
if selection.distinct.is_some() {
unsupported_sql_err!("DISTINCT");
}
if selection.into.is_some() {
unsupported_sql_err!("INTO");
}
if !selection.lateral_views.is_empty() {
unsupported_sql_err!("LATERAL");
}
if selection.prewhere.is_some() {
unsupported_sql_err!("PREWHERE");
}
if !selection.cluster_by.is_empty() {
unsupported_sql_err!("CLUSTER BY");
}
if !selection.distribute_by.is_empty() {
unsupported_sql_err!("DISTRIBUTE BY");
}
if !selection.sort_by.is_empty() {
unsupported_sql_err!("SORT BY");
}
if selection.having.is_some() {
unsupported_sql_err!("HAVING");
}
if !selection.named_window.is_empty() {
unsupported_sql_err!("WINDOW");
}
if selection.qualify.is_some() {
unsupported_sql_err!("QUALIFY");
}
if selection.connect_by.is_some() {
unsupported_sql_err!("CONNECT BY");
}
check_select_features(selection)?;

// FROM/JOIN
let from = selection.clone().from;
Expand Down Expand Up @@ -246,18 +140,23 @@ impl SQLPlanner {
}
}

let to_select = selection
// split the selection into the groupby expressions and the rest
let (groupby_selection, to_select) = selection
.projection
.iter()
.map(|expr| self.select_item_to_expr(expr))
.collect::<SQLPlannerResult<Vec<_>>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
.partition::<Vec<_>, _>(|expr| {
groupby_exprs
.iter()
.any(|e| expr.input_mapping() == e.input_mapping())
});

if !groupby_exprs.is_empty() {
let rel = self.relation_mut();
rel.inner = rel.inner.aggregate(to_select, groupby_exprs)?;
rel.inner = rel.inner.aggregate(to_select, groupby_exprs.clone())?;
} else if !to_select.is_empty() {
let rel = self.relation_mut();
let has_aggs = to_select.iter().any(has_agg);
Expand All @@ -268,7 +167,86 @@ impl SQLPlanner {
}
}

Ok(())
if let Some(order_by) = &query.order_by {
if order_by.interpolate.is_some() {
unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]");
}
// TODO: if ordering by a column not in the projection, this will fail.
let (exprs, descending) = self.plan_order_by_exprs(order_by.exprs.as_slice())?;
let rel = self.relation_mut();
rel.inner = rel.inner.sort(exprs, descending)?;
}

// Properly apply or remove the groupby columns from the selection
// This needs to be done after the orderby
// otherwise, the orderby will not be able to reference the grouping columns
//
// ex: SELECT sum(a) as sum_a, max(a) as max_a, b as c FROM table GROUP BY b
//
// The groupby columns are [b]
// the evaluation of sum(a) and max(a) are already handled by the earlier aggregate,
// so our projection is [sum_a, max_a, (b as c)]
// leaving us to handle (b as c)
//
// we filter for the columns in the schema that are not in the groupby keys,
// [sum_a, max_a, b] -> [sum_a, max_a]
//
// Then we add the groupby columns back in with the correct expressions
// this gives us the final projection: [sum_a, max_a, (b as c)]
if !groupby_exprs.is_empty() {
let rel = self.relation_mut();
let schema = rel.inner.schema();

let groupby_keys = groupby_exprs
.iter()
.map(|e| Ok(e.to_field(&schema)?.name))
.collect::<DaftResult<Vec<_>>>()?;

let selection_colums = schema
.exclude(groupby_keys.as_ref())?
.names()
.iter()
.map(|n| col(n.as_str()))
.chain(groupby_selection)
.collect();

rel.inner = rel.inner.select(selection_colums)?;
}

if let Some(limit) = &query.limit {
let limit = self.plan_expr(limit)?;
if let Expr::Literal(LiteralValue::Int64(limit)) = limit.as_ref() {
let rel = self.relation_mut();
rel.inner = rel.inner.limit(*limit, true)?; // TODO: Should this be eager or not?
} else {
invalid_operation_err!(
"LIMIT <n> must be a constant integer, instead got: {limit}"
);
}
}

Ok(self.current_relation.clone().unwrap().inner)
}

fn plan_order_by_exprs(
&self,
expr: &[sqlparser::ast::OrderByExpr],
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<bool>)> {
let mut exprs = Vec::with_capacity(expr.len());
let mut desc = Vec::with_capacity(expr.len());
for order_by_expr in expr {
if order_by_expr.nulls_first.is_some() {
unsupported_sql_err!("NULLS FIRST");
}
if order_by_expr.with_fill.is_some() {
unsupported_sql_err!("WITH FILL");
}
let expr = self.plan_expr(&order_by_expr.expr)?;
desc.push(!order_by_expr.asc.unwrap_or(true));

exprs.push(expr);
}
Ok((exprs, desc))
}

fn plan_from(&self, from: &TableWithJoins) -> SQLPlannerResult<Relation> {
Expand Down Expand Up @@ -952,6 +930,89 @@ impl SQLPlanner {
}
}

/// Checks if the SQL query is valid syntax and doesn't use unsupported features.
/// /// This function examines various clauses and options in the provided [sqlparser::ast::Query]
/// and returns an error if any unsupported features are encountered.
fn check_query_features(query: &sqlparser::ast::Query) -> SQLPlannerResult<()> {
if let Some(with) = &query.with {
unsupported_sql_err!("WITH: {with}")
}
if !query.limit_by.is_empty() {
unsupported_sql_err!("LIMIT BY");
}
if query.offset.is_some() {
unsupported_sql_err!("OFFSET");
}
if query.fetch.is_some() {
unsupported_sql_err!("FETCH");
}
if !query.locks.is_empty() {
unsupported_sql_err!("LOCKS");
}
if let Some(for_clause) = &query.for_clause {
unsupported_sql_err!("{for_clause}");
}
if query.settings.is_some() {
unsupported_sql_err!("SETTINGS");
}
if let Some(format_clause) = &query.format_clause {
unsupported_sql_err!("{format_clause}");
}
Ok(())
}

/// Checks if the features used in the SQL SELECT statement are supported.
///
/// This function examines various clauses and options in the provided [sqlparser::ast::Select]
/// and returns an error if any unsupported features are encountered.
///
/// # Arguments
///
/// * `selection` - A reference to the [sqlparser::ast::Select] to be checked.
///
/// # Returns
///
/// * `SQLPlannerResult<()>` - Ok(()) if all features are supported, or an error describing
/// the first unsupported feature encountered.
fn check_select_features(selection: &sqlparser::ast::Select) -> SQLPlannerResult<()> {
if selection.top.is_some() {
unsupported_sql_err!("TOP");
}
if selection.distinct.is_some() {
unsupported_sql_err!("DISTINCT");
}
if selection.into.is_some() {
unsupported_sql_err!("INTO");
}
if !selection.lateral_views.is_empty() {
unsupported_sql_err!("LATERAL");
}
if selection.prewhere.is_some() {
unsupported_sql_err!("PREWHERE");
}
if !selection.cluster_by.is_empty() {
unsupported_sql_err!("CLUSTER BY");
}
if !selection.distribute_by.is_empty() {
unsupported_sql_err!("DISTRIBUTE BY");
}
if !selection.sort_by.is_empty() {
unsupported_sql_err!("SORT BY");
}
if selection.having.is_some() {
unsupported_sql_err!("HAVING");
}
if !selection.named_window.is_empty() {
unsupported_sql_err!("WINDOW");
}
if selection.qualify.is_some() {
unsupported_sql_err!("QUALIFY");
}
if selection.connect_by.is_some() {
unsupported_sql_err!("CONNECT BY");
}
Ok(())
}
pub fn sql_expr<S: AsRef<str>>(s: S) -> SQLPlannerResult<ExprRef> {
let planner = SQLPlanner::default();

Expand Down
22 changes: 20 additions & 2 deletions tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,26 @@ def test_sql_global_agg():
def test_sql_groupby_agg():
df = daft.from_pydict({"n": [1, 1, 2, 2], "v": [1, 2, 3, 4]})
catalog = SQLCatalog({"test": df})
df = daft.sql("SELECT sum(v) FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert df.collect().to_pydict() == {"n": [1, 2], "v": [3, 7]}
actual = daft.sql("SELECT sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"sum": [3, 7]}

# test with grouping column
actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"n": [1, 2], "sum": [3, 7]}

# test with multiple columns
actual = daft.sql("SELECT max(v) as max, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"max": [2, 4], "sum": [3, 7]}

# test with aliased grouping key
actual = daft.sql("SELECT n as n_alias, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"n_alias": [1, 2], "sum": [3, 7]}

actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY -n", catalog=catalog)
assert actual.collect().to_pydict() == {"n": [2, 1], "sum": [7, 3]}

actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY sum", catalog=catalog)
assert actual.collect().to_pydict() == {"n": [1, 2], "sum": [3, 7]}


def test_sql_count_star():
Expand Down

0 comments on commit 07e92f6

Please sign in to comment.