[GLUTEN-3590][CORE] Reduce driver memory usage by using serialized bytes for substrait plan in GlutenPartition (#3591)
exmy authored Nov 2, 2023
1 parent f82737a commit 805d909
Showing 9 changed files with 35 additions and 38 deletions.
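The motivation: a GlutenPartition object is retained on the driver for every task, so holding a deserialized protobuf Plan keeps a whole message-object graph alive per partition, while the serialized form is a single flat byte array, and the bytes are what the native side ultimately consumes anyway. A minimal Scala sketch of the round trip (the empty plan is illustrative only; real plans come from wsCxt.root.toProtobuf, as the diffs below show):

import io.substrait.proto.Plan

object PlanBytesSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative stand-in for a real substrait plan.
    val plan: Plan = Plan.newBuilder().build()

    // What GlutenPartition stores after this commit: the flat wire form.
    val planBytes: Array[Byte] = plan.toByteArray

    // The structured form stays recoverable from the bytes, although this
    // commit hands them straight to the native kernel over JNI and never
    // re-parses them on the JVM side.
    val parsed: Plan = Plan.parseFrom(planBytes)
    assert(parsed == plan)
  }
}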
File 1 of 9:

@@ -27,7 +27,6 @@
 import io.glutenproject.substrait.plan.PlanNode;

 import com.google.protobuf.Any;
-import io.substrait.proto.Plan;
 import org.apache.spark.SparkConf;
 import org.apache.spark.sql.internal.SQLConf;

@@ -80,12 +79,12 @@ private PlanNode buildNativeConfNode(Map<String, String> confs) {
   // Used by WholeStageTransform to create the native computing pipeline and
   // return a columnar result iterator.
   public GeneralOutIterator createKernelWithBatchIterator(
-      Plan wsPlan, List<GeneralInIterator> iterList, boolean materializeInput) {
+      byte[] wsPlan, List<GeneralInIterator> iterList, boolean materializeInput) {
     long allocId = CHNativeMemoryAllocators.contextInstance().getNativeInstanceId();
     long handle =
         jniWrapper.nativeCreateKernelWithIterator(
             allocId,
-            getPlanBytesBuf(wsPlan),
+            wsPlan,
             iterList.toArray(new GeneralInIterator[0]),
             buildNativeConfNode(
                 GlutenConfig.getNativeBackendConf(
@@ -115,10 +114,6 @@ public GeneralOutIterator createKernelWithBatchIterator(
     return createOutIterator(handle);
   }

-  private byte[] getPlanBytesBuf(Plan planNode) {
-    return planNode.toByteArray();
-  }
-
   private GeneralOutIterator createOutIterator(long nativeHandle) {
     return new BatchIterator(nativeHandle);
   }
File 2 of 9 (CHIteratorApi):

@@ -91,19 +91,19 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
             fileFormats(i)),
           SoftAffinityUtil.getFilePartitionLocations(f))
       case _ =>
-        throw new UnsupportedOperationException(s"Unsupport operators.")
+        throw new UnsupportedOperationException(s"Unsupported input partition.")
     })
     wsCxt.substraitContext.initLocalFilesNodesIndex(0)
     wsCxt.substraitContext.setLocalFilesNodes(localFilesNodesWithLocations.map(_._1))
     val substraitPlan = wsCxt.root.toProtobuf
-    if (index < 3) {
+    if (index == 0) {
       logOnLevel(
         GlutenConfig.getConf.substraitPlanLogLevel,
         s"The substrait plan for partition $index:\n${SubstraitPlanPrinterUtil
            .substraitPlanToJson(substraitPlan)}"
       )
     }
-    GlutenPartition(index, substraitPlan, localFilesNodesWithLocations.head._2)
+    GlutenPartition(index, substraitPlan.toByteArray, localFilesNodesWithLocations.head._2)
   }

   /**
@@ -185,7 +185,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
       }.asJava)
     // we need to complete dependency RDD's firstly
     transKernel.createKernelWithBatchIterator(
-      rootNode.toProtobuf,
+      rootNode.toProtobuf.toByteArray,
       columnarNativeIterator,
       materializeInput)
   }
File 3 of 9 (IteratorApiImpl):

@@ -122,7 +122,7 @@ class IteratorApiImpl extends IteratorApi with Logging {
     wsCxt.substraitContext.initLocalFilesNodesIndex(0)
     wsCxt.substraitContext.setLocalFilesNodes(localFilesNodesWithLocations.map(_._1))
     val substraitPlan = wsCxt.root.toProtobuf
-    GlutenPartition(index, substraitPlan, localFilesNodesWithLocations.head._2)
+    GlutenPartition(index, substraitPlan.toByteArray, localFilesNodesWithLocations.head._2)
   }

   /**
@@ -187,7 +187,9 @@ class IteratorApiImpl extends IteratorApi with Logging {
         iter => new ColumnarBatchInIterator(iter.asJava)
       }.asJava)
     val nativeResultIterator =
-      transKernel.createKernelWithBatchIterator(rootNode.toProtobuf, columnarNativeIterator)
+      transKernel.createKernelWithBatchIterator(
+        rootNode.toProtobuf.toByteArray,
+        columnarNativeIterator)

     pipelineTime += TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)
File 4 of 9 (PlanBuilder):

@@ -27,6 +27,9 @@
 import java.util.Map;

 public class PlanBuilder {
+
+  public static byte[] EMPTY_PLAN = empty().toProtobuf().toByteArray();
+
   private PlanBuilder() {}

   public static PlanNode makePlan(
File 5 of 9 (GlutenPartition definitions):

@@ -30,21 +30,22 @@ import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.ExecutorManager

-import io.substrait.proto.Plan
-
 import scala.collection.mutable

 trait BaseGlutenPartition extends Partition with InputPartition {
-  def plan: Plan
+  def plan: Array[Byte]
 }

-case class GlutenPartition(index: Int, plan: Plan, locations: Array[String] = Array.empty[String])
+case class GlutenPartition(
+    index: Int,
+    plan: Array[Byte],
+    locations: Array[String] = Array.empty[String])
   extends BaseGlutenPartition {

   override def preferredLocations(): Array[String] = locations
 }

-case class GlutenFilePartition(index: Int, files: Array[PartitionedFile], plan: Plan)
+case class GlutenFilePartition(index: Int, files: Array[PartitionedFile], plan: Array[Byte])
   extends BaseGlutenPartition {
   override def preferredLocations(): Array[String] = {
     // Computes total number of bytes can be retrieved from each host.
@@ -74,15 +75,11 @@ case class GlutenMergeTreePartition(
     tablePath: String,
     minParts: Long,
     maxParts: Long,
-    plan: Plan = PlanBuilder.empty().toProtobuf)
+    plan: Array[Byte] = PlanBuilder.EMPTY_PLAN)
   extends BaseGlutenPartition {
   override def preferredLocations(): Array[String] = {
     Array.empty[String]
   }
-
-  def copySubstraitPlan(newSubstraitPlan: Plan): GlutenMergeTreePartition = {
-    this.copy(plan = newSubstraitPlan)
-  }
 }

 case class FirstZippedPartitionsPartition(
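With plan now typed as Array[Byte], rebinding a different plan can go through the case class's generated copy method, which is presumably why the dedicated copySubstraitPlan helper could be dropped. A short sketch (values are illustrative, and the imports for GlutenPartition and PlanBuilder are assumed from their packages in the diffs above):

import io.substrait.proto.Plan

object CopySketch {
  // Build a partition from the precomputed empty-plan bytes, as the test
  // suite in file 8 does; the index and locations here are illustrative.
  val part = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, Array("host-1"))

  // Rebind a different plan through the generated copy method.
  val newPlan: Plan = Plan.newBuilder().build() // illustrative empty plan
  val rebound: GlutenPartition = part.copy(plan = newPlan.toByteArray)
}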
File 6 of 9 (WholeStageTransformer):

@@ -265,7 +265,11 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
             .genFilePartition(i, currentPartitions, allScanPartitionSchemas, fileFormats, wsCxt)
         })
       (wsCxt, substraitPlanPartitions)
-    }(t => logOnLevel(substraitPlanLogLevel, s"Generating the Substrait plan took: $t ms."))
+    }(
+      t =>
+        logOnLevel(
+          substraitPlanLogLevel,
+          s"$nodeName generating the substrait plan took: $t ms."))

     new GlutenWholeStageColumnarRDD(
       sparkContext,
@@ -291,7 +295,8 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
      * result, genFinalStageIterator rather than genFirstStageIterator will be invoked
      */
     val resCtx = GlutenTimeMetric.withMillisTime(doWholeStageTransform()) {
-      t => logOnLevel(substraitPlanLogLevel, s"Generating the Substrait plan took: $t ms.")
+      t =>
+        logOnLevel(substraitPlanLogLevel, s"$nodeName generating the substrait plan took: $t ms.")
     }
     new WholeStageZippedPartitionsRDD(
       sparkContext,
File 7 of 9 (ExpressionConverter):

@@ -79,7 +79,7 @@ object ExpressionConverter extends SQLConfHelper with Logging {
       expr: Expression,
       attributeSeq: Seq[Attribute]): ExpressionTransformer = {
     logDebug(
-      s"replaceWithExpressionTransformer expr: $expr class: ${expr.getClass}} " +
+      s"replaceWithExpressionTransformer expr: $expr class: ${expr.getClass} " +
        s"name: ${expr.prettyName}")

     expr match {
File 8 of 9 (SoftAffinitySuite):

@@ -62,7 +62,7 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate

     val locations = SoftAffinityUtil.getFilePartitionLocations(partition)

-    val nativePartition = new GlutenPartition(0, PlanBuilder.empty().toProtobuf, locations)
+    val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations)
     assertResult(Set("host-1", "host-2", "host-3")) {
       nativePartition.preferredLocations().toSet
     }
@@ -91,7 +91,7 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate

     val locations = SoftAffinityUtil.getFilePartitionLocations(partition)

-    val nativePartition = new GlutenPartition(0, PlanBuilder.empty().toProtobuf, locations)
+    val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations)

     assertResult(Set("host-1", "host-4", "host-5")) {
       nativePartition.preferredLocations().toSet
@@ -121,7 +121,7 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate

     val locations = SoftAffinityUtil.getFilePartitionLocations(partition)

-    val nativePartition = new GlutenPartition(0, PlanBuilder.empty().toProtobuf, locations)
+    val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations)

     assertResult(Set("executor_host-2_2", "executor_host-1_0")) {
       nativePartition.preferredLocations().toSet
@@ -133,7 +133,7 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate

     val locations = SoftAffinityUtil.getNativeMergeTreePartitionLocations(partition)

-    val nativePartition = new GlutenPartition(0, PlanBuilder.empty().toProtobuf, locations)
+    val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations)

     assertResult(Set("executor_host-1_1")) {
       nativePartition.preferredLocations().toSet
@@ -163,7 +163,7 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate

     val locations = SoftAffinityUtil.getFilePartitionLocations(partition)

-    val nativePartition = new GlutenPartition(0, PlanBuilder.empty().toProtobuf, locations)
+    val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations)

     assertResult(Set("host-1", "host-5", "host-6")) {
       nativePartition.preferredLocations().toSet
File 9 of 9:

@@ -24,7 +24,6 @@
 import io.glutenproject.utils.DebugUtil;
 import io.glutenproject.validate.NativePlanValidationInfo;

-import io.substrait.proto.Plan;
 import org.apache.spark.TaskContext;
 import org.apache.spark.util.SparkDirectoryUtil;

@@ -58,7 +57,7 @@ public NativePlanValidationInfo doNativeValidateWithFailureReason(byte[] subPlan
   // Used by WholeStageTransform to create the native computing pipeline and
   // return a columnar result iterator.
   public GeneralOutIterator createKernelWithBatchIterator(
-      Plan wsPlan, List<GeneralInIterator> iterList) throws RuntimeException, IOException {
+      byte[] wsPlan, List<GeneralInIterator> iterList) throws RuntimeException, IOException {
     final AtomicReference<ColumnarBatchOutIterator> outIterator = new AtomicReference<>();
     final NativeMemoryManager nmm =
         NativeMemoryManagers.create(
@@ -85,7 +84,7 @@ public GeneralOutIterator createKernelWithBatchIterator(
     long iterHandle =
         jniWrapper.nativeCreateKernelWithIterator(
             memoryManagerHandle,
-            getPlanBytesBuf(wsPlan),
+            wsPlan,
             iterList.toArray(new GeneralInIterator[0]),
             TaskContext.get().stageId(),
             TaskContext.getPartitionId(),
@@ -100,8 +99,4 @@ private ColumnarBatchOutIterator createOutIterator(
       Runtime runtime, long iterHandle, NativeMemoryManager nmm) throws IOException {
     return new ColumnarBatchOutIterator(runtime, iterHandle, nmm);
   }
-
-  private byte[] getPlanBytesBuf(Plan planNode) {
-    return planNode.toByteArray();
-  }
 }

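Taken together, the plan now crosses the driver-to-executor and JVM-to-native boundaries in a single representation, serialized exactly once per partition. A condensed sketch of the resulting flow (KernelSketch is a hypothetical stand-in for the backend evaluators in files 1 and 9, not a real Gluten API; imports assumed as above):

// Mirrors the new byte[]-based kernel signature from files 1 and 9;
// a hypothetical stand-in, not a real Gluten interface.
trait KernelSketch {
  def createKernelWithBatchIterator(plan: Array[Byte]): Unit
}

object FlowSketch {
  def run(kernel: KernelSketch): Unit = {
    // Driver side: serialize the substrait plan once and store only the bytes.
    val partition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN)
    // Executor side: hand the same bytes straight to the native kernel,
    // with no intermediate Plan object and no extra copy.
    kernel.createKernelWithBatchIterator(partition.plan)
  }
}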