diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc index eefe8b867a72..122088879bdb 100644 --- a/cpp/core/jni/JniWrapper.cc +++ b/cpp/core/jni/JniWrapper.cc @@ -783,6 +783,7 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper jdouble reallocThreshold, jlong firstBatchHandle, jlong taskAttemptId, + jint partitionKeySeed, jint pushBufferMaxSize, jobject partitionPusher, jstring partitionWriterTypeJstr) { @@ -825,6 +826,7 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper } shuffleWriterOptions.task_attempt_id = (int64_t)taskAttemptId; + shuffleWriterOptions.partition_key_seed = partitionKeySeed; shuffleWriterOptions.compression_threshold = bufferCompressThreshold; auto partitionWriterTypeC = env->GetStringUTFChars(partitionWriterTypeJstr, JNI_FALSE); diff --git a/cpp/core/shuffle/Options.h b/cpp/core/shuffle/Options.h index 129b425dd3a6..092a5080f71b 100644 --- a/cpp/core/shuffle/Options.h +++ b/cpp/core/shuffle/Options.h @@ -58,6 +58,7 @@ struct ShuffleWriterOptions { int64_t thread_id = -1; int64_t task_attempt_id = -1; + int32_t partition_key_seed = 0; arrow::ipc::IpcWriteOptions ipc_write_options = arrow::ipc::IpcWriteOptions::Defaults(); diff --git a/cpp/core/shuffle/Partitioner.cc b/cpp/core/shuffle/Partitioner.cc index 31afb6e190b0..c777f6ae2cfb 100644 --- a/cpp/core/shuffle/Partitioner.cc +++ b/cpp/core/shuffle/Partitioner.cc @@ -23,12 +23,15 @@ namespace gluten { -arrow::Result> Partitioner::make(Partitioning partitioning, int32_t numPartitions) { +arrow::Result> Partitioner::make( + Partitioning partitioning, + int32_t numPartitions, + int32_t partitionKeySeed) { switch (partitioning) { case Partitioning::kHash: return std::make_shared(numPartitions); case Partitioning::kRoundRobin: - return std::make_shared(numPartitions); + return std::make_shared(numPartitions, partitionKeySeed); case Partitioning::kSingle: return std::make_shared(); case Partitioning::kRange: diff --git a/cpp/core/shuffle/Partitioner.h b/cpp/core/shuffle/Partitioner.h index c60a15cf45ce..f4517808de9d 100644 --- a/cpp/core/shuffle/Partitioner.h +++ b/cpp/core/shuffle/Partitioner.h @@ -26,7 +26,10 @@ namespace gluten { class Partitioner { public: - static arrow::Result> make(Partitioning partitioning, int32_t numPartitions); + static arrow::Result> make( + Partitioning partitioning, + int32_t numPartitions, + int32_t partitionKeySeed); // Whether the first column is partition key. bool hasPid() const { diff --git a/cpp/core/shuffle/RoundRobinPartitioner.cc b/cpp/core/shuffle/RoundRobinPartitioner.cc index 94de3f9247eb..93dbce56d403 100644 --- a/cpp/core/shuffle/RoundRobinPartitioner.cc +++ b/cpp/core/shuffle/RoundRobinPartitioner.cc @@ -27,21 +27,9 @@ arrow::Status gluten::RoundRobinPartitioner::compute( std::fill(std::begin(partition2RowCount), std::end(partition2RowCount), 0); row2Partition.resize(numRows); - int32_t pidSelection = pidSelection_; - for (int32_t i = 0; i < numRows;) { - int32_t low = i; - int32_t up = std::min((int64_t)(i + (numPartitions_ - pidSelection)), numRows); - for (; low != up;) { - row2Partition[low++] = pidSelection++; - } - - pidSelection_ = pidSelection; - pidSelection = 0; - i = up; - } - - if (pidSelection_ >= numPartitions_) { - pidSelection_ -= numPartitions_; + for (int32_t i = 0; i < numRows; ++i) { + pidSelection_ = (pidSelection_ + 1) % numPartitions_; + row2Partition[i] = pidSelection_; } for (auto& pid : row2Partition) { diff --git a/cpp/core/shuffle/RoundRobinPartitioner.h b/cpp/core/shuffle/RoundRobinPartitioner.h index 8ea15e5afc7c..50facf0cdb34 100644 --- a/cpp/core/shuffle/RoundRobinPartitioner.h +++ b/cpp/core/shuffle/RoundRobinPartitioner.h @@ -23,7 +23,8 @@ namespace gluten { class RoundRobinPartitioner final : public Partitioner { public: - RoundRobinPartitioner(int32_t numPartitions) : Partitioner(numPartitions, false) {} + RoundRobinPartitioner(int32_t numPartitions, int32_t partitionKeySeed) + : Partitioner(numPartitions, false), pidSelection_(partitionKeySeed) {} arrow::Status compute( const int32_t* pidArr, diff --git a/cpp/core/tests/RoundRobinPartitionerTest.cc b/cpp/core/tests/RoundRobinPartitionerTest.cc index 5fb3e00feb19..e29ebcd55d6a 100644 --- a/cpp/core/tests/RoundRobinPartitionerTest.cc +++ b/cpp/core/tests/RoundRobinPartitionerTest.cc @@ -62,18 +62,18 @@ class RoundRobinPartitionerTest : public ::testing::Test { }; TEST_F(RoundRobinPartitionerTest, TestInit) { - int numPart = 0; - prepareData(numPart); + int numPart = 2; + prepareData(numPart, 1); ASSERT_NE(partitioner_, nullptr); int32_t pidSelection = getPidSelection(); - ASSERT_EQ(pidSelection, 0); + ASSERT_EQ(pidSelection, 1); } TEST_F(RoundRobinPartitionerTest, TestComoputeNormal) { // numRows equal numPart { int numPart = 10; - prepareData(numPart); + prepareData(numPart, 0); int numRows = 10; ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_, partition2RowCount_).ok()); ASSERT_EQ(getPidSelection(), 0); @@ -85,7 +85,7 @@ TEST_F(RoundRobinPartitionerTest, TestComoputeNormal) { // numRows less than numPart { int numPart = 10; - prepareData(numPart); + prepareData(numPart, 0); int numRows = 8; ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_, partition2RowCount_).ok()); ASSERT_EQ(getPidSelection(), 8); @@ -99,7 +99,7 @@ TEST_F(RoundRobinPartitionerTest, TestComoputeNormal) { // numRows greater than numPart { int numPart = 10; - prepareData(numPart); + prepareData(numPart, 0); int numRows = 12; ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_, partition2RowCount_).ok()); ASSERT_EQ(getPidSelection(), 2); @@ -113,7 +113,7 @@ TEST_F(RoundRobinPartitionerTest, TestComoputeNormal) { // numRows greater than 2*numPart { int numPart = 10; - prepareData(numPart); + prepareData(numPart, 0); int numRows = 22; ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_, partition2RowCount_).ok()); ASSERT_EQ(getPidSelection(), 2); @@ -127,7 +127,7 @@ TEST_F(RoundRobinPartitionerTest, TestComoputeNormal) { TEST_F(RoundRobinPartitionerTest, TestComoputeContinuous) { int numPart = 10; - prepareData(numPart); + prepareData(numPart, 0); { int numRows = 8; diff --git a/cpp/velox/shuffle/VeloxShuffleWriter.cc b/cpp/velox/shuffle/VeloxShuffleWriter.cc index 2329352a94b5..8a5dd38309d0 100644 --- a/cpp/velox/shuffle/VeloxShuffleWriter.cc +++ b/cpp/velox/shuffle/VeloxShuffleWriter.cc @@ -411,7 +411,9 @@ arrow::Status VeloxShuffleWriter::init() { VELOX_CHECK_NOT_NULL(options_.memory_pool); ARROW_ASSIGN_OR_RAISE(partitionWriter_, partitionWriterCreator_->make(this)); - ARROW_ASSIGN_OR_RAISE(partitioner_, Partitioner::make(options_.partitioning, numPartitions_)); + ARROW_ASSIGN_OR_RAISE( + partitioner_, + Partitioner::make(options_.partitioning, numPartitions_, options_.partition_key_seed)); // pre-allocated buffer size for each partition, unit is row count // when partitioner is SinglePart, partial variables don`t need init diff --git a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala index e6b5efbc97eb..d75f307d5e39 100644 --- a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala +++ b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala @@ -29,6 +29,7 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SparkResourceUtil +import org.apache.spark.util.random.XORShiftRandom import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.common.CelebornConf @@ -72,6 +73,12 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V]( } else { val handle = ColumnarBatches.getNativeHandle(cb) if (nativeShuffleWriter == -1L) { + val partitionKeySeed = dep.nativePartitioning.getShortName match { + case "rr" => + new XORShiftRandom(context.partitionId()) + .nextInt(dep.partitioner.numPartitions) + case _ => 0 + } nativeShuffleWriter = jniWrapper.makeForRSS( dep.nativePartitioning, nativeBufferSize, diff --git a/gluten-data/src/main/java/io/glutenproject/vectorized/ShuffleWriterJniWrapper.java b/gluten-data/src/main/java/io/glutenproject/vectorized/ShuffleWriterJniWrapper.java index 4b440c766ce7..58ef5da6e550 100644 --- a/gluten-data/src/main/java/io/glutenproject/vectorized/ShuffleWriterJniWrapper.java +++ b/gluten-data/src/main/java/io/glutenproject/vectorized/ShuffleWriterJniWrapper.java @@ -64,7 +64,8 @@ public long make( boolean writeEOS, double reallocThreshold, long handle, - long taskAttemptId) { + long taskAttemptId, + int partitionKeySeed) { return nativeMake( part.getShortName(), part.getNumPartitions(), @@ -81,6 +82,7 @@ public long make( reallocThreshold, handle, taskAttemptId, + partitionKeySeed, 0, null, "local"); @@ -105,6 +107,7 @@ public long makeForRSS( long memoryManagerHandle, long handle, long taskAttemptId, + int partitionKeySeed, String partitionWriterType, double reallocThreshold) { return nativeMake( @@ -123,6 +126,7 @@ public long makeForRSS( reallocThreshold, handle, taskAttemptId, + partitionKeySeed, pushBufferMaxSize, pusher, partitionWriterType); @@ -144,6 +148,7 @@ public native long nativeMake( double reallocThreshold, long handle, long taskAttemptId, + int partitionKeySeed, int pushBufferMaxSize, Object pusher, String partitionWriterType); diff --git a/gluten-data/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/gluten-data/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala index 9e3cca7744ce..e28671a37d1c 100644 --- a/gluten-data/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala +++ b/gluten-data/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala @@ -30,6 +30,7 @@ import org.apache.spark.memory.SparkMemoryUtil import org.apache.spark.scheduler.MapStatus import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SparkDirectoryUtil, SparkResourceUtil, Utils} +import org.apache.spark.util.random.XORShiftRandom import java.io.IOException @@ -121,6 +122,12 @@ class ColumnarShuffleWriter[K, V]( val rows = cb.numRows() val handle = ColumnarBatches.getNativeHandle(cb) if (nativeShuffleWriter == -1L) { + val partitionKeySeed = dep.nativePartitioning.getShortName match { + case "rr" => + new XORShiftRandom(taskContext.partitionId()) + .nextInt(dep.partitioner.numPartitions) + case _ => 0 + } nativeShuffleWriter = jniWrapper.make( dep.nativePartitioning, nativeBufferSize, @@ -155,7 +162,8 @@ class ColumnarShuffleWriter[K, V]( writeEOS, reallocThreshold, handle, - taskContext.taskAttemptId() + taskContext.taskAttemptId(), + partitionKeySeed ) } val startTime = System.nanoTime()