Support get native plan tree string
ulysses-you committed Nov 15, 2023
1 parent 31e354f commit c0d9752
Showing 15 changed files with 133 additions and 28 deletions.
CHTransformerApi.scala
@@ -263,4 +263,8 @@ class CHTransformerApi extends TransformerApi with Logging {
val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
}

+  override def getNativePlanString(substraitPlan: Array[Byte], details: Boolean): String = {
+    throw new UnsupportedOperationException("CH backend does not support this method")
+  }
}
TransformerApiImpl.scala
@@ -17,10 +17,12 @@
package io.glutenproject.backendsapi.velox

import io.glutenproject.backendsapi.TransformerApi
+import io.glutenproject.exec.Runtimes
import io.glutenproject.expression.ConverterUtils
import io.glutenproject.extension.ValidationResult
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
import io.glutenproject.utils.InputPartitionsUtil
+import io.glutenproject.vectorized.PlanEvaluatorJniWrapper

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateMap, Explode, Generator, JsonTuple, Literal, PosExplode}
@@ -123,4 +125,14 @@ class TransformerApiImpl extends TransformerApi with Logging {
val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
ExpressionBuilder.makeCast(typeNode, childNode, !nullOnOverflow)
}

+  override def getNativePlanString(substraitPlan: Array[Byte], details: Boolean): String = {
+    val tmpRuntime = Runtimes.tmpInstance()
+    try {
+      val jniWrapper = PlanEvaluatorJniWrapper.forRuntime(tmpRuntime)
+      jniWrapper.nativePlanString(substraitPlan, details)
+    } finally {
+      tmpRuntime.release()
+    }
+  }
}
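
The Velox override above borrows a short-lived native runtime purely to render the plan, and the finally block guarantees the native side is released even if rendering throws. Below is a minimal sketch of the same borrow/release idiom factored into a reusable helper; the name withTmpRuntime is hypothetical, and the sketch assumes the Runtime type is the sibling of the Runtimes factory imported above.

import io.glutenproject.exec.{Runtime, Runtimes}

// Hypothetical helper, not part of this commit: run `f` against a throwaway
// native runtime and always release it, mirroring the try/finally above.
def withTmpRuntime[T](f: Runtime => T): T = {
  val runtime = Runtimes.tmpInstance()
  try f(runtime)
  finally runtime.release()
}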
NativeBenchmarkPlanGenerator.scala
@@ -59,10 +59,10 @@ class NativeBenchmarkPlanGenerator extends VeloxWholeStageTransformerSuite {
val executedPlan = df.queryExecution.executedPlan
val lastStageTransformer = executedPlan.find(_.isInstanceOf[WholeStageTransformer])
assert(lastStageTransformer.nonEmpty)
-    var planJson = lastStageTransformer.get.asInstanceOf[WholeStageTransformer].getPlanJson
+    var planJson = lastStageTransformer.get.asInstanceOf[WholeStageTransformer].substraitPlanJson
assert(planJson.isEmpty)
executedPlan.execute()
-    planJson = lastStageTransformer.get.asInstanceOf[WholeStageTransformer].getPlanJson
+    planJson = lastStageTransformer.get.asInstanceOf[WholeStageTransformer].substraitPlanJson
assert(planJson.nonEmpty)
}
spark.sparkContext.setLogLevel(logLevel)
@@ -82,7 +82,7 @@ class NativeBenchmarkPlanGenerator extends VeloxWholeStageTransformerSuite {
val finalPlan = executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
val lastStageTransformer = finalPlan.find(_.isInstanceOf[WholeStageTransformer])
assert(lastStageTransformer.nonEmpty)
-      val planJson = lastStageTransformer.get.asInstanceOf[WholeStageTransformer].getPlanJson
+      val planJson = lastStageTransformer.get.asInstanceOf[WholeStageTransformer].substraitPlanJson
assert(planJson.nonEmpty)
}
spark.sparkContext.setLogLevel(logLevel)
@@ -141,7 +141,7 @@ class NativeBenchmarkPlanGenerator extends VeloxWholeStageTransformerSuite {
val lastStageTransformer = finalPlan.find(_.isInstanceOf[WholeStageTransformer])
assert(lastStageTransformer.nonEmpty)
val plan =
-      lastStageTransformer.get.asInstanceOf[WholeStageTransformer].getPlanJson.split('\n')
+      lastStageTransformer.get.asInstanceOf[WholeStageTransformer].substraitPlanJson.split('\n')

val exampleJsonFile = Paths.get(generatedPlanDir, "example.json")
Files.write(exampleJsonFile, plan.toList.asJava, StandardCharsets.UTF_8)
TestOperator.scala
@@ -19,12 +19,13 @@ package io.glutenproject.execution
import org.apache.spark.SparkConf
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.{GenerateExec, RDDScanExec}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.{avg, col, udf}
import org.apache.spark.sql.types.{DecimalType, StringType, StructField, StructType}

import scala.collection.JavaConverters

-class TestOperator extends VeloxWholeStageTransformerSuite {
+class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {

protected val rootPath: String = getClass.getResource("/").getPath
override protected val backend: String = "velox"
@@ -585,4 +586,17 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
}
}
}

test("Support get native plan tree string") {
runQueryAndCompare("select l_partkey + 1, count(*) from lineitem group by l_partkey + 1") {
df =>
val wholeStageTransformers = collect(df.queryExecution.executedPlan) {
case w: WholeStageTransformer => w
}
val nativePlanString = wholeStageTransformers.head.nativePlanString()
assert(nativePlanString.contains("Aggregation[FINAL"))
assert(nativePlanString.contains("Aggregation[PARTIAL"))
assert(nativePlanString.contains("TableScan"))
}
}
}
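
The test leans on the AdaptiveSparkPlanHelper mixin added above: its collect walks into adaptive subtrees that a plain TreeNode traversal would miss. Outside a test suite, the same inspection might look like the sketch below; it assumes a Gluten-enabled SparkSession with the Velox backend and a registered lineitem table, and unwraps AdaptiveSparkPlanExec by hand the way the suite's helper does internally.

import io.glutenproject.execution.WholeStageTransformer
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

val df = spark.sql("select l_partkey + 1, count(*) from lineitem group by l_partkey + 1")
df.collect() // materialize so the final adaptive plan is in place
val plan = df.queryExecution.executedPlan match {
  case a: AdaptiveSparkPlanExec => a.executedPlan // unwrap AQE
  case p => p
}
plan.foreach {
  case w: WholeStageTransformer => println(w.nativePlanString(details = true))
  case _ =>
}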
2 changes: 2 additions & 0 deletions cpp/core/compute/Runtime.h
@@ -84,6 +84,8 @@ class Runtime : public std::enable_shared_from_this<Runtime> {
GLUTEN_CHECK(parseProtobuf(data, size, &substraitPlan_) == true, "Parse substrait plan failed");
}

+  virtual std::string planString(bool details, const std::unordered_map<std::string, std::string>& sessionConf) = 0;

// Just for benchmark
::substrait::Plan& getPlan() {
return substraitPlan_;
18 changes: 18 additions & 0 deletions cpp/core/jni/JniWrapper.cc
@@ -331,6 +331,24 @@ JNIEXPORT void JNICALL Java_io_glutenproject_exec_RuntimeJniWrapper_releaseRuntime(
JNI_METHOD_END()
}

+JNIEXPORT jstring JNICALL Java_io_glutenproject_vectorized_PlanEvaluatorJniWrapper_nativePlanString( // NOLINT
+    JNIEnv* env,
+    jobject wrapper,
+    jbyteArray planArray,
+    jboolean details) {
+  JNI_METHOD_START
+
+  auto planData = reinterpret_cast<const uint8_t*>(env->GetByteArrayElements(planArray, 0));
+  auto planSize = env->GetArrayLength(planArray);
+  auto ctx = gluten::getRuntime(env, wrapper);
+  ctx->parsePlan(planData, planSize);
+  auto& conf = ctx->getConfMap();
+  auto planString = ctx->planString(details, conf);
+  return env->NewStringUTF(planString.c_str());
+
+  JNI_METHOD_END(nullptr)
+}

JNIEXPORT jlong JNICALL
Java_io_glutenproject_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWithIterator( // NOLINT
JNIEnv* env,
21 changes: 15 additions & 6 deletions cpp/velox/compute/VeloxPlanConverter.cc
@@ -32,8 +32,12 @@ using namespace facebook;
VeloxPlanConverter::VeloxPlanConverter(
const std::vector<std::shared_ptr<ResultIterator>>& inputIters,
velox::memory::MemoryPool* veloxPool,
-    const std::unordered_map<std::string, std::string>& confMap)
-    : inputIters_(inputIters), substraitVeloxPlanConverter_(veloxPool, confMap), pool_(veloxPool) {}
+    const std::unordered_map<std::string, std::string>& confMap,
+    bool validationMode)
+    : inputIters_(inputIters),
+      validationMode_(validationMode),
+      substraitVeloxPlanConverter_(veloxPool, confMap, validationMode),
+      pool_(veloxPool) {}

void VeloxPlanConverter::setInputPlanNode(const ::substrait::FetchRel& fetchRel) {
if (fetchRel.has_input()) {
@@ -118,9 +122,6 @@ void VeloxPlanConverter::setInputPlanNode(const ::substrait::ReadRel& sread) {
if (iterIdx == -1) {
return;
}
-  if (inputIters_.size() == 0) {
-    throw std::runtime_error("Invalid input iterator.");
-  }

// Get the input schema of this iterator.
uint64_t colNum = 0;
@@ -140,8 +141,16 @@ void VeloxPlanConverter::setInputPlanNode(const ::substrait::ReadRel& sread) {
outNames.emplace_back(colName);
}

+  std::shared_ptr<ResultIterator> iterator;
+  if (!validationMode_) {
+    if (inputIters_.size() == 0) {
+      throw std::runtime_error("Invalid input iterator.");
+    }
+    iterator = inputIters_[iterIdx];
+  }

auto outputType = ROW(std::move(outNames), std::move(veloxTypeList));
-  auto vectorStream = std::make_shared<RowVectorStream>(pool_, std::move(inputIters_[iterIdx]), outputType);
+  auto vectorStream = std::make_shared<RowVectorStream>(pool_, std::move(iterator), outputType);
auto valuesNode = std::make_shared<ValueStreamNode>(nextPlanNodeId(), outputType, std::move(vectorStream));
substraitVeloxPlanConverter_.insertInputNode(iterIdx, valuesNode, planNodeId_);
}
5 changes: 4 additions & 1 deletion cpp/velox/compute/VeloxPlanConverter.h
@@ -32,7 +32,8 @@ class VeloxPlanConverter {
explicit VeloxPlanConverter(
const std::vector<std::shared_ptr<ResultIterator>>& inputIters,
facebook::velox::memory::MemoryPool* veloxPool,
-      const std::unordered_map<std::string, std::string>& confMap);
+      const std::unordered_map<std::string, std::string>& confMap,
+      bool validationMode = false);

std::shared_ptr<const facebook::velox::core::PlanNode> toVeloxPlan(::substrait::Plan& substraitPlan);

@@ -71,6 +72,8 @@

std::vector<std::shared_ptr<ResultIterator>> inputIters_;

+  bool validationMode_;

SubstraitToVeloxPlanConverter substraitVeloxPlanConverter_;

facebook::velox::memory::MemoryPool* pool_;
8 changes: 8 additions & 0 deletions cpp/velox/compute/VeloxRuntime.cc
@@ -69,6 +69,14 @@ void VeloxRuntime::getInfoAndIds(
}
}

+std::string VeloxRuntime::planString(bool details, const std::unordered_map<std::string, std::string>& sessionConf) {
+  std::vector<std::shared_ptr<ResultIterator>> inputs;
+  auto veloxMemoryPool = gluten::defaultLeafVeloxMemoryPool();
+  VeloxPlanConverter veloxPlanConverter(inputs, veloxMemoryPool.get(), sessionConf, true);
+  auto veloxPlan = veloxPlanConverter.toVeloxPlan(substraitPlan_);
+  return veloxPlan->toString(details, true);
+}

std::shared_ptr<ResultIterator> VeloxRuntime::createResultIterator(
MemoryManager* memoryManager,
const std::string& spillDir,
2 changes: 2 additions & 0 deletions cpp/velox/compute/VeloxRuntime.h
@@ -107,6 +107,8 @@ class VeloxRuntime final : public Runtime {
arrow::MemoryPool* arrowPool,
struct ArrowSchema* cSchema) override;

+  std::string planString(bool details, const std::unordered_map<std::string, std::string>& sessionConf) override;

std::shared_ptr<const facebook::velox::core::PlanNode> getVeloxPlan() {
return veloxPlan_;
}
3 changes: 3 additions & 0 deletions cpp/velox/tests/RuntimeTest.cc
@@ -84,6 +84,9 @@ class DummyRuntime final : public Runtime {
std::shared_ptr<ColumnarBatch> select(MemoryManager*, std::shared_ptr<ColumnarBatch>, std::vector<int32_t>) override {
throw GlutenException("Not yet implemented");
}
+  std::string planString(bool details, const std::unordered_map<std::string, std::string>& sessionConf) override {
+    throw GlutenException("Not yet implemented");
+  }

private:
ResourceMap<std::shared_ptr<ResultIterator>> resultIteratorHolder_;
TransformerApi.scala
@@ -83,4 +83,6 @@ trait TransformerApi {
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode

+  def getNativePlanString(substraitPlan: Array[Byte], details: Boolean): String
}
WholeStageTransformer.scala
@@ -107,13 +107,28 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = false)
val numaBindingInfo: GlutenNumaBindingInfo = GlutenConfig.getConf.numaBindingInfo
val substraitPlanLogLevel: String = GlutenConfig.getConf.substraitPlanLogLevel

-  private var planJson: String = ""
+  @transient
+  private var wholeStageTransformerContext: Option[WholeStageTransformContext] = None

-  def getPlanJson: String = {
-    if (log.isDebugEnabled() && planJson.isEmpty) {
-      logWarning("Plan in JSON string is empty. This may due to the plan has not been executed.")
-    }
-    planJson
-  }
+  def substraitPlan: PlanNode = {
+    if (wholeStageTransformerContext.isDefined) {
+      // TODO: remove this work around after we make `RelNode#toProtobuf` idempotent
+      // see `SubstraitContext#getCurrentLocalFileNode`.
+      wholeStageTransformerContext.get.substraitContext.initLocalFilesNodesIndex(0)
+      wholeStageTransformerContext.get.root
+    } else {
+      generateWholeStageTransformContext().root
+    }
+  }
+
+  def substraitPlanJson: String = {
+    SubstraitPlanPrinterUtil.substraitPlanToJson(substraitPlan.toProtobuf)
+  }
+
+  def nativePlanString(details: Boolean = true): String = {
+    BackendsApiManager.getTransformerApiInstance.getNativePlanString(
+      substraitPlan.toProtobuf.toByteArray,
+      details)
+  }

override def output: Seq[Attribute] = child.output
@@ -145,9 +160,9 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = false)
maxFields,
printNodeId = printNodeId,
indent)
-    if (verbose && planJson.nonEmpty) {
+    if (verbose && wholeStageTransformerContext.isDefined) {
append(prefix + "Substrait plan:\n")
-      append(planJson)
+      append(substraitPlanJson)
append("\n")
}
}
@@ -157,10 +172,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = false)
// See buildSparkPlanGraphNode in SparkPlanGraph.scala of Spark.
override def nodeName: String = s"WholeStageCodegenTransformer ($transformStageId)"

-  def doWholeStageTransform(): WholeStageTransformContext = {
-    // invoke SparkPlan.prepare to do subquery preparation etc.
-    super.prepare()
-
+  private def generateWholeStageTransformContext(): WholeStageTransformContext = {
val substraitContext = new SubstraitContext
val childCtx = child
.asInstanceOf[TransformSupport]
@@ -191,13 +203,19 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = false)
PlanBuilder.makePlan(substraitContext, Lists.newArrayList(childCtx.root), outNames)
}

-    if (log.isDebugEnabled()) {
-      planJson = SubstraitPlanPrinterUtil.substraitPlanToJson(planNode.toProtobuf)
-    }

WholeStageTransformContext(planNode, substraitContext)
}

+  def doWholeStageTransform(): WholeStageTransformContext = {
+    // invoke SparkPlan.prepare to do subquery preparation etc.
+    super.prepare()
+    val context = generateWholeStageTransformContext()
+    if (conf.getConf(GlutenConfig.CACHE_WHOLE_STAGE_TRANSFORMER_CONTEXT)) {
+      wholeStageTransformerContext = Some(context)
+    }
+    context
+  }

/** Find all BasicScanExecTransformers in one WholeStageTransformer */
private def findAllScanTransformers(): Seq[BasicScanExecTransformer] = {
val basicScanExecTransformers = new mutable.ListBuffer[BasicScanExecTransformer]()
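
The refactor above separates plan generation (generateWholeStageTransformContext) from execution-time preparation (doWholeStageTransform), so the Substrait plan can be rebuilt on demand instead of being captured only when debug logging happened to be enabled. A small sketch of how the two new accessors might be used side by side; the dumpStage helper is hypothetical.

import io.glutenproject.execution.WholeStageTransformer

// Hypothetical debugging helper: print both views of one pipeline's plan.
def dumpStage(w: WholeStageTransformer): Unit = {
  println(w.substraitPlanJson)  // the Substrait plan rendered as JSON
  println(w.nativePlanString()) // the backend plan tree; details defaults to true
}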
PlanEvaluatorJniWrapper.java
@@ -54,6 +54,8 @@ public long handle() {
*/
native NativePlanValidationInfo nativeValidateWithFailureReason(byte[] subPlan);

+  public native String nativePlanString(byte[] substraitPlan, boolean details);

/**
* Create a native compute kernel and return a columnar result iterator.
*
GlutenConfig.scala
@@ -1311,4 +1311,12 @@ object GlutenConfig {
"'spark.bloom_filter.max_num_bits'")
.longConf
.createWithDefault(4194304L)

+  val CACHE_WHOLE_STAGE_TRANSFORMER_CONTEXT =
+    buildConf("spark.gluten.sql.cacheWholeStageTransformerContext")
+      .internal()
+      .doc("When true, `WholeStageTransformer` will cache the `WholeStageTransformerContext` " +
+        "when executing")
+      .booleanConf
+      .createWithDefault(false)
}
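
The flag is internal and off by default: caching the context costs a little driver memory but lets verbose explain output include the Substrait JSON after a query has run. A sketch of enabling it for one session; the property key comes from this diff, while the session setup is illustrative.

import org.apache.spark.sql.SparkSession

// Illustrative: cache the WholeStageTransformerContext so that explain output
// on an executed query can append the Substrait plan JSON.
val spark = SparkSession
  .builder()
  .config("spark.gluten.sql.cacheWholeStageTransformerContext", "true")
  .getOrCreate()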
