Skip to content

Commit

Permalink
fix: Ensure hex and bitstring literals work inside SQL IN clauses (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored May 8, 2024
1 parent ddc30ab commit 12b40b9
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 24 deletions.
57 changes: 36 additions & 21 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -581,24 +581,10 @@ impl SQLExprVisitor<'_> {
.map_err(|_| polars_err!(ComputeError: "cannot parse literal: {:?}", s))?
},
SQLValue::SingleQuotedByteStringLiteral(b) => {
// note: for PostgreSQL this syntax represents a BIT string literal (eg: b'10101') not a BYTE
// string literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser
// note: for PostgreSQL this represents a BIT string literal (eg: b'10101') not a BYTE string
// literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser
// patterned the token name after BigQuery (where b'str' really IS a byte string)
if !b.chars().all(|c| c == '0' || c == '1') {
polars_bail!(ComputeError: "bit string literal should contain only 0s and 1s; found '{}'", b)
}
let n_bits = b.len();
let s = b.as_str();
lit(match n_bits {
0 => b"".to_vec(),
1..=8 => u8::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
9..=16 => u16::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
17..=32 => u32::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
33..=64 => u64::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
_ => {
polars_bail!(ComputeError: "cannot parse bit string literal with len > 64 (len={:?})", n_bits)
},
})
bitstring_to_bytes_literal(b)?
},
SQLValue::SingleQuotedString(s) => lit(s.clone()),
other => polars_bail!(ComputeError: "SQL value {:?} is not yet supported", other),
Expand Down Expand Up @@ -635,10 +621,24 @@ impl SQLExprVisitor<'_> {
}
.map_err(|_| polars_err!(ComputeError: "cannot parse literal: {s:?}"))?
},
SQLValue::SingleQuotedString(s)
| SQLValue::NationalStringLiteral(s)
| SQLValue::HexStringLiteral(s)
| SQLValue::DoubleQuotedString(s) => AnyValue::StringOwned(s.into()),
#[cfg(feature = "binary_encoding")]
SQLValue::HexStringLiteral(x) => {
if x.len() % 2 != 0 {
polars_bail!(ComputeError: "hex string literal must have an even number of digits; found '{}'", x)
};
AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap())
},
SQLValue::SingleQuotedByteStringLiteral(b) => {
// note: for PostgreSQL this represents a BIT literal (eg: b'10101') not BYTE
let bytes_literal = bitstring_to_bytes_literal(b)?;
match bytes_literal {
Expr::Literal(LiteralValue::Binary(v)) => AnyValue::BinaryOwned(v.to_vec()),
_ => polars_bail!(ComputeError: "failed to parse bitstring literal: {:?}", b),
}
},
SQLValue::SingleQuotedString(s) | SQLValue::DoubleQuotedString(s) => {
AnyValue::StringOwned(s.into())
},
other => polars_bail!(ComputeError: "SQL value {:?} is not yet supported", other),
})
}
Expand Down Expand Up @@ -1107,3 +1107,18 @@ pub(crate) fn parse_date_part(expr: Expr, part: &str) -> PolarsResult<Expr> {
},
)
}

fn bitstring_to_bytes_literal(b: &String) -> PolarsResult<Expr> {
let n_bits = b.len();
if !b.chars().all(|c| c == '0' || c == '1') || n_bits > 64 {
polars_bail!(ComputeError: "bit string literal should contain only 0s and 1s and have length <= 64; found '{}' with length {}", b, n_bits)
}
let s = b.as_str();
Ok(lit(match n_bits {
0 => b"".to_vec(),
1..=8 => u8::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
9..=16 => u16::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
17..=32 => u32::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
_ => u64::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
}))
}
34 changes: 31 additions & 3 deletions py-polars/tests/unit/sql/test_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from polars.exceptions import ComputeError


def test_bin_hex_literals() -> None:
def test_bit_hex_literals() -> None:
with pl.SQLContext(df=None, eager_execution=True) as ctx:
out = ctx.execute(
"""
Expand Down Expand Up @@ -37,7 +37,7 @@ def test_bin_hex_literals() -> None:
}


def test_bin_hex_filter() -> None:
def test_bit_hex_filter() -> None:
df = pl.DataFrame(
{"bin": [b"\x01", b"\x02", b"\x03", b"\x04"], "val": [9, 8, 7, 6]}
)
Expand All @@ -47,7 +47,7 @@ def test_bin_hex_filter() -> None:
assert out.to_series().to_list() == [7, 6]


def test_bin_hex_errors() -> None:
def test_bit_hex_errors() -> None:
with pl.SQLContext(test=None) as ctx:
with pytest.raises(
ComputeError,
Expand All @@ -60,3 +60,31 @@ def test_bin_hex_errors() -> None:
match="hex string literal must have an even number of digits",
):
ctx.execute("SELECT x'00F' FROM test", eager=True)

with pytest.raises(
ComputeError,
match="hex string literal must have an even number of digits",
):
pl.sql_expr("colx IN (x'FF',x'123')")

with pytest.raises(
ComputeError,
match=r'NationalStringLiteral\("hmmm"\) is not yet supported',
):
pl.sql_expr("N'hmmm'")


def test_bit_hex_membership() -> None:
df = pl.DataFrame(
{
"x": [b"\x05", b"\xff", b"\xcc", b"\x0b"],
"y": [1, 2, 3, 4],
}
)
# this checks the internal `visit_any_value` codepath
for values in (
"b'0101', b'1011'",
"x'05', x'0b'",
):
dff = df.filter(pl.sql_expr(f"x IN ({values})"))
assert dff["y"].to_list() == [1, 4]

0 comments on commit 12b40b9

Please sign in to comment.