[VL] Support StreamingAggregate if child output ordering is satisfied #3828
@@ -0,0 +1,87 @@ (new file: StreamingAggregateBenchmark.scala)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.benchmark

import io.glutenproject.GlutenConfig

import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.internal.SQLConf

/**
 * Benchmark to measure performance for streaming aggregate. To run this benchmark:
 * {{{
 *   bin/spark-submit --class <this class> --jars <spark core test jar> <sql core test jar>
 * }}}
 */
object StreamingAggregateBenchmark extends SqlBasedBenchmark {
  private val numRows = {
    spark.sparkContext.conf.getLong("spark.gluten.benchmark.rows", 8 * 1000 * 1000)
  }

  private val mode = {
    spark.sparkContext.conf.getLong("spark.gluten.benchmark.remainder", 4 * 1000 * 1000)
  }

  private def doBenchmark(): Unit = {
    val benchmark = new Benchmark("streaming aggregate", numRows, output = output)

    val query =
      """
        |SELECT c1, count(*), sum(c2) FROM (
        |SELECT t1.c1, t2.c2 FROM t t1 JOIN t t2 ON t1.c1 = t2.c1
        |)
        |GROUP BY c1
        |""".stripMargin
    benchmark.addCase(s"Enable streaming aggregate", 3) {
      _ =>
        withSQLConf(
          GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE.key -> "true",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
          GlutenConfig.COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false"
        ) {
          spark.sql(query).noop()
        }
    }

    benchmark.addCase(s"Disable streaming aggregate", 3) {
      _ =>
        withSQLConf(
          GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE.key -> "false",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
          GlutenConfig.COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false"
        ) {
          spark.sql(query).noop()
        }
    }

    benchmark.run()
  }

  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
    spark
      .range(numRows)
      .selectExpr(s"id % $mode as c1", "id as c2")
      .write
      .saveAsTable("t")

    try {
      doBenchmark()
    } finally {
      spark.sql("DROP TABLE t")
    }
  }
}
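
The benchmark reads two optional settings from the Spark conf: spark.gluten.benchmark.rows sets the input row count (default 8,000,000), and spark.gluten.benchmark.remainder sets the modulo used to derive the grouping key c1 (default 4,000,000), which controls the number of distinct groups; note the field holding it is named mode in the code. A possible invocation, following the scaladoc above, with placeholder jar paths since the exact jars depend on the build:

    bin/spark-submit \
      --class org.apache.spark.sql.execution.benchmark.StreamingAggregateBenchmark \
      --conf spark.gluten.benchmark.rows=8000000 \
      --conf spark.gluten.benchmark.remainder=4000000 \
      --jars <spark core test jar> <sql core test jar>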
@@ -976,7 +976,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& aggRel)
   if (aggRel.has_advanced_extension()) {
     std::vector<TypePtr> types;
     const auto& extension = aggRel.advanced_extension();
-    if (!validateInputTypes(extension, types)) {
+    // Aggregate always has advanced extension for steaming aggregate optimization,
+    // but only some of them have enhancement for validation.
+    if (extension.has_enhancement() && !validateInputTypes(extension, types)) {
       logValidateMsg("native validation failed due to: Validation failed for input types in AggregateRel.");
       return false;
     }

Review comment (on the added comment):
Reviewer: Nit: steaming -> streaming
@@ -16,14 +16,15 @@
  */
 package io.glutenproject.execution

 import io.glutenproject.GlutenConfig
 import io.glutenproject.backendsapi.BackendsApiManager
 import io.glutenproject.expression._
 import io.glutenproject.extension.ValidationResult
 import io.glutenproject.metrics.MetricsUpdater
 import io.glutenproject.substrait.`type`.TypeBuilder
 import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
 import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode}
-import io.glutenproject.substrait.extensions.ExtensionBuilder
+import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
 import io.glutenproject.substrait.rel.{RelBuilder, RelNode}

 import org.apache.spark.rdd.RDD

@@ -35,7 +36,7 @@ import org.apache.spark.sql.execution.aggregate._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch

-import com.google.protobuf.Any
+import com.google.protobuf.{Any, StringValue}

 import java.util.{ArrayList => JArrayList, List => JList}
@@ -73,6 +74,24 @@ abstract class HashAggregateExecBaseTransformer(
       aggregateAttributes)
   }

+  protected def isGroupingKeysPreGrouped: Boolean = {
+    if (!conf.getConf(GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE)) {
+      return false
+    }
+
+    if (groupingExpressions.isEmpty) {
+      return false
+    }
+
+    val childOrdering = child match {
+      case agg: HashAggregateExecBaseTransformer
+          if agg.groupingExpressions == this.groupingExpressions =>
+        agg.child.outputOrdering
+      case _ => child.outputOrdering
+    }
+    val requiredOrdering = groupingExpressions.map(expr => SortOrder.apply(expr, Ascending))
+    SortOrder.orderingSatisfies(childOrdering, requiredOrdering)
+  }
+
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
     throw new UnsupportedOperationException(s"This operator doesn't support doExecuteColumnar().")
   }

Review thread (on the COLUMNAR_PREFER_STREAMING_AGGREGATE check):
Reviewer: The check doesn't do what the method name says. I would suggest renaming it.

Review thread (on the child match):
Reviewer: Why propagate the child's child's ordering here? It would be great if we could also leave some comments in the code. Thanks.
Author: Added some comments, hope it is clear now.

Review thread (on the required ordering):
Reviewer: Why ascending by default?
Author: It follows the default behavior of SMJ requiredChildOrdering, see https://github.com/apache/spark/blob/ef27b9b15687dad416b6353409b1b44bc1451885/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala#L81-L92
Reviewer: Does that mean it only fits with SMJ?
Author: IIUC
Reviewer: Is it by Spark's limitation that a plan node can't require "asc or desc" child ordering at the same time?
Author: Yes, it's a kind of hard code, but so far I did not see any requirements about changing it. Will Velox operators require or introduce descending ordering?
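To make the ordering check above concrete, here is a minimal, self-contained sketch (an assumed example, not code from this PR) of how SortOrder.orderingSatisfies behaves for a single grouping key; the attribute c1 is hypothetical:

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.IntegerType

object OrderingCheckSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical attribute standing in for a grouping key such as c1.
    val c1 = AttributeReference("c1", IntegerType)()

    // What the child claims about its output ordering, e.g. after a sort-merge join on c1.
    val childOrdering = Seq(SortOrder(c1, Ascending))

    // What the streaming aggregate requires: each grouping key ordered ascending.
    val requiredOrdering = Seq(SortOrder(c1, Ascending))

    // True when the child ordering satisfies each required sort order, position by position.
    println(SortOrder.orderingSatisfies(childOrdering, requiredOrdering)) // true
  }
}

As for the look-through in the child match arm: the thread above suggests the intent is that a partial aggregate with the same grouping keys does not disturb the key order of its input when that input is already pre-grouped (it would itself run as a streaming aggregate), so the ordering of its own child is the relevant one to test.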
@@ -328,11 +347,13 @@ abstract class HashAggregateExecBaseTransformer(
         }
       })

+    val extensionNode = getAdvancedExtension()
     RelBuilder.makeAggregateRel(
       inputRel,
       groupingList,
       aggregateFunctionList,
       aggFilterList,
+      extensionNode,
       context,
       operatorId)
   }
@@ -533,30 +554,39 @@ abstract class HashAggregateExecBaseTransformer(
           aggExpr.mode,
           aggregateFunctionList)
       })
-    if (!validation) {
-      RelBuilder.makeAggregateRel(
-        input,
-        groupingList,
-        aggregateFunctionList,
-        aggFilterList,
-        context,
-        operatorId)
-    } else {
+
+    val extensionNode = getAdvancedExtension(validation, originalInputAttributes)
+    RelBuilder.makeAggregateRel(
+      input,
+      groupingList,
+      aggregateFunctionList,
+      aggFilterList,
+      extensionNode,
+      context,
+      operatorId)
+  }
+
+  protected def getAdvancedExtension(
+      validation: Boolean = false,
+      originalInputAttributes: Seq[Attribute] = Seq.empty): AdvancedExtensionNode = {
+    val enhancement = if (validation) {
       // Use an extension node to send the input types through Substrait plan for validation.
       val inputTypeNodeList = originalInputAttributes
         .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
         .asJava
-      val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-        Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
-      RelBuilder.makeAggregateRel(
-        input,
-        groupingList,
-        aggregateFunctionList,
-        aggFilterList,
-        extensionNode,
-        context,
-        operatorId)
+      Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)
+    } else {
+      null
+    }
+
+    val isStreaming = if (isGroupingKeysPreGrouped) {
+      "1"
+    } else {
+      "0"
+    }
+    val optimization =
+      Any.pack(StringValue.newBuilder.setValue(s"isStreaming=$isStreaming\n").build)
+    ExtensionBuilder.makeAdvancedExtension(optimization, enhancement)
+  }

   protected def getAggRel(
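For reference, a minimal sketch of how the isStreaming flag packed by getAdvancedExtension could be recovered from the optimization payload. This is not part of this PR (the real consumer is the native planner, as the C++ hunk above shows), and the helper name is hypothetical; it only illustrates the payload format:

import com.google.protobuf.{Any => ProtoAny, StringValue}

object OptimizationPayloadSketch {
  // Hypothetical helper: getAdvancedExtension packs a StringValue whose text
  // contains a line of the form "isStreaming=0" or "isStreaming=1".
  def isStreamingEnabled(optimization: ProtoAny): Boolean = {
    val text = optimization.unpack(classOf[StringValue]).getValue
    text.linesIterator.exists(_.trim == "isStreaming=1")
  }
}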
Review thread (on the benchmark):
Reviewer: Seems to me that we don't check the query result in the benchmark code. So should we add some simple UT code with a result check?
Author: Yeah, I will add a test later.
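
A sketch of the kind of result-checking test being discussed, assuming a QueryTest-style Gluten suite that provides withSQLConf and checkAnswer, plus a table t like the benchmark's (hypothetical, not code from this PR):

import org.apache.spark.sql.Row

test("streaming aggregate returns the same result as hash aggregate") {
  val query =
    """
      |SELECT c1, count(*), sum(c2) FROM (
      |SELECT t1.c1, t2.c2 FROM t t1 JOIN t t2 ON t1.c1 = t2.c1
      |)
      |GROUP BY c1
      |""".stripMargin

  // Baseline: the regular hash aggregate path.
  var expected: Seq[Row] = Seq.empty
  withSQLConf(GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE.key -> "false") {
    expected = spark.sql(query).collect().toSeq
  }
  // The streaming aggregate path must produce the same rows.
  withSQLConf(GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE.key -> "true") {
    checkAnswer(spark.sql(query), expected)
  }
}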