From 5c12ba454f560b88f427f4c25775c21168857395 Mon Sep 17 00:00:00 2001
From: alexander-beedie
Date: Sun, 23 Jun 2024 23:52:32 +0400
Subject: [PATCH] support `replace` & `rename` select wildcard options

---
 .../src/executors/projection_utils.rs         |  2 +-
 .../src/plans/conversion/dsl_to_ir.rs         |  2 +-
 crates/polars-sql/src/context.rs              | 95 ++++++++++++++-----
 crates/polars-sql/tests/statements.rs         | 26 +++--
 .../tests/unit/sql/test_wildcard_opts.py      | 79 +++++++++++++++
 5 files changed, 167 insertions(+), 37 deletions(-)
 create mode 100644 py-polars/tests/unit/sql/test_wildcard_opts.py

diff --git a/crates/polars-mem-engine/src/executors/projection_utils.rs b/crates/polars-mem-engine/src/executors/projection_utils.rs
index 125b774cf935..1ca7e085bfa3 100644
--- a/crates/polars-mem-engine/src/executors/projection_utils.rs
+++ b/crates/polars-mem-engine/src/executors/projection_utils.rs
@@ -263,7 +263,7 @@ pub(super) fn check_expand_literals(
     if duplicate_check && !names.insert(name) {
         let msg = format!(
-            "the name: '{}' is duplicate\n\n\
+            "the name '{}' is duplicate\n\n\
             It's possible that multiple expressions are returning the same default column \
             name. If this is the case, try renaming the columns with \
             `.alias(\"new_name\")` to avoid duplicate column names.",
diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
index cd273a6c6ddc..a05417dbfd24 100644
--- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
+++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
@@ -707,7 +707,7 @@ fn resolve_with_columns(
     if !output_names.insert(field.name().clone()) {
         let msg = format!(
-            "the name: '{}' passed to `LazyFrame.with_columns` is duplicate\n\n\
+            "the name '{}' passed to `LazyFrame.with_columns` is duplicate\n\n\
             It's possible that multiple expressions are returning the same default column name. \
             If this is the case, try renaming the columns with `.alias(\"new_name\")` to avoid \
             duplicate column names.",
diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs
index 0fc38925c2bb..ea86de06f4e5 100644
--- a/crates/polars-sql/src/context.rs
+++ b/crates/polars-sql/src/context.rs
@@ -9,9 +9,9 @@ use polars_plan::dsl::function_expr::StructFunction;
 use polars_plan::prelude::*;
 use sqlparser::ast::{
     Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, JoinConstraint,
-    JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr,
-    SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator,
-    Value as SQLValue, Values, WildcardAdditionalOptions,
+    JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, RenameSelectItem, Select,
+    SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, TableAlias, TableFactor,
+    TableWithJoins, UnaryOperator, Value as SQLValue, Values, WildcardAdditionalOptions,
 };
 use sqlparser::dialect::GenericDialect;
 use sqlparser::parser::{Parser, ParserOptions};
@@ -590,13 +590,11 @@ impl SQLContext {
     /// Execute the 'SELECT' part of the query.
     fn execute_select(&mut self, select_stmt: &Select, query: &Query) -> PolarsResult<LazyFrame> {
-        // Determine involved dataframes.
-        // Note: implicit joins require more work in query parsing,
-        // explicit joins are preferred for now (ref: #16662)
-
         let mut lf = if select_stmt.from.is_empty() {
             DataFrame::empty().lazy()
         } else {
+            // Note: implicit joins need more work to support properly,
+            // explicit joins are preferred for now (ref: #16662)
             let from = select_stmt.clone().from;
             if from.len() > 1 {
                 polars_bail!(SQLInterface: "multiple tables in FROM clause are not currently supported (found {}); use explicit JOIN syntax instead", from.len())
@@ -604,12 +602,16 @@ impl SQLContext {
             self.execute_from_statement(from.first().unwrap())?
         };

-        // Filter expression.
+        // Filter expression (WHERE clause)
         let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?;
         lf = self.process_where(lf, &select_stmt.selection)?;

-        // Column projections.
-        let mut excluded_cols = Vec::new();
+        // 'SELECT *' modifiers
+        let mut excluded_cols = vec![];
+        let mut replace_exprs = vec![];
+        let mut rename_cols = (&mut vec![], &mut vec![]);
+
+        // Column projections (SELECT clause)
         let projections: Vec<Expr> = select_stmt
             .projection
             .iter()
@@ -627,6 +629,8 @@ impl SQLContext {
                         obj_name,
                         wildcard_options,
                         &mut excluded_cols,
+                        &mut rename_cols,
+                        &mut replace_exprs,
                         Some(schema.deref()),
                     )?,
                     SelectItem::Wildcard(wildcard_options) => {
@@ -639,6 +643,9 @@ impl SQLContext {
                             cols,
                             wildcard_options,
                             &mut excluded_cols,
+                            &mut rename_cols,
+                            &mut replace_exprs,
+                            Some(schema.deref()),
                         )?
                     },
                 })
@@ -700,8 +707,8 @@ impl SQLContext {
             // No sort, select cols as given
             lf.select(projections)
         } else {
-            // Add all projections to the base frame as any of
-            // the original columns may be required for the sort
+            // Add projections to the base frame as any of the
+            // original columns may be required for the sort
             lf = lf.with_columns(projections.clone());

             // Final/selected cols (also ensures accurate ordinal position refs)
@@ -737,7 +744,7 @@ impl SQLContext {
             }
         };

-        // Apply optional 'distinct' clause.
+        // Apply optional DISTINCT clause.
         lf = match &select_stmt.distinct {
             Some(Distinct::Distinct) => lf.unique_stable(None, UniqueKeepStrategy::Any),
             Some(Distinct::On(exprs)) => {
@@ -764,6 +771,13 @@ impl SQLContext {
             None => lf,
         };

+        // Apply final 'SELECT *' modifiers
+        if !replace_exprs.is_empty() {
+            lf = lf.with_columns(replace_exprs);
+        }
+        if !rename_cols.0.is_empty() {
+            lf = lf.rename(rename_cols.0, rename_cols.1);
+        }
         Ok(lf)
     }

@@ -1160,13 +1174,22 @@ impl SQLContext {
         ObjectName(idents): &ObjectName,
         options: &WildcardAdditionalOptions,
         excluded_cols: &mut Vec<String>,
+        rename_cols: &mut (&mut Vec<String>, &mut Vec<String>),
+        replace_exprs: &mut Vec<Expr>,
         schema: Option<&Schema>,
     ) -> PolarsResult<Vec<Expr>> {
         let mut new_idents = idents.clone();
         new_idents.push(Ident::new("*"));
         let expr = resolve_compound_identifier(self, new_idents.deref(), schema);
-        self.process_wildcard_additional_options(expr?, options, excluded_cols)
+        self.process_wildcard_additional_options(
+            expr?,
+            options,
+            excluded_cols,
+            rename_cols,
+            replace_exprs,
+            schema,
+        )
     }

     fn process_wildcard_additional_options(
@@ -1174,26 +1197,48 @@ impl SQLContext {
         exprs: Vec<Expr>,
         options: &WildcardAdditionalOptions,
         excluded_cols: &mut Vec<String>,
+        rename_cols: &mut (&mut Vec<String>, &mut Vec<String>),
+        replace_exprs: &mut Vec<Expr>,
+        schema: Option<&Schema>,
     ) -> PolarsResult<Vec<Expr>> {
-        // bail on unsupported wildcard options
-        if options.opt_ilike.is_some() {
-            polars_bail!(SQLSyntax: "ILIKE wildcard option is unsupported")
-        } else if options.opt_rename.is_some() {
-            polars_bail!(SQLSyntax: "RENAME wildcard option is unsupported")
-        } else if options.opt_replace.is_some() {
-            polars_bail!(SQLSyntax: "REPLACE wildcard option is unsupported")
-        } else if options.opt_except.is_some() {
-            polars_bail!(SQLSyntax: "EXCEPT wildcard option is unsupported (use EXCLUDE instead)")
+        // bail on (currently) unsupported wildcard options
+        if options.opt_except.is_some() {
+            polars_bail!(SQLInterface: "EXCEPT wildcard option is unsupported (use EXCLUDE instead)")
+        } else if options.opt_ilike.is_some() {
+            polars_bail!(SQLInterface: "ILIKE wildcard option is currently unsupported")
+        } else if options.opt_rename.is_some() && options.opt_replace.is_some() {
+            // pending an upstream fix: https://github.com/sqlparser-rs/sqlparser-rs/pull/1321
+            polars_bail!(SQLInterface: "RENAME and REPLACE wildcard options cannot (yet) be used simultaneously")
         }
-        if let Some(exc_items) = &options.opt_exclude {
-            *excluded_cols = match exc_items {
+        if let Some(items) = &options.opt_exclude {
+            *excluded_cols = match items {
                 ExcludeSelectItem::Single(ident) => vec![ident.value.clone()],
                 ExcludeSelectItem::Multiple(idents) => {
                     idents.iter().map(|i| i.value.clone()).collect()
                 },
             };
         }
+        if let Some(items) = &options.opt_rename {
+            match items {
+                RenameSelectItem::Single(rename) => {
+                    rename_cols.0.push(rename.ident.value.clone());
+                    rename_cols.1.push(rename.alias.value.clone());
+                },
+                RenameSelectItem::Multiple(renames) => {
+                    for rn in renames {
+                        rename_cols.0.push(rn.ident.value.clone());
+                        rename_cols.1.push(rn.alias.value.clone());
+                    }
+                },
+            }
+        }
+        if let Some(replacements) = &options.opt_replace {
+            for rp in &replacements.items {
+                let replacement_expr = parse_sql_expr(&rp.expr, self, schema);
+                replace_exprs.push(replacement_expr?.alias(rp.column_name.value.as_str()));
+            }
+        }
         Ok(exprs)
     }
diff --git a/crates/polars-sql/tests/statements.rs b/crates/polars-sql/tests/statements.rs
index 0af4ae64fa86..dd1f89027c46 100644
--- a/crates/polars-sql/tests/statements.rs
+++ b/crates/polars-sql/tests/statements.rs
@@ -419,7 +419,7 @@ fn
test_resolve_join_column_select_13618() { } #[test] -fn test_compound_join_nested_and_with_brackets() { +fn test_compound_join_and_select_exclude_rename_replace() { let df1 = df! { "a" => [1, 2, 3, 4, 5], "b" => [1, 2, 3, 4, 5], @@ -442,10 +442,13 @@ fn test_compound_join_nested_and_with_brackets() { ctx.register("df2", df2.lazy()); let sql = r#" - SELECT df1.* EXCLUDE "e", df2.e - FROM df1 - INNER JOIN df2 ON df1.a = df2.a AND - ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + SELECT * RENAME ("ee" AS "e") + FROM ( + SELECT df1.* EXCLUDE "e", df2.e AS "ee" + FROM df1 + INNER JOIN df2 ON df1.a = df2.a AND + ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + ) tbl "#; let actual = ctx.execute(sql).unwrap().collect().unwrap(); let expected = df! { @@ -465,10 +468,13 @@ fn test_compound_join_nested_and_with_brackets() { ); let sql = r#" - SELECT * EXCLUDE ("e", "e:df2"), df1.e - FROM df1 - INNER JOIN df2 ON df1.a = df2.a AND - ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + SELECT * REPLACE ("ee" || "ee" AS "ee") + FROM ( + SELECT * EXCLUDE ("e", "e:df2"), df1.e AS "ee" + FROM df1 + INNER JOIN df2 ON df1.a = df2.a AND + ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + ) tbl "#; let actual = ctx.execute(sql).unwrap().collect().unwrap(); @@ -481,7 +487,7 @@ fn test_compound_join_nested_and_with_brackets() { "b:df2" => [1, 3], "c:df2" => [0, 4], "d:df2" => [0, 4], - "e" => ["a", "c"], + "ee" => ["aa", "cc"], } .unwrap(); diff --git a/py-polars/tests/unit/sql/test_wildcard_opts.py b/py-polars/tests/unit/sql/test_wildcard_opts.py new file mode 100644 index 000000000000..ad17a215f7da --- /dev/null +++ b/py-polars/tests/unit/sql/test_wildcard_opts.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import DuplicateError + + +@pytest.fixture() +def df() -> pl.DataFrame: + return pl.DataFrame({"num": [999, 666], "str": ["b", "a"], "val": [2.0, 0.5]}) + + +@pytest.mark.parametrize( + ("excluded", "expected"), + [ + ("num", ["str", "val"]), + ("(val, num)", ["str"]), + ("(str, num)", ["val"]), + ("(str, val, num)", []), + ], +) +def test_select_exclude( + excluded: str, + expected: list[str], + df: pl.DataFrame, +) -> None: + assert df.sql(f"SELECT * EXCLUDE {excluded} FROM self").columns == expected + + +def test_select_exclude_error(df: pl.DataFrame) -> None: + with pytest.raises(DuplicateError, match="the name 'num' is duplicate"): + # note: missing "()" around the exclude option results in dupe col + assert df.sql("SELECT * EXCLUDE val, num FROM self") + + +@pytest.mark.parametrize( + ("renames", "expected"), + [ + ("val AS value", ["num", "str", "value"]), + ("(num AS flt)", ["flt", "str", "val"]), + ("(val AS value, num AS flt)", ["flt", "str", "value"]), + ], +) +def test_select_rename( + renames: str, + expected: list[str], + df: pl.DataFrame, +) -> None: + assert df.sql(f"SELECT * RENAME {renames} FROM self").columns == expected + + +@pytest.mark.parametrize( + ("replacements", "check_cols", "expected"), + [ + ( + "(num // 3 AS num)", + ["num"], + [(333,), (222,)], + ), + ( + "((str || str) AS str, num / 3 AS num)", + ["num", "str"], + [(333, "bb"), (222, "aa")], + ), + ], +) +def test_select_replace( + replacements: str, + check_cols: list[str], + expected: list[tuple[Any]], + df: pl.DataFrame, +) -> None: + res = df.sql(f"SELECT * REPLACE {replacements} FROM self") + + assert res.select(check_cols).rows() == expected + assert res.columns == df.columns
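
For reference, a minimal usage sketch of the three wildcard modifiers exercised above, run
through `DataFrame.sql`; it simply mirrors the fixture and parametrized cases in the new
test_wildcard_opts.py, and the comments show what those tests assert.

    import polars as pl

    df = pl.DataFrame({"num": [999, 666], "str": ["b", "a"], "val": [2.0, 0.5]})

    # EXCLUDE: drop the listed columns from the wildcard expansion
    df.sql("SELECT * EXCLUDE (val, num) FROM self").columns   # ["str"]

    # RENAME: keep every column, renaming the listed ones in place
    df.sql("SELECT * RENAME (num AS flt) FROM self").columns  # ["flt", "str", "val"]

    # REPLACE: keep column names/positions, swapping in a new expression
    df.sql("SELECT * REPLACE (num // 3 AS num) FROM self")    # "num" -> [333, 222]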