Skip to content

Commit

Permalink
Reuse presto's array agg & set agg and ignore null input value (#431)
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE authored Nov 3, 2023
1 parent d02e83a commit 51ef237
Show file tree
Hide file tree
Showing 9 changed files with 1,140 additions and 4 deletions.
23 changes: 23 additions & 0 deletions velox/functions/prestosql/aggregates/ValueList.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,27 @@ bool ValueListReader::next(BaseVector& output, vector_size_t outputIndex) {
pos_++;
return pos_ < size_;
}

bool ValueListReader::nextIgnoreNull(
BaseVector& output,
vector_size_t outputIndex,
bool& skipped) {
if (pos_ == lastNullsStart_) {
nulls_ = lastNulls_;
} else if (pos_ % 64 == 0) {
nulls_ = nullsStream_.read<uint64_t>();
}

if (nulls_ & (1UL << (pos_ % 64))) {
// Ignore null.
skipped = true;
} else {
exec::ContainerRowSerde::instance().deserialize(
dataStream_, outputIndex, &output);
}

pos_++;
return pos_ < size_;
}

} // namespace facebook::velox::aggregate
2 changes: 2 additions & 0 deletions velox/functions/prestosql/aggregates/ValueList.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class ValueListReader {
explicit ValueListReader(ValueList& values);

bool next(BaseVector& output, vector_size_t outputIndex);
bool
nextIgnoreNull(BaseVector& output, vector_size_t outputIndex, bool& skipped);

private:
const vector_size_t size_;
Expand Down
211 changes: 211 additions & 0 deletions velox/functions/sparksql/aggregates/ArrayAggAggregate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/exec/ContainerRowSerde.h"
#include "velox/expression/FunctionSignature.h"
#include "velox/functions/prestosql/aggregates/AggregateNames.h"
#include "velox/functions/prestosql/aggregates/ValueList.h"
#include "velox/vector/ComplexVector.h"

namespace facebook::velox::functions::aggregate::sparksql {
namespace {

using namespace facebook::velox::aggregate;

struct ArrayAccumulator {
ValueList elements;
};

class ArrayAggAggregate : public exec::Aggregate {
public:
explicit ArrayAggAggregate(TypePtr resultType) : Aggregate(resultType) {}

int32_t accumulatorFixedWidthSize() const override {
return sizeof(ArrayAccumulator);
}

bool isFixedSize() const override {
return false;
}

void initializeNewGroups(
char** groups,
folly::Range<const vector_size_t*> indices) override {
for (auto index : indices) {
new (groups[index] + offset_) ArrayAccumulator();
}
}

void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
override {
auto vector = (*result)->as<ArrayVector>();
VELOX_CHECK(vector);
vector->resize(numGroups);

auto elements = vector->elements();
elements->resize(countElements(groups, numGroups));

uint64_t* rawNulls = getRawNulls(vector);
vector_size_t offset = 0;
for (int32_t i = 0; i < numGroups; ++i) {
// No null result, either empty array or non-empty array.
clearNull(rawNulls, i);
auto& values = value<ArrayAccumulator>(groups[i])->elements;
auto arraySize = values.size();
if (arraySize) {
ValueListReader reader(values);
int count = 0;
for (auto index = 0; index < arraySize; ++index) {
// To mark whether skipped due to null input value.
bool skipped = false;
reader.nextIgnoreNull(*elements, offset + count, skipped);
if (!skipped) {
count++;
}
}
vector->setOffsetAndSize(i, offset, count);
offset += count;
} else {
// Empty array.
vector->setOffsetAndSize(i, offset, 0);
}
}
}

void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result)
override {
extractValues(groups, numGroups, result);
}

void addRawInput(
char** groups,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool /*mayPushdown*/) override {
decodedElements_.decode(*args[0], rows);
rows.applyToSelected([&](vector_size_t row) {
auto group = groups[row];
auto tracker = trackRowSize(group);
value<ArrayAccumulator>(group)->elements.appendValue(
decodedElements_, row, allocator_);
});
}

void addIntermediateResults(
char** groups,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool /*mayPushdown*/) override {
decodedIntermediate_.decode(*args[0], rows);

auto arrayVector = decodedIntermediate_.base()->as<ArrayVector>();
auto& elements = arrayVector->elements();
rows.applyToSelected([&](vector_size_t row) {
auto group = groups[row];
auto decodedRow = decodedIntermediate_.index(row);
auto tracker = trackRowSize(group);
value<ArrayAccumulator>(group)->elements.appendRange(
elements,
arrayVector->offsetAt(decodedRow),
arrayVector->sizeAt(decodedRow),
allocator_);
});
}

void addSingleGroupRawInput(
char* group,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool /* mayPushdown */) override {
auto& values = value<ArrayAccumulator>(group)->elements;

decodedElements_.decode(*args[0], rows);
auto tracker = trackRowSize(group);
rows.applyToSelected([&](vector_size_t row) {
values.appendValue(decodedElements_, row, allocator_);
});
}

void addSingleGroupIntermediateResults(
char* group,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool /* mayPushdown */) override {
decodedIntermediate_.decode(*args[0], rows);
auto arrayVector = decodedIntermediate_.base()->as<ArrayVector>();

auto& values = value<ArrayAccumulator>(group)->elements;
auto elements = arrayVector->elements();
rows.applyToSelected([&](vector_size_t row) {
auto decodedRow = decodedIntermediate_.index(row);
values.appendRange(
elements,
arrayVector->offsetAt(decodedRow),
arrayVector->sizeAt(decodedRow),
allocator_);
});
}

void destroy(folly::Range<char**> groups) override {
for (auto group : groups) {
value<ArrayAccumulator>(group)->elements.free(allocator_);
}
}

private:
vector_size_t countElements(char** groups, int32_t numGroups) const {
vector_size_t size = 0;
for (int32_t i = 0; i < numGroups; ++i) {
size += value<ArrayAccumulator>(groups[i])->elements.size();
}
return size;
}

// Reusable instance of DecodedVector for decoding input vectors.
DecodedVector decodedElements_;
DecodedVector decodedIntermediate_;
};

exec::AggregateRegistrationResult registerArray(const std::string& name) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
exec::AggregateFunctionSignatureBuilder()
.typeVariable("E")
.returnType("array(E)")
.intermediateType("array(E)")
.argumentType("E")
.build()};

return exec::registerAggregateFunction(
name,
std::move(signatures),
[name](
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType) -> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_EQ(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<ArrayAggAggregate>(resultType);
},
false,
true);
}

} // namespace

void registerArrayAggregate(const std::string& prefix) {
registerArray(prefix + kArrayAgg);
}

} // namespace facebook::velox::functions::aggregate::sparksql
8 changes: 6 additions & 2 deletions velox/functions/sparksql/aggregates/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
# limitations under the License.
add_library(
velox_functions_spark_aggregates
BitwiseXorAggregate.cpp BloomFilterAggAggregate.cpp FirstLastAggregate.cpp
Register.cpp)
ArrayAggAggregate.cpp
BitwiseXorAggregate.cpp
BloomFilterAggAggregate.cpp
FirstLastAggregate.cpp
Register.cpp
SetAggAggregate.cpp)

target_link_libraries(velox_functions_spark_aggregates fmt::fmt velox_exec
velox_expression_functions velox_aggregates velox_vector)
Expand Down
4 changes: 4 additions & 0 deletions velox/functions/sparksql/aggregates/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ namespace facebook::velox::functions::aggregate::sparksql {
using namespace facebook::velox::functions::sparksql::aggregates;

extern void registerFirstLastAggregates(const std::string& prefix);
extern void registerArrayAggregate(const std::string& prefix);
extern void registerSetAggAggregate(const std::string& prefix);

void registerAggregateFunctions(const std::string& prefix) {
registerArrayAggregate(prefix);
registerFirstLastAggregates(prefix);
registerBitwiseXorAggregate(prefix);
registerBloomFilterAggAggregate(prefix + "bloom_filter_agg");
registerDecimalAvgAggregate(prefix + "decimal_avg");
registerDecimalSumAggregate(prefix + "decimal_sum");
registerSetAggAggregate(prefix);
}
} // namespace facebook::velox::functions::aggregate::sparksql
Loading

0 comments on commit 51ef237

Please sign in to comment.