Skip to content

Commit

Permalink
Test
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed Mar 4, 2024
1 parent 29528ad commit 506c4bf
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 67 deletions.
11 changes: 11 additions & 0 deletions velox/core/QueryConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ class QueryConfig {
static constexpr const char* kSparkBloomFilterMaxNumBits =
"spark.bloom_filter.max_num_bits";

/// The current spark partition id.
static constexpr const char* kSparkPartitionId = "spark.partition_id";

/// The number of local parallel table writer operators per task.
static constexpr const char* kTaskWriterCount = "task_writer_count";

Expand Down Expand Up @@ -685,6 +688,14 @@ class QueryConfig {
return value;
}

int32_t sparkPartitionId() const {
auto id = get<int32_t>(kSparkPartitionId);
VELOX_CHECK(id.has_value(), "Spark partition id is not set.");
auto value = id.value();
VELOX_CHECK_GE(value, 0, "Invalid Spark partition id.");
return value;
}

bool exprTrackCpuUsage() const {
return get<bool>(kExprTrackCpuUsage, false);
}
Expand Down
3 changes: 3 additions & 0 deletions velox/docs/configs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,6 @@ Spark-specific Configuration
- 4194304
- The maximum number of bits to use for the bloom filter in :spark:func:`bloom_filter_agg` function,
the value of this config can not exceed the default value.
* - spark.partition_id
- integer
- The current task's Spark partition ID. It's set by the query engine (Spark) prior to task execution.
55 changes: 13 additions & 42 deletions velox/functions/sparksql/Rand.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,54 +23,25 @@ template <typename T>
struct RandFunction {
static constexpr bool is_deterministic = false;

FOLLY_ALWAYS_INLINE void call(double& result) {
result = folly::Random::randDouble01();
template <typename TInput>
void initialize(const core::QueryConfig& config, const TInput* seedInput) {
auto partitionId = config.sparkPartitionId();
generator_ = std::mt19937{};
int64_t seed = seedInput ? (int64_t)*seedInput : 0;
generator_.seed(seed + partitionId);
}

FOLLY_ALWAYS_INLINE void callNullable(
double& result,
const int32_t* seed,
const int32_t* partitionIndex) {
initializeGenerator(seed, partitionIndex);
result = folly::Random::randDouble01(*generator_);
}

// To differentiate generator for each thread, seed plus partitionIndex is
// the actual seed used for generator.
FOLLY_ALWAYS_INLINE void callNullable(
double& result,
const int64_t* seed,
const int32_t* partitionIndex) {
initializeGenerator(seed, partitionIndex);
result = folly::Random::randDouble01(*generator_);
FOLLY_ALWAYS_INLINE void call(double& result) {
result = folly::Random::randDouble01();
}

// For NULL constant input of unknown type.
FOLLY_ALWAYS_INLINE void callNullable(
double& result,
const UnknownValue* /*seed*/,
const int32_t* partitionIndex) {
initializeGenerator<int64_t>(nullptr, partitionIndex);
result = folly::Random::randDouble01(*generator_);
template <typename TInput>
FOLLY_ALWAYS_INLINE void callNullable(double& result, TInput /*seedInput*/) {
result = folly::Random::randDouble01(generator_);
}

private:
template <typename TSeed>
FOLLY_ALWAYS_INLINE void initializeGenerator(
const TSeed* seed,
const int32_t* partitionIndex) {
VELOX_USER_CHECK_NOT_NULL(partitionIndex, "partitionIndex cannot be null.");
if (!generator_.has_value()) {
generator_ = std::mt19937{};
if (seed != nullptr) {
generator_->seed((int64_t)*seed + *partitionIndex);
} else {
// For null seed, partitionIndex is the seed, consistent with Spark.
generator_->seed(*partitionIndex);
}
}
}

std::optional<std::mt19937> generator_;
std::mt19937 generator_;
};

} // namespace facebook::velox::functions::sparksql
22 changes: 4 additions & 18 deletions velox/functions/sparksql/RegisterArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,10 @@ namespace facebook::velox::functions::sparksql {

void registerRandFunctions(const std::string& prefix) {
registerFunction<RandFunction, double>({prefix + "rand", prefix + "random"});
// Has seed & partition index as input.
registerFunction<
RandFunction,
double,
int32_t /*seed*/,
int32_t /*partition index*/>({prefix + "rand", prefix + "random"});
// Has seed & partition index as input.
registerFunction<
RandFunction,
double,
int64_t /*seed*/,
int32_t /*partition index*/>({prefix + "rand", prefix + "random"});
// NULL constant as seed of unknown type.
registerFunction<
RandFunction,
double,
UnknownValue /*seed*/,
int32_t /*partition index*/>({prefix + "rand", prefix + "random"});
registerFunction<RandFunction, double, Constant<int32_t>>(
{prefix + "rand", prefix + "random"});
registerFunction<RandFunction, double, Constant<int64_t>>(
{prefix + "rand", prefix + "random"});
}

void registerArithmeticFunctions(const std::string& prefix) {
Expand Down
21 changes: 14 additions & 7 deletions velox/functions/sparksql/tests/RandTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,31 @@ class RandTest : public SparkFunctionBaseTest {
}

protected:
void setSparkPartitionId(int32_t partitionId) {
queryCtx_->testingOverrideConfigUnsafe(
{{core::QueryConfig::kSparkPartitionId, std::to_string(partitionId)}});
}

std::optional<double> rand(int32_t seed, int32_t partitionIndex = 0) {
setSparkPartitionId(partitionIndex);
return evaluateOnce<double>(
fmt::format("rand({}, {})", seed, partitionIndex),
makeRowVector(ROW({}), 1));
fmt::format("rand({})", seed), makeRowVector(ROW({}), 1));
}

std::optional<double> randWithNullSeed(int32_t partitionIndex = 0) {
return evaluateOnce<double>(
fmt::format("rand(NULL, {})", partitionIndex),
makeRowVector(ROW({}), 1));
setSparkPartitionId(partitionIndex);
std::optional<int32_t> seed = std::nullopt;
return evaluateOnce<double>("rand(c0)", seed);
}

std::optional<double> randWithNoSeed() {
setSparkPartitionId(0);
return evaluateOnce<double>("rand()", makeRowVector(ROW({}), 1));
}

VectorPtr randWithBatchInput(int32_t seed, int32_t partitionIndex = 0) {
auto exprSet = compileExpression(
fmt::format("rand({}, {})", seed, partitionIndex), ROW({}));
setSparkPartitionId(partitionIndex);
auto exprSet = compileExpression(fmt::format("rand({})", seed), ROW({}));
return evaluate(*exprSet, makeRowVector(ROW({}), 20));
}

Expand Down Expand Up @@ -92,6 +98,7 @@ TEST_F(RandTest, withSeed) {

// Test with batch input.
auto batchResult1 = randWithBatchInput(100);
ASSERT_FALSE(batchResult1->isConstantEncoding());
auto batchResult2 = randWithBatchInput(100);
// Same seed & partition index produce same results.
velox::test::assertEqualVectors(batchResult1, batchResult2);
Expand Down

0 comments on commit 506c4bf

Please sign in to comment.