Skip to content

Commit

Permalink
fix: Resolve multiple SQL JOIN issues
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed May 27, 2024
1 parent d856b49 commit b6729e9
Show file tree
Hide file tree
Showing 18 changed files with 415 additions and 154 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1068,8 +1068,13 @@ impl LazyFrame {

/// Creates the Cartesian product from both frames, preserving the order of the left keys.
#[cfg(feature = "cross_join")]
pub fn cross_join(self, other: LazyFrame) -> LazyFrame {
self.join(other, vec![], vec![], JoinArgs::new(JoinType::Cross))
pub fn cross_join(self, other: LazyFrame, suffix: Option<String>) -> LazyFrame {
    // A cross join has no key columns, so both key lists are empty; the
    // optional suffix is carried on the join args.
    let cross_args = JoinArgs::new(JoinType::Cross).with_suffix(suffix);
    self.join(other, vec![], vec![], cross_args)
}

/// Left outer join this query with another lazy query.
Expand Down Expand Up @@ -1220,9 +1225,7 @@ impl LazyFrame {
if let Some(suffix) = args.suffix {
builder = builder.suffix(suffix);
}

// Note: args.slice is set by the optimizer

builder.finish()
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ fn test_cse_columns_projections() -> PolarsResult<()> {
]?
.lazy();

let left = left.cross_join(right.clone().select([col("A")]));
let left = left.cross_join(right.clone().select([col("A")]), None);
let q = left.join(
right.rename(["B"], ["C"]),
[col("A"), col("C")],
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/projection_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn test_cross_join_pd() -> PolarsResult<()> {
"price" => [5, 4]
]?;

let q = food.lazy().cross_join(drink.lazy()).select([
let q = food.lazy().cross_join(drink.lazy(), None).select([
col("name").alias("food"),
col("name_right").alias("beverage"),
(col("price") + col("price_right")).alias("total"),
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1171,7 +1171,7 @@ fn test_cross_join() -> PolarsResult<()> {
"b" => [None, Some(12)]
]?;

let out = df1.lazy().cross_join(df2.lazy()).collect()?;
let out = df1.lazy().cross_join(df2.lazy(), None).collect()?;
assert_eq!(out.shape(), (6, 4));
Ok(())
}
Expand Down
14 changes: 9 additions & 5 deletions crates/polars-lazy/src/tests/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ fn test_streaming_union_order() -> PolarsResult<()> {
fn test_streaming_union_join() -> PolarsResult<()> {
let q = get_csv_glob();
let q = q.select([col("sugars_g"), col("calories")]);
let q = q.clone().cross_join(q);
let q = q.clone().cross_join(q, None);

assert_streaming_with_default(q, true, true);
Ok(())
Expand Down Expand Up @@ -166,18 +166,22 @@ fn test_streaming_cross_join() -> PolarsResult<()> {
"a" => [1 ,2, 3]
]?;
let q = df.lazy();
let out = q.clone().cross_join(q).with_streaming(true).collect()?;
let out = q
.clone()
.cross_join(q, None)
.with_streaming(true)
.collect()?;
assert_eq!(out.shape(), (9, 2));

let q = get_parquet_file().with_projection_pushdown(false);
let q1 = q
.clone()
.select([col("calories")])
.cross_join(q.clone())
.cross_join(q.clone(), None)
.filter(col("calories").gt(col("calories_right")));
let q2 = q1
.select([all().name().suffix("_second")])
.cross_join(q)
.cross_join(q, None)
.filter(col("calories_right_second").lt(col("calories")))
.select([
col("calories"),
Expand Down Expand Up @@ -266,7 +270,7 @@ fn test_streaming_slice() -> PolarsResult<()> {
]?
.lazy();

let q = lf_a.clone().cross_join(lf_a).slice(10, 20);
let q = lf_a.clone().cross_join(lf_a, None).slice(10, 20);
let a = q.with_streaming(true).collect().unwrap();
assert_eq!(a.shape(), (20, 2));

Expand Down
5 changes: 5 additions & 0 deletions crates/polars-ops/src/frame/join/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ impl JoinArgs {
self
}

pub fn with_suffix(mut self, suffix: Option<String>) -> Self {
self.suffix = suffix;
self
}

/// Returns the stored suffix, or `"_right"` when no suffix has been set.
pub fn suffix(&self) -> &str {
    match &self.suffix {
        Some(custom) => custom.as_str(),
        None => "_right",
    }
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ pub enum Excluded {
impl Expr {
/// Get Field result of the expression. The schema is the input data.
pub fn to_field(&self, schema: &Schema, ctxt: Context) -> PolarsResult<Field> {
// this is not called much and th expression depth is typically shallow
// this is not called much and the expression depth is typically shallow
let mut arena = Arena::with_capacity(5);
self.to_field_amortized(schema, ctxt, &mut arena)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ fn resolve_join_suffixes(
.iter()
.map(|proj| {
let name = column_node_to_name(*proj, expr_arena);
if name.contains(suffix) && schema_after_join.get(&name).is_none() {
if name.ends_with(suffix) && schema_after_join.get(&name).is_none() {
let downstream_name = &name.as_ref()[..name.len() - suffix.len()];
let col = AExpr::Column(ColumnName::from(downstream_name));
let node = expr_arena.add(col);
Expand Down
1 change: 1 addition & 0 deletions crates/polars-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ arrow = { workspace = true }
polars-core = { workspace = true }
polars-error = { workspace = true }
polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] }
polars-ops = { workspace = true }
polars-plan = { workspace = true }

hex = { workspace = true }
Expand Down
143 changes: 108 additions & 35 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@ use std::cell::RefCell;
use polars_core::prelude::*;
use polars_error::to_compute_err;
use polars_lazy::prelude::*;
use polars_ops::frame::JoinCoalesce;
use polars_plan::prelude::*;
use sqlparser::ast::{
Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinOperator,
ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator,
SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator,
Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinConstraint,
JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr,
SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator,
Value as SQLValue, WildcardAdditionalOptions,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserOptions};

use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry};
use crate::sql_expr::{parse_sql_expr, process_join};
use crate::sql_expr::{parse_sql_expr, process_join_constraint};
use crate::table_functions::PolarsTableFunctions;

/// The SQLContext is the main entry point for executing SQL queries.
Expand All @@ -23,7 +24,8 @@ pub struct SQLContext {
pub(crate) table_map: PlHashMap<String, LazyFrame>,
pub(crate) function_registry: Arc<dyn FunctionRegistry>,
cte_map: RefCell<PlHashMap<String, LazyFrame>>,
aliases: RefCell<PlHashMap<String, String>>,
table_aliases: RefCell<PlHashMap<String, String>>,
joined_aliases: RefCell<PlHashMap<String, PlHashMap<String, String>>>,
}

impl Default for SQLContext {
Expand All @@ -32,7 +34,8 @@ impl Default for SQLContext {
function_registry: Arc::new(DefaultFunctionRegistry {}),
table_map: Default::default(),
cte_map: Default::default(),
aliases: Default::default(),
table_aliases: Default::default(),
joined_aliases: Default::default(),
}
}
}
Expand Down Expand Up @@ -110,11 +113,16 @@ impl SQLContext {
.map_err(to_compute_err)?
.parse_statements()
.map_err(to_compute_err)?;
polars_ensure!(ast.len() == 1, ComputeError: "One and only one statement at a time please");

polars_ensure!(ast.len() == 1, ComputeError: "One (and only one) statement at a time please");

let res = self.execute_statement(ast.first().unwrap());
// Every execution should clear the CTE map.

// Every execution should clear the statement-level maps.
self.cte_map.borrow_mut().clear();
self.aliases.borrow_mut().clear();
self.table_aliases.borrow_mut().clear();
self.joined_aliases.borrow_mut().clear();

res
}

Expand All @@ -137,22 +145,6 @@ impl SQLContext {
}

impl SQLContext {
/// Stores `lf` in the CTE map under an owned copy of `name`; an existing
/// entry with the same name is overwritten.
fn register_cte(&mut self, name: &str, lf: LazyFrame) {
    let key = name.to_owned();
    self.cte_map.borrow_mut().insert(key, lf);
}

/// Looks up a frame by `name`, returning a clone of the first match.
///
/// Lookup order: the registered table map, then the CTE map, then the alias
/// map (whose value is used as a key into the table map).
pub(super) fn get_table_from_current_scope(&self, name: &str) -> Option<LazyFrame> {
    // 1. a table registered directly under this name
    if let Some(lf) = self.table_map.get(name) {
        return Some(lf.clone());
    }
    // 2. a CTE registered under this name
    if let Some(lf) = self.cte_map.borrow().get(name) {
        return Some(lf.clone());
    }
    // 3. an alias pointing at a registered table
    let aliases = self.aliases.borrow();
    let target = aliases.get(name)?;
    self.table_map.get(target).cloned()
}

pub(crate) fn execute_statement(&mut self, stmt: &Statement) -> PolarsResult<LazyFrame> {
let ast = stmt;
Ok(match ast {
Expand Down Expand Up @@ -183,6 +175,31 @@ impl SQLContext {
self.process_limit_offset(lf, &query.limit, &query.offset)
}

/// Looks up a frame by `name`, returning a clone of the first match.
///
/// Lookup order: the registered table map, then the CTE map, then the table
/// alias map (whose value is used as a key into the table map).
pub(super) fn get_table_from_current_scope(&self, name: &str) -> Option<LazyFrame> {
    // 1. a table registered directly under this name
    if let Some(lf) = self.table_map.get(name) {
        return Some(lf.clone());
    }
    // 2. a CTE registered under this name
    if let Some(lf) = self.cte_map.borrow().get(name) {
        return Some(lf.clone());
    }
    // 3. a table alias pointing at a registered table
    let aliases = self.table_aliases.borrow();
    let target = aliases.get(name)?;
    self.table_map.get(target).cloned()
}

/// Resolves `column_name` for table `tbl_name` against the join-alias map.
///
/// If `joined_aliases` has an entry for `tbl_name` and that entry maps
/// `column_name` to an aliased name, the aliased name is returned. In every
/// other case (unknown table, or column not aliased) `column_name` is
/// returned unchanged.
pub(super) fn resolve_name(&self, tbl_name: &str, column_name: &str) -> String {
    // A single borrow and a single lookup: a missing table entry and a
    // missing column entry both fall through to the same default.
    self.joined_aliases
        .borrow()
        .get(tbl_name)
        .and_then(|aliases| aliases.get(column_name))
        .cloned()
        .unwrap_or_else(|| column_name.to_string())
}

fn process_set_expr(&mut self, expr: &SetExpr, query: &Query) -> PolarsResult<LazyFrame> {
match expr {
SetExpr::Select(select_stmt) => self.execute_select(select_stmt, query),
Expand Down Expand Up @@ -296,6 +313,10 @@ impl SQLContext {
}
}

/// Stores `lf` in the CTE map under an owned copy of `name`; an existing
/// entry with the same name is overwritten.
fn register_cte(&mut self, name: &str, lf: LazyFrame) {
    let key = name.to_owned();
    self.cte_map.borrow_mut().insert(key, lf);
}

fn register_ctes(&mut self, query: &Query) -> PolarsResult<()> {
if let Some(with) = &query.with {
if with.recursive {
Expand All @@ -316,40 +337,63 @@ impl SQLContext {
if !tbl_expr.joins.is_empty() {
for tbl in &tbl_expr.joins {
let (r_name, rf) = self.get_table(&tbl.relation)?;
let left_schema = lf.schema()?;
let right_schema = rf.schema()?;

lf = match &tbl.join_operator {
JoinOperator::CrossJoin => lf.cross_join(rf),
JoinOperator::FullOuter(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Full)?
self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Full)?
},
JoinOperator::Inner(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)?
self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)?
},
JoinOperator::LeftOuter(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Left)?
self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Left)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::LeftAnti(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Anti)?
self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Anti)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::LeftSemi(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Semi)?
self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Semi)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::RightAnti(constraint) => {
process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Anti)?
self.process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Anti)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::RightSemi(constraint) => {
process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Semi)?
self.process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Semi)?
},
JoinOperator::CrossJoin => lf.cross_join(rf, Some(format!(":{}", r_name))),
join_type => {
polars_bail!(
InvalidOperation:
"join type '{:?}' not yet supported by polars-sql", join_type
);
},
}
};

// track join-aliased columns so we can resolve them later
let joined_schema = lf.schema()?;
self.joined_aliases.borrow_mut().insert(
r_name.to_string(),
right_schema
.iter_names()
.filter_map(|name| {
// col exists in both tables and is aliased in the joined result
let aliased_name = format!("{}:{}", name, r_name);
if left_schema.contains(name)
&& joined_schema.contains(aliased_name.as_str())
{
Some((name.to_string(), aliased_name))
} else {
None
}
})
.collect::<PlHashMap<String, String>>(),
);
}
};
Ok(lf)
Expand Down Expand Up @@ -406,6 +450,10 @@ impl SQLContext {
})
.collect::<PolarsResult<_>>()?;

// dbg!(projections.clone());
// dbg!(lf.clone().collect()?);
// dbg!(lf.clone().select(projections.clone()).collect()?);

// Check for "GROUP BY ..." (after projections, as there may be ordinal/position ints).
let mut group_by_keys: Vec<Expr> = Vec::new();
match &select_stmt.group_by {
Expand Down Expand Up @@ -578,6 +626,31 @@ impl SQLContext {
Ok(lf)
}

/// Builds the lazy join of `left_tbl` with `right_tbl` for a SQL JOIN clause.
///
/// The join keys come from `process_join_constraint`, called with
/// `constraint` and the two table names; any error it returns is propagated.
/// The join is built with:
/// - `join_type` as the join strategy,
/// - a suffix of `":<join_tbl_name>"` for clashing column names,
/// - `JoinCoalesce::KeepColumns` as the coalesce setting.
pub(super) fn process_join(
    &self,
    left_tbl: LazyFrame,
    right_tbl: LazyFrame,
    constraint: &JoinConstraint,
    tbl_name: &str,
    join_tbl_name: &str,
    join_type: JoinType,
) -> PolarsResult<LazyFrame> {
    let (left_on, right_on) = process_join_constraint(constraint, tbl_name, join_tbl_name)?;

    // Both frames are owned and not used again, so they are moved into the
    // builder instead of being cloned.
    Ok(left_tbl
        .join_builder()
        .with(right_tbl)
        .left_on(left_on)
        .right_on(right_on)
        .how(join_type)
        .suffix(format!(":{}", join_tbl_name))
        .coalesce(JoinCoalesce::KeepColumns)
        .finish())
}

fn process_subqueries(&self, lf: LazyFrame, exprs: Vec<&mut Expr>) -> LazyFrame {
let mut contexts = vec![];
for expr in exprs {
Expand Down Expand Up @@ -644,7 +717,7 @@ impl SQLContext {
if let Some(lf) = self.get_table_from_current_scope(tbl_name) {
match alias {
Some(alias) => {
self.aliases
self.table_aliases
.borrow_mut()
.insert(alias.name.value.clone(), tbl_name.to_string());
Ok((alias.to_string(), lf))
Expand Down
Loading

0 comments on commit b6729e9

Please sign in to comment.