Skip to content

Commit

Permalink
feat: Tweak implementation of r_base::sum()
Browse files Browse the repository at this point in the history
  • Loading branch information
krlmlr committed Sep 23, 2024
1 parent d4fa3a1 commit 6287696
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/rfuns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,7 @@ struct RSumOperation {

template <class STATE>
static void Initialize(STATE &state) {
state.value = 0;
state.is_set = false;
state.is_null = false;
}
Expand Down Expand Up @@ -990,10 +991,10 @@ unique_ptr<FunctionData> BindRSum_dispatch(ClientContext &context, AggregateFunc
function = AggregateFunction::UnaryAggregate<RSumKeepNaState<double>, double, double, RSumOperation<RegularAdd, NA_RM>>(type, type);
break;
case LogicalTypeId::INTEGER:
function = AggregateFunction::UnaryAggregate<RSumKeepNaState<hugeint_t>, int32_t, hugeint_t, RSumOperation<HugeintAdd, NA_RM>>(type, type);
function = AggregateFunction::UnaryAggregate<RSumKeepNaState<double>, int32_t, double, RSumOperation<RegularAdd, NA_RM>>(type, LogicalTypeId::DOUBLE);
break;
case LogicalTypeId::BOOLEAN:
function = AggregateFunction::UnaryAggregate<RSumKeepNaState<int32_t>, bool, int32_t, RSumOperation<RegularAdd, NA_RM>>(LogicalType::BOOLEAN, LogicalType::INTEGER);
function = AggregateFunction::UnaryAggregate<RSumKeepNaState<int32_t>, bool, int32_t, RSumOperation<RegularAdd, NA_RM>>(type, LogicalType::INTEGER);
break;
default:
break;
Expand All @@ -1011,8 +1012,7 @@ unique_ptr<FunctionData> BindRSum(ClientContext &context, AggregateFunction &fun
}
}

void add_RSum(AggregateFunctionSet& set, const LogicalType& type) {
auto return_type = type == LogicalType::BOOLEAN ? LogicalType::INTEGER : type;
void add_RSum(AggregateFunctionSet& set, const LogicalType& type, const LogicalType& return_type) {
set.AddFunction(AggregateFunction(
{type, LogicalType::BOOLEAN}, return_type,
nullptr, nullptr, nullptr, nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr,
Expand All @@ -1029,9 +1029,9 @@ void add_RSum(AggregateFunctionSet& set, const LogicalType& type) {
AggregateFunctionSet base_r_sum() {
AggregateFunctionSet set("r_base::sum");

add_RSum(set, LogicalType::BOOLEAN);
add_RSum(set, LogicalType::INTEGER);
add_RSum(set, LogicalType::DOUBLE);
add_RSum(set, LogicalType::BOOLEAN, LogicalType::INTEGER);
add_RSum(set, LogicalType::INTEGER, LogicalType::DOUBLE);
add_RSum(set, LogicalType::DOUBLE, LogicalType::DOUBLE);

return set;
}
Expand Down

0 comments on commit 6287696

Please sign in to comment.