Skip to content

Commit

Permalink
[FEAT]: sql tbl alias, and compound ident for joins (#3066)
Browse files Browse the repository at this point in the history
- closes #3065
- closes #3059
  • Loading branch information
universalmind303 authored Oct 17, 2024
1 parent 69fef20 commit e4c6f3f
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 53 deletions.
2 changes: 1 addition & 1 deletion src/daft-sql/src/modules/aggs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ fn handle_count(inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResul
},
[FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(name))] => {
match planner.relation_opt() {
Some(rel) if name.to_string() == rel.name => {
Some(rel) if name.to_string() == rel.get_name() => {
let schema = rel.schema();
col(schema.fields[0].name.clone())
.count(daft_core::count_mode::CountMode::All)
Expand Down
129 changes: 84 additions & 45 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

use common_error::DaftResult;
use daft_core::prelude::*;
Expand All @@ -12,8 +12,8 @@ use daft_plan::{LogicalPlanBuilder, LogicalPlanRef};
use sqlparser::{
ast::{
ArrayElemTypeDef, BinaryOperator, CastKind, ExactNumberInfo, GroupByExpr, Ident, Query,
SelectItem, Statement, StructField, Subscript, TableWithJoins, TimezoneInfo, UnaryOperator,
Value, WildcardAdditionalOptions,
SelectItem, Statement, StructField, Subscript, TableAlias, TableWithJoins, TimezoneInfo,
UnaryOperator, Value, WildcardAdditionalOptions,
},
dialect::GenericDialect,
parser::{Parser, ParserOptions},
Expand All @@ -32,11 +32,28 @@ use crate::{
/// A table source that has been planned for use in a SQL query: a logical plan
/// builder together with the name (and optional alias) it is referenced by.
pub struct Relation {
/// Logical plan builder for this relation; joins are applied to it.
pub(crate) inner: LogicalPlanBuilder,
/// Name the relation was created with (a catalog table name or a table function name).
pub(crate) name: String,
/// Optional SQL table alias; when set, `get_name` returns it instead of `name`.
pub(crate) alias: Option<TableAlias>,
}

impl Relation {
pub fn new(inner: LogicalPlanBuilder, name: String) -> Self {
Self { inner, name }
Self {
inner,
name,
alias: None,
}
}
/// Consumes the relation and returns it with `alias` attached, replacing any
/// alias that was previously set.
pub fn with_alias(mut self, alias: TableAlias) -> Self {
    self.alias = Some(alias);
    self
}
/// Returns the name this relation is referenced by in a query: the alias
/// when one is set, otherwise the relation's own name.
pub fn get_name(&self) -> String {
    match &self.alias {
        Some(alias) => ident_to_str(&alias.name),
        None => self.name.clone(),
    }
}
pub(crate) fn schema(&self) -> SchemaRef {
self.inner.schema()
Expand All @@ -46,13 +63,15 @@ impl Relation {
/// Plans SQL text into logical plans, resolving table names through a catalog.
pub struct SQLPlanner {
/// Catalog consulted to resolve table names to plans.
catalog: SQLCatalog,
/// Relation currently being planned, if any; reset by `clear_context`.
current_relation: Option<Relation>,
/// Relations seen while planning the FROM clause, keyed by `Relation::get_name`
/// (the alias when present); used to resolve `table.column` identifiers.
table_map: HashMap<String, Relation>,
}

impl Default for SQLPlanner {
    /// Builds a planner backed by a fresh catalog, with no current relation
    /// and an empty table map.
    fn default() -> Self {
        let catalog = SQLCatalog::new();
        let table_map = HashMap::new();
        Self {
            catalog,
            current_relation: None,
            table_map,
        }
    }
}
Expand All @@ -62,6 +81,7 @@ impl SQLPlanner {
Self {
catalog: context,
current_relation: None,
table_map: HashMap::new(),
}
}

Expand All @@ -76,6 +96,12 @@ impl SQLPlanner {
self.current_relation.as_ref()
}

/// Resets the per-query planning state: forgets every relation registered in
/// the table map and drops the current relation, so state from one query is
/// not carried into the next.
fn clear_context(&mut self) {
    self.table_map.clear();
    self.current_relation = None;
}

pub fn plan_sql(&mut self, sql: &str) -> SQLPlannerResult<LogicalPlanRef> {
let tokens = Tokenizer::new(&GenericDialect {}, sql).tokenize()?;

Expand All @@ -88,12 +114,14 @@ impl SQLPlanner {

let statements = parser.parse_statements()?;

match statements.len() {
let plan = match statements.len() {
1 => Ok(self.plan_statement(&statements[0])?),
other => {
unsupported_sql_err!("Only exactly one SQL statement allowed, found {}", other)
}
}
};
self.clear_context();
plan
}

fn plan_statement(&mut self, statement: &Statement) -> SQLPlannerResult<LogicalPlanRef> {
Expand Down Expand Up @@ -256,19 +284,19 @@ impl SQLPlanner {
Ok((exprs, desc))
}

fn plan_from(&self, from: &TableWithJoins) -> SQLPlannerResult<Relation> {
fn plan_from(&mut self, from: &TableWithJoins) -> SQLPlannerResult<Relation> {
fn collect_compound_identifiers(
left: &[Ident],
right: &[Ident],
left_name: &str,
right_name: &str,
left_rel: &Relation,
right_rel: &Relation,
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<ExprRef>)> {
if left.len() == 2 && right.len() == 2 {
let (tbl_a, col_a) = (&left[0].value, &left[1].value);
let (tbl_b, col_b) = (&right[0].value, &right[1].value);

// switch left/right operands if the caller has them in reverse
if left_name == tbl_b || right_name == tbl_a {
if &left_rel.get_name() == tbl_b || &right_rel.get_name() == tbl_a {
Ok((vec![col(col_b.as_ref())], vec![col(col_a.as_ref())]))
} else {
Ok((vec![col(col_a.as_ref())], vec![col(col_b.as_ref())]))
Expand All @@ -280,8 +308,8 @@ impl SQLPlanner {

fn process_join_on(
expression: &sqlparser::ast::Expr,
left_name: &str,
right_name: &str,
left_rel: &Relation,
right_rel: &Relation,
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<ExprRef>)> {
if let sqlparser::ast::Expr::BinaryOp { left, op, right } = expression {
match *op {
Expand All @@ -291,16 +319,14 @@ impl SQLPlanner {
sqlparser::ast::Expr::CompoundIdentifier(right),
) = (left.as_ref(), right.as_ref())
{
collect_compound_identifiers(left, right, left_name, right_name)
collect_compound_identifiers(left, right, left_rel, right_rel)
} else {
unsupported_sql_err!("JOIN clauses support '=' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right);
}
}
BinaryOperator::And => {
let (mut left_i, mut right_i) =
process_join_on(left, left_name, right_name)?;
let (mut left_j, mut right_j) =
process_join_on(right, left_name, right_name)?;
// Recurse into BOTH operands of the AND so every conjunct contributes its
// join keys: the left operand first, then the right operand.
let (mut left_i, mut right_i) = process_join_on(left, left_rel, right_rel)?;
let (mut left_j, mut right_j) = process_join_on(right, left_rel, right_rel)?;
left_i.append(&mut left_j);
right_i.append(&mut right_j);
Ok((left_i, right_i))
Expand All @@ -310,14 +336,15 @@ impl SQLPlanner {
}
}
} else if let sqlparser::ast::Expr::Nested(expr) = expression {
process_join_on(expr, left_name, right_name)
process_join_on(expr, left_rel, right_rel)
} else {
unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found expression = {:?}", expression);
}
}

let relation = from.relation.clone();
let mut left_rel = self.plan_relation(&relation)?;
self.table_map.insert(left_rel.get_name(), left_rel.clone());

for join in &from.joins {
use sqlparser::ast::{
Expand All @@ -327,17 +354,16 @@ impl SQLPlanner {
OuterApply, RightAnti, RightOuter, RightSemi,
},
};
let Relation {
inner: right_plan,
name: right_name,
} = self.plan_relation(&join.relation)?;
let right_rel = self.plan_relation(&join.relation)?;
self.table_map
.insert(right_rel.get_name(), right_rel.clone());

match &join.join_operator {
Inner(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel.name, &right_name)?;
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner = left_rel.inner.join(
right_plan,
right_rel.inner,
left_on,
right_on,
JoinType::Inner,
Expand All @@ -350,24 +376,30 @@ impl SQLPlanner {
.map(|i| col(i.value.clone()))
.collect::<Vec<_>>();

left_rel.inner =
left_rel
.inner
.join(right_plan, on.clone(), on, JoinType::Inner, None)?;
left_rel.inner = left_rel.inner.join(
right_rel.inner,
on.clone(),
on,
JoinType::Inner,
None,
)?;
}
LeftOuter(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel.name, &right_name)?;
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner =
left_rel
.inner
.join(right_plan, left_on, right_on, JoinType::Left, None)?;
left_rel.inner = left_rel.inner.join(
right_rel.inner,
left_on,
right_on,
JoinType::Left,
None,
)?;
}
RightOuter(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel.name, &right_name)?;
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner = left_rel.inner.join(
right_plan,
right_rel.inner,
left_on,
right_on,
JoinType::Right,
Expand All @@ -376,10 +408,10 @@ impl SQLPlanner {
}

FullOuter(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel.name, &right_name)?;
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner = left_rel.inner.join(
right_plan,
right_rel.inner,
left_on,
right_on,
JoinType::Outer,
Expand All @@ -402,29 +434,36 @@ impl SQLPlanner {
}

fn plan_relation(&self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult<Relation> {
match rel {
let (rel, alias) = match rel {
sqlparser::ast::TableFactor::Table {
name,
args: Some(args),
alias,
..
} => {
let tbl_fn = name.0.first().unwrap().value.as_str();

self.plan_table_function(tbl_fn, args, alias)
(self.plan_table_function(tbl_fn, args)?, alias.clone())
}
sqlparser::ast::TableFactor::Table {
name, args: None, ..
name,
args: None,
alias,
..
} => {
let table_name = name.to_string();
let plan = self
.catalog
.get_table(&table_name)
.ok_or_else(|| PlannerError::table_not_found(table_name.clone()))?;
let plan_builder = LogicalPlanBuilder::new(plan, None);
Ok(Relation::new(plan_builder, table_name))
(Relation::new(plan_builder, table_name), alias.clone())
}
_ => todo!(),
};
if let Some(alias) = alias {
Ok(rel.with_alias(alias))
} else {
Ok(rel)
}
}

Expand All @@ -433,23 +472,23 @@ impl SQLPlanner {

let root = idents.next().unwrap();
let root = ident_to_str(root);
let current_relation = match &self.current_relation {
let current_relation = match self.table_map.get(&root) {
Some(rel) => rel,
None => {
return Err(PlannerError::TableNotFound {
message: "Expected table".to_string(),
})
}
};
if root == current_relation.name {
if root == current_relation.get_name() {
let column = idents.next().unwrap();
let column_name = ident_to_str(column);
let current_schema = current_relation.inner.schema();
let f = current_schema.get_field(&column_name).ok();
if let Some(field) = f {
Ok(vec![col(field.name.clone())])
} else {
column_not_found_err!(&column_name, &current_relation.name);
column_not_found_err!(&column_name, &current_relation.get_name());
}
} else {
table_not_found_err!(root);
Expand Down
9 changes: 2 additions & 7 deletions src/daft-sql/src/table_provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{collections::HashMap, sync::Arc};
use daft_plan::LogicalPlanBuilder;
use once_cell::sync::Lazy;
use read_parquet::ReadParquetFunction;
use sqlparser::ast::{TableAlias, TableFunctionArgs};
use sqlparser::ast::TableFunctionArgs;

use crate::{
error::SQLPlannerResult,
Expand Down Expand Up @@ -52,7 +52,6 @@ impl SQLPlanner {
&self,
fn_name: &str,
args: &TableFunctionArgs,
alias: &Option<TableAlias>,
) -> SQLPlannerResult<Relation> {
let fns = &SQL_TABLE_FUNCTIONS;

Expand All @@ -61,12 +60,8 @@ impl SQLPlanner {
};

let builder = func.plan(self, args)?;
let name = alias
.as_ref()
.map(|a| a.name.value.clone())
.unwrap_or_else(|| fn_name.to_string());

Ok(Relation::new(builder, name))
Ok(Relation::new(builder, fn_name.to_string()))
}
}

Expand Down
27 changes: 27 additions & 0 deletions tests/sql/test_joins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import daft
from daft import col


def test_joins_using():
    """`JOIN ... USING (idx)` in SQL must match the DataFrame join on `idx`."""
    # The SQL text refers to the locals `df1` and `df2` by name; do not rename them.
    df1 = daft.from_pydict({"idx": [1, 2], "val": [10, 20]})
    df2 = daft.from_pydict({"idx": [1, 2], "score": [0.1, 0.2]})

    query = "select * from df1 join df2 using (idx)"
    actual = daft.sql(query).collect().to_pydict()

    expected = df1.join(df2, on="idx").collect().to_pydict()

    assert actual == expected


def test_joins_with_alias():
    """Table aliases (`foo`, `bar`) must be usable in both the ON and WHERE clauses."""
    # The SQL text refers to the locals `df1` and `df2` by name; do not rename them.
    df1 = daft.from_pydict({"idx": [1, 2], "val": [10, 20]})
    df2 = daft.from_pydict({"idx": [1, 2], "score": [0.1, 0.2]})

    query = "select * from df1 as foo join df2 as bar on (foo.idx=bar.idx) where bar.score>0.1"
    actual = daft.sql(query).collect().to_pydict()

    joined = df1.join(df2, on="idx")
    expected = joined.filter(col("score") > 0.1).collect().to_pydict()

    assert actual == expected
6 changes: 6 additions & 0 deletions tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,9 @@ def test_sql_multi_statement_sql_error():
catalog = SQLCatalog({})
with pytest.raises(Exception, match="one SQL statement allowed"):
daft.sql("SELECT * FROM df; SELECT * FROM df", catalog)


def test_sql_tbl_alias():
    """A single aliased table can be referenced by its alias in SELECT and WHERE."""
    catalog = SQLCatalog({"df": daft.from_pydict({"n": [1, 2, 3]})})
    query = "SELECT df_alias.n FROM df AS df_alias where df_alias.n = 2"

    result = daft.sql(query, catalog).collect().to_pydict()

    assert result == {"n": [2]}

0 comments on commit e4c6f3f

Please sign in to comment.