From b6729e968d8a0d44aa296928ae92e32deabfc9f6 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 18 Jan 2024 14:01:15 +0000 Subject: [PATCH] fix: Resolve multiple SQL `JOIN` issues --- Cargo.lock | 1 + crates/polars-lazy/src/frame/mod.rs | 11 +- crates/polars-lazy/src/tests/cse.rs | 2 +- .../src/tests/projection_queries.rs | 2 +- crates/polars-lazy/src/tests/queries.rs | 2 +- crates/polars-lazy/src/tests/streaming.rs | 14 +- crates/polars-ops/src/frame/join/args.rs | 5 + crates/polars-plan/src/dsl/expr.rs | 2 +- .../optimizer/projection_pushdown/joins.rs | 2 +- crates/polars-sql/Cargo.toml | 1 + crates/polars-sql/src/context.rs | 143 +++++++++--- crates/polars-sql/src/sql_expr.rs | 39 +--- crates/polars-sql/tests/statements.rs | 216 ++++++++++++++---- .../rust/user-guide/transformations/joins.rs | 2 +- py-polars/tests/unit/sql/test_joins.py | 85 ++++++- .../tests/unit/sql/test_miscellaneous.py | 2 +- py-polars/tests/unit/sql/test_subqueries.py | 38 +-- py-polars/tests/unit/test_queries.py | 2 +- 18 files changed, 415 insertions(+), 154 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4364f9fb5fd08..0e51d5ba97cba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3160,6 +3160,7 @@ dependencies = [ "polars-core", "polars-error", "polars-lazy", + "polars-ops", "polars-plan", "rand", "serde", diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 03e48a059183c..6fd013d246e4e 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -1068,8 +1068,13 @@ impl LazyFrame { /// Creates the Cartesian product from both frames, preserving the order of the left keys. #[cfg(feature = "cross_join")] - pub fn cross_join(self, other: LazyFrame) -> LazyFrame { - self.join(other, vec![], vec![], JoinArgs::new(JoinType::Cross)) + pub fn cross_join(self, other: LazyFrame, suffix: Option) -> LazyFrame { + self.join( + other, + vec![], + vec![], + JoinArgs::new(JoinType::Cross).with_suffix(suffix), + ) } /// Left outer join this query with another lazy query. @@ -1220,9 +1225,7 @@ impl LazyFrame { if let Some(suffix) = args.suffix { builder = builder.suffix(suffix); } - // Note: args.slice is set by the optimizer - builder.finish() } diff --git a/crates/polars-lazy/src/tests/cse.rs b/crates/polars-lazy/src/tests/cse.rs index 4452b3845b467..95c6c5be64dfd 100644 --- a/crates/polars-lazy/src/tests/cse.rs +++ b/crates/polars-lazy/src/tests/cse.rs @@ -305,7 +305,7 @@ fn test_cse_columns_projections() -> PolarsResult<()> { ]? .lazy(); - let left = left.cross_join(right.clone().select([col("A")])); + let left = left.cross_join(right.clone().select([col("A")]), None); let q = left.join( right.rename(["B"], ["C"]), [col("A"), col("C")], diff --git a/crates/polars-lazy/src/tests/projection_queries.rs b/crates/polars-lazy/src/tests/projection_queries.rs index 71e43ab10d3e5..e3d5cb9f25dda 100644 --- a/crates/polars-lazy/src/tests/projection_queries.rs +++ b/crates/polars-lazy/src/tests/projection_queries.rs @@ -47,7 +47,7 @@ fn test_cross_join_pd() -> PolarsResult<()> { "price" => [5, 4] ]?; - let q = food.lazy().cross_join(drink.lazy()).select([ + let q = food.lazy().cross_join(drink.lazy(), None).select([ col("name").alias("food"), col("name_right").alias("beverage"), (col("price") + col("price_right")).alias("total"), diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index 822c830d9d1c0..7b3c76487080e 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -1171,7 +1171,7 @@ fn test_cross_join() -> PolarsResult<()> { "b" => [None, Some(12)] ]?; - let out = df1.lazy().cross_join(df2.lazy()).collect()?; + let out = df1.lazy().cross_join(df2.lazy(), None).collect()?; assert_eq!(out.shape(), (6, 4)); Ok(()) } diff --git a/crates/polars-lazy/src/tests/streaming.rs b/crates/polars-lazy/src/tests/streaming.rs index e34c16a34334d..1ca82f18b832b 100644 --- a/crates/polars-lazy/src/tests/streaming.rs +++ b/crates/polars-lazy/src/tests/streaming.rs @@ -85,7 +85,7 @@ fn test_streaming_union_order() -> PolarsResult<()> { fn test_streaming_union_join() -> PolarsResult<()> { let q = get_csv_glob(); let q = q.select([col("sugars_g"), col("calories")]); - let q = q.clone().cross_join(q); + let q = q.clone().cross_join(q, None); assert_streaming_with_default(q, true, true); Ok(()) @@ -166,18 +166,22 @@ fn test_streaming_cross_join() -> PolarsResult<()> { "a" => [1 ,2, 3] ]?; let q = df.lazy(); - let out = q.clone().cross_join(q).with_streaming(true).collect()?; + let out = q + .clone() + .cross_join(q, None) + .with_streaming(true) + .collect()?; assert_eq!(out.shape(), (9, 2)); let q = get_parquet_file().with_projection_pushdown(false); let q1 = q .clone() .select([col("calories")]) - .cross_join(q.clone()) + .cross_join(q.clone(), None) .filter(col("calories").gt(col("calories_right"))); let q2 = q1 .select([all().name().suffix("_second")]) - .cross_join(q) + .cross_join(q, None) .filter(col("calories_right_second").lt(col("calories"))) .select([ col("calories"), @@ -266,7 +270,7 @@ fn test_streaming_slice() -> PolarsResult<()> { ]? .lazy(); - let q = lf_a.clone().cross_join(lf_a).slice(10, 20); + let q = lf_a.clone().cross_join(lf_a, None).slice(10, 20); let a = q.with_streaming(true).collect().unwrap(); assert_eq!(a.shape(), (20, 2)); diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index 4fbc305968347..f19be5352a569 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -88,6 +88,11 @@ impl JoinArgs { self } + pub fn with_suffix(mut self, suffix: Option) -> Self { + self.suffix = suffix; + self + } + pub fn suffix(&self) -> &str { self.suffix.as_deref().unwrap_or("_right") } diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index 62c16d5e90420..08688fcd5aba8 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -304,7 +304,7 @@ pub enum Excluded { impl Expr { /// Get Field result of the expression. The schema is the input data. pub fn to_field(&self, schema: &Schema, ctxt: Context) -> PolarsResult { - // this is not called much and th expression depth is typically shallow + // this is not called much and the expression depth is typically shallow let mut arena = Arena::with_capacity(5); self.to_field_amortized(schema, ctxt, &mut arena) } diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs index 0a663e9e91955..9ca01c0898da2 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs @@ -442,7 +442,7 @@ fn resolve_join_suffixes( .iter() .map(|proj| { let name = column_node_to_name(*proj, expr_arena); - if name.contains(suffix) && schema_after_join.get(&name).is_none() { + if name.ends_with(suffix) && schema_after_join.get(&name).is_none() { let downstream_name = &name.as_ref()[..name.len() - suffix.len()]; let col = AExpr::Column(ColumnName::from(downstream_name)); let node = expr_arena.add(col); diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 65dc5e9523fdf..1db0e88a116cb 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -13,6 +13,7 @@ arrow = { workspace = true } polars-core = { workspace = true } polars-error = { workspace = true } polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] } +polars-ops = { workspace = true } polars-plan = { workspace = true } hex = { workspace = true } diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 46e1c055f6452..aa0bbe843b47a 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -3,18 +3,19 @@ use std::cell::RefCell; use polars_core::prelude::*; use polars_error::to_compute_err; use polars_lazy::prelude::*; +use polars_ops::frame::JoinCoalesce; use polars_plan::prelude::*; use sqlparser::ast::{ - Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinOperator, - ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, - SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator, + Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinConstraint, + JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, + SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value as SQLValue, WildcardAdditionalOptions, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry}; -use crate::sql_expr::{parse_sql_expr, process_join}; +use crate::sql_expr::{parse_sql_expr, process_join_constraint}; use crate::table_functions::PolarsTableFunctions; /// The SQLContext is the main entry point for executing SQL queries. @@ -23,7 +24,8 @@ pub struct SQLContext { pub(crate) table_map: PlHashMap, pub(crate) function_registry: Arc, cte_map: RefCell>, - aliases: RefCell>, + table_aliases: RefCell>, + joined_aliases: RefCell>>, } impl Default for SQLContext { @@ -32,7 +34,8 @@ impl Default for SQLContext { function_registry: Arc::new(DefaultFunctionRegistry {}), table_map: Default::default(), cte_map: Default::default(), - aliases: Default::default(), + table_aliases: Default::default(), + joined_aliases: Default::default(), } } } @@ -110,11 +113,16 @@ impl SQLContext { .map_err(to_compute_err)? .parse_statements() .map_err(to_compute_err)?; - polars_ensure!(ast.len() == 1, ComputeError: "One and only one statement at a time please"); + + polars_ensure!(ast.len() == 1, ComputeError: "One (and only one) statement at a time please"); + let res = self.execute_statement(ast.first().unwrap()); - // Every execution should clear the CTE map. + + // Every execution should clear the statement-level maps. self.cte_map.borrow_mut().clear(); - self.aliases.borrow_mut().clear(); + self.table_aliases.borrow_mut().clear(); + self.joined_aliases.borrow_mut().clear(); + res } @@ -137,22 +145,6 @@ impl SQLContext { } impl SQLContext { - fn register_cte(&mut self, name: &str, lf: LazyFrame) { - self.cte_map.borrow_mut().insert(name.to_owned(), lf); - } - - pub(super) fn get_table_from_current_scope(&self, name: &str) -> Option { - let table_name = self.table_map.get(name).cloned(); - table_name - .or_else(|| self.cte_map.borrow().get(name).cloned()) - .or_else(|| { - self.aliases - .borrow() - .get(name) - .and_then(|alias| self.table_map.get(alias).cloned()) - }) - } - pub(crate) fn execute_statement(&mut self, stmt: &Statement) -> PolarsResult { let ast = stmt; Ok(match ast { @@ -183,6 +175,31 @@ impl SQLContext { self.process_limit_offset(lf, &query.limit, &query.offset) } + pub(super) fn get_table_from_current_scope(&self, name: &str) -> Option { + let table_name = self.table_map.get(name).cloned(); + table_name + .or_else(|| self.cte_map.borrow().get(name).cloned()) + .or_else(|| { + self.table_aliases + .borrow() + .get(name) + .and_then(|alias| self.table_map.get(alias).cloned()) + }) + } + + pub(super) fn resolve_name(&self, tbl_name: &str, column_name: &str) -> String { + if self.joined_aliases.borrow().contains_key(tbl_name) { + self.joined_aliases + .borrow() + .get(tbl_name) + .and_then(|aliases| aliases.get(column_name)) + .cloned() + .unwrap_or_else(|| column_name.to_string()) + } else { + column_name.to_string() + } + } + fn process_set_expr(&mut self, expr: &SetExpr, query: &Query) -> PolarsResult { match expr { SetExpr::Select(select_stmt) => self.execute_select(select_stmt, query), @@ -296,6 +313,10 @@ impl SQLContext { } } + fn register_cte(&mut self, name: &str, lf: LazyFrame) { + self.cte_map.borrow_mut().insert(name.to_owned(), lf); + } + fn register_ctes(&mut self, query: &Query) -> PolarsResult<()> { if let Some(with) = &query.with { if with.recursive { @@ -316,40 +337,63 @@ impl SQLContext { if !tbl_expr.joins.is_empty() { for tbl in &tbl_expr.joins { let (r_name, rf) = self.get_table(&tbl.relation)?; + let left_schema = lf.schema()?; + let right_schema = rf.schema()?; + lf = match &tbl.join_operator { - JoinOperator::CrossJoin => lf.cross_join(rf), JoinOperator::FullOuter(constraint) => { - process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Full)? + self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Full)? }, JoinOperator::Inner(constraint) => { - process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)? + self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)? }, JoinOperator::LeftOuter(constraint) => { - process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Left)? + self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Left)? }, #[cfg(feature = "semi_anti_join")] JoinOperator::LeftAnti(constraint) => { - process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Anti)? + self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Anti)? }, #[cfg(feature = "semi_anti_join")] JoinOperator::LeftSemi(constraint) => { - process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Semi)? + self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Semi)? }, #[cfg(feature = "semi_anti_join")] JoinOperator::RightAnti(constraint) => { - process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Anti)? + self.process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Anti)? }, #[cfg(feature = "semi_anti_join")] JoinOperator::RightSemi(constraint) => { - process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Semi)? + self.process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Semi)? }, + JoinOperator::CrossJoin => lf.cross_join(rf, Some(format!(":{}", r_name))), join_type => { polars_bail!( InvalidOperation: "join type '{:?}' not yet supported by polars-sql", join_type ); }, - } + }; + + // track join-aliased columns so we can resolve them later + let joined_schema = lf.schema()?; + self.joined_aliases.borrow_mut().insert( + r_name.to_string(), + right_schema + .iter_names() + .filter_map(|name| { + // col exists in both tables and is aliased in the joined result + let aliased_name = format!("{}:{}", name, r_name); + if left_schema.contains(name) + && joined_schema.contains(aliased_name.as_str()) + { + Some((name.to_string(), aliased_name)) + } else { + None + } + }) + .collect::>(), + ); } }; Ok(lf) @@ -406,6 +450,10 @@ impl SQLContext { }) .collect::>()?; + // dbg!(projections.clone()); + // dbg!(lf.clone().collect()?); + // dbg!(lf.clone().select(projections.clone()).collect()?); + // Check for "GROUP BY ..." (after projections, as there may be ordinal/position ints). let mut group_by_keys: Vec = Vec::new(); match &select_stmt.group_by { @@ -578,6 +626,31 @@ impl SQLContext { Ok(lf) } + pub(super) fn process_join( + &self, + left_tbl: LazyFrame, + right_tbl: LazyFrame, + constraint: &JoinConstraint, + tbl_name: &str, + join_tbl_name: &str, + join_type: JoinType, + ) -> PolarsResult { + let (left_on, right_on) = process_join_constraint(constraint, tbl_name, join_tbl_name)?; + + let joined_tbl = left_tbl + .clone() + .join_builder() + .with(right_tbl.clone()) + .left_on(left_on) + .right_on(right_on) + .how(join_type) + .suffix(format!(":{}", join_tbl_name)) + .coalesce(JoinCoalesce::KeepColumns) + .finish(); + + Ok(joined_tbl) + } + fn process_subqueries(&self, lf: LazyFrame, exprs: Vec<&mut Expr>) -> LazyFrame { let mut contexts = vec![]; for expr in exprs { @@ -644,7 +717,7 @@ impl SQLContext { if let Some(lf) = self.get_table_from_current_scope(tbl_name) { match alias { Some(alias) => { - self.aliases + self.table_aliases .borrow_mut() .insert(alias.name.value.clone(), tbl_name.to_string()); Ok((alias.to_string(), lf)) diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 480b0cce23c34..3ecaa771c15e3 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -350,7 +350,12 @@ impl SQLExprVisitor<'_> { let schema = lf.schema()?; if let Some((_, name, _)) = schema.get_full(&column_name.value) { - Ok(col(name)) + let resolved = &self.ctx.resolve_name(&tbl_name.value, &column_name.value); + Ok(if name != resolved { + col(resolved).alias(name) + } else { + col(name) + }) } else { polars_bail!( ColumnNotFound: "no column named '{}' found in table '{}'", @@ -959,25 +964,6 @@ impl SQLExprVisitor<'_> { } } -pub(super) fn process_join( - left_tbl: LazyFrame, - right_tbl: LazyFrame, - constraint: &JoinConstraint, - tbl_name: &str, - join_tbl_name: &str, - join_type: JoinType, -) -> PolarsResult { - let (left_on, right_on) = process_join_constraint(constraint, tbl_name, join_tbl_name)?; - - Ok(left_tbl - .join_builder() - .with(right_tbl) - .left_on(left_on) - .right_on(right_on) - .how(join_type) - .finish()) -} - fn collect_compound_identifiers( left: &[Ident], right: &[Ident], @@ -988,12 +974,11 @@ fn collect_compound_identifiers( let (tbl_a, col_a) = (&left[0].value, &left[1].value); let (tbl_b, col_b) = (&right[0].value, &right[1].value); - if left_name == tbl_a && right_name == tbl_b { - Ok((vec![col(col_a)], vec![col(col_b)])) - } else if left_name == tbl_b && right_name == tbl_a { + // switch left/right operands if the caller has them in reverse + if left_name == tbl_b || right_name == tbl_a { Ok((vec![col(col_b)], vec![col(col_a)])) } else { - polars_bail!(InvalidOperation: "collect_compound_identifiers: left_name={:?}, right_name={:?}, tbl_a={:?}, tbl_b={:?}", left_name, right_name, tbl_a, tbl_b); + Ok((vec![col(col_a)], vec![col(col_b)])) } } else { polars_bail!(InvalidOperation: "collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}", left.len(), right.len()); @@ -1042,9 +1027,9 @@ pub(super) fn process_join_constraint( if let JoinConstraint::On(SQLExpr::BinaryOp { left, op, right }) = constraint { if op == &BinaryOperator::And { let (mut left_on, mut right_on) = process_join_on(left, left_name, right_name)?; - let (left_on_2, right_on_2) = process_join_on(right, left_name, right_name)?; - left_on.extend(left_on_2); - right_on.extend(right_on_2); + let (left_on_, right_on_) = process_join_on(right, left_name, right_name)?; + left_on.extend(left_on_); + right_on.extend(right_on_); return Ok((left_on, right_on)); } if op != &BinaryOperator::Eq { diff --git a/crates/polars-sql/tests/statements.rs b/crates/polars-sql/tests/statements.rs index 712df6873d90c..a5d671ca53cd1 100644 --- a/crates/polars-sql/tests/statements.rs +++ b/crates/polars-sql/tests/statements.rs @@ -188,11 +188,17 @@ fn iss_9560_join_as() { let expected = df! { "id" => [1, 2, 3, 4], "ano" => [2, 3, 4, 5], - "ano_right" => [2, 3, 4, 5], + "id:t2" => [1, 2, 3, 4], + "ano:t2" => [2, 3, 4, 5], } .unwrap(); - assert!(actual.equals(&expected)); + assert!( + actual.equals(&expected), + "expected = {:?}\nactual={:?}", + expected, + actual + ); } fn prepare_compound_join_context() -> SQLContext { @@ -206,11 +212,10 @@ fn prepare_compound_join_context() -> SQLContext { "b" => [0, 3, 4, 5, 6] } .unwrap(); - let df3 = df! { "a" => [1, 2, 3, 4, 5], - "b" => [0, 3, 4, 5, 6], - "c" => [0, 3, 4, 5, 6] + "b" => [0, 3, 4, 5, 7], + "c" => [1, 3, 4, 5, 7] } .unwrap(); let mut ctx = SQLContext::new(); @@ -232,6 +237,8 @@ fn test_compound_join_basic() { let expected = df! { "a" => [2, 3], "b" => [3, 4], + "a:df2" => [2, 3], + "b:df2" => [3, 4], } .unwrap(); @@ -253,18 +260,21 @@ fn test_compound_join_different_column_names() { let df2 = df! { "a" => [0, 2, 3, 4, 5], "b" => [1, 2, 3, 5, 6], - "c" => [7, 8, 9, 10, 11], + "c" => [7, 5, 3, 5, 7], } .unwrap(); + let mut ctx = SQLContext::new(); ctx.register("df1", df1.lazy()); ctx.register("df2", df2.lazy()); let sql = r#" - SELECT * FROM df1 INNER JOIN df2 ON df1.a = df2.b AND df1.b = df2.a + SELECT df1.a, df2.b, df2.c + FROM df1 INNER JOIN df2 + ON df1.a = df2.b AND df1.a = df2.c + ORDER BY a "#; let actual = ctx.execute(sql).unwrap().collect().unwrap(); - let expected = df! { "a" => [2, 3], "b" => [2, 3], @@ -284,11 +294,11 @@ fn test_compound_join_different_column_names() { fn test_compound_join_three_tables() { let mut ctx = prepare_compound_join_context(); let sql = r#" - SELECT * FROM df1 - INNER JOIN df2 - ON df1.a = df2.a AND df1.b = df2.b - INNER JOIN df3 - ON df1.a = df3.a AND df1.b = df3.b + SELECT df3.* FROM df1 + INNER JOIN df2 + ON df1.a = df2.a AND df1.b = df2.b + INNER JOIN df3 + ON df3.a = df1.a AND df3.b = df1.b "#; let actual = ctx.execute(sql).unwrap().collect().unwrap(); @@ -327,30 +337,84 @@ fn test_compound_join_nested_and() { ctx.register("df1", df1.lazy()); ctx.register("df2", df2.lazy()); - let sql = r#" - SELECT * FROM df1 - INNER JOIN df2 ON - df1.a = df2.a AND - df1.b = df2.b AND - df1.c = df2.c AND - df1.d = df2.d - "#; - let actual = ctx.execute(sql).unwrap().collect().unwrap(); + for cols in [ + "df1.*", + "df2.*", + "df1.a, df1.b, df2.c, df2.d", + "df2.a, df2.b, df1.c, df1.d", + ] { + let sql = format!( + r#" + SELECT {} FROM df1 + INNER JOIN df2 ON + df1.a = df2.a AND + df1.b = df2.b AND + df1.c = df2.c AND + df1.d = df2.d + "#, + cols + ); + let actual = ctx.execute(sql.as_str()).unwrap().collect().unwrap(); + let expected = df! { + "a" => [1, 3], + "b" => [1, 3], + "c" => [0, 4], + "d" => [0, 4], + } + .unwrap(); + + assert!( + actual.equals(&expected), + "expected = {:?}\nactual={:?}", + expected, + actual + ); + } +} - let expected = df! { - "a" => [1, 3], - "b" => [1, 3], - "c" => [0, 4], - "d" => [0, 4], +#[test] +fn test_resolve_join_column_select_13618() { + let df1 = df! { + "A" => [1, 2, 3, 4, 5], + "B" => [5, 4, 3, 2, 1], + "fruits" => ["banana", "banana", "apple", "apple", "banana"], + "cars" => ["beetle", "audi", "beetle", "beetle", "beetle"], } .unwrap(); + let df2 = df1.clone(); - assert!( - actual.equals(&expected), - "expected = {:?}\nactual={:?}", - expected, - actual - ); + let mut ctx = SQLContext::new(); + ctx.register("tbl", df1.lazy()); + ctx.register("other", df2.lazy()); + + let join_types = vec!["LEFT", "INNER", "FULL OUTER", ""]; + for join_type in join_types { + let sql = format!( + r#" + SELECT tbl.A, other.B, tbl.fruits, other.cars + FROM tbl + {} JOIN other ON tbl.A = other.B + ORDER BY tbl.A ASC + "#, + join_type + ); + let actual = ctx.execute(sql.as_str()).unwrap().collect().unwrap(); + let expected = df! { + "A" => [1, 2, 3, 4, 5], + "B" => [1, 2, 3, 4, 5], + "fruits" => ["banana", "banana", "apple", "apple", "banana"], + "cars" => ["beetle", "beetle", "beetle", "audi", "beetle"], + } + .unwrap(); + + assert!( + actual.equals(&expected), + "({} JOIN) expected = {:?}\nactual={:?}", + join_type, + expected, + actual + ); + } } #[test] @@ -360,13 +424,15 @@ fn test_compound_join_nested_and_with_brackets() { "b" => [1, 2, 3, 4, 5], "c" => [0, 3, 4, 5, 6], "d" => [0, 3, 4, 5, 6], + "e" => ["a", "b", "c", "d", "?"], } .unwrap(); let df2 = df! { "a" => [1, 2, 3, 4, 5], "b" => [1, 3, 3, 5, 6], "c" => [0, 3, 4, 5, 6], - "d" => [0, 3, 4, 5, 6] + "d" => [0, 3, 4, 5, 6], + "e" => ["w", "x", "y", "z", "!"], } .unwrap(); let mut ctx = SQLContext::new(); @@ -374,12 +440,10 @@ fn test_compound_join_nested_and_with_brackets() { ctx.register("df2", df2.lazy()); let sql = r#" - SELECT * FROM df1 - INNER JOIN df2 ON - df1.a = df2.a AND - ((df1.b = df2.b AND - df1.c = df2.c) AND - df1.d = df2.d) + SELECT df1.* EXCLUDE "e", df2.e + FROM df1 + INNER JOIN df2 ON df1.a = df2.a AND + ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) "#; let actual = ctx.execute(sql).unwrap().collect().unwrap(); @@ -388,6 +452,7 @@ fn test_compound_join_nested_and_with_brackets() { "b" => [1, 3], "c" => [0, 4], "d" => [0, 4], + "e" => ["w", "y"], } .unwrap(); @@ -397,6 +462,77 @@ fn test_compound_join_nested_and_with_brackets() { expected, actual ); + + let sql = r#" + SELECT * EXCLUDE ("e", "e:df2"), df1.e + FROM df1 + INNER JOIN df2 ON df1.a = df2.a AND + ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + "#; + let actual = ctx.execute(sql).unwrap().collect().unwrap(); + + let expected = df! { + "a" => [1, 3], + "b" => [1, 3], + "c" => [0, 4], + "d" => [0, 4], + "a:df2" => [1, 3], + "b:df2" => [1, 3], + "c:df2" => [0, 4], + "d:df2" => [0, 4], + "e" => ["a", "c"], + } + .unwrap(); + + assert!( + actual.equals(&expected), + "expected = {:?}\nactual={:?}", + expected, + actual + ); +} + +#[test] +fn test_join_subquery_utf8() { + // (色) color and (野菜) vegetable + let df1 = df! { + "色" => ["赤", "緑", "黄色"], + "野菜" => ["トマト", "ケール", "コーン"], + } + .unwrap(); + + // (色) color and (動物) animal + let df2 = df! { + "色" => ["黄色", "緑", "赤"], + "動物" => ["ゴシキヒワ", "蛙", "レッサーパンダ"], + } + .unwrap(); + + let mut ctx = SQLContext::new(); + ctx.register("df1", df1.lazy()); + ctx.register("df2", df2.lazy()); + + let expected = df! { + "色" => ["黄色", "緑", "赤"], + "野菜" => ["コーン", "ケール", "トマト"], + "動物" => ["ゴシキヒワ", "蛙", "レッサーパンダ"], + } + .unwrap(); + + let sql = r#" + SELECT df1.*, df2.動物 + FROM df1 + INNER JOIN (SELECT 動物, 色 FROM df2) AS df2 + ON df1.色 = df2.色 + "#; + let actual = ctx.execute(sql).unwrap().collect().unwrap(); + + assert!( + actual.equals(&expected), + "expected = {:?}\nactual={:?}", + expected, + actual + ); } #[test] diff --git a/docs/src/rust/user-guide/transformations/joins.rs b/docs/src/rust/user-guide/transformations/joins.rs index cc6d7ec9cb6a8..26e4a2a067a60 100644 --- a/docs/src/rust/user-guide/transformations/joins.rs +++ b/docs/src/rust/user-guide/transformations/joins.rs @@ -96,7 +96,7 @@ fn main() -> Result<(), Box> { let df_cross_join = df_colors .clone() .lazy() - .cross_join(df_sizes.clone().lazy()) + .cross_join(df_sizes.clone().lazy(), None) .collect()?; println!("{}", &df_cross_join); // --8<-- [end:cross] diff --git a/py-polars/tests/unit/sql/test_joins.py b/py-polars/tests/unit/sql/test_joins.py index a498cbb6629e2..a90ac39790479 100644 --- a/py-polars/tests/unit/sql/test_joins.py +++ b/py-polars/tests/unit/sql/test_joins.py @@ -1,5 +1,6 @@ from __future__ import annotations +from io import BytesIO from pathlib import Path import pytest @@ -61,6 +62,46 @@ def test_join_anti_semi(sql: str, expected: pl.DataFrame) -> None: assert_frame_equal(expected, ctx.execute(sql)) +def test_join_cross() -> None: + frames = { + "tbl_a": pl.DataFrame({"a": [1, 2, 3], "b": [4, 0, 6], "c": ["w", "y", "z"]}), + "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}), + } + with pl.SQLContext(frames, eager_execution=True) as ctx: + out = ctx.execute( + """ + SELECT * + FROM tbl_a + CROSS JOIN tbl_b + ORDER BY a, b, c + """ + ) + assert out.rows() == [ + (1, 4, "w", 3, 6, "x"), + (1, 4, "w", 2, 5, "y"), + (1, 4, "w", 1, 4, "z"), + (2, 0, "y", 3, 6, "x"), + (2, 0, "y", 2, 5, "y"), + (2, 0, "y", 1, 4, "z"), + (3, 6, "z", 3, 6, "x"), + (3, 6, "z", 2, 5, "y"), + (3, 6, "z", 1, 4, "z"), + ] + + +def test_join_cross_13618() -> None: + df1 = pl.DataFrame({"id": [1, 2, 3]}) + df2 = pl.DataFrame({"id": [3, 4, 5]}) # noqa: F841 + res = df1.sql( + """ + SELECT df2.id + FROM self CROSS JOIN df2 + WHERE self.id = df2.id + """, + ) + assert_frame_equal(res, pl.DataFrame({"id": [3]})) + + @pytest.mark.parametrize( "join_clause", [ @@ -87,9 +128,10 @@ def test_join_inner(foods_ipc_path: Path, join_clause: str) -> None: "calories": [45, 20], "fats_g": [0.5, 0.0], "sugars_g": [2, 2], - "calories_right": [45, 45], - "fats_g_right": [0.5, 0.5], - "sugars_g_right": [2, 2], + "category:foods2": ["vegetables", "vegetables"], + "calories:foods2": [45, 45], + "fats_g:foods2": [0.5, 0.5], + "sugars_g:foods2": [2, 2], } @@ -178,6 +220,43 @@ def test_join_left_multi_nested() -> None: ] +def test_join_misc_13618() -> None: + import polars as pl + + df = pl.DataFrame( + { + "A": [1, 2, 3, 4, 5], + "B": [5, 4, 3, 2, 1], + "fruits": ["banana", "banana", "apple", "apple", "banana"], + "cars": ["beetle", "audi", "beetle", "beetle", "beetle"], + } + ) + res = ( + pl.SQLContext(t=df, t1=df, eager_execution=True) + .execute("SELECT t.A, t.fruits, t1.B, t1.cars FROM t JOIN t1 ON t.A=t1.B") + .to_dict(as_series=False) + ) + assert res == { + "A": [5, 4, 3, 2, 1], + "fruits": ["banana", "apple", "apple", "banana", "banana"], + "B": [5, 4, 3, 2, 1], + "cars": ["beetle", "audi", "beetle", "beetle", "beetle"], + } + + +def test_join_misc_16255() -> None: + df1 = pl.read_csv(BytesIO(b"id,data\n1,open")) + df2 = pl.read_csv(BytesIO(b"id,data\n1,closed")) # noqa: F841 + res = df1.sql( + """ + SELECT a.id, a.data AS d1, b.data AS d2 + FROM self AS a JOIN df2 AS b + ON a.id = b.id + """ + ) + assert res.rows() == [(1, "open", "closed")] + + @pytest.mark.parametrize( "constraint", ["tbl.a != tbl.b", "tbl.a > tbl.b", "a >= b", "a < b", "b <= a"] ) diff --git a/py-polars/tests/unit/sql/test_miscellaneous.py b/py-polars/tests/unit/sql/test_miscellaneous.py index d32674db78fff..f5d119b5c48b2 100644 --- a/py-polars/tests/unit/sql/test_miscellaneous.py +++ b/py-polars/tests/unit/sql/test_miscellaneous.py @@ -180,7 +180,7 @@ def test_order_by(foods_ipc_path: Path) -> None: df.x, df.y as y_alias FROM df - ORDER BY y + ORDER BY y_alias """, eager=True, ) diff --git a/py-polars/tests/unit/sql/test_subqueries.py b/py-polars/tests/unit/sql/test_subqueries.py index 7e9fc0e124d60..651e325a4b992 100644 --- a/py-polars/tests/unit/sql/test_subqueries.py +++ b/py-polars/tests/unit/sql/test_subqueries.py @@ -5,30 +5,18 @@ def test_join_on_subquery() -> None: - df1 = pl.DataFrame( - { - "x": [-1, 0, 1, 2, 3, 4], - } - ) - - df2 = pl.DataFrame( - { - "y": [0, 1, 2, 3], - } - ) + df1 = pl.DataFrame({"x": [-1, 0, 1, 2, 3, 4]}) + df2 = pl.DataFrame({"y": [0, 1, 2, 3]}) sql = pl.SQLContext(df1=df1, df2=df2) res = sql.execute( """ - SELECT - * - FROM df1 + SELECT * FROM df1 INNER JOIN (SELECT * FROM df2) AS df2 ON df1.x = df2.y """, eager=True, ) - df_expected_join = pl.DataFrame({"x": [0, 1, 2, 3]}) assert_frame_equal( left=res, @@ -37,30 +25,18 @@ def test_join_on_subquery() -> None: def test_from_subquery() -> None: - df1 = pl.DataFrame( - { - "x": [-1, 0, 1, 2, 3, 4], - } - ) - - df2 = pl.DataFrame( - { - "y": [0, 1, 2, 3], - } - ) + df1 = pl.DataFrame({"x": [-1, 0, 1, 2, 3, 4]}) + df2 = pl.DataFrame({"y": [0, 1, 2, 3]}) sql = pl.SQLContext(df1=df1, df2=df2) res = sql.execute( """ - SELECT - * - FROM (SELECT * FROM df1) AS df1 + SELECT * FROM (SELECT * FROM df1) AS df1 INNER JOIN (SELECT * FROM df2) AS df2 ON df1.x = df2.y """, eager=True, ) - df_expected_join = pl.DataFrame({"x": [0, 1, 2, 3]}) assert_frame_equal( left=res, @@ -75,14 +51,12 @@ def test_in_subquery() -> None: "y": [2, 3, 4, 5, 6, 7], } ) - df_other = pl.DataFrame( { "w": [1, 2, 3, 4, 5, 6], "z": [2, 3, 4, 5, 6, 7], } ) - df_chars = pl.DataFrame( { "one": ["a", "b", "c", "d", "e", "f"], diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index 4ca3851a09d50..b71bbdc18e012 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -62,7 +62,7 @@ def test_overflow_uint16_agg_mean() -> None: pl.DataFrame( { "col1": ["A" for _ in range(1025)], - "col3": [64 for i in range(1025)], + "col3": [64 for _ in range(1025)], } ) .with_columns(