Skip to content

Commit

Permalink
fix: Fix use of COUNT(*) in SQL GROUP BY operations
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed May 24, 2024
1 parent d5f9c3b commit e611fa2
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 19 deletions.
19 changes: 13 additions & 6 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,11 @@ impl SQLContext {
// nested agg/window funcs to the group key (also ignores literals).
GroupByExpr::All => {
projections.iter().for_each(|expr| match expr {
// immediately match the most common cases (col|agg|lit, optionally aliased).
Expr::Agg(_) | Expr::Literal(_) => (),
// immediately match the most common cases (col|agg|len|lit, optionally aliased).
Expr::Agg(_) | Expr::Len | Expr::Literal(_) => (),
Expr::Column(_) => group_by_keys.push(expr.clone()),
Expr::Alias(e, _) if matches!(&**e, Expr::Agg(_) | Expr::Literal(_)) => (),
Expr::Alias(e, _)
if matches!(&**e, Expr::Agg(_) | Expr::Len | Expr::Literal(_)) => {},
Expr::Alias(e, _) if matches!(&**e, Expr::Column(_)) => {
if let Expr::Column(name) = &**e {
group_by_keys.push(col(name));
Expand All @@ -457,7 +458,9 @@ impl SQLContext {
_ => {
// If not quick-matched, add if no nested agg/window expressions
if !has_expr(expr, |e| {
matches!(e, Expr::Agg(_)) || matches!(e, Expr::Window { .. })
matches!(e, Expr::Agg(_))
|| matches!(e, Expr::Len)
|| matches!(e, Expr::Window { .. })
}) {
group_by_keys.push(expr.clone())
}
Expand Down Expand Up @@ -734,8 +737,12 @@ impl SQLContext {
let mut group_key_aliases = PlHashSet::new();

for mut e in projections {
// If simple aliased expression we defer aliasing until after the group_by.
let is_agg_or_window = has_expr(e, |e| matches!(e, Expr::Agg(_) | Expr::Window { .. }));
// `Len` represents COUNT(*) so we treat as an aggregation here.
let is_agg_or_window = has_expr(e, |e| {
matches!(e, Expr::Agg(_) | Expr::Len | Expr::Window { .. })
});

// Note: if simple aliased expression we defer aliasing until after the group_by.
if let Expr::Alias(expr, alias) = e {
if e.clone().meta().is_simple_projection() {
group_key_aliases.insert(alias.as_ref());
Expand Down
6 changes: 2 additions & 4 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1235,16 +1235,14 @@ impl SQLFunctionVisitor<'_> {
fn visit_count(&mut self) -> PolarsResult<Expr> {
let args = extract_args(self.func);
match (self.func.distinct, args.as_slice()) {
// count()
(false, []) => Ok(len()),
// count(*), count()
(false, [FunctionArgExpr::Wildcard] | []) => Ok(len()),
// count(column_name)
(false, [FunctionArgExpr::Expr(sql_expr)]) => {
let expr = parse_sql_expr(sql_expr, self.ctx, None)?;
let expr = self.apply_window_spec(expr, &self.func.over)?;
Ok(expr.count())
},
// count(*)
(false, [FunctionArgExpr::Wildcard]) => Ok(len()),
// count(distinct column_name)
(true, [FunctionArgExpr::Expr(sql_expr)]) => {
let expr = parse_sql_expr(sql_expr, self.ctx, None)?;
Expand Down
6 changes: 4 additions & 2 deletions crates/polars-sql/tests/simple_exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ fn test_group_by_simple() -> PolarsResult<()> {
a AS "aa",
SUM(b) AS "bb",
SUM(a + b) AS "cc",
COUNT(a) AS "total_count"
COUNT(a) AS "count_a",
COUNT(*) AS "count_star"
FROM df
GROUP BY a
LIMIT 100
Expand All @@ -81,7 +82,8 @@ fn test_group_by_simple() -> PolarsResult<()> {
.agg(&[
col("b").sum().alias("bb"),
(col("a") + col("b")).sum().alias("cc"),
col("a").count().alias("total_count"),
col("a").count().alias("count_a"),
col("a").len().alias("count_star"),
])
.limit(100)
.sort(["aa"], Default::default())
Expand Down
4 changes: 4 additions & 0 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3485,6 +3485,10 @@ def write_database(
"""
Write the data in a Polars DataFrame to a database.
.. versionadded:: 0.20.26
Support for instantiated connection objects in addition to URI strings, and
a new `engine_options` parameter.
Parameters
----------
table_name
Expand Down
5 changes: 5 additions & 0 deletions py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,11 @@ def by_index(*indices: int | range | Sequence[int | range]) -> SelectorType:
One or more column indices (or range objects).
Negative indexing is supported.
Notes
-----
Matching columns are returned in the order in which their indexes
appear in the selector, not the underlying schema order.
See Also
--------
by_dtype : Select all columns matching the given dtypes.
Expand Down
27 changes: 20 additions & 7 deletions py-polars/tests/unit/sql/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def test_group_by_all() -> None:
SELECT
a,
SUM(b),
SUM(c)
SUM(c),
COUNT(*) AS n
FROM self
GROUP BY ALL
ORDER BY a
Expand All @@ -93,9 +94,10 @@ def test_group_by_all() -> None:
"a": ["xx", "yy", "zz"],
"b": [9, 6, 6],
"c": [231, 165, 66],
"n": [3, 2, 1],
}
)
assert_frame_equal(expected, res)
assert_frame_equal(expected, res, check_dtype=False)

# more involved determination of agg/group columns
res = df.sql(
Expand Down Expand Up @@ -170,22 +172,33 @@ def test_group_by_ordinal_position() -> None:
df = pl.DataFrame(
{
"a": ["xx", "yy", "xx", "yy", "xx", "zz"],
"b": [1, 2, 3, 4, 5, 6],
"b": [1, None, 3, 4, 5, 6],
"c": [99, 99, 66, 66, 66, 66],
}
)
expected = pl.LazyFrame({"c": [66, 99], "total_b": [18, 3]})
expected = pl.LazyFrame(
{
"c": [66, 99],
"total_b": [18, 1],
"count_b": [4, 1],
"count_star": [4, 2],
}
)

with pl.SQLContext(frame=df) as ctx:
res1 = ctx.execute(
"""
SELECT c, SUM(b) AS total_b
SELECT
c,
SUM(b) AS total_b,
COUNT(b) AS count_b,
COUNT(*) AS count_star
FROM frame
GROUP BY 1
ORDER BY c
"""
)
assert_frame_equal(res1, expected)
assert_frame_equal(res1, expected, check_dtype=False)

res2 = ctx.execute(
"""
Expand All @@ -196,7 +209,7 @@ def test_group_by_ordinal_position() -> None:
)
SELECT c, total_b FROM grp ORDER BY c"""
)
assert_frame_equal(res2, expected)
assert_frame_equal(res2, expected.select(expected.columns[:2]))


def test_group_by_errors() -> None:
Expand Down

0 comments on commit e611fa2

Please sign in to comment.