diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index f07e80a177ff..8beaa0aecfbb 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -49,7 +49,7 @@ private object CHRuleApi { // Inject the regular Spark rules directly. injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply) injector.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark)) - injector.injectQueryStagePrepRule(spark => LazyExpandRule(spark)) + // injector.injectQueryStagePrepRule(spark => LazyExpandRule(spark)) injector.injectParser( (spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface)) injector.injectParser( @@ -84,6 +84,7 @@ private object CHRuleApi { injector.injectTransform(_ => CollapseProjectExecTransformer) injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session)) injector.injectTransform(c => PushdownAggregatePreProjectionAheadExpand.apply(c.session)) + injector.injectTransform(c => LazyAggregateExpandRule.apply(c.session)) injector.injectTransform( c => intercept( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyExpandRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala similarity index 50% rename from backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyExpandRule.scala rename to backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala index b78feb78b193..ad5893d801cb 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyExpandRule.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala @@ -17,16 +17,17 @@ package org.apache.gluten.extension import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings +import org.apache.gluten.execution._ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.aggregate._ -import org.apache.spark.sql.execution.exchange._ +import org.apache.spark.sql.types._ /* * For aggregation with grouping sets, we need to expand the grouping sets @@ -49,48 +50,39 @@ import org.apache.spark.sql.execution.exchange._ * If the aggregation involves distinct, we can't do this optimization. */ -case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Logging { +case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan] with Logging { override def apply(plan: SparkPlan): SparkPlan = { - logDebug(s"xxx enable lazy aggregate expand: ${CHBackendSettings.enableLazyAggregateExpand}") + logError( + s"xxx 1031 enable lazy aggregate expand: " + + s"${CHBackendSettings.enableLazyAggregateExpand}") if (!CHBackendSettings.enableLazyAggregateExpand) { return plan } plan.transformUp { - case finalAggregate @ HashAggregateExec( - _, - _, - _, - _, - _, - _, + case shuffle @ ColumnarShuffleExchangeExec( + HashPartitioning(hashExpressions, _), + CHHashAggregateExecTransformer( + _, + groupingExpressions, + aggregateExpressions, + _, + _, + resultExpressions, + ExpandExecTransformer(projections, output, child)), _, _, - ShuffleExchangeExec( - HashPartitioning(hashExpressions, _), - HashAggregateExec( - _, - _, - _, - groupingExpressions, - aggregateExpressions, - _, - _, - resultExpressions, - ExpandExec(projections, output, child)), - _ - ) + _ ) => - logError(s"xxx match plan:$finalAggregate") - // move expand node after shuffle node - if ( - groupingExpressions.forall(_.isInstanceOf[Attribute]) && - hashExpressions.forall(_.isInstanceOf[Attribute]) && - aggregateExpressions.forall(_.filter.isEmpty) - ) { - val shuffle = - finalAggregate.asInstanceOf[HashAggregateExec].child.asInstanceOf[ShuffleExchangeExec] - val partialAggregate = shuffle.child.asInstanceOf[HashAggregateExec] - val expand = partialAggregate.child.asInstanceOf[ExpandExec] + logError(s"xxx match plan:$shuffle") + val partialAggregate = shuffle.child.asInstanceOf[CHHashAggregateExecTransformer] + val expand = partialAggregate.child.asInstanceOf[ExpandExecTransformer] + logError( + s"xxx partialAggregate: groupingExpressions:" + + s"${partialAggregate.groupingExpressions}\n" + + s"aggregateAttributes:${partialAggregate.aggregateAttributes}\n" + + s"aggregateExpressions:${partialAggregate.aggregateExpressions}\n" + + s"resultExpressions:${partialAggregate.resultExpressions}") + if (isSupportedAggregate(partialAggregate, expand, shuffle)) { val attributesToReplace = buildReplaceAttributeMapForAggregate( groupingExpressions, @@ -113,54 +105,41 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo ) val newShuffle = shuffle.copy(child = newExpand) - val newFinalAggregate = finalAggregate.copy(child = newShuffle) - logError(s"xxx new plan: $newFinalAggregate") - newFinalAggregate + logError(s"xxx new plan: $newShuffle") + newShuffle } else { - finalAggregate + shuffle } - case finalAggregate @ HashAggregateExec( - _, - _, - _, - _, + case shuffle @ ColumnarShuffleExchangeExec( + HashPartitioning(hashExpressions, _), + CHHashAggregateExecTransformer( + _, + groupingExpressions, + aggregateExpressions, + _, + _, + resultExpressions, + FilterExecTransformer(_, ExpandExecTransformer(projections, output, child))), _, _, - _, - _, - ShuffleExchangeExec( - HashPartitioning(hashExpressions, _), - HashAggregateExec( - _, - _, - _, - groupingExpressions, - aggregateExpressions, - _, - _, - resultExpressions, - FilterExec(_, ExpandExec(projections, output, child))), - _ - ) + _ ) => - logError(s"xxx match plan:$finalAggregate") - if ( - groupingExpressions.forall(_.isInstanceOf[Attribute]) && - hashExpressions.forall(_.isInstanceOf[Attribute]) && - aggregateExpressions.forall(_.filter.isEmpty) - ) { - val shuffle = - finalAggregate.asInstanceOf[HashAggregateExec].child.asInstanceOf[ShuffleExchangeExec] - val partialAggregate = shuffle.child.asInstanceOf[HashAggregateExec] - val filter = partialAggregate.child.asInstanceOf[FilterExec] - val expand = filter.child.asInstanceOf[ExpandExec] - + val partialAggregate = shuffle.child.asInstanceOf[CHHashAggregateExecTransformer] + val filter = partialAggregate.child.asInstanceOf[FilterExecTransformer] + val expand = filter.child.asInstanceOf[ExpandExecTransformer] + logError( + s"xxx partialAggregate: groupingExpressions:" + + s"${partialAggregate.groupingExpressions}\n" + + s"aggregateAttributes:${partialAggregate.aggregateAttributes}\n" + + s"aggregateExpressions:${partialAggregate.aggregateExpressions}\n" + + s"resultExpressions:${partialAggregate.resultExpressions}") + if (isSupportedAggregate(partialAggregate, expand, shuffle)) { val attributesToReplace = buildReplaceAttributeMapForAggregate( groupingExpressions, projections, output ) - logDebug(s"xxx attributesToReplace: $attributesToReplace") + logError(s"xxx attributesToReplace: $attributesToReplace") val newPartialAggregate = buildNewAggregateExec( partialAggregate, @@ -176,16 +155,125 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo ) val newFilter = filter.copy(child = newExpand) + val newShuffle = shuffle.copy(child = newFilter) - val newFinalAggregate = finalAggregate.copy(child = newShuffle) - logError(s"xxx new plan: $newFinalAggregate") - newFinalAggregate + logError(s"xxx new plan: $newShuffle") + newShuffle + } else { - finalAggregate + shuffle } } } + // Just enable for simple cases. Some of cases that are not supported: + // 1. select count(a),count(b), count(1), count(distinct(a)), count(distinct(b)) from values + // (1, null), (2,2) as data(a,b); + // 2. select n_name, count(distinct n_regionkey) as col1, + // count(distinct concat(n_regionkey, n_nationkey)) as col2 from + // nation group by n_name; + def isSupportedAggregate( + aggregate: CHHashAggregateExecTransformer, + expand: ExpandExecTransformer, + shuffle: ColumnarShuffleExchangeExec): Boolean = { + // all grouping keys must be attribute references + val expandOutputAttributes = expand.child.output.toSet + if (aggregate.groupingExpressions.exists(!_.isInstanceOf[Attribute])) { + logError(s"xxx Not all grouping expression are attribute references") + return false + } + // all shuffle keys must be attribute references + if ( + shuffle.outputPartitioning + .asInstanceOf[HashPartitioning] + .expressions + .exists(!_.isInstanceOf[Attribute]) + ) { + logError(s"xxx Not all shuffle hash expression are attribute references") + return false + } + + // For safety, only enalbe for some aggregate functions + // All the parameters in aggregate functions must be the references of the output of expand's + // child + if ( + !aggregate.aggregateExpressions.forall( + e => + isSupportedAggregateFunction(e) && e.aggregateFunction.references.forall( + expandOutputAttributes.contains(_))) + ) { + logError(s"xxx Some aggregate functions are not supported") + return false + } + + // ensure the last column of expand is grouping id + val groupIdIndex = findGroupingIdIndex(expand) + logError(s"xxx Find group id at index: $groupIdIndex") + if (groupIdIndex == -1) { + return false; + } + val groupIdAttribute = expand.output(groupIdIndex) + if ( + !groupIdAttribute.name.startsWith("grouping_id") && !groupIdAttribute.name.startsWith("gid") + && !groupIdAttribute.name.startsWith("spark_grouping_id") + ) { + logError(s"xxx Not found group id column at index $groupIdIndex") + return false + } + expand.projections.forall { + projection => + val groupId = projection(groupIdIndex) + groupId + .isInstanceOf[Literal] && (groupId.dataType.isInstanceOf[LongType] || groupId.dataType + .isInstanceOf[IntegerType]) + } + } + + def findGroupingIdIndex(expand: ExpandExecTransformer): Int = { + var groupIdIndexes = Seq[Int]() + for (col <- 0 until expand.output.length) { + val expandCol = expand.projections(0)(col) + if ( + expandCol.isInstanceOf[Literal] && (expandCol.dataType + .isInstanceOf[LongType] || expandCol.dataType.isInstanceOf[IntegerType]) + ) { + if ( + expand.projections.forall { + projection => + val e = projection(col) + e.isInstanceOf[Literal] && + (e.dataType.isInstanceOf[LongType] || e.dataType.isInstanceOf[IntegerType]) + } + ) { + groupIdIndexes +:= col + } + } + } + if (groupIdIndexes.length == 1) { + groupIdIndexes(0) + } else { + -1 + } + } + + // Some of aggregate functions' output columns are not consistent with the output of gluten. + // - average: in partial aggregation, the outputs are sum and count, but gluten only generates one + // column, avg. + // - sum: if the input's type is decimal, the output are sum and isEmpty, but gluten doesn't use + // the isEmpty column. + def isSupportedAggregateFunction(aggregateExpression: AggregateExpression): Boolean = { + if (aggregateExpression.filter.isDefined) { + return false + } + aggregateExpression.aggregateFunction match { + case _: Count => true + case _: Max => true + case _: Min => true + case sum: Sum => !sum.dataType.isInstanceOf[DecimalType] + case _ => false + } + } + def getReplaceAttribute( toReplace: Attribute, attributesToReplace: Map[Attribute, Attribute]): Attribute = { @@ -215,7 +303,6 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo if (attr != null) { attributeMap += (e.toAttribute -> attr) } - // attributeMap +=(e.toAttribute -> fullExpandProjection(index).asInstanceOf[Attribute]) } attributeMap } @@ -225,33 +312,17 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo originalExpandProjections: Seq[Seq[Expression]], originalExpandOutput: Seq[Attribute], newExpandOutput: Seq[Attribute]): Seq[Seq[Expression]] = { - var groupingKeysPosition = Map[String, Int]() - originalGroupingExpressions.foreach { - e => - e match { - case ne: NamedExpression => - val index = originalExpandOutput.indexWhere(_.semanticEquals(ne.toAttribute)) - if (index != -1) { - groupingKeysPosition += (ne.name -> index) - } - case _ => - } - } - val newExpandProjections = originalExpandProjections.map { projection => - val res = newExpandOutput.map { + newExpandOutput.map { attr => - if (attr.isInstanceOf[Attribute]) { - groupingKeysPosition.get(attr.name) match { - case Some(attrPos) => projection(attrPos) - case None => attr - } + val index = originalExpandOutput.indexWhere(_.semanticEquals(attr)) + if (index != -1) { + projection(index) } else { attr } } - res } newExpandProjections } @@ -259,11 +330,10 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo // Make the child of expand be the child of aggregate // Need to replace some attributes def buildNewAggregateExec( - partialAggregate: HashAggregateExec, - expand: ExpandExec, + partialAggregate: CHHashAggregateExecTransformer, + expand: ExpandExecTransformer, attributesToReplace: Map[Attribute, Attribute]): SparkPlan = { val expandOutput = expand.output - // As far as know, the last attribute in the output is the groupId attribute. val groupIdAttribute = expandOutput(expandOutput.length - 1) // if the grouping keys contains literal, they should not be in attributesToReplace @@ -274,6 +344,7 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo .filter( e => e.toAttribute != groupIdAttribute && attributesToReplace.contains(e.toAttribute)) .map(e => getReplaceAttribute(e.toAttribute, attributesToReplace)) + .distinct logError( s"xxx newGroupingExpresion: $newGroupingExpresion,\n" + s"groupingExpressions: $groupingExpressions") @@ -289,6 +360,7 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo .isEmpty || attributesToReplace.contains(e.toAttribute)) } .map(e => getReplaceAttribute(e.toAttribute, attributesToReplace)) + .distinct logError( s"xxx newResultExpressions: $newResultExpressions\n" + s"resultExpressions:$resultExpressions") @@ -299,13 +371,15 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo } def buildNewExpandExec( - expand: ExpandExec, - partialAggregate: HashAggregateExec, + expand: ExpandExecTransformer, + partialAggregate: CHHashAggregateExecTransformer, child: SparkPlan, attributesToReplace: Map[Attribute, Attribute]): SparkPlan = { - val newExpandProjectionTemplate = - partialAggregate.output - .map(e => getReplaceAttribute(e, attributesToReplace)) + // The output of the native plan is not completely consistent with Spark. + val aggregateOutput = partialAggregate.output + val newExpandProjectionTemplate = aggregateOutput + // aggregateOutput.map(e => getReplaceAttribute(e, attributesToReplace)) + logError(s"xxx aggregateResultAttributes: ${partialAggregate.aggregateResultAttributes}") logError(s"xxx newExpandProjectionTemplate: $newExpandProjectionTemplate") val newExpandProjections = buildNewExpandProjections( @@ -315,7 +389,7 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo newExpandProjectionTemplate ) logError(s"xxx newExpandProjections: $newExpandProjections\nprojections:${expand.projections}") - ExpandExec(newExpandProjections, partialAggregate.output, child) + ExpandExecTransformer(newExpandProjections, aggregateOutput, child) } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala index 2f9b2a78b4a7..a56f45d1ba3d 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala @@ -563,63 +563,5 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite { compareResultsAgainstVanillaSpark(sql, true, { _ => }) spark.sql("drop table t1") } - - test("GLLUTEN-7647 lazy expand") { - var sql = - """ - |select n_regionkey, n_nationkey, sum(n_regionkey), count(n_name) - |from nation group by n_regionkey, n_nationkey with cube - |order by n_regionkey, n_nationkey - |""".stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }) - - sql = """ - |select n_regionkey, n_nationkey, sum(n_regionkey), count(distinct n_name) - |from nation group by n_regionkey, n_nationkey with cube - |order by n_regionkey, n_nationkey - |""".stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }) - - sql = """ - |select n_regionkey, n_nationkey, sum(distinct n_regionkey), count(distinct n_name) - |from nation group by n_regionkey, n_nationkey with cube - |order by n_regionkey, n_nationkey - |""".stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }) - - sql = - """ - |select n_regionkey, n_nationkey, sum(distinct n_regionkey), count(distinct n_name) - |from nation group by n_regionkey, n_nationkey grouping sets((n_regionkey), (n_nationkey)) - |order by n_regionkey, n_nationkey - |""".stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }) - - sql = """ - |select n_regionkey, n_nationkey, sum(distinct n_regionkey), count(distinct n_name) - |from nation group by n_regionkey, n_nationkey - |grouping sets((n_regionkey, null), (null, n_nationkey)) - |order by n_regionkey, n_nationkey - |""".stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }) - - sql = """ - |select * from( - |select n_regionkey, n_nationkey, sum(n_regionkey), count(n_name) - |from nation group by n_regionkey, n_nationkey with cube - |) where n_regionkey != 0 - |order by n_regionkey, n_nationkey - |""".stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }) - - sql = """ - |select * from( - |select n_regionkey, n_nationkey, sum(n_regionkey), count(distinct n_name) - |from nation group by n_regionkey, n_nationkey with cube - |) where n_regionkey != 0 - |order by n_regionkey, n_nationkey - |""".stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }) - } } // scalastyle:off line.size.limit diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala index 739b040dba1d..191325c88195 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -3022,5 +3022,69 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr compareResultsAgainstVanillaSpark(query_sql, true, { _ => }) spark.sql("drop table test_tbl_7220") } + + test("GLLUTEN-7647 lazy expand") { + var sql = + """ + |select n_regionkey, n_nationkey, + |sum(n_regionkey), count(n_name), max(n_regionkey), min(n_regionkey) + |from nation group by n_regionkey, n_nationkey with cube + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select n_regionkey, n_nationkey, sum(n_regionkey), count(distinct n_name) + |from nation group by n_regionkey, n_nationkey with cube + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select n_regionkey, n_nationkey, + |sum(distinct n_regionkey), count(distinct n_name), max(n_regionkey), min(n_regionkey) + |from nation group by n_regionkey, n_nationkey with cube + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = + """ + |select n_regionkey, n_nationkey, + |sum(distinct n_regionkey), count(distinct n_name), max(n_regionkey), min(n_regionkey) + |from nation group by n_regionkey, n_nationkey grouping sets((n_regionkey), (n_nationkey)) + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select n_regionkey, n_nationkey, + |sum(distinct n_regionkey), count(distinct n_name), max(n_regionkey), min(n_regionkey) + |from nation group by n_regionkey, n_nationkey + |grouping sets((n_regionkey, null), (null, n_nationkey)) + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select * from( + |select n_regionkey, n_nationkey, + |sum(n_regionkey), count(n_name), max(n_regionkey), min(n_regionkey) + |from nation group by n_regionkey, n_nationkey with cube + |) where n_regionkey != 0 + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select * from( + |select n_regionkey, n_nationkey, + |sum(n_regionkey), count(distinct n_name), max(n_regionkey), min(n_regionkey) + |from nation group by n_regionkey, n_nationkey with cube + |) where n_regionkey != 0 + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + } } // scalastyle:on line.size.limit