diff --git a/velox/functions/prestosql/aggregates/ValueList.cpp b/velox/functions/prestosql/aggregates/ValueList.cpp index 89da682020080..4c2433809a3f2 100644 --- a/velox/functions/prestosql/aggregates/ValueList.cpp +++ b/velox/functions/prestosql/aggregates/ValueList.cpp @@ -123,4 +123,27 @@ bool ValueListReader::next(BaseVector& output, vector_size_t outputIndex) { pos_++; return pos_ < size_; } + +bool ValueListReader::nextIgnoreNull( + BaseVector& output, + vector_size_t outputIndex, + bool& skipped) { + if (pos_ == lastNullsStart_) { + nulls_ = lastNulls_; + } else if (pos_ % 64 == 0) { + nulls_ = nullsStream_.read(); + } + + if (nulls_ & (1UL << (pos_ % 64))) { + // Ignore null. + skipped = true; + } else { + exec::ContainerRowSerde::instance().deserialize( + dataStream_, outputIndex, &output); + } + + pos_++; + return pos_ < size_; +} + } // namespace facebook::velox::aggregate diff --git a/velox/functions/prestosql/aggregates/ValueList.h b/velox/functions/prestosql/aggregates/ValueList.h index 13271b2856927..9eb5d17e008b9 100644 --- a/velox/functions/prestosql/aggregates/ValueList.h +++ b/velox/functions/prestosql/aggregates/ValueList.h @@ -108,6 +108,8 @@ class ValueListReader { explicit ValueListReader(ValueList& values); bool next(BaseVector& output, vector_size_t outputIndex); + bool + nextIgnoreNull(BaseVector& output, vector_size_t outputIndex, bool& skipped); private: const vector_size_t size_; diff --git a/velox/functions/sparksql/aggregates/ArrayAggAggregate.cpp b/velox/functions/sparksql/aggregates/ArrayAggAggregate.cpp new file mode 100644 index 0000000000000..651a72c122adb --- /dev/null +++ b/velox/functions/sparksql/aggregates/ArrayAggAggregate.cpp @@ -0,0 +1,211 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/exec/ContainerRowSerde.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/functions/prestosql/aggregates/AggregateNames.h" +#include "velox/functions/prestosql/aggregates/ValueList.h" +#include "velox/vector/ComplexVector.h" + +namespace facebook::velox::functions::aggregate::sparksql { +namespace { + +using namespace facebook::velox::aggregate; + +struct ArrayAccumulator { + ValueList elements; +}; + +class ArrayAggAggregate : public exec::Aggregate { + public: + explicit ArrayAggAggregate(TypePtr resultType) : Aggregate(resultType) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(ArrayAccumulator); + } + + bool isFixedSize() const override { + return false; + } + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + for (auto index : indices) { + new (groups[index] + offset_) ArrayAccumulator(); + } + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto vector = (*result)->as(); + VELOX_CHECK(vector); + vector->resize(numGroups); + + auto elements = vector->elements(); + elements->resize(countElements(groups, numGroups)); + + uint64_t* rawNulls = getRawNulls(vector); + vector_size_t offset = 0; + for (int32_t i = 0; i < numGroups; ++i) { + // No null result, either empty array or non-empty array. + clearNull(rawNulls, i); + auto& values = value(groups[i])->elements; + auto arraySize = values.size(); + if (arraySize) { + ValueListReader reader(values); + int count = 0; + for (auto index = 0; index < arraySize; ++index) { + // To mark whether skipped due to null input value. + bool skipped = false; + reader.nextIgnoreNull(*elements, offset + count, skipped); + if (!skipped) { + count++; + } + } + vector->setOffsetAndSize(i, offset, count); + offset += count; + } else { + // Empty array. + vector->setOffsetAndSize(i, offset, 0); + } + } + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + extractValues(groups, numGroups, result); + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodedElements_.decode(*args[0], rows); + rows.applyToSelected([&](vector_size_t row) { + auto group = groups[row]; + auto tracker = trackRowSize(group); + value(group)->elements.appendValue( + decodedElements_, row, allocator_); + }); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodedIntermediate_.decode(*args[0], rows); + + auto arrayVector = decodedIntermediate_.base()->as(); + auto& elements = arrayVector->elements(); + rows.applyToSelected([&](vector_size_t row) { + auto group = groups[row]; + auto decodedRow = decodedIntermediate_.index(row); + auto tracker = trackRowSize(group); + value(group)->elements.appendRange( + elements, + arrayVector->offsetAt(decodedRow), + arrayVector->sizeAt(decodedRow), + allocator_); + }); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + auto& values = value(group)->elements; + + decodedElements_.decode(*args[0], rows); + auto tracker = trackRowSize(group); + rows.applyToSelected([&](vector_size_t row) { + values.appendValue(decodedElements_, row, allocator_); + }); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedIntermediate_.decode(*args[0], rows); + auto arrayVector = decodedIntermediate_.base()->as(); + + auto& values = value(group)->elements; + auto elements = arrayVector->elements(); + rows.applyToSelected([&](vector_size_t row) { + auto decodedRow = decodedIntermediate_.index(row); + values.appendRange( + elements, + arrayVector->offsetAt(decodedRow), + arrayVector->sizeAt(decodedRow), + allocator_); + }); + } + + void destroy(folly::Range groups) override { + for (auto group : groups) { + value(group)->elements.free(allocator_); + } + } + + private: + vector_size_t countElements(char** groups, int32_t numGroups) const { + vector_size_t size = 0; + for (int32_t i = 0; i < numGroups; ++i) { + size += value(groups[i])->elements.size(); + } + return size; + } + + // Reusable instance of DecodedVector for decoding input vectors. + DecodedVector decodedElements_; + DecodedVector decodedIntermediate_; +}; + +exec::AggregateRegistrationResult registerArray(const std::string& name) { + std::vector> signatures{ + exec::AggregateFunctionSignatureBuilder() + .typeVariable("E") + .returnType("array(E)") + .intermediateType("array(E)") + .argumentType("E") + .build()}; + + return exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType) -> std::unique_ptr { + VELOX_CHECK_EQ( + argTypes.size(), 1, "{} takes at most one argument", name); + return std::make_unique(resultType); + }, + false, + true); +} + +} // namespace + +void registerArrayAggregate(const std::string& prefix) { + registerArray(prefix + kArrayAgg); +} + +} // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/CMakeLists.txt b/velox/functions/sparksql/aggregates/CMakeLists.txt index 4977cea9ea7c7..03da0b966d6b7 100644 --- a/velox/functions/sparksql/aggregates/CMakeLists.txt +++ b/velox/functions/sparksql/aggregates/CMakeLists.txt @@ -13,8 +13,12 @@ # limitations under the License. add_library( velox_functions_spark_aggregates - BitwiseXorAggregate.cpp BloomFilterAggAggregate.cpp FirstLastAggregate.cpp - Register.cpp) + ArrayAggAggregate.cpp + BitwiseXorAggregate.cpp + BloomFilterAggAggregate.cpp + FirstLastAggregate.cpp + Register.cpp + SetAggAggregate.cpp) target_link_libraries(velox_functions_spark_aggregates fmt::fmt velox_exec velox_expression_functions velox_aggregates velox_vector) diff --git a/velox/functions/sparksql/aggregates/Register.cpp b/velox/functions/sparksql/aggregates/Register.cpp index f1983111c028a..176a6feead6b1 100644 --- a/velox/functions/sparksql/aggregates/Register.cpp +++ b/velox/functions/sparksql/aggregates/Register.cpp @@ -25,12 +25,16 @@ namespace facebook::velox::functions::aggregate::sparksql { using namespace facebook::velox::functions::sparksql::aggregates; extern void registerFirstLastAggregates(const std::string& prefix); +extern void registerArrayAggregate(const std::string& prefix); +extern void registerSetAggAggregate(const std::string& prefix); void registerAggregateFunctions(const std::string& prefix) { + registerArrayAggregate(prefix); registerFirstLastAggregates(prefix); registerBitwiseXorAggregate(prefix); registerBloomFilterAggAggregate(prefix + "bloom_filter_agg"); registerDecimalAvgAggregate(prefix + "decimal_avg"); registerDecimalSumAggregate(prefix + "decimal_sum"); + registerSetAggAggregate(prefix); } } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/SetAggAggregate.cpp b/velox/functions/sparksql/aggregates/SetAggAggregate.cpp new file mode 100644 index 0000000000000..9052e5ebc91f9 --- /dev/null +++ b/velox/functions/sparksql/aggregates/SetAggAggregate.cpp @@ -0,0 +1,451 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/exec/Aggregate.h" +#include "velox/functions/prestosql/aggregates/AggregateNames.h" +#include "velox/functions/prestosql/aggregates/Strings.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::functions::aggregate::sparksql { + +namespace { + +using namespace facebook::velox::aggregate; +using namespace facebook::velox::aggregate::prestosql; + +/// Maintains a set of unique values of fixed-width type (integers). Also +/// maintains a flag indicating whether there was a null value. +template +struct Accumulator { + // bool hasNull{false}; + folly:: + F14FastSet, std::equal_to, AlignedStlAllocator> + uniqueValues; + + explicit Accumulator(HashStringAllocator* allocator) + : uniqueValues{AlignedStlAllocator(allocator)} {} + + /// Adds value if new. No-op if the value was added before. + void addValue( + const DecodedVector& decoded, + vector_size_t index, + HashStringAllocator* /*allocator*/) { + if (decoded.isNullAt(index)) { + // Ignore null. + // hasNull = true; + } else { + uniqueValues.insert(decoded.valueAt(index)); + } + } + + /// Adds new values from an array. + void addValues( + const ArrayVector& arrayVector, + vector_size_t index, + const DecodedVector& values, + HashStringAllocator* allocator) { + const auto size = arrayVector.sizeAt(index); + const auto offset = arrayVector.offsetAt(index); + + for (auto i = 0; i < size; ++i) { + addValue(values, offset + i, allocator); + } + } + + /// Returns number of unique values including null. + size_t size() const { + // return uniqueValues.size() + (hasNull ? 1 : 0); + return uniqueValues.size(); + } + + /// Copies the unique values and null into the specified vector starting at + /// the specified offset. + vector_size_t extractValues(FlatVector& values, vector_size_t offset) { + vector_size_t index = offset; + for (auto value : uniqueValues) { + values.set(index++, value); + } + + // if (hasNull) { + // values.setNull(index++, true); + // } + + return index - offset; + } +}; + +/// Maintains a set of unique strings. +struct StringViewAccumulator { + /// A set of unique StringViews pointing to storage managed by 'strings'. + Accumulator base; + + /// Stores unique non-null non-inline strings. + Strings strings; + + explicit StringViewAccumulator(HashStringAllocator* allocator) + : base{allocator} {} + + void addValue( + const DecodedVector& decoded, + vector_size_t index, + HashStringAllocator* allocator) { + if (decoded.isNullAt(index)) { + // Ignore null. + // base.hasNull = true; + } else { + auto value = decoded.valueAt(index); + if (!value.isInline()) { + if (base.uniqueValues.contains(value)) { + return; + } + value = strings.append(value, *allocator); + } + base.uniqueValues.insert(value); + } + } + + void addValues( + const ArrayVector& arrayVector, + vector_size_t index, + const DecodedVector& values, + HashStringAllocator* allocator) { + const auto size = arrayVector.sizeAt(index); + const auto offset = arrayVector.offsetAt(index); + + for (auto i = 0; i < size; ++i) { + addValue(values, offset + i, allocator); + } + } + + size_t size() const { + return base.size(); + } + + vector_size_t extractValues( + FlatVector& values, + vector_size_t offset) { + return base.extractValues(values, offset); + } +}; + +template +struct AccumulatorTypeTraits { + using AccumulatorType = Accumulator; +}; + +template <> +struct AccumulatorTypeTraits { + using AccumulatorType = StringViewAccumulator; +}; + +template +class SetBaseAggregate : public exec::Aggregate { + public: + explicit SetBaseAggregate(const TypePtr& resultType) + : exec::Aggregate(resultType) {} + + using AccumulatorType = typename AccumulatorTypeTraits::AccumulatorType; + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(AccumulatorType); + } + + bool isFixedSize() const override { + return false; + } + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + exec::Aggregate::setAllNulls(groups, indices); + for (auto i : indices) { + new (groups[i] + offset_) AccumulatorType(allocator_); + } + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto arrayVector = (*result)->as(); + arrayVector->resize(numGroups); + + auto* rawOffsets = arrayVector->offsets()->asMutable(); + auto* rawSizes = arrayVector->sizes()->asMutable(); + + vector_size_t numValues = 0; + uint64_t* rawNulls = getRawNulls(arrayVector); + for (auto i = 0; i < numGroups; ++i) { + auto* group = groups[i]; + if (isNull(group)) { + // arrayVector->setNull(i, true); + } else { + clearNull(rawNulls, i); + + const auto size = value(group)->size(); + + rawOffsets[i] = numValues; + rawSizes[i] = size; + + numValues += size; + } + } + + auto values = arrayVector->elements()->as>(); + values->resize(numValues); + + vector_size_t offset = 0; + for (auto i = 0; i < numGroups; ++i) { + auto* group = groups[i]; + if (!isNull(group)) { + offset += value(group)->extractValues(*values, offset); + } + } + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + return extractValues(groups, numGroups, result); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decoded_.decode(*args[0], rows); + + auto baseArray = decoded_.base()->template as(); + decodedElements_.decode(*baseArray->elements()); + + rows.applyToSelected([&](vector_size_t i) { + if (decoded_.isNullAt(i)) { + return; + } + + auto* group = groups[i]; + clearNull(group); + + auto tracker = trackRowSize(group); + + auto decodedIndex = decoded_.index(i); + value(group)->addValues( + *baseArray, decodedIndex, decodedElements_, allocator_); + }); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decoded_.decode(*args[0], rows); + + auto baseArray = decoded_.base()->template as(); + + decodedElements_.decode(*baseArray->elements()); + + auto* accumulator = value(group); + + auto tracker = trackRowSize(group); + rows.applyToSelected([&](vector_size_t i) { + if (decoded_.isNullAt(i)) { + return; + } + + clearNull(group); + + auto decodedIndex = decoded_.index(i); + accumulator->addValues( + *baseArray, decodedIndex, decodedElements_, allocator_); + }); + } + + void destroy(folly::Range groups) override { + if constexpr (std::is_same_v) { + for (auto* group : groups) { + if (!isNull(group)) { + value(group)->strings.free(*allocator_); + } + } + } + } + + protected: + inline AccumulatorType* value(char* group) { + return reinterpret_cast(group + Aggregate::offset_); + } + + DecodedVector decoded_; + DecodedVector decodedElements_; +}; + +template +class SetAggAggregate : public SetBaseAggregate { + public: + explicit SetAggAggregate(const TypePtr& resultType) + : SetBaseAggregate(resultType) {} + + using Base = SetBaseAggregate; + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + Base::decoded_.decode(*args[0], rows); + + rows.applyToSelected([&](vector_size_t i) { + auto* group = groups[i]; + Base::clearNull(group); + + auto tracker = Base::trackRowSize(group); + Base::value(group)->addValue(Base::decoded_, i, Base::allocator_); + }); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + Base::decoded_.decode(*args[0], rows); + + Base::clearNull(group); + auto* accumulator = Base::value(group); + + auto tracker = Base::trackRowSize(group); + rows.applyToSelected([&](vector_size_t i) { + accumulator->addValue(Base::decoded_, i, Base::allocator_); + }); + } +}; + +template +class SetUnionAggregate : public SetBaseAggregate { + public: + explicit SetUnionAggregate(const TypePtr& resultType) + : SetBaseAggregate(resultType) {} + + using Base = SetBaseAggregate; + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + Base::addIntermediateResults(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + Base::addSingleGroupIntermediateResults(group, rows, args, mayPushdown); + } +}; + +template