diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0167e52..ca61ad3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -12,7 +12,8 @@ include_directories(src/include)
 set(EXTENSION_SOURCES src/rfuns_extension.cpp
     src/add.cpp
     src/relop.cpp
-    src/dispatch.cpp)
+    src/dispatch.cpp
+    src/sum.cpp)
 
 build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES})
 build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES})
diff --git a/duckdb-rfuns-r/NAMESPACE b/duckdb-rfuns-r/NAMESPACE
index e61cbac..7ecaf75 100644
--- a/duckdb-rfuns-r/NAMESPACE
+++ b/duckdb-rfuns-r/NAMESPACE
@@ -1,5 +1,6 @@
 # Generated by roxygen2: do not edit by hand
 
+export(rfuns_sum)
 import(DBI)
 import(cli)
 import(rlang)
diff --git a/duckdb-rfuns-r/R/aggregate.R b/duckdb-rfuns-r/R/aggregate.R
new file mode 100644
index 0000000..274995b
--- /dev/null
+++ b/duckdb-rfuns-r/R/aggregate.R
@@ -0,0 +1,41 @@
+#' sum
+#'
+#' @param x vector
+#' @param na.rm should the missing values be removed
+#'
+#' @examples
+#' rfuns_sum(1:10)
+#'
+#' @export
+rfuns_sum <- function(x, na.rm = TRUE) {
+  rfuns_aggregate("sum", tibble(x = x), na.rm = na.rm)
+}
+
+# Internal helper: evaluates the DuckDB aggregate `r_base::<fun>` over `data`.
+# The columns of `data` are renamed x1, x2, ... and passed as column
+# references; the values in `...` are appended as constant arguments.
+# Returns the first column of the aggregate result.
+rfuns_aggregate <- function(fun, data, ...) {
+  con <- local_duckdb_con()
+
+  names(data) <- paste("x", seq_len(ncol(data)), sep = "")
+  in_df <- as_tibble(data)
+  in_rel <- duckdb:::rel_from_df(con, in_df)
+
+  refs <- map(names(data), duckdb:::expr_reference)
+  constants <- map(list2(...), duckdb:::expr_constant)
+
+  exprs <- list(
+    duckdb:::expr_function(
+      paste0("r_base::", fun),
+      list2(!!!refs, !!!constants)
+    )
+  )
+
+  agg <- duckdb:::rel_aggregate(in_rel, list(), exprs)
+
+  withr::with_options(list(duckdb.materialize_message = FALSE), {
+    duckdb:::rel_to_altrep(agg)[, 1][]
+  })
+}
+
diff --git a/duckdb-rfuns-r/man/rfuns_sum.Rd b/duckdb-rfuns-r/man/rfuns_sum.Rd
new file mode 100644
index 0000000..28435a5
--- /dev/null
+++ b/duckdb-rfuns-r/man/rfuns_sum.Rd
@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/aggregate.R
+\name{rfuns_sum}
+\alias{rfuns_sum}
+\title{sum}
+\usage{
+rfuns_sum(x, na.rm = TRUE)
+}
+\arguments{
+\item{x}{vector}
+
+\item{na.rm}{should the missing values be removed}
+}
+\description{
+sum
+}
+\examples{
+rfuns_sum(1:10)
+
+}
diff --git a/duckdb-rfuns-r/tests/testthat/_snaps/sum.md b/duckdb-rfuns-r/tests/testthat/_snaps/sum.md
new file mode 100644
index 0000000..f88b33d
--- /dev/null
+++ b/duckdb-rfuns-r/tests/testthat/_snaps/sum.md
@@ -0,0 +1,32 @@
+# r_base::sum(, na.rm = )
+
+    Code
+      rfuns_sum(c(TRUE, FALSE), na.rm = "hello")
+    Condition
+      Error:
+      ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(BOOLEAN, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(BOOLEAN, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"}
+
+---
+
+    Code
+      rfuns_sum(1:10, na.rm = "hello")
+    Condition
+      Error:
+      ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(INTEGER, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(INTEGER, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"}
+
+---
+
+    Code
+      rfuns_sum(c(1, 2, 3), na.rm = "hello")
+    Condition
+      Error:
+      ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(DOUBLE, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(DOUBLE, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"}
+
+# r_base::sum(
+
+    Code
+      rfuns_sum("HufflePuff")
+    Condition
+      Error:
+      ! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(VARCHAR, BOOLEAN)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(VARCHAR, BOOLEAN)","error_subtype":"NO_MATCHING_FUNCTION"}
+
diff --git a/duckdb-rfuns-r/tests/testthat/test-sum.R b/duckdb-rfuns-r/tests/testthat/test-sum.R
new file mode 100644
index 0000000..fc9eea6
--- /dev/null
+++ b/duckdb-rfuns-r/tests/testthat/test-sum.R
@@ -0,0 +1,51 @@
+test_that("r_base::sum()", {
+  expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE)), 2L)
+  expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE, NA)), 2L)
+
+  expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE, NA), na.rm = TRUE), 2L)
+  expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE, NA), na.rm = FALSE), NA_integer_)
+
+  expect_equal(rfuns_sum(logical(), na.rm = FALSE), 0L)
+  expect_equal(rfuns_sum(logical(), na.rm = TRUE), 0L)
+
+  expect_equal(rfuns_sum(NA, na.rm = TRUE), 0L)
+  expect_equal(rfuns_sum(NA, na.rm = FALSE), NA_integer_)
+})
+
+test_that("r_base::sum()", {
+  expect_equal(rfuns_sum(1:10), 55)
+  expect_equal(rfuns_sum(c(1:10, NA)), 55)
+
+  expect_equal(rfuns_sum(c(1:10, NA), na.rm = TRUE), 55)
+  expect_equal(rfuns_sum(c(1:10, NA), na.rm = FALSE), NA_integer_)
+
+  expect_equal(rfuns_sum(integer(), na.rm = FALSE), 0L)
+  expect_equal(rfuns_sum(integer(), na.rm = TRUE), 0L)
+
+  expect_equal(rfuns_sum(NA_integer_, na.rm = TRUE), 0L)
+  expect_equal(rfuns_sum(NA_integer_, na.rm = FALSE), NA_integer_)
+})
+
+test_that("r_base::sum()", {
+  expect_equal(rfuns_sum(c(1.1, 2.2, 3.3)), 6.6)
+  expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA)), 6.6)
+
+  expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA), na.rm = TRUE), 6.6)
+  expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA), na.rm = FALSE), NA_real_)
+
+  expect_equal(rfuns_sum(double(), na.rm = FALSE), 0)
+  expect_equal(rfuns_sum(double(), na.rm = TRUE), 0)
+
+  expect_equal(rfuns_sum(NA_real_, na.rm = TRUE), 0L)
+  expect_equal(rfuns_sum(NA_real_, na.rm = FALSE), NA_real_)
+})
+
+test_that("r_base::sum(, na.rm = )", {
+  expect_snapshot(error = TRUE, rfuns_sum(c(TRUE, FALSE), na.rm = "hello"))
+  expect_snapshot(error = TRUE, rfuns_sum(1:10, na.rm = "hello"))
+  expect_snapshot(error = TRUE, rfuns_sum(c(1, 2, 3), na.rm = "hello"))
+})
+
+test_that("r_base::sum(", {
+  expect_snapshot(error = TRUE, rfuns_sum("HufflePuff"))
+})
diff --git a/src/include/rfuns_extension.hpp b/src/include/rfuns_extension.hpp
index afffc66..c737ce0 100644
--- a/src/include/rfuns_extension.hpp
+++ b/src/include/rfuns_extension.hpp
@@ -31,6 +31,9 @@ ScalarFunctionSet base_r_lte();
 ScalarFunctionSet base_r_gt();
 ScalarFunctionSet base_r_gte();
 
+// sum
+AggregateFunctionSet base_r_sum();
+
 ScalarFunctionSet binary_dispatch(ScalarFunctionSet fn) ;
 
 } // namespace rfuns
diff --git a/src/rfuns_extension.cpp b/src/rfuns_extension.cpp
index 170437e..48eec20 100644
--- a/src/rfuns_extension.cpp
+++ b/src/rfuns_extension.cpp
@@ -32,6 +32,9 @@ static void register_rfuns(DatabaseInstance &instance) {
 	register_binary(instance, base_r_lte());
 	register_binary(instance, base_r_gt());
 	register_binary(instance, base_r_gte());
+
+	// sum
+	ExtensionUtil::RegisterFunction(instance, base_r_sum());
 }
 
 } // namespace rfuns
diff --git a/src/sum.cpp b/src/sum.cpp
new file mode 100644
index 0000000..b3add92
--- /dev/null
+++ b/src/sum.cpp
@@ -0,0 +1,155 @@
+#include "rfuns_extension.hpp"
+
+#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp"
+
+// NOTE(review): the collapsed text this patch was restored from had lost
+// everything between angle brackets. The two system header names below and
+// every template parameter/argument list in this file are reconstructed from
+// usage - confirm them against the real sources.
+#include <math.h>
+#include <climits>
+#include "duckdb/core_functions/aggregate/sum_helpers.hpp"
+#include "duckdb/core_functions/aggregate/distributive_functions.hpp"
+
+namespace duckdb {
+namespace rfuns {
+
+// Accumulator shared by every r_base::sum overload.
+template <class T>
+struct RSumKeepNaState {
+	T value;      // running total; Initialize() starts it at 0
+	bool is_set;  // true once the state has been fed at least one row
+	bool is_null; // true once a NULL was seen with na.rm = FALSE
+};
+
+// Aggregate callbacks for r_base::sum. ADDOP provides AddNumber/AddConstant;
+// NA_RM is the compile-time value of R's na.rm argument.
+template <class ADDOP, bool NA_RM>
+struct RSumOperation {
+
+	template <class STATE>
+	static void Initialize(STATE &state) {
+		state.is_set = false;
+		state.is_null = false;
+		state.value = 0;
+	}
+
+	// Returns NA_RM. With na.rm = FALSE, Operation() checks RowIsValid()
+	// itself so that a NULL row can mark the whole state as NULL.
+	static bool IgnoreNull() {
+		return NA_RM;
+	}
+
+	template <class INPUT_TYPE, class STATE, class OP>
+	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
+		if (state.is_null) {
+			return; // already NULL: no further row can change the result
+		}
+		state.is_set = true;
+		if (!NA_RM && !unary_input.RowIsValid()) {
+			state.is_null = true;
+		} else {
+			ADDOP::template AddNumber<STATE, INPUT_TYPE>(state, input);
+		}
+	}
+
+	template <class INPUT_TYPE, class STATE, class OP>
+	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) {
+		// Same bookkeeping as Operation(): a NULL state stays NULL and is_set
+		// records that the state received input.
+		if (state.is_null) {
+			return;
+		}
+		state.is_set = true;
+		if (!NA_RM && !unary_input.RowIsValid()) {
+			state.is_null = true;
+		} else {
+			ADDOP::template AddConstant<STATE, INPUT_TYPE>(state, input, count);
+		}
+	}
+
+	template <class STATE, class OP>
+	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
+		// Both partial states contribute to the result: keeping only one of
+		// them would drop the other's partial sum and its NULL marker.
+		target.is_set = target.is_set || source.is_set;
+		target.is_null = target.is_null || source.is_null;
+		if (!target.is_null) {
+			// A state that never received input still holds 0.
+			target.value += source.value;
+		}
+	}
+
+	template <class T, class STATE>
+	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
+		if (state.is_null) {
+			finalize_data.ReturnNull();
+		} else {
+			target = state.value;
+		}
+	}
+};
+
+// Replaces `function` with the unary aggregate matching the type of the first
+// argument, instantiated for the given NA_RM.
+template <bool NA_RM>
+void BindRSum_dispatch(ClientContext &context, AggregateFunction &function, vector<unique_ptr<Expression>> &arguments) {
+	auto type = arguments[0]->return_type;
+
+	switch (type.id()) {
+	case LogicalTypeId::DOUBLE:
+		function = AggregateFunction::UnaryAggregate<RSumKeepNaState<double>, double, double, RSumOperation<RegularAdd, NA_RM>>(type, type);
+		break;
+	case LogicalTypeId::INTEGER:
+		// NOTE(review): the C++ result type here is hugeint_t while the logical
+		// return type passed in is INTEGER - confirm the two are meant to differ.
+		function = AggregateFunction::UnaryAggregate<RSumKeepNaState<hugeint_t>, int32_t, hugeint_t, RSumOperation<HugeintAdd, NA_RM>>(type, type);
+		break;
+	case LogicalTypeId::BOOLEAN:
+		function = AggregateFunction::UnaryAggregate<RSumKeepNaState<int32_t>, bool, int32_t, RSumOperation<RegularAdd, NA_RM>>(LogicalType::BOOLEAN, LogicalType::INTEGER);
+		break;
+	default:
+		// Fail loudly instead of leaving `function` without an implementation.
+		throw InternalException("r_base::sum: unsupported argument type %s", type.ToString());
+	}
+}
+
+// Bind callback: picks the NA_RM = true/false instantiation from the second
+// argument and installs it through BindRSum_dispatch. Returns no bind data.
+unique_ptr<FunctionData> BindRSum(ClientContext &context, AggregateFunction &function, vector<unique_ptr<Expression>> &arguments) {
+	// NOTE(review): na.rm is taken from the textual form of the expression, so
+	// anything other than the literal "true" selects na.rm = FALSE - confirm.
+	auto na_rm = arguments[1]->ToString() == "true";
+	if (na_rm) {
+		BindRSum_dispatch<true>(context, function, arguments);
+	} else {
+		BindRSum_dispatch<false>(context, function, arguments);
+	}
+
+	return nullptr;
+}
+
+// Builds the (type, BOOLEAN) overload. The function pointers are left as
+// nullptr; BindRSum is passed as the bind callback. BOOLEAN returns INTEGER.
+AggregateFunction RSum(const LogicalType& type) {
+	auto return_type = type == LogicalType::BOOLEAN ? LogicalType::INTEGER : type;
+	return AggregateFunction(
+		{type, LogicalType::BOOLEAN}, return_type,
+		nullptr, nullptr, nullptr, nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr,
+		BindRSum
+	);
+}
+
+// Creates the r_base::sum set with BOOLEAN, INTEGER and DOUBLE overloads.
+AggregateFunctionSet base_r_sum() {
+	AggregateFunctionSet set("r_base::sum");
+
+	set.AddFunction(RSum(LogicalType::BOOLEAN));
+	set.AddFunction(RSum(LogicalType::INTEGER));
+	set.AddFunction(RSum(LogicalType::DOUBLE));
+
+	return set;
+}
+
+} // namespace rfuns
+} // namespace duckdb