Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
marin-ma committed Apr 11, 2024
1 parent 848b257 commit f4020fd
Showing 1 changed file with 47 additions and 44 deletions.
91 changes: 47 additions & 44 deletions velox/functions/sparksql/Hash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ const int32_t kDefaultSeed = 42;
struct Murmur3Hash;
struct XxHash64;

/// A template struct that contains the seed and return type of the hash
/// function.
template <typename HashClass>
struct HashTraits {};

Expand All @@ -44,12 +46,9 @@ struct HashTraits<XxHash64> {
using ReturnType = int64_t;
};

template <
typename HashClass,
typename SeedType,
typename ReturnType,
typename T>
ReturnType hashOne(T input, SeedType seed) {
// Computes the hash value of input using the hash function in HashClass.
template <typename HashClass, typename SeedType, typename ReturnType>
ReturnType hashOne(int32_t input, SeedType seed) {
return HashClass::hashInt32(input, seed);
}

Expand Down Expand Up @@ -83,54 +82,56 @@ ReturnType hashOne(StringView input, SeedType seed) {
return HashClass::hashBytes(input, seed);
}

/// Class to compute hashes identical to one produced by Spark.
/// Hashes are computed using the algorithm implemented in HashClass.
template <
typename HashClass,
TypeKind kind,
typename SeedType = typename HashTraits<HashClass>::SeedType,
typename ReturnType = typename HashTraits<HashClass>::ReturnType>
class PrimitiveVectorHasher;
class SparkVectorHasher {
public:
SparkVectorHasher(DecodedVector& decoded) : decoded_(decoded) {}

virtual ~SparkVectorHasher() = default;

// Compute the hash value of input vector at index.
ReturnType hashAt(vector_size_t index, SeedType seed) {
if (decoded_.isNullAt(index)) {
return seed;
}
return hashNotNull(index, seed);
}

virtual ReturnType hashNotNull(vector_size_t index, SeedType seed) = 0;

protected:
const DecodedVector& decoded_;
};

template <
typename HashClass,
TypeKind kind,
typename SeedType = typename HashTraits<HashClass>::SeedType,
typename ReturnType = typename HashTraits<HashClass>::ReturnType>
class ArrayVectorHasher;
class PrimitiveVectorHasher;

template <
typename HashClass,
typename SeedType = typename HashTraits<HashClass>::SeedType,
typename ReturnType = typename HashTraits<HashClass>::ReturnType>
class MapVectorHasher;
class ArrayVectorHasher;

template <
typename HashClass,
typename SeedType = typename HashTraits<HashClass>::SeedType,
typename ReturnType = typename HashTraits<HashClass>::ReturnType>
class RowVectorHasher;
class MapVectorHasher;

template <
typename HashClass,
typename SeedType = typename HashTraits<HashClass>::SeedType,
typename ReturnType = typename HashTraits<HashClass>::ReturnType>
class SparkVectorHasher {
public:
// Compute the hash value of input vector at index.
ReturnType hashAt(vector_size_t index, SeedType seed) {
if (decoded_.isNullAt(index)) {
return seed;
}
return hashNotNull(index, seed);
}

virtual ReturnType hashNotNull(vector_size_t index, SeedType seed) = 0;

SparkVectorHasher(DecodedVector& decoded) : decoded_(decoded) {}

virtual ~SparkVectorHasher() = default;

protected:
const DecodedVector& decoded_;
};
class RowVectorHasher;

template <typename HashClass, TypeKind kind>
std::shared_ptr<SparkVectorHasher<HashClass>> createPrimitiveVectorHasher(
Expand All @@ -141,18 +142,20 @@ std::shared_ptr<SparkVectorHasher<HashClass>> createPrimitiveVectorHasher(
template <typename HashClass>
std::shared_ptr<SparkVectorHasher<HashClass>> createVectorHasher(
DecodedVector& decoded) {
auto baseType = decoded.base()->type();
if (baseType->isPrimitiveType()) {
return VELOX_DYNAMIC_SCALAR_TEMPLATE_TYPE_DISPATCH(
createPrimitiveVectorHasher, HashClass, baseType->kind(), decoded);
} else if (baseType->isArray()) {
return std::make_shared<ArrayVectorHasher<HashClass>>(decoded);
} else if (baseType->isMap()) {
return std::make_shared<MapVectorHasher<HashClass>>(decoded);
} else if (baseType->isRow()) {
return std::make_shared<RowVectorHasher<HashClass>>(decoded);
}
VELOX_UNREACHABLE();
switch (decoded.base()->typeKind()) {
case TypeKind::ARRAY:
return std::make_shared<ArrayVectorHasher<HashClass>>(decoded);
case TypeKind::MAP:
return std::make_shared<MapVectorHasher<HashClass>>(decoded);
case TypeKind::ROW:
return std::make_shared<RowVectorHasher<HashClass>>(decoded);
default:
return VELOX_DYNAMIC_SCALAR_TEMPLATE_TYPE_DISPATCH(
createPrimitiveVectorHasher,
HashClass,
decoded.base()->typeKind(),
decoded);
}
}

template <
Expand Down Expand Up @@ -288,7 +291,7 @@ void applyWithType(
SeedType hashSeed = seed ? *seed : kDefaultSeed;

auto& result = *resultRef->as<FlatVector<ReturnType>>();
rows.applyToSelected([&](int row) { result.set(row, hashSeed); });
rows.applyToSelected([&](auto row) { result.set(row, hashSeed); });

exec::LocalSelectivityVector selectedMinusNulls(context);

Expand All @@ -304,7 +307,7 @@ void applyWithType(
}

auto hasher = createVectorHasher<HashClass>(*decoded);
selected->applyToSelected([&](int row) {
selected->applyToSelected([&](auto row) {
result.set(row, hasher->hashNotNull(row, result.valueAt(row)));
});
}
Expand Down

0 comments on commit f4020fd

Please sign in to comment.