Support StreamingAggregate if child output ordering is satisfied
ulysses-you committed Nov 24, 2023
1 parent 35756a1 commit 68dd54d
Showing 9 changed files with 178 additions and 62 deletions.
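What the change does: the aggregate transformer now tags its Substrait AggregateRel with an `isStreaming` hint whenever the child's output ordering already satisfies the grouping keys, and the Velox backend turns that hint into pre-grouped keys on the AggregationNode. A minimal sketch of a query shape that benefits, assuming a Spark session with the Gluten Velox backend on the classpath (derived from the benchmark added below; the sort-merge join leaves both sides ordered on c1):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
// Force a sort-merge join so the aggregate's child is already ordered on c1.
spark.conf.set("spark.gluten.sql.columnar.preferStreamingAggregate", "true")
spark.conf.set("spark.gluten.sql.columnar.forceShuffledHashJoin", "false")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.range(1000).selectExpr("id % 10 as c1", "id as c2").createOrReplaceTempView("t")
spark.sql(
  """
    |SELECT c1, count(*), sum(c2) FROM (
    |  SELECT t1.c1, t2.c2 FROM t t1 JOIN t t2 ON t1.c1 = t2.c1
    |) GROUP BY c1
    |""".stripMargin).collect()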
@@ -18,10 +18,9 @@ package io.glutenproject.execution

import io.glutenproject.execution.CHHashAggregateExecTransformer.getAggregateResultAttributes
import io.glutenproject.expression._
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.`type`.TypeNode
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.rel.{LocalFilesBuilder, RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions._
@@ -30,8 +29,6 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.types._

import com.google.protobuf.Any

import java.util

object CHHashAggregateExecTransformer {
@@ -285,31 +282,16 @@ case class CHHashAggregateExecTransformer(
)
aggregateFunctionList.add(aggFunctionNode)
})
if (!validation) {
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
context,
operatorId)
} else {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
for (attr <- originalInputAttributes) {
inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}

val extensionNode = getAdvancedExtension(validation, originalInputAttributes)
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}

override def isStreaming: Boolean = false
@@ -654,11 +654,14 @@ case class HashAggregateExecTransformer(
}
addFunctionNode(args, aggregateFunc, childrenNodes, aggExpr.mode, aggregateFunctionList)
})

val extensionNode = getAdvancedExtension()
RelBuilder.makeAggregateRel(
projectRel,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.benchmark

import io.glutenproject.GlutenConfig

import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.internal.SQLConf

/**
* Benchmark to measure performance for streaming aggregate. To run this benchmark:
* {{{
* bin/spark-submit --class <this class> --jars <spark core test jar> <sql core test jar>
* }}}
*/
object StreamingAggregateBenchmark extends SqlBasedBenchmark {
private val numRows = {
spark.sparkContext.conf.getLong("spark.gluten.benchmark.rows", 8 * 1000 * 1000)
}

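// Modulus for deriving the grouping key below (c1 = id % mode).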
private val mode = {
spark.sparkContext.conf.getLong("spark.gluten.benchmark.remainder", 4 * 1000 * 1000)
}

private def doBenchmark(): Unit = {
val benchmark = new Benchmark("streaming aggregate", numRows, output = output)

val query =
"""
|SELECT c1, count(*), sum(c2) FROM (
|SELECT t1.c1, t2.c2 FROM t t1 JOIN t t2 ON t1.c1 = t2.c1
|)
|GROUP BY c1
|""".stripMargin
benchmark.addCase(s"Enable streaming aggregate", 3) {
_ =>
withSQLConf(
GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
GlutenConfig.COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false"
) {
spark.sql(query).noop()
}
}

benchmark.addCase(s"Disable streaming aggregate", 3) {
_ =>
withSQLConf(
GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE.key -> "false",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
GlutenConfig.COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false"
) {
spark.sql(query).noop()
}
}

benchmark.run()
}

override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
spark
.range(numRows)
.selectExpr(s"id % $mode as c1", "id as c2")
.write
.saveAsTable("t")

try {
doBenchmark()
} finally {
spark.sql("DROP TABLE t")
}
}
}
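Both sizing knobs are read from the Spark conf when the object initializes, so a run can be shaped at submit time. A hedged example invocation, following the scaladoc above (jar paths are placeholders):

bin/spark-submit --class org.apache.spark.sql.execution.benchmark.StreamingAggregateBenchmark \
  --conf spark.gluten.benchmark.rows=16000000 \
  --conf spark.gluten.benchmark.remainder=8000000 \
  --jars <spark core test jar> <sql core test jar>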
5 changes: 5 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -365,6 +365,11 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::AggregateRel& aggRel)

bool ignoreNullKeys = false;
std::vector<core::FieldAccessTypedExprPtr> preGroupingExprs;
if (aggRel.has_advanced_extension() &&
SubstraitParser::configSetInOptimization(aggRel.advanced_extension(), "isStreaming=")) {
preGroupingExprs.reserve(veloxGroupingExprs.size());
preGroupingExprs.insert(preGroupingExprs.begin(), veloxGroupingExprs.begin(), veloxGroupingExprs.end());
}

// Get the output names of Aggregation.
std::vector<std::string> aggOutNames;
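The handshake: the Scala side (see getAdvancedExtension below) packs a newline-separated `key=value` string into the extension's optimization slot, and `configSetInOptimization` checks it here; with the hint set, the converter seeds Velox's pre-grouping expressions with all grouping keys, which is what makes the aggregation streaming. A rough Scala mirror of that check, under the assumed semantics that an entry is "set" when its value is 1:

// Rough mirror of SubstraitParser::configSetInOptimization (assumed
// semantics: newline-separated "key=value" entries, set means value "1").
def configSetInOptimization(optimization: String, key: String): Boolean =
  optimization.split("\n").exists(l => l.startsWith(key) && l.stripPrefix(key) == "1")

assert(configSetInOptimization("isStreaming=1\n", "isStreaming="))
assert(!configSetInOptimization("isStreaming=0\n", "isStreaming="))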
4 changes: 3 additions & 1 deletion cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -978,7 +978,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& aggRel)
if (aggRel.has_advanced_extension()) {
std::vector<TypePtr> types;
const auto& extension = aggRel.advanced_extension();
if (!validateInputTypes(extension, types)) {
// Aggregate always has an advanced extension for the streaming aggregate optimization,
// but only some also carry an enhancement used for validation.
if (extension.has_enhancement() && !validateInputTypes(extension, types)) {
logValidateMsg("native validation failed due to: Validation failed for input types in AggregateRel.");
return false;
}
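Context for this guard: a Substrait AdvancedExtension carries two optional payloads, and after this commit every AggregateRel sets the optimization slot while only validation-time plans also set the enhancement slot with the input types. A schematic stand-in (simplified, not the real protobuf API):

// Schematic only: models the two AdvancedExtension slots to show why the
// validator must now check has_enhancement() before validating input types.
case class ExtensionSketch(optimization: Option[String], enhancement: Option[String])

val atExecution = ExtensionSketch(Some("isStreaming=1\n"), None) // no types to validate
val atValidation = ExtensionSketch(Some("isStreaming=0\n"), Some("struct<...>")) // validate types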
@@ -81,17 +81,6 @@ public static RelNode makeProjectRel(
return new ProjectRelNode(input, expressionNodes, extensionNode, emitStartIndex);
}

public static RelNode makeAggregateRel(
RelNode input,
List<ExpressionNode> groupings,
List<AggregateFunctionNode> aggregateFunctionNodes,
List<ExpressionNode> filters,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return new AggregateRelNode(input, groupings, aggregateFunctionNodes, filters);
}

public static RelNode makeAggregateRel(
RelNode input,
List<ExpressionNode> groupings,
@@ -111,6 +111,8 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkPlan)
}
}

override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override protected def doValidateInternal(): ValidationResult = {
if (cond == null) {
// The computing of this Filter is not needed.
@@ -181,6 +183,7 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkPlan)
case class ProjectExecTransformer private (projectList: Seq[NamedExpression], child: SparkPlan)
extends UnaryTransformSupport
with PredicateHelper
with AliasAwareOutputOrdering
with Logging {

// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@@ -189,6 +192,10 @@ case class ProjectExecTransformer private (projectList: Seq[NamedExpression], child: SparkPlan)

val sparkConf: SparkConf = sparkContext.getConf

override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering

override protected def outputExpressions: Seq[NamedExpression] = projectList

override protected def doValidateInternal(): ValidationResult = {
val substraitContext = new SubstraitContext
// Firstly, need to check if the Substrait plan for this operator can be successfully generated.
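These overrides exist so ordering information survives Filter and Project on the way to the aggregate; AliasAwareOutputOrdering additionally rewrites the child's sort orders in terms of the project's output aliases. The satisfies-check the aggregate then performs (the same API isGroupingKeysPreGrouped uses below) looks like this self-contained sketch:

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.LongType

// If the child reports an ordering on c1 and the grouping keys require an
// ascending order on c1, orderingSatisfies holds and streaming is eligible.
val c1 = AttributeReference("c1", LongType)()
val childOrdering = Seq(SortOrder(c1, Ascending))
val requiredOrdering = Seq(SortOrder(c1, Ascending))
assert(SortOrder.orderingSatisfies(childOrdering, requiredOrdering))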
@@ -16,14 +16,15 @@
*/
package io.glutenproject.execution

import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression._
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.MetricsUpdater
import io.glutenproject.substrait.`type`.TypeBuilder
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.rdd.RDD
@@ -35,7 +36,7 @@ import org.apache.spark.sql.execution.aggregate._
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.Any
import com.google.protobuf.{Any, StringValue}

import java.util.{ArrayList => JArrayList, List => JList}

@@ -73,6 +74,24 @@ abstract class HashAggregateExecBaseTransformer(
aggregateAttributes)
}

protected def isGroupingKeysPreGrouped: Boolean = {
if (!conf.getConf(GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE)) {
return false
}
if (groupingExpressions.isEmpty) {
return false
}

val childOrdering = child match {
case agg: HashAggregateExecBaseTransformer
if agg.groupingExpressions == this.groupingExpressions =>
agg.child.outputOrdering
case _ => child.outputOrdering
}
val requiredOrdering = groupingExpressions.map(expr => SortOrder.apply(expr, Ascending))
SortOrder.orderingSatisfies(childOrdering, requiredOrdering)
}

override def doExecuteColumnar(): RDD[ColumnarBatch] = {
throw new UnsupportedOperationException(s"This operator doesn't support doExecuteColumnar().")
}
@@ -328,11 +347,13 @@ abstract class HashAggregateExecBaseTransformer(
}
})

val extensionNode = getAdvancedExtension()
RelBuilder.makeAggregateRel(
inputRel,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}
@@ -533,30 +554,39 @@ abstract class HashAggregateExecBaseTransformer(
aggExpr.mode,
aggregateFunctionList)
})
if (!validation) {
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
context,
operatorId)
} else {

val extensionNode = getAdvancedExtension(validation, originalInputAttributes)
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}

protected def getAdvancedExtension(
validation: Boolean = false,
originalInputAttributes: Seq[Attribute] = Seq.empty): AdvancedExtensionNode = {
val enhancement = if (validation) {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)
} else {
null
}

val isStreaming = if (isGroupingKeysPreGrouped) {
"1"
} else {
"0"
}
val optimization =
Any.pack(StringValue.newBuilder.setValue(s"isStreaming=$isStreaming\n").build)
ExtensionBuilder.makeAdvancedExtension(optimization, enhancement)
}

protected def getAggRel(
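The optimization payload built by getAdvancedExtension is just a protobuf StringValue packed into an Any; a standalone sketch of the round trip:

import com.google.protobuf.{Any, StringValue}

// Pack the streaming hint the way getAdvancedExtension does, then unpack it
// as a consumer would before scanning for "isStreaming=".
val optimization = Any.pack(StringValue.newBuilder.setValue("isStreaming=1\n").build)
val value = optimization.unpack(classOf[StringValue]).getValue
assert(value.startsWith("isStreaming="))

One subtlety in isGroupingKeysPreGrouped above: a final aggregate whose child is a partial aggregate with identical grouping keys inspects the partial aggregate's child ordering instead, which assumes the intermediate streaming aggregate preserves its input order on the grouping keys.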
11 changes: 11 additions & 0 deletions shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -644,6 +644,17 @@
.checkValues(Set("streaming", "sort"))
.createWithDefault("streaming")

val COLUMNAR_PREFER_STREAMING_AGGREGATE =
buildConf("spark.gluten.sql.columnar.preferStreamingAggregate")
.internal()
.doc(
"Velox backend supports `StreamingAggregate`. `StreamingAggregate` uses the less " +
"memory as it does not need to hold all groups in memory, so it could avoid spill. " +
"When true and the child output ordering satisfies the grouping key then " +
"Gluten will choose `StreamingAggregate` as the native operator.")
.booleanConf
.createWithDefault(true)

val COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED =
buildConf("spark.gluten.sql.columnar.forceShuffledHashJoin")
.internal()
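The flag is internal and defaults to true, so in practice it acts as an escape hatch; a hedged usage sketch, assuming a running SparkSession with Gluten enabled:

// Disable streaming aggregate for the session if it regresses a workload.
spark.sql("SET spark.gluten.sql.columnar.preferStreamingAggregate=false")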
