From 12b40b9394b9823ae7daeca222b056c09b78aa2f Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Wed, 8 May 2024 15:12:30 +0400 Subject: [PATCH] fix: Ensure hex and bitstring literals work inside SQL `IN` clauses (#16101) --- crates/polars-sql/src/sql_expr.rs | 57 ++++++++++++++--------- py-polars/tests/unit/sql/test_literals.py | 34 ++++++++++++-- 2 files changed, 67 insertions(+), 24 deletions(-) diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index a2655caf7342..b20fde159b4f 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -581,24 +581,10 @@ impl SQLExprVisitor<'_> { .map_err(|_| polars_err!(ComputeError: "cannot parse literal: {:?}", s))? }, SQLValue::SingleQuotedByteStringLiteral(b) => { - // note: for PostgreSQL this syntax represents a BIT string literal (eg: b'10101') not a BYTE - // string literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser + // note: for PostgreSQL this represents a BIT string literal (eg: b'10101') not a BYTE string + // literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser // patterned the token name after BigQuery (where b'str' really IS a byte string) - if !b.chars().all(|c| c == '0' || c == '1') { - polars_bail!(ComputeError: "bit string literal should contain only 0s and 1s; found '{}'", b) - } - let n_bits = b.len(); - let s = b.as_str(); - lit(match n_bits { - 0 => b"".to_vec(), - 1..=8 => u8::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - 9..=16 => u16::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - 17..=32 => u32::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - 33..=64 => u64::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - _ => { - polars_bail!(ComputeError: "cannot parse bit string literal with len > 64 (len={:?})", n_bits) - }, - }) + bitstring_to_bytes_literal(b)? }, SQLValue::SingleQuotedString(s) => lit(s.clone()), other => polars_bail!(ComputeError: "SQL value {:?} is not yet supported", other), @@ -635,10 +621,24 @@ impl SQLExprVisitor<'_> { } .map_err(|_| polars_err!(ComputeError: "cannot parse literal: {s:?}"))? }, - SQLValue::SingleQuotedString(s) - | SQLValue::NationalStringLiteral(s) - | SQLValue::HexStringLiteral(s) - | SQLValue::DoubleQuotedString(s) => AnyValue::StringOwned(s.into()), + #[cfg(feature = "binary_encoding")] + SQLValue::HexStringLiteral(x) => { + if x.len() % 2 != 0 { + polars_bail!(ComputeError: "hex string literal must have an even number of digits; found '{}'", x) + }; + AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap()) + }, + SQLValue::SingleQuotedByteStringLiteral(b) => { + // note: for PostgreSQL this represents a BIT literal (eg: b'10101') not BYTE + let bytes_literal = bitstring_to_bytes_literal(b)?; + match bytes_literal { + Expr::Literal(LiteralValue::Binary(v)) => AnyValue::BinaryOwned(v.to_vec()), + _ => polars_bail!(ComputeError: "failed to parse bitstring literal: {:?}", b), + } + }, + SQLValue::SingleQuotedString(s) | SQLValue::DoubleQuotedString(s) => { + AnyValue::StringOwned(s.into()) + }, other => polars_bail!(ComputeError: "SQL value {:?} is not yet supported", other), }) } @@ -1107,3 +1107,18 @@ pub(crate) fn parse_date_part(expr: Expr, part: &str) -> PolarsResult { }, ) } + +fn bitstring_to_bytes_literal(b: &String) -> PolarsResult { + let n_bits = b.len(); + if !b.chars().all(|c| c == '0' || c == '1') || n_bits > 64 { + polars_bail!(ComputeError: "bit string literal should contain only 0s and 1s and have length <= 64; found '{}' with length {}", b, n_bits) + } + let s = b.as_str(); + Ok(lit(match n_bits { + 0 => b"".to_vec(), + 1..=8 => u8::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + 9..=16 => u16::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + 17..=32 => u32::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + _ => u64::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + })) +} diff --git a/py-polars/tests/unit/sql/test_literals.py b/py-polars/tests/unit/sql/test_literals.py index 0f24963e6c64..f46b347fa547 100644 --- a/py-polars/tests/unit/sql/test_literals.py +++ b/py-polars/tests/unit/sql/test_literals.py @@ -6,7 +6,7 @@ from polars.exceptions import ComputeError -def test_bin_hex_literals() -> None: +def test_bit_hex_literals() -> None: with pl.SQLContext(df=None, eager_execution=True) as ctx: out = ctx.execute( """ @@ -37,7 +37,7 @@ def test_bin_hex_literals() -> None: } -def test_bin_hex_filter() -> None: +def test_bit_hex_filter() -> None: df = pl.DataFrame( {"bin": [b"\x01", b"\x02", b"\x03", b"\x04"], "val": [9, 8, 7, 6]} ) @@ -47,7 +47,7 @@ def test_bin_hex_filter() -> None: assert out.to_series().to_list() == [7, 6] -def test_bin_hex_errors() -> None: +def test_bit_hex_errors() -> None: with pl.SQLContext(test=None) as ctx: with pytest.raises( ComputeError, @@ -60,3 +60,31 @@ def test_bin_hex_errors() -> None: match="hex string literal must have an even number of digits", ): ctx.execute("SELECT x'00F' FROM test", eager=True) + + with pytest.raises( + ComputeError, + match="hex string literal must have an even number of digits", + ): + pl.sql_expr("colx IN (x'FF',x'123')") + + with pytest.raises( + ComputeError, + match=r'NationalStringLiteral\("hmmm"\) is not yet supported', + ): + pl.sql_expr("N'hmmm'") + + +def test_bit_hex_membership() -> None: + df = pl.DataFrame( + { + "x": [b"\x05", b"\xff", b"\xcc", b"\x0b"], + "y": [1, 2, 3, 4], + } + ) + # this checks the internal `visit_any_value` codepath + for values in ( + "b'0101', b'1011'", + "x'05', x'0b'", + ): + dff = df.filter(pl.sql_expr(f"x IN ({values})")) + assert dff["y"].to_list() == [1, 4]