[VL] Support StreamingAggregate if child output ordering is satisfied #3828

Merged: 3 commits, Nov 27, 2023
@@ -18,10 +18,9 @@ package io.glutenproject.execution

import io.glutenproject.execution.CHHashAggregateExecTransformer.getAggregateResultAttributes
import io.glutenproject.expression._
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.`type`.TypeNode
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.rel.{LocalFilesBuilder, RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions._
@@ -30,8 +29,6 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.types._

import com.google.protobuf.Any

import java.util

object CHHashAggregateExecTransformer {
@@ -285,31 +282,16 @@ case class CHHashAggregateExecTransformer(
)
aggregateFunctionList.add(aggFunctionNode)
})
if (!validation) {
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
context,
operatorId)
} else {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
for (attr <- originalInputAttributes) {
inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}

val extensionNode = getAdvancedExtension(validation, originalInputAttributes)
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}

override def isStreaming: Boolean = false
@@ -654,11 +654,14 @@ case class HashAggregateExecTransformer(
}
addFunctionNode(args, aggregateFunc, childrenNodes, aggExpr.mode, aggregateFunctionList)
})

val extensionNode = getAdvancedExtension()
RelBuilder.makeAggregateRel(
projectRel,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}
@@ -16,11 +16,14 @@
*/
package io.glutenproject.execution

import io.glutenproject.GlutenConfig

import org.apache.spark.SparkConf
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.{GenerateExec, RDDScanExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.{avg, col, udf}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, StringType, StructField, StructType}

import scala.collection.JavaConverters
@@ -603,4 +606,37 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
assert(nativePlanString.contains("TableScan"))
}
}

test("Support StreamingAggregate if child output ordering is satisfied") {
withTable("t") {
spark
.range(10000)
.selectExpr(s"id % 999 as c1", "id as c2")
.write
.saveAsTable("t")

withSQLConf(
GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE.key -> "true",
GlutenConfig.COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1"
) {
val query =
"""
|SELECT c1, count(*), sum(c2) FROM (
|SELECT t1.c1, t2.c2 FROM t t1 JOIN t t2 ON t1.c1 = t2.c1
|)
|GROUP BY c1
|""".stripMargin
runQueryAndCompare(query) {
df =>
assert(
find(df.queryExecution.executedPlan)(
_.isInstanceOf[SortMergeJoinExecTransformer]).isDefined)
assert(
find(df.queryExecution.executedPlan)(
_.isInstanceOf[HashAggregateExecTransformer]).isDefined)
}
}
}
}
}
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.benchmark

import io.glutenproject.GlutenConfig

import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.internal.SQLConf

/**
* Benchmark to measure performance for streaming aggregate. To run this benchmark:
* {{{
* bin/spark-submit --class <this class> --jars <spark core test jar> <sql core test jar>
* }}}
*/
object StreamingAggregateBenchmark extends SqlBasedBenchmark {
private val numRows = {
spark.sparkContext.conf.getLong("spark.gluten.benchmark.rows", 8 * 1000 * 1000)
}

private val mode = {
spark.sparkContext.conf.getLong("spark.gluten.benchmark.remainder", 4 * 1000 * 1000)
}

private def doBenchmark(): Unit = {
val benchmark = new Benchmark("streaming aggregate", numRows, output = output)

val query =
"""
|SELECT c1, count(*), sum(c2) FROM (
|SELECT t1.c1, t2.c2 FROM t t1 JOIN t t2 ON t1.c1 = t2.c1
|)
|GROUP BY c1
|""".stripMargin
benchmark.addCase(s"Enable streaming aggregate", 3) {
_ =>
withSQLConf(
GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
GlutenConfig.COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false"
) {
spark.sql(query).noop()
}
}

benchmark.addCase(s"Disable streaming aggregate", 3) {
_ =>
withSQLConf(
GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE.key -> "false",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
GlutenConfig.COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false"
) {
spark.sql(query).noop()
}
}

benchmark.run()
}

override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
spark
.range(numRows)
.selectExpr(s"id % $mode as c1", "id as c2")
.write
.saveAsTable("t")

try {
doBenchmark()
Member: Seems to me that we don't check query result in the benchmark code. So should we add some simple UT code with result check?

Contributor Author: yeah, I will add a test later

} finally {
spark.sql("DROP TABLE t")
}
}
}
5 changes: 5 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -365,6 +365,11 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::

bool ignoreNullKeys = false;
std::vector<core::FieldAccessTypedExprPtr> preGroupingExprs;
if (aggRel.has_advanced_extension() &&
SubstraitParser::configSetInOptimization(aggRel.advanced_extension(), "isStreaming=")) {
preGroupingExprs.reserve(veloxGroupingExprs.size());
preGroupingExprs.insert(preGroupingExprs.begin(), veloxGroupingExprs.begin(), veloxGroupingExprs.end());
}

// Get the output names of Aggregation.
std::vector<std::string> aggOutNames;
4 changes: 3 additions & 1 deletion cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -976,7 +976,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag
if (aggRel.has_advanced_extension()) {
std::vector<TypePtr> types;
const auto& extension = aggRel.advanced_extension();
if (!validateInputTypes(extension, types)) {
// Aggregate always has advanced extension for steaming aggregate optimization,
Contributor: Nit: steaming -> streaming

// but only some of them have enhancement for validation.
if (extension.has_enhancement() && !validateInputTypes(extension, types)) {
logValidateMsg("native validation failed due to: Validation failed for input types in AggregateRel.");
return false;
}
@@ -81,17 +81,6 @@ public static RelNode makeProjectRel(
return new ProjectRelNode(input, expressionNodes, extensionNode, emitStartIndex);
}

public static RelNode makeAggregateRel(
RelNode input,
List<ExpressionNode> groupings,
List<AggregateFunctionNode> aggregateFunctionNodes,
List<ExpressionNode> filters,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return new AggregateRelNode(input, groupings, aggregateFunctionNodes, filters);
}

public static RelNode makeAggregateRel(
RelNode input,
List<ExpressionNode> groupings,
@@ -16,14 +16,15 @@
*/
package io.glutenproject.execution

import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression._
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.MetricsUpdater
import io.glutenproject.substrait.`type`.TypeBuilder
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.rdd.RDD
@@ -35,7 +36,7 @@ import org.apache.spark.sql.execution.aggregate._
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.Any
import com.google.protobuf.{Any, StringValue}

import java.util.{ArrayList => JArrayList, List => JList}

@@ -73,6 +74,24 @@ abstract class HashAggregateExecBaseTransformer(
aggregateAttributes)
}

protected def isGroupingKeysPreGrouped: Boolean = {
if (!conf.getConf(GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE)) {
return false
}
Member: The check doesn't do what the method name isGroupingKeysPreGrouped says. I would suggest renaming isGroupingKeysPreGrouped to something like isCapableForStreamingAggregation, or just moving this config check out.

if (groupingExpressions.isEmpty) {
return false
}

val childOrdering = child match {
case agg: HashAggregateExecBaseTransformer
if agg.groupingExpressions == this.groupingExpressions =>
agg.child.outputOrdering
Member: Why propagate the child's child's ordering here? It would be great if we could also leave some comments in the code. Thanks.

Contributor Author: Added some comments, hope it is clear now.

case _ => child.outputOrdering
}
val requiredOrdering = groupingExpressions.map(expr => SortOrder.apply(expr, Ascending))
Contributor: Why ascending by default?

Contributor: Does that mean it only fits with SMJ?

Contributor Author: IIUC, Ascending is Spark's default ordering; all the built-in SQL operators follow it, e.g. SMJ, Window, SortAggregate.

Member: Is it by Spark's limitation that a plan node can't require "asc or desc" child ordering at the same time?

Contributor Author: Yes, it's kind of hard-coded, but so far I have not seen any requirement to change it. Will Velox operators require or introduce Descending ordering?

SortOrder.orderingSatisfies(childOrdering, requiredOrdering)
}
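
For context on the check above: SortOrder.orderingSatisfies is Spark Catalyst's prefix-and-semantic-equality test between a child's output ordering and a required ordering. Below is a minimal standalone sketch of that decision, assuming a single long-typed grouping key c1; the attribute and object name are illustrative and not part of this PR.

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.LongType

// Standalone sketch, not from the PR: how orderingSatisfies decides whether the
// grouping keys already arrive sorted ascending, so a streaming aggregation can
// be requested from the native backend.
object OrderingSatisfiesSketch extends App {
  val c1 = AttributeReference("c1", LongType)() // hypothetical grouping-key attribute

  // Ordering produced upstream, e.g. by a sort-merge join keyed on c1.
  val childOrdering = Seq(SortOrder(c1, Ascending))

  // Ordering the aggregate requires: every grouping key ascending (Spark's default direction).
  val requiredOrdering = Seq(SortOrder(c1, Ascending))

  // true => isGroupingKeysPreGrouped would report the child as pre-grouped.
  println(SortOrder.orderingSatisfies(childOrdering, requiredOrdering))
}

In the PR's test above, that ordering comes from the sort-merge join on t1.c1 = t2.c1, which is why the GROUP BY c1 query takes the streaming path without an extra sort.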

override def doExecuteColumnar(): RDD[ColumnarBatch] = {
throw new UnsupportedOperationException(s"This operator doesn't support doExecuteColumnar().")
}
@@ -328,11 +347,13 @@ abstract class HashAggregateExecBaseTransformer(
}
})

val extensionNode = getAdvancedExtension()
RelBuilder.makeAggregateRel(
inputRel,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}
@@ -533,30 +554,39 @@ abstract class HashAggregateExecBaseTransformer(
aggExpr.mode,
aggregateFunctionList)
})
if (!validation) {
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
context,
operatorId)
} else {

val extensionNode = getAdvancedExtension(validation, originalInputAttributes)
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}

protected def getAdvancedExtension(
validation: Boolean = false,
originalInputAttributes: Seq[Attribute] = Seq.empty): AdvancedExtensionNode = {
val enhancement = if (validation) {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)
} else {
null
}

val isStreaming = if (isGroupingKeysPreGrouped) {
"1"
} else {
"0"
}
val optimization =
Any.pack(StringValue.newBuilder.setValue(s"isStreaming=$isStreaming\n").build)
ExtensionBuilder.makeAdvancedExtension(optimization, enhancement)
}
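
For context on the extension built above: the optimization slot is just a protobuf StringValue packed into google.protobuf.Any, and the Velox-side converter (SubstraitToVeloxPlan.cc) detects the flag via SubstraitParser::configSetInOptimization(aggRel.advanced_extension(), "isStreaming="). Below is a standalone round-trip sketch, assuming protobuf-java is on the classpath; the object name is illustrative and not part of this PR.

import com.google.protobuf.{Any, StringValue}

// Standalone sketch, not from the PR: pack the streaming hint the way
// getAdvancedExtension does, then read it back as a plan consumer could.
object OptimizationPayloadSketch extends App {
  // Producer side: the hint emitted when the grouping keys are pre-grouped.
  val optimization =
    Any.pack(StringValue.newBuilder.setValue("isStreaming=1\n").build)

  // Consumer side: unpack the StringValue and look for the key=value entry.
  val text = optimization.unpack(classOf[StringValue]).getValue
  val isStreaming = text.split("\n").contains("isStreaming=1")
  println(isStreaming) // prints: true
}

Putting the hint in the optimization slot rather than the enhancement slot keeps validation behavior unchanged: the validator only parses input types when an enhancement is present, per the SubstraitToVeloxPlanValidator.cc change above.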

protected def getAggRel(