Skip to content

Commit

Permalink
Merge pull request #75 from hannes/sum
Browse files Browse the repository at this point in the history
`r_base::sum()` handling `na.rm = `
  • Loading branch information
romainfrancois authored Apr 24, 2024
2 parents 5c6cbaa + f3861e4 commit 5ca3a39
Show file tree
Hide file tree
Showing 9 changed files with 270 additions and 1 deletion.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ include_directories(src/include)
set(EXTENSION_SOURCES src/rfuns_extension.cpp
src/add.cpp
src/relop.cpp
src/dispatch.cpp)
src/dispatch.cpp
src/sum.cpp)

build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES})
build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES})
Expand Down
1 change: 1 addition & 0 deletions duckdb-rfuns-r/NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

export(rfuns_sum)
import(DBI)
import(cli)
import(rlang)
Expand Down
37 changes: 37 additions & 0 deletions duckdb-rfuns-r/R/aggregate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#' sum
#'
#' @param x vector
#' @param na.rm should the missing values be removed
#'
#' @examples
#' rfuns_sum(1:10)
#'
#' @export
rfuns_sum <- function(x, na.rm = TRUE) {
rfuns_aggregate("sum", tibble(x = x), na.rm = na.rm)
}

rfuns_aggregate <- function(fun, data, ...) {
con <- local_duckdb_con()

names(data) <- paste("x", seq_len(ncol(data)), sep = "")
in_df <- as_tibble(data)
in_rel <- duckdb:::rel_from_df(con, in_df)

refs <- map(names(data), duckdb:::expr_reference)
constants <- map(list2(...), duckdb:::expr_constant)

exprs <- list(
duckdb:::expr_function(
paste0("r_base::", fun),
list2(!!!refs, !!!constants)
)
)

agg <- duckdb:::rel_aggregate(in_rel, list(), exprs)

withr::with_options(list(duckdb.materialize_message = FALSE), {
duckdb:::rel_to_altrep(agg)[, 1][]
})
}

20 changes: 20 additions & 0 deletions duckdb-rfuns-r/man/rfuns_sum.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions duckdb-rfuns-r/tests/testthat/_snaps/sum.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# r_base::sum(<?>, na.rm = <VARCHAR>)

Code
rfuns_sum(c(TRUE, FALSE), na.rm = "hello")
Condition
Error:
! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(BOOLEAN, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(BOOLEAN, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"}

---

Code
rfuns_sum(1:10, na.rm = "hello")
Condition
Error:
! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(INTEGER, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(INTEGER, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"}

---

Code
rfuns_sum(c(1, 2, 3), na.rm = "hello")
Condition
Error:
! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(DOUBLE, VARCHAR)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(DOUBLE, VARCHAR)","error_subtype":"NO_MATCHING_FUNCTION"}

# r_base::sum(<VARCHAR>

Code
rfuns_sum("HufflePuff")
Condition
Error:
! {"exception_type":"Binder","exception_message":"No function matches the given name and argument types 'r_base::sum(VARCHAR, BOOLEAN)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tr_base::sum(BOOLEAN, BOOLEAN) -> INTEGER\n\tr_base::sum(INTEGER, BOOLEAN) -> INTEGER\n\tr_base::sum(DOUBLE, BOOLEAN) -> DOUBLE\n","name":"r_base::sum","candidates":"r_base::sum(BOOLEAN, BOOLEAN) -> INTEGER,r_base::sum(INTEGER, BOOLEAN) -> INTEGER,r_base::sum(DOUBLE, BOOLEAN) -> DOUBLE","call":"r_base::sum(VARCHAR, BOOLEAN)","error_subtype":"NO_MATCHING_FUNCTION"}

51 changes: 51 additions & 0 deletions duckdb-rfuns-r/tests/testthat/test-sum.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
test_that("r_base::sum(<BOOLEAN>)", {
expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE)), 2L)
expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE, NA)), 2L)

expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE, NA), na.rm = TRUE), 2L)
expect_equal(rfuns_sum(c(TRUE, TRUE, FALSE, NA), na.rm = FALSE), NA_integer_)

expect_equal(rfuns_sum(logical(), na.rm = FALSE), 0L)
expect_equal(rfuns_sum(logical(), na.rm = TRUE), 0L)

expect_equal(rfuns_sum(NA, na.rm = TRUE), 0L)
expect_equal(rfuns_sum(NA, na.rm = FALSE), NA_integer_)
})

test_that("r_base::sum(<INTEGER>)", {
expect_equal(rfuns_sum(1:10), 55)
expect_equal(rfuns_sum(c(1:10, NA)), 55)

expect_equal(rfuns_sum(c(1:10, NA), na.rm = TRUE), 55)
expect_equal(rfuns_sum(c(1:10, NA), na.rm = FALSE), NA_integer_)

expect_equal(rfuns_sum(integer(), na.rm = FALSE), 0L)
expect_equal(rfuns_sum(integer(), na.rm = TRUE), 0L)

expect_equal(rfuns_sum(NA_integer_, na.rm = TRUE), 0L)
expect_equal(rfuns_sum(NA_integer_, na.rm = FALSE), NA_integer_)
})

test_that("r_base::sum(<DOUBLE>)", {
expect_equal(rfuns_sum(c(1.1, 2.2, 3.3)), 6.6)
expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA)), 6.6)

expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA), na.rm = TRUE), 6.6)
expect_equal(rfuns_sum(c(1.1, 2.2, 3.3, NA), na.rm = FALSE), NA_real_)

expect_equal(rfuns_sum(double(), na.rm = FALSE), 0)
expect_equal(rfuns_sum(double(), na.rm = TRUE), 0)

expect_equal(rfuns_sum(NA_real_, na.rm = TRUE), 0L)
expect_equal(rfuns_sum(NA_real_, na.rm = FALSE), NA_real_)
})

test_that("r_base::sum(<?>, na.rm = <VARCHAR>)", {
expect_snapshot(error = TRUE, rfuns_sum(c(TRUE, FALSE), na.rm = "hello"))
expect_snapshot(error = TRUE, rfuns_sum(1:10, na.rm = "hello"))
expect_snapshot(error = TRUE, rfuns_sum(c(1, 2, 3), na.rm = "hello"))
})

test_that("r_base::sum(<VARCHAR>", {
expect_snapshot(error = TRUE, rfuns_sum("HufflePuff"))
})
3 changes: 3 additions & 0 deletions src/include/rfuns_extension.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ ScalarFunctionSet base_r_lte();
ScalarFunctionSet base_r_gt();
ScalarFunctionSet base_r_gte();

// sum
AggregateFunctionSet base_r_sum();

ScalarFunctionSet binary_dispatch(ScalarFunctionSet fn) ;

} // namespace rfuns
Expand Down
3 changes: 3 additions & 0 deletions src/rfuns_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ static void register_rfuns(DatabaseInstance &instance) {
register_binary(instance, base_r_lte());
register_binary(instance, base_r_gt());
register_binary(instance, base_r_gte());

// sum
ExtensionUtil::RegisterFunction(instance, base_r_sum());
}
} // namespace rfuns

Expand Down
121 changes: 121 additions & 0 deletions src/sum.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#include "rfuns_extension.hpp"

#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp"

#include <math.h>
#include <climits>
#include "duckdb/core_functions/aggregate/sum_helpers.hpp"
#include "duckdb/core_functions/aggregate/distributive_functions.hpp"

namespace duckdb {
namespace rfuns {

template <class T>
struct RSumKeepNaState {
T value;
bool is_set;
bool is_null;
};

template <class ADDOP, bool NA_RM>
struct RSumOperation {

template <class STATE>
static void Initialize(STATE &state) {
state.is_set = false;
state.is_null = false;
state.value = 0;
}

static bool IgnoreNull() {
return NA_RM;
}

template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
if (state.is_null) return;
state.is_set = true;
if (!NA_RM && !unary_input.RowIsValid()) {
state.is_null = true;
} else {
ADDOP::template AddNumber<STATE, INPUT_TYPE>(state, input);
}
}

template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) {
if (!NA_RM && !unary_input.RowIsValid()) {
state.is_null = true;
} else {
ADDOP::template AddConstant<STATE, INPUT_TYPE>(state, input, count);
}
}

template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
if (!target.is_set) {
target = source;
}
}

template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.is_null) {
finalize_data.ReturnNull();
} else {
target = state.value;
}
}
};

template <bool NA_RM>
void BindRSum_dispatch(ClientContext &context, AggregateFunction &function, vector<unique_ptr<Expression>> &arguments) {
auto type = arguments[0]->return_type;

switch (type.id()) {
case LogicalTypeId::DOUBLE:
function = AggregateFunction::UnaryAggregate<RSumKeepNaState<double>, double, double, RSumOperation<RegularAdd, NA_RM>>(type, type);
break;
case LogicalTypeId::INTEGER:
function = AggregateFunction::UnaryAggregate<RSumKeepNaState<hugeint_t>, int32_t, hugeint_t, RSumOperation<HugeintAdd, NA_RM>>(type, type);
break;
case LogicalTypeId::BOOLEAN:
function = AggregateFunction::UnaryAggregate<RSumKeepNaState<int32_t>, bool, int32_t, RSumOperation<RegularAdd, NA_RM>>(LogicalType::BOOLEAN, LogicalType::INTEGER);
break;
default:
break;
}
}

unique_ptr<FunctionData> BindRSum(ClientContext &context, AggregateFunction &function, vector<unique_ptr<Expression>> &arguments) {
auto na_rm = arguments[1]->ToString() == "true";
if (na_rm) {
BindRSum_dispatch<true>(context, function, arguments);
} else {
BindRSum_dispatch<false>(context, function, arguments);
}

return nullptr;
}

AggregateFunction RSum(const LogicalType& type) {
auto return_type = type == LogicalType::BOOLEAN ? LogicalType::INTEGER : type;
return AggregateFunction(
{type, LogicalType::BOOLEAN}, return_type,
nullptr, nullptr, nullptr, nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr,
BindRSum
);
}

AggregateFunctionSet base_r_sum() {
AggregateFunctionSet set("r_base::sum");

set.AddFunction(RSum(LogicalType::BOOLEAN));
set.AddFunction(RSum(LogicalType::INTEGER));
set.AddFunction(RSum(LogicalType::DOUBLE));

return set;
}

}
}

0 comments on commit 5ca3a39

Please sign in to comment.