diff --git a/R/backend-dbplyr__duckdb_connection.R b/R/backend-dbplyr__duckdb_connection.R index c27f03f6f..ea259067f 100644 --- a/R/backend-dbplyr__duckdb_connection.R +++ b/R/backend-dbplyr__duckdb_connection.R @@ -307,6 +307,7 @@ sql_translation.duckdb_connection <- function(con) { ), sql_translator( .parent = base_agg, + prod = sql_aggregate("PRODUCT"), cor = sql_aggregate_2("CORR"), cov = sql_aggregate_2("COVAR_SAMP"), sd = sql_aggregate("STDDEV", "sd"), @@ -319,6 +320,7 @@ sql_translation.duckdb_connection <- function(con) { ), sql_translator( .parent = base_win, + prod = win_aggregate("PRODUCT"), cor = win_aggregate_2("CORR"), cov = win_aggregate_2("COVAR_SAMP"), sd = win_aggregate("STDDEV"), diff --git a/tests/testthat/test_backend-dbplyr__duckdb_connection.R b/tests/testthat/test_backend-dbplyr__duckdb_connection.R index 92b696ed9..8bba1c982 100644 --- a/tests/testthat/test_backend-dbplyr__duckdb_connection.R +++ b/tests/testthat/test_backend-dbplyr__duckdb_connection.R @@ -68,8 +68,6 @@ test_that("duckdb custom scalars translated correctly", { expect_equal(translate(as.POSIXct("2019-01-01 01:01:01")), sql(r"{CAST('2019-01-01 01:01:01' AS TIMESTAMP)}")) }) - - test_that("pasting translated correctly", { skip_if_no_R4() skip_if_not_installed("dbplyr") @@ -180,6 +178,36 @@ test_that("datetime escaping working as in DBI", { expect_equal(escape("2020-01-01 18:23:45 PST"), sql(r"{'2020-01-01 18:23:45 PST'}")) }) +test_that("aggregators translated correctly", { + skip_if_no_R4() + skip_if_not_installed("dbplyr") + con <- dbConnect(duckdb()) + on.exit(dbDisconnect(con, shutdown = TRUE)) + translate <- function(...) dbplyr::translate_sql(..., con = con) + sql <- function(...) dbplyr::sql(...) + + expect_equal(translate(sum(x), window = FALSE), sql(r"{SUM(x)}")) + expect_equal(translate(sum(x), window = TRUE), sql(r"{SUM(x) OVER ()}")) + + expect_equal(translate(prod(x), window = FALSE), sql(r"{PRODUCT(x)}")) + expect_equal(translate(prod(x), window = TRUE), sql(r"{PRODUCT(x) OVER ()}")) + + expect_equal(translate(sd(x), window = FALSE), sql(r"{STDDEV(x)}")) + expect_equal(translate(sd(x), window = TRUE), sql(r"{STDDEV(x) OVER ()}")) + + expect_equal(translate(var(x), window = FALSE), sql(r"{VARIANCE(x)}")) + expect_equal(translate(var(x), window = TRUE), sql(r"{VARIANCE(x) OVER ()}")) + + expect_equal(translate(all(x), window = FALSE), sql(r"{BOOL_AND(x)}")) + expect_equal(translate(all(x), window = TRUE), sql(r"{BOOL_AND(x) OVER ()}")) + + expect_equal(translate(any(x), window = FALSE), sql(r"{BOOL_OR(x)}")) + expect_equal(translate(any(x), window = TRUE), sql(r"{BOOL_OR(x) OVER ()}")) + + expect_equal(translate(str_flatten(x, ","), window = FALSE), sql(r"{STRING_AGG(x, ',')}")) + expect_equal(translate(str_flatten(x, ","), window = TRUE), sql(r"{STRING_AGG(x, ',') OVER ()}")) +}) + test_that("two variable aggregates are translated correctly", { skip_if_no_R4() skip_if_not_installed("dbplyr")