diff --git a/velox/functions/sparksql/Hash.cpp b/velox/functions/sparksql/Hash.cpp index 8d48f9d882db..7ec03cc011e6 100644 --- a/velox/functions/sparksql/Hash.cpp +++ b/velox/functions/sparksql/Hash.cpp @@ -46,57 +46,39 @@ struct HashTraits { template < typename HashClass, - typename T, - typename SeedType = typename HashTraits::SeedType, - typename ReturnType = typename HashTraits::ReturnType> + typename SeedType, + typename ReturnType, + typename T> ReturnType hashOne(T input, SeedType seed) { return HashClass::hashInt32(input, seed); } -template < - typename HashClass, - typename SeedType = typename HashTraits::SeedType, - typename ReturnType = typename HashTraits::ReturnType> +template ReturnType hashOne(int64_t input, SeedType seed) { return HashClass::hashInt64(input, seed); } -template < - typename HashClass, - typename SeedType = typename HashTraits::SeedType, - typename ReturnType = typename HashTraits::ReturnType> +template ReturnType hashOne(float input, SeedType seed) { return HashClass::hashFloat(input, seed); } -template < - typename HashClass, - typename SeedType = typename HashTraits::SeedType, - typename ReturnType = typename HashTraits::ReturnType> +template ReturnType hashOne(double input, SeedType seed) { return HashClass::hashDouble(input, seed); } -template < - typename HashClass, - typename SeedType = typename HashTraits::SeedType, - typename ReturnType = typename HashTraits::ReturnType> +template ReturnType hashOne(int128_t input, SeedType seed) { return HashClass::hashLongDecimal(input, seed); } -template < - typename HashClass, - typename SeedType = typename HashTraits::SeedType, - typename ReturnType = typename HashTraits::ReturnType> +template ReturnType hashOne(Timestamp input, SeedType seed) { return HashClass::hashTimestamp(input, seed); } -template < - typename HashClass, - typename SeedType = typename HashTraits::SeedType, - typename ReturnType = typename HashTraits::ReturnType> +template ReturnType hashOne(StringView input, SeedType seed) { return HashClass::hashBytes(input, seed); } @@ -130,7 +112,7 @@ template < typename HashClass, typename SeedType = typename HashTraits::SeedType, typename ReturnType = typename HashTraits::ReturnType> -class VectorHasher { +class SparkVectorHasher { public: // Compute the hash value of input vector at index. ReturnType hashAt(vector_size_t index, SeedType seed) { @@ -142,22 +124,22 @@ class VectorHasher { virtual ReturnType hashNotNull(vector_size_t index, SeedType seed) = 0; - VectorHasher(DecodedVector& decoded) : decoded_(decoded) {} + SparkVectorHasher(DecodedVector& decoded) : decoded_(decoded) {} - virtual ~VectorHasher() = default; + virtual ~SparkVectorHasher() = default; protected: const DecodedVector& decoded_; }; template -std::shared_ptr> createPrimitiveVectorHasher( +std::shared_ptr> createPrimitiveVectorHasher( DecodedVector& decoded) { return std::make_shared>(decoded); } template -std::shared_ptr> createVectorHasher( +std::shared_ptr> createVectorHasher( DecodedVector& decoded) { auto baseType = decoded.base()->type(); if (baseType->isPrimitiveType()) { @@ -178,14 +160,13 @@ template < TypeKind kind, typename SeedType, typename ReturnType> -class PrimitiveVectorHasher - : public VectorHasher { +class PrimitiveVectorHasher : public SparkVectorHasher { public: PrimitiveVectorHasher(DecodedVector& decoded) - : VectorHasher(decoded) {} + : SparkVectorHasher(decoded) {} ReturnType hashNotNull(vector_size_t index, SeedType seed) override { - return hashOne( + return hashOne( this->decoded_.template valueAt::NativeType>( index), seed); @@ -193,9 +174,10 @@ class PrimitiveVectorHasher }; template -class ArrayVectorHasher : public VectorHasher { +class ArrayVectorHasher : public SparkVectorHasher { public: - ArrayVectorHasher(DecodedVector& decoded) : VectorHasher(decoded) { + ArrayVectorHasher(DecodedVector& decoded) + : SparkVectorHasher(decoded) { base_ = decoded.base()->as(); indices_ = decoded.indices(); @@ -219,13 +201,14 @@ class ArrayVectorHasher : public VectorHasher { const ArrayVector* base_; const int32_t* indices_; DecodedVector decodedElements_; - std::shared_ptr> elementHasher_; + std::shared_ptr> elementHasher_; }; template -class MapVectorHasher : public VectorHasher { +class MapVectorHasher : public SparkVectorHasher { public: - MapVectorHasher(DecodedVector& decoded) : VectorHasher(decoded) { + MapVectorHasher(DecodedVector& decoded) + : SparkVectorHasher(decoded) { base_ = decoded.base()->as(); indices_ = decoded.indices(); @@ -253,14 +236,15 @@ class MapVectorHasher : public VectorHasher { const int32_t* indices_; DecodedVector decodedKeys_; DecodedVector decodedValues_; - std::shared_ptr> keyHasher_; - std::shared_ptr> valueHasher_; + std::shared_ptr> keyHasher_; + std::shared_ptr> valueHasher_; }; template -class RowVectorHasher : public VectorHasher { +class RowVectorHasher : public SparkVectorHasher { public: - RowVectorHasher(DecodedVector& decoded) : VectorHasher(decoded) { + RowVectorHasher(DecodedVector& decoded) + : SparkVectorHasher(decoded) { base_ = decoded.base()->as(); indices_ = decoded.indices(); @@ -285,7 +269,7 @@ class RowVectorHasher : public VectorHasher { const RowVector* base_; const int32_t* indices_; std::vector decodedChildren_; - std::vector>> hashers_; + std::vector>> hashers_; }; // ReturnType can be either int32_t or int64_t