diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala index ca5118089755a..1f6af34ffd059 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala @@ -193,55 +193,59 @@ class TestOperator extends VeloxWholeStageTransformerSuite { } test("window expression") { - runQueryAndCompare( - "select row_number() over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + Seq("sort", "streaming").foreach { windowType => + withSQLConf( + "spark.gluten.sql.columnar.backend.velox.window.type" -> windowType.toString) { + runQueryAndCompare( + "select row_number() over" + + " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } - runQueryAndCompare( - "select rank() over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + runQueryAndCompare( + "select rank() over" + + " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } - runQueryAndCompare( - "select dense_rank() over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + runQueryAndCompare( + "select dense_rank() over" + + " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } - runQueryAndCompare( - "select percent_rank() over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + runQueryAndCompare( + "select percent_rank() over" + + " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } - runQueryAndCompare( - "select cume_dist() over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + runQueryAndCompare( + "select cume_dist() over" + + " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } - runQueryAndCompare( - "select l_suppkey, l_orderkey, nth_value(l_orderkey, 2) over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[WindowExecTransformer] - }) > 0) + runQueryAndCompare( + "select l_suppkey, l_orderkey, nth_value(l_orderkey, 2) over" + + " (partition by l_suppkey order by l_orderkey) from lineitem ") { + df => + { + assert( + getExecutedPlan(df).count( + plan => { + plan.isInstanceOf[WindowExecTransformer] + }) > 0) + } } - } - - runQueryAndCompare( - "select sum(l_partkey + 1) over" + - " (partition by l_suppkey order by l_orderkey) from lineitem") { _ => } - runQueryAndCompare( - "select max(l_partkey) over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + runQueryAndCompare( + "select sum(l_partkey + 1) over" + + " (partition by l_suppkey order by l_orderkey) from lineitem") { _ => } - runQueryAndCompare( - "select min(l_partkey) over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + runQueryAndCompare( + "select max(l_partkey) over" + + " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } - runQueryAndCompare( - "select avg(l_partkey) over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + runQueryAndCompare( + "select min(l_partkey) over" + + " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + runQueryAndCompare( + "select avg(l_partkey) over" + + " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + } + } } test("chr function") { diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index bc77d4531dc1b..d3add1750c9bf 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -636,15 +636,29 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: } auto [sortingKeys, sortingOrders] = processSortField(windowRel.sorts(), inputType); - return std::make_shared( - nextPlanNodeId(), - partitionKeys, - sortingKeys, - sortingOrders, - windowColumnNames, - windowNodeFunctions, - true /*inputsSorted*/, - childNode); + + if (windowRel.has_advanced_extension() && + SubstraitParser::configSetInOptimization(windowRel.advanced_extension(), "isStreaming=")) { + return std::make_shared( + nextPlanNodeId(), + partitionKeys, + sortingKeys, + sortingOrders, + windowColumnNames, + windowNodeFunctions, + true /*inputsSorted*/, + childNode); + } else { + return std::make_shared( + nextPlanNodeId(), + partitionKeys, + sortingKeys, + sortingOrders, + windowColumnNames, + windowNodeFunctions, + false /*inputsSorted*/, + childNode); + } } core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::SortRel& sortRel) { diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WindowExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WindowExecTransformer.scala index fdf23f2838532..dbc10c8770bac 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/WindowExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/WindowExecTransformer.scala @@ -16,6 +16,7 @@ */ package io.glutenproject.execution +import io.glutenproject.GlutenConfig import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.expression._ import io.glutenproject.extension.ValidationResult @@ -33,7 +34,7 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.window.WindowExecBase import org.apache.spark.sql.vectorized.ColumnarBatch -import com.google.protobuf.Any +import com.google.protobuf.{Any, StringValue} import io.substrait.proto.SortField import java.util.{ArrayList => JArrayList} @@ -80,6 +81,26 @@ case class WindowExecTransformer( override def outputPartitioning: Partitioning = child.outputPartitioning + def genWindowParametersBuilder(): com.google.protobuf.Any.Builder = { + // Start with "WindowParameters:" + val windowParametersStr = new StringBuffer("WindowParameters:") + // isStreaming: 1 for streaming, 0 for sort + val isStreaming: Int = + if (GlutenConfig.getConf.veloxColumnarWindowType.equals("streaming")) 1 else 0 + + windowParametersStr + .append("isStreaming=") + .append(isStreaming) + .append("\n") + val message = StringValue + .newBuilder() + .setValue(windowParametersStr.toString) + .build() + com.google.protobuf.Any.newBuilder + .setValue(message.toByteString) + .setTypeUrl("/google.protobuf.StringValue") + } + def getRelNode( context: SubstraitContext, windowExpression: Seq[NamedExpression], @@ -132,11 +153,14 @@ case class WindowExecTransformer( builder.build() }.asJava if (!validation) { + val extensionNode = + ExtensionBuilder.makeAdvancedExtension(genWindowParametersBuilder.build(), null) RelBuilder.makeWindowRel( input, windowExpressions, partitionsExpressions, sortFieldList, + extensionNode, context, operatorId) } else { diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala index 287ae58076511..30407892750dc 100644 --- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala +++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala @@ -61,6 +61,8 @@ class GlutenConfig(conf: SQLConf) extends Logging { def enableColumnarWindow: Boolean = conf.getConf(COLUMNAR_WINDOW_ENABLED) + def veloxColumnarWindowType: String = conf.getConfString(COLUMNAR_VELOX_WINDOW_TYPE.key) + def enableColumnarShuffledHashJoin: Boolean = conf.getConf(COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED) def enableNativeColumnarToRow: Boolean = conf.getConf(COLUMNAR_COLUMNAR_TO_ROW_ENABLED) @@ -619,6 +621,21 @@ object GlutenConfig { .booleanConf .createWithDefault(true) + val COLUMNAR_VELOX_WINDOW_TYPE = + buildConf("spark.gluten.sql.columnar.backend.velox.window.type") + .internal() + .doc( + "Velox backend supports both SortWindow and" + + " StreamingWindow operators." + + " The StreamingWindow operator skips the sorting step" + + " in the input but does not support spill." + + " On the other hand, the SortWindow operator is " + + "responsible for sorting the input data within the" + + " Window operator and also supports spill.") + .stringConf + .checkValues(Set("streaming", "sort")) + .createWithDefault("streaming") + val COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED = buildConf("spark.gluten.sql.columnar.forceShuffledHashJoin") .internal()