From f3073e81bdbf6c8e52d73d8c5be5f6c95afb78ea Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Fri, 19 Apr 2024 13:43:34 +0000 Subject: [PATCH 1/8] initial r_base::sum() passthrough --- CMakeLists.txt | 3 +- duckdb-rfuns-r/R/aggregate.R | 22 +++++++++++++ duckdb-rfuns-r/tests/testthat/_snaps/sum.md | 8 +++++ duckdb-rfuns-r/tests/testthat/test-sum.R | 24 ++++++++++++++ src/include/rfuns_extension.hpp | 3 ++ src/rfuns_extension.cpp | 3 ++ src/sum.cpp | 36 +++++++++++++++++++++ 7 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 duckdb-rfuns-r/R/aggregate.R create mode 100644 duckdb-rfuns-r/tests/testthat/_snaps/sum.md create mode 100644 duckdb-rfuns-r/tests/testthat/test-sum.R create mode 100644 src/sum.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 0167e52..ca61ad3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,8 @@ include_directories(src/include) set(EXTENSION_SOURCES src/rfuns_extension.cpp src/add.cpp src/relop.cpp - src/dispatch.cpp) + src/dispatch.cpp + src/sum.cpp) build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES}) build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES}) diff --git a/duckdb-rfuns-r/R/aggregate.R b/duckdb-rfuns-r/R/aggregate.R new file mode 100644 index 0000000..db281dd --- /dev/null +++ b/duckdb-rfuns-r/R/aggregate.R @@ -0,0 +1,22 @@ +rfuns_sum <- function(x, na.rm = TRUE) { + con <- local_duckdb_con() + + in_df <- tibble::tibble(x = x) + in_rel <- duckdb:::rel_from_df(con, in_df) + + exprs <- list( + duckdb:::expr_function( + "r_base::sum", + list( + duckdb:::expr_reference("x"), + duckdb:::expr_constant(TRUE) + ) + ) + ) + + agg <- duckdb:::rel_aggregate(in_rel, list(), exprs) + + withr::with_options(list(duckdb.materialize_message = FALSE), { + duckdb:::rel_to_altrep(agg)[, 1][] + }) +} diff --git a/duckdb-rfuns-r/tests/testthat/_snaps/sum.md b/duckdb-rfuns-r/tests/testthat/_snaps/sum.md new file mode 100644 index 0000000..6a01776 --- /dev/null +++ b/duckdb-rfuns-r/tests/testthat/_snaps/sum.md @@ -0,0 +1,8 @@ +# r_base::sum( + + Code + rfuns_sum("HufflePuff") + Condition + Error: + ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(VARCHAR, BOOLEAN)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(VARCHAR, BOOLEAN)","error_subtype":"NO_MATCHING_FUNCTION"} + diff --git a/duckdb-rfuns-r/tests/testthat/test-sum.R b/duckdb-rfuns-r/tests/testthat/test-sum.R new file mode 100644 index 0000000..b58d305 --- /dev/null +++ b/duckdb-rfuns-r/tests/testthat/test-sum.R @@ -0,0 +1,24 @@ +test_that("r_base::sum()", { + expect_equal(rfuns_sum(1:10), 55) + expect_equal(rfuns_sum(c(1:10, NA)), 55) + + expect_equal(rfuns_sum(c(1:10, NA), na.rm = TRUE), 55) + + # TODO: should be NA + expect_equal(rfuns_sum(c(1:10, NA), na.rm = FALSE), 55) +}) + + +test_that("r_base::sum()", { + expect_equal(rfuns_sum(c(1.1, 2.2, 3.3)), 6.6) + expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA)), 6.6) + + expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA), na.rm = TRUE), 6.6) + + # TODO: should be NA + expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA), na.rm = FALSE), 6.6) +}) + +test_that("r_base::sum(", { + expect_snapshot(error = TRUE, rfuns_sum("HufflePuff")) +}) diff --git a/src/include/rfuns_extension.hpp b/src/include/rfuns_extension.hpp index afffc66..c737ce0 100644 --- a/src/include/rfuns_extension.hpp +++ b/src/include/rfuns_extension.hpp @@ -31,6 +31,9 @@ ScalarFunctionSet base_r_lte(); ScalarFunctionSet base_r_gt(); ScalarFunctionSet base_r_gte(); +// sum +AggregateFunctionSet base_r_sum(); + ScalarFunctionSet binary_dispatch(ScalarFunctionSet fn) ; } // namespace rfuns diff --git a/src/rfuns_extension.cpp b/src/rfuns_extension.cpp index 170437e..48eec20 100644 --- a/src/rfuns_extension.cpp +++ b/src/rfuns_extension.cpp @@ -32,6 +32,9 @@ static void register_rfuns(DatabaseInstance &instance) { register_binary(instance, base_r_lte()); register_binary(instance, base_r_gt()); register_binary(instance, base_r_gte()); + + // sum + ExtensionUtil::RegisterFunction(instance, base_r_sum()); } } // namespace rfuns diff --git a/src/sum.cpp b/src/sum.cpp new file mode 100644 index 0000000..6b0bd05 --- /dev/null +++ b/src/sum.cpp @@ -0,0 +1,36 @@ +#include "rfuns_extension.hpp" + +#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" + +#include +#include +#include "duckdb/core_functions/aggregate/sum_helpers.hpp" +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" + +namespace duckdb { +namespace rfuns { + +unique_ptr BindRSum(ClientContext &context, AggregateFunction &function, vector> &arguments) { + function = SumFun::GetFunctions().GetFunctionByArguments(context, {arguments[0]->return_type}); + return nullptr; +} + +AggregateFunction RSum(const LogicalType& type) { + return AggregateFunction( + {type, LogicalType::BOOLEAN}, type, + nullptr, nullptr, nullptr, nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, + BindRSum + ); +} + +AggregateFunctionSet base_r_sum() { + AggregateFunctionSet set("r_base::sum"); + + set.AddFunction(RSum(LogicalType::INTEGER)); + set.AddFunction(RSum(LogicalType::DOUBLE)); + + return set; +} + +} +} From 4c395f392beaddb2231258f8e61d5043e79c0372 Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Wed, 24 Apr 2024 06:56:14 +0000 Subject: [PATCH 2/8] handle sum(na.rm = FALSE) --- src/sum.cpp | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/src/sum.cpp b/src/sum.cpp index 6b0bd05..6e76b08 100644 --- a/src/sum.cpp +++ b/src/sum.cpp @@ -10,8 +10,84 @@ namespace duckdb { namespace rfuns { +template +struct RSumKeepNaState { + T value; + bool is_set; + bool is_null; +}; + +template +struct RSumKeepNaOperation { + + template + static void Initialize(STATE &state) { + state.is_set = false; + state.is_null = false; + state.value = 0; + } + + static bool IgnoreNull() { + return false; + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + if (state.is_null) return; + state.is_set = true; + if (!unary_input.RowIsValid()) { + state.is_null = true; + } else { + ADDOP::template AddNumber(state, input); + } + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) { + if (!unary_input.RowIsValid()) { + state.is_null = true; + } else { + ADDOP::template AddConstant(state, input, count); + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!target.is_set) { + target = source; + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.is_set || state.is_null) { + finalize_data.ReturnNull(); + } else { + target = state.value; + } + } +}; + unique_ptr BindRSum(ClientContext &context, AggregateFunction &function, vector> &arguments) { - function = SumFun::GetFunctions().GetFunctionByArguments(context, {arguments[0]->return_type}); + auto& na_rm = arguments[1]; + + auto na_keep = na_rm->ToString() != "true"; + if (na_keep) { + auto type = arguments[0]->return_type; + switch (type.id()) { + case LogicalTypeId::DOUBLE: + function = AggregateFunction::UnaryAggregate, double, double, RSumKeepNaOperation>(type, type); + break; + case LogicalTypeId::INTEGER: + function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, RSumKeepNaOperation>(type, type); + break; + default: + break; + } + } else { + function = SumFun::GetFunctions().GetFunctionByArguments(context, {arguments[0]->return_type}); + } + return nullptr; } From 74fcc4967c394e10ac71dfd9c22e5556908ed0b1 Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Wed, 24 Apr 2024 06:59:51 +0000 Subject: [PATCH 3/8] some comments --- src/sum.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/sum.cpp b/src/sum.cpp index 6e76b08..bdd2a3c 100644 --- a/src/sum.cpp +++ b/src/sum.cpp @@ -69,10 +69,13 @@ struct RSumKeepNaOperation { }; unique_ptr BindRSum(ClientContext &context, AggregateFunction &function, vector> &arguments) { - auto& na_rm = arguments[1]; - - auto na_keep = na_rm->ToString() != "true"; - if (na_keep) { + auto na_rm = arguments[1]->ToString() == "true"; + if (na_rm) { + // na.rm = TRUE, just use the regular duckdb function + function = SumFun::GetFunctions().GetFunctionByArguments(context, {arguments[0]->return_type}); + } else { + // na.rm = FALSE + // use a custom function that does not ignore nulls and returns null if there are any auto type = arguments[0]->return_type; switch (type.id()) { case LogicalTypeId::DOUBLE: @@ -84,8 +87,6 @@ unique_ptr BindRSum(ClientContext &context, AggregateFunction &fun default: break; } - } else { - function = SumFun::GetFunctions().GetFunctionByArguments(context, {arguments[0]->return_type}); } return nullptr; From e711e79bf55d8df791ea20117dce92dc23649fbe Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Wed, 24 Apr 2024 07:22:49 +0000 Subject: [PATCH 4/8] can't use duckdb version --- src/sum.cpp | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/sum.cpp b/src/sum.cpp index bdd2a3c..9d17e7c 100644 --- a/src/sum.cpp +++ b/src/sum.cpp @@ -17,8 +17,8 @@ struct RSumKeepNaState { bool is_null; }; -template -struct RSumKeepNaOperation { +template +struct RSumOperation { template static void Initialize(STATE &state) { @@ -28,14 +28,14 @@ struct RSumKeepNaOperation { } static bool IgnoreNull() { - return false; + return NA_RM; } template static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { if (state.is_null) return; state.is_set = true; - if (!unary_input.RowIsValid()) { + if (!NA_RM && !unary_input.RowIsValid()) { state.is_null = true; } else { ADDOP::template AddNumber(state, input); @@ -44,7 +44,7 @@ struct RSumKeepNaOperation { template static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) { - if (!unary_input.RowIsValid()) { + if (!NA_RM && !unary_input.RowIsValid()) { state.is_null = true; } else { ADDOP::template AddConstant(state, input, count); @@ -60,7 +60,7 @@ struct RSumKeepNaOperation { template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.is_set || state.is_null) { + if (state.is_null) { finalize_data.ReturnNull(); } else { target = state.value; @@ -69,20 +69,30 @@ struct RSumKeepNaOperation { }; unique_ptr BindRSum(ClientContext &context, AggregateFunction &function, vector> &arguments) { + auto type = arguments[0]->return_type; auto na_rm = arguments[1]->ToString() == "true"; if (na_rm) { // na.rm = TRUE, just use the regular duckdb function - function = SumFun::GetFunctions().GetFunctionByArguments(context, {arguments[0]->return_type}); + switch (type.id()) { + case LogicalTypeId::DOUBLE: + function = AggregateFunction::UnaryAggregate, double, double, RSumOperation>(type, type); + break; + case LogicalTypeId::INTEGER: + function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, RSumOperation>(type, type); + break; + default: + break; + } + } else { // na.rm = FALSE // use a custom function that does not ignore nulls and returns null if there are any - auto type = arguments[0]->return_type; switch (type.id()) { case LogicalTypeId::DOUBLE: - function = AggregateFunction::UnaryAggregate, double, double, RSumKeepNaOperation>(type, type); + function = AggregateFunction::UnaryAggregate, double, double, RSumOperation>(type, type); break; case LogicalTypeId::INTEGER: - function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, RSumKeepNaOperation>(type, type); + function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, RSumOperation>(type, type); break; default: break; From c2b303e300c62ac791d432f7ee4a26f1a8bc4564 Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Wed, 24 Apr 2024 07:28:19 +0000 Subject: [PATCH 5/8] BindRSum_dispatch --- src/sum.cpp | 43 ++++++++++++++++++------------------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/src/sum.cpp b/src/sum.cpp index 9d17e7c..819a460 100644 --- a/src/sum.cpp +++ b/src/sum.cpp @@ -68,35 +68,28 @@ struct RSumOperation { } }; -unique_ptr BindRSum(ClientContext &context, AggregateFunction &function, vector> &arguments) { +template +void BindRSum_dispatch(ClientContext &context, AggregateFunction &function, vector> &arguments) { auto type = arguments[0]->return_type; + + switch (type.id()) { + case LogicalTypeId::DOUBLE: + function = AggregateFunction::UnaryAggregate, double, double, RSumOperation>(type, type); + break; + case LogicalTypeId::INTEGER: + function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, RSumOperation>(type, type); + break; + default: + break; + } +} + +unique_ptr BindRSum(ClientContext &context, AggregateFunction &function, vector> &arguments) { auto na_rm = arguments[1]->ToString() == "true"; if (na_rm) { - // na.rm = TRUE, just use the regular duckdb function - switch (type.id()) { - case LogicalTypeId::DOUBLE: - function = AggregateFunction::UnaryAggregate, double, double, RSumOperation>(type, type); - break; - case LogicalTypeId::INTEGER: - function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, RSumOperation>(type, type); - break; - default: - break; - } - + BindRSum_dispatch(context, function, arguments); } else { - // na.rm = FALSE - // use a custom function that does not ignore nulls and returns null if there are any - switch (type.id()) { - case LogicalTypeId::DOUBLE: - function = AggregateFunction::UnaryAggregate, double, double, RSumOperation>(type, type); - break; - case LogicalTypeId::INTEGER: - function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, RSumOperation>(type, type); - break; - default: - break; - } + BindRSum_dispatch(context, function, arguments); } return nullptr; From 76c9be3fc06248b0eecb19333f34483b5148382d Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Wed, 24 Apr 2024 07:29:21 +0000 Subject: [PATCH 6/8] tests --- duckdb-rfuns-r/NAMESPACE | 1 + duckdb-rfuns-r/R/aggregate.R | 27 ++++++++++++++++++------ duckdb-rfuns-r/man/rfuns_sum.Rd | 20 ++++++++++++++++++ duckdb-rfuns-r/tests/testthat/test-sum.R | 17 ++++++++++----- 4 files changed, 54 insertions(+), 11 deletions(-) create mode 100644 duckdb-rfuns-r/man/rfuns_sum.Rd diff --git a/duckdb-rfuns-r/NAMESPACE b/duckdb-rfuns-r/NAMESPACE index 8bfbadf..88a7ba2 100644 --- a/duckdb-rfuns-r/NAMESPACE +++ b/duckdb-rfuns-r/NAMESPACE @@ -1,5 +1,6 @@ # Generated by roxygen2: do not edit by hand +export(rfuns_sum) import(DBI) import(rlang) importFrom(constructive,construct) diff --git a/duckdb-rfuns-r/R/aggregate.R b/duckdb-rfuns-r/R/aggregate.R index db281dd..274995b 100644 --- a/duckdb-rfuns-r/R/aggregate.R +++ b/duckdb-rfuns-r/R/aggregate.R @@ -1,16 +1,30 @@ +#' sum +#' +#' @param x vector +#' @param na.rm should the missing values be removed +#' +#' @examples +#' rfuns_sum(1:10) +#' +#' @export rfuns_sum <- function(x, na.rm = TRUE) { + rfuns_aggregate("sum", tibble(x = x), na.rm = na.rm) +} + +rfuns_aggregate <- function(fun, data, ...) { con <- local_duckdb_con() - in_df <- tibble::tibble(x = x) + names(data) <- paste("x", seq_len(ncol(data)), sep = "") + in_df <- as_tibble(data) in_rel <- duckdb:::rel_from_df(con, in_df) + refs <- map(names(data), duckdb:::expr_reference) + constants <- map(list2(...), duckdb:::expr_constant) + exprs <- list( duckdb:::expr_function( - "r_base::sum", - list( - duckdb:::expr_reference("x"), - duckdb:::expr_constant(TRUE) - ) + paste0("r_base::", fun), + list2(!!!refs, !!!constants) ) ) @@ -20,3 +34,4 @@ rfuns_sum <- function(x, na.rm = TRUE) { duckdb:::rel_to_altrep(agg)[, 1][] }) } + diff --git a/duckdb-rfuns-r/man/rfuns_sum.Rd b/duckdb-rfuns-r/man/rfuns_sum.Rd new file mode 100644 index 0000000..28435a5 --- /dev/null +++ b/duckdb-rfuns-r/man/rfuns_sum.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aggregate.R +\name{rfuns_sum} +\alias{rfuns_sum} +\title{sum} +\usage{ +rfuns_sum(x, na.rm = TRUE) +} +\arguments{ +\item{x}{vector} + +\item{na.rm}{should the missing values be removed} +} +\description{ +sum +} +\examples{ +rfuns_sum(1:10) + +} diff --git a/duckdb-rfuns-r/tests/testthat/test-sum.R b/duckdb-rfuns-r/tests/testthat/test-sum.R index b58d305..2092188 100644 --- a/duckdb-rfuns-r/tests/testthat/test-sum.R +++ b/duckdb-rfuns-r/tests/testthat/test-sum.R @@ -3,20 +3,27 @@ test_that("r_base::sum()", { expect_equal(rfuns_sum(c(1:10, NA)), 55) expect_equal(rfuns_sum(c(1:10, NA), na.rm = TRUE), 55) + expect_equal(rfuns_sum(c(1:10, NA), na.rm = FALSE), NA_integer_) - # TODO: should be NA - expect_equal(rfuns_sum(c(1:10, NA), na.rm = FALSE), 55) -}) + expect_equal(rfuns_sum(integer(), na.rm = FALSE), 0L) + expect_equal(rfuns_sum(integer(), na.rm = TRUE), 0L) + expect_equal(rfuns_sum(NA_integer_, na.rm = TRUE), 0L) + expect_equal(rfuns_sum(NA_integer_, na.rm = FALSE), NA_integer_) +}) test_that("r_base::sum()", { expect_equal(rfuns_sum(c(1.1, 2.2, 3.3)), 6.6) expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA)), 6.6) expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA), na.rm = TRUE), 6.6) + expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA), na.rm = FALSE), NA_real_) + + expect_equal(rfuns_sum(double(), na.rm = FALSE), 0) + expect_equal(rfuns_sum(double(), na.rm = TRUE), 0) - # TODO: should be NA - expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA), na.rm = FALSE), 6.6) + expect_equal(rfuns_sum(NA_real_, na.rm = TRUE), 0L) + expect_equal(rfuns_sum(NA_real_, na.rm = FALSE), NA_real_) }) test_that("r_base::sum(", { From 115cce8a0298ccc47fad1f3a48bd92645af22139 Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Wed, 24 Apr 2024 07:34:49 +0000 Subject: [PATCH 7/8] more tests --- duckdb-rfuns-r/tests/testthat/_snaps/sum.md | 16 ++++++++++++++++ duckdb-rfuns-r/tests/testthat/test-sum.R | 5 +++++ 2 files changed, 21 insertions(+) diff --git a/duckdb-rfuns-r/tests/testthat/_snaps/sum.md b/duckdb-rfuns-r/tests/testthat/_snaps/sum.md index 6a01776..2b0b029 100644 --- a/duckdb-rfuns-r/tests/testthat/_snaps/sum.md +++ b/duckdb-rfuns-r/tests/testthat/_snaps/sum.md @@ -1,3 +1,19 @@ +# r_base::sum(, na.rm = ) + + Code + rfuns_sum(1:10, na.rm = "hello") + Condition + Error: + ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(INTEGER, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(INTEGER, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"} + +--- + + Code + rfuns_sum(c(1, 2, 3), na.rm = "hello") + Condition + Error: + ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(DOUBLE, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(DOUBLE, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"} + # r_base::sum( Code diff --git a/duckdb-rfuns-r/tests/testthat/test-sum.R b/duckdb-rfuns-r/tests/testthat/test-sum.R index 2092188..290dd54 100644 --- a/duckdb-rfuns-r/tests/testthat/test-sum.R +++ b/duckdb-rfuns-r/tests/testthat/test-sum.R @@ -26,6 +26,11 @@ test_that("r_base::sum()", { expect_equal(rfuns_sum(NA_real_, na.rm = FALSE), NA_real_) }) +test_that("r_base::sum(, na.rm = )", { + expect_snapshot(error = TRUE, rfuns_sum(1:10, na.rm = "hello")) + expect_snapshot(error = TRUE, rfuns_sum(c(1, 2, 3), na.rm = "hello")) +}) + test_that("r_base::sum(", { expect_snapshot(error = TRUE, rfuns_sum("HufflePuff")) }) From f3861e42df8db7ceb57fb502502d0db5038ec91d Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Wed, 24 Apr 2024 08:10:50 +0000 Subject: [PATCH 8/8] sum() --- duckdb-rfuns-r/tests/testthat/_snaps/sum.md | 14 +++++++++++--- duckdb-rfuns-r/tests/testthat/test-sum.R | 15 +++++++++++++++ src/sum.cpp | 7 ++++++- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/duckdb-rfuns-r/tests/testthat/_snaps/sum.md b/duckdb-rfuns-r/tests/testthat/_snaps/sum.md index 2b0b029..f88b33d 100644 --- a/duckdb-rfuns-r/tests/testthat/_snaps/sum.md +++ b/duckdb-rfuns-r/tests/testthat/_snaps/sum.md @@ -1,10 +1,18 @@ # r_base::sum(, na.rm = ) + Code + rfuns_sum(c(TRUE, FALSE), na.rm = "hello") + Condition + Error: + ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(BOOLEAN, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(BOOLEAN, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"} + +--- + Code rfuns_sum(1:10, na.rm = "hello") Condition Error: - ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(INTEGER, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(INTEGER, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"} + ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(INTEGER, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(INTEGER, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"} --- @@ -12,7 +20,7 @@ rfuns_sum(c(1, 2, 3), na.rm = "hello") Condition Error: - ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(DOUBLE, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(DOUBLE, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"} + ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(DOUBLE, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(DOUBLE, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"} # r_base::sum( @@ -20,5 +28,5 @@ rfuns_sum("HufflePuff") Condition Error: - ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(VARCHAR, BOOLEAN)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(VARCHAR, BOOLEAN)","error_subtype":"NO_MATCHING_FUNCTION"} + ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(VARCHAR, BOOLEAN)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(VARCHAR, BOOLEAN)","error_subtype":"NO_MATCHING_FUNCTION"} diff --git a/duckdb-rfuns-r/tests/testthat/test-sum.R b/duckdb-rfuns-r/tests/testthat/test-sum.R index 290dd54..fc9eea6 100644 --- a/duckdb-rfuns-r/tests/testthat/test-sum.R +++ b/duckdb-rfuns-r/tests/testthat/test-sum.R @@ -1,3 +1,17 @@ +test_that("r_base::sum()", { + expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE)), 2L) + expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE, NA)), 2L) + + expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE, NA), na.rm = TRUE), 2L) + expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE, NA), na.rm = FALSE), NA_integer_) + + expect_equal(rfuns_sum(logical(), na.rm = FALSE), 0L) + expect_equal(rfuns_sum(logical(), na.rm = TRUE), 0L) + + expect_equal(rfuns_sum(NA, na.rm = TRUE), 0L) + expect_equal(rfuns_sum(NA, na.rm = FALSE), NA_integer_) +}) + test_that("r_base::sum()", { expect_equal(rfuns_sum(1:10), 55) expect_equal(rfuns_sum(c(1:10, NA)), 55) @@ -27,6 +41,7 @@ test_that("r_base::sum()", { }) test_that("r_base::sum(, na.rm = )", { + expect_snapshot(error = TRUE, rfuns_sum(c(TRUE, FALSE), na.rm = "hello")) expect_snapshot(error = TRUE, rfuns_sum(1:10, na.rm = "hello")) expect_snapshot(error = TRUE, rfuns_sum(c(1, 2, 3), na.rm = "hello")) }) diff --git a/src/sum.cpp b/src/sum.cpp index 819a460..b3add92 100644 --- a/src/sum.cpp +++ b/src/sum.cpp @@ -79,6 +79,9 @@ void BindRSum_dispatch(ClientContext &context, AggregateFunction &function, vect case LogicalTypeId::INTEGER: function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, RSumOperation>(type, type); break; + case LogicalTypeId::BOOLEAN: + function = AggregateFunction::UnaryAggregate, bool, int32_t, RSumOperation>(LogicalType::BOOLEAN, LogicalType::INTEGER); + break; default: break; } @@ -96,8 +99,9 @@ unique_ptr BindRSum(ClientContext &context, AggregateFunction &fun } AggregateFunction RSum(const LogicalType& type) { + auto return_type = type == LogicalType::BOOLEAN ? LogicalType::INTEGER : type; return AggregateFunction( - {type, LogicalType::BOOLEAN}, type, + {type, LogicalType::BOOLEAN}, return_type, nullptr, nullptr, nullptr, nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, BindRSum ); @@ -106,6 +110,7 @@ AggregateFunction RSum(const LogicalType& type) { AggregateFunctionSet base_r_sum() { AggregateFunctionSet set("r_base::sum"); + set.AddFunction(RSum(LogicalType::BOOLEAN)); set.AddFunction(RSum(LogicalType::INTEGER)); set.AddFunction(RSum(LogicalType::DOUBLE));