diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index 8beaa0aecfbb..c34a46e15d0e 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -49,7 +49,6 @@ private object CHRuleApi {
     // Inject the regular Spark rules directly.
     injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply)
     injector.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark))
-    // injector.injectQueryStagePrepRule(spark => LazyExpandRule(spark))
     injector.injectParser(
       (spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface))
     injector.injectParser(
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala
index 5cdf5388021b..109305393516 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala
@@ -52,9 +52,7 @@ import org.apache.spark.sql.types._
 
 case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan] with Logging {
   override def apply(plan: SparkPlan): SparkPlan = {
-    logDebug(
-      s"xxx 1031 enable lazy aggregate expand: " +
-        s"${CHBackendSettings.enableLazyAggregateExpand}")
+    logError(s"xxx enable lazy aggregate expand: ${CHBackendSettings.enableLazyAggregateExpand}")
     if (!CHBackendSettings.enableLazyAggregateExpand) {
       return plan
     }
@@ -73,27 +71,27 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
           _,
           _
         ) =>
-        logDebug(s"xxx match plan:$shuffle")
+        logError(s"xxx match plan:$shuffle")
         val partialAggregate = shuffle.child.asInstanceOf[CHHashAggregateExecTransformer]
         val expand = partialAggregate.child.asInstanceOf[ExpandExecTransformer]
-        logDebug(
+        logError(
           s"xxx partialAggregate: groupingExpressions:" +
             s"${partialAggregate.groupingExpressions}\n" +
             s"aggregateAttributes:${partialAggregate.aggregateAttributes}\n" +
             s"aggregateExpressions:${partialAggregate.aggregateExpressions}\n" +
             s"resultExpressions:${partialAggregate.resultExpressions}")
-        if (isSupportedAggregate(partialAggregate, expand, shuffle)) {
+        if (doValidation(partialAggregate, expand, shuffle)) {
           val attributesToReplace = buildReplaceAttributeMap(expand)
-          logDebug(s"xxx attributesToReplace: $attributesToReplace")
+          logError(s"xxx attributesToReplace: $attributesToReplace")
 
-          val newPartialAggregate = buildNewAggregateExec(
+          val newPartialAggregate = buildAheadAggregateExec(
             partialAggregate,
             expand,
             attributesToReplace
           )
 
-          val newExpand = buildNewExpandExec(
+          val newExpand = buildPostExpandExec(
             expand,
             partialAggregate,
             newPartialAggregate,
@@ -101,7 +99,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
           )
 
           val newShuffle = shuffle.copy(child = newExpand)
-          logDebug(s"xxx new plan: $newShuffle")
+          logError(s"xxx new plan: $newShuffle")
           newShuffle
         } else {
           shuffle
@@ -123,23 +121,23 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
         val partialAggregate = shuffle.child.asInstanceOf[CHHashAggregateExecTransformer]
         val filter = partialAggregate.child.asInstanceOf[FilterExecTransformer]
         val expand = filter.child.asInstanceOf[ExpandExecTransformer]
-        logDebug(
+        logError(
           s"xxx partialAggregate: groupingExpressions:" +
             s"${partialAggregate.groupingExpressions}\n" +
             s"aggregateAttributes:${partialAggregate.aggregateAttributes}\n" +
             s"aggregateExpressions:${partialAggregate.aggregateExpressions}\n" +
             s"resultExpressions:${partialAggregate.resultExpressions}")
-        if (isSupportedAggregate(partialAggregate, expand, shuffle)) {
+        if (doValidation(partialAggregate, expand, shuffle)) {
           val attributesToReplace = buildReplaceAttributeMap(expand)
-          logDebug(s"xxx attributesToReplace: $attributesToReplace")
+          logError(s"xxx attributesToReplace: $attributesToReplace")
 
-          val newPartialAggregate = buildNewAggregateExec(
+          val newPartialAggregate = buildAheadAggregateExec(
             partialAggregate,
             expand,
             attributesToReplace
           )
-          val newExpand = buildNewExpandExec(
+          val newExpand = buildPostExpandExec(
             expand,
             partialAggregate,
             newPartialAggregate,
@@ -149,7 +147,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
           )
 
           val newFilter = filter.copy(child = newExpand)
           val newShuffle = shuffle.copy(child = newFilter)
-          logDebug(s"xxx new plan: $newShuffle")
+          logError(s"xxx new plan: $newShuffle")
           newShuffle
         } else {
@@ -164,24 +162,27 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
   // 2. select n_name, count(distinct n_regionkey) as col1,
   //    count(distinct concat(n_regionkey, n_nationkey)) as col2 from
   //    nation group by n_name;
-  def isSupportedAggregate(
+  def doValidation(
       aggregate: CHHashAggregateExecTransformer,
       expand: ExpandExecTransformer,
       shuffle: ColumnarShuffleExchangeExec): Boolean = {
     // all grouping keys must be attribute references
     val expandOutputAttributes = expand.child.output.toSet
-    if (aggregate.groupingExpressions.exists(!_.isInstanceOf[Attribute])) {
-      logDebug(s"xxx Not all grouping expression are attribute references")
+    if (
+      !aggregate.groupingExpressions.forall(
+        e => e.isInstanceOf[Attribute] || e.isInstanceOf[Literal])
+    ) {
+      logError(s"xxx Not all grouping expression are attribute references")
       return false
     }
     // all shuffle keys must be attribute references
     if (
-      shuffle.outputPartitioning
+      !shuffle.outputPartitioning
         .asInstanceOf[HashPartitioning]
         .expressions
-        .exists(!_.isInstanceOf[Attribute])
+        .forall(e => e.isInstanceOf[Attribute] || e.isInstanceOf[Literal])
     ) {
-      logDebug(s"xxx Not all shuffle hash expression are attribute references")
+      logError(s"xxx Not all shuffle hash expression are attribute references")
       return false
     }
@@ -190,50 +191,54 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
     if (
       !aggregate.aggregateExpressions.forall(
         e =>
-          isSupportedAggregateFunction(e) &&
+          isValidAggregateFunction(e) &&
            e.aggregateFunction.references.forall(expandOutputAttributes.contains(_)))
     ) {
-      logDebug(s"xxx Some aggregate functions are not supported")
+      logError(s"xxx Some aggregate functions are not supported")
       return false
     }
 
-    // ensure the last column of expand is grouping id
-    val groupIdIndex = findGroupingIdIndex(expand)
-    logDebug(s"xxx Find group id at index: $groupIdIndex")
-    if (groupIdIndex == -1) {
-      return false;
-    }
-    val groupIdAttribute = expand.output(groupIdIndex)
-    if (
-      !groupIdAttribute.name.startsWith("grouping_id") && !groupIdAttribute.name.startsWith("gid")
-        && !groupIdAttribute.name.startsWith("spark_grouping_id")
-    ) {
-      logDebug(s"xxx Not found group id column at index $groupIdIndex")
-      return false
-    }
-    expand.projections.forall {
-      projection =>
-        val groupId = projection(groupIdIndex)
-        groupId
-          .isInstanceOf[Literal] && (groupId.dataType.isInstanceOf[LongType] || groupId.dataType
-          .isInstanceOf[IntegerType])
-    }
+    // get the group id's position in the expand's output
+    val gidIndex = findGroupingIdIndex(expand)
+    gidIndex != -1
   }
 
+  // group id column doesn't have a fixed position, so we need to find it.
   def findGroupingIdIndex(expand: ExpandExecTransformer): Int = {
+    def isValidGroupIdColumn(e: Expression, gids: Set[Long]): Long = {
+      if (!e.isInstanceOf[Literal]) {
+        return -1
+      }
+      val literalValue = e.asInstanceOf[Literal].value
+      e.dataType match {
+        case _: LongType =>
+          if (gids.contains(literalValue.asInstanceOf[Long])) {
+            -1
+          } else {
+            literalValue.asInstanceOf[Long]
+          }
+        case _: IntegerType =>
+          if (gids.contains(literalValue.asInstanceOf[Int].toLong)) {
+            -1
+          } else {
+            literalValue.asInstanceOf[Int].toLong
+          }
+        case _ => -1
+      }
+    }
+
     var groupIdIndexes = Seq[Int]()
     for (col <- 0 until expand.output.length) {
       val expandCol = expand.projections(0)(col)
-      if (
-        expandCol.isInstanceOf[Literal] && (expandCol.dataType
-          .isInstanceOf[LongType] || expandCol.dataType.isInstanceOf[IntegerType])
-      ) {
+      // gids should be unique
+      var gids = Set[Long]()
+      if (isValidGroupIdColumn(expandCol, gids) != -1) {
         if (
           expand.projections.forall {
             projection =>
-              val e = projection(col)
-              e.isInstanceOf[Literal] &&
-              (e.dataType.isInstanceOf[LongType] || e.dataType.isInstanceOf[IntegerType])
+              val res = isValidGroupIdColumn(projection(col), gids)
+              gids += res
+              res != -1
           }
         ) {
           groupIdIndexes +:= col
@@ -241,6 +246,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
       }
     }
     if (groupIdIndexes.length == 1) {
+      logError(s"xxx gid is at pos ${groupIdIndexes(0)}")
       groupIdIndexes(0)
     } else {
       -1
@@ -252,7 +258,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
   //   column, avg.
   // - sum: if the input's type is decimal, the output are sum and isEmpty, but gluten doesn't use
   //   the isEmpty column.
-  def isSupportedAggregateFunction(aggregateExpression: AggregateExpression): Boolean = {
+  def isValidAggregateFunction(aggregateExpression: AggregateExpression): Boolean = {
     if (aggregateExpression.filter.isDefined) {
       return false
     }
@@ -290,7 +296,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
     attributeMap
   }
 
-  def buildNewExpandProjections(
+  def buildPostExpandProjections(
       originalExpandProjections: Seq[Seq[Expression]],
       originalExpandOutput: Seq[Attribute],
       newExpandOutput: Seq[Attribute]): Seq[Seq[Expression]] = {
@@ -309,65 +315,55 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
     newExpandProjections
   }
 
-  // Make the child of expand be the child of aggregate
-  // Need to replace some attributes
-  def buildNewAggregateExec(
+  // 1. make expand's child be aggregate's child
+  // 2. replace the attributes in groupingExpressions and resultExpressions as needed
+  def buildAheadAggregateExec(
       partialAggregate: CHHashAggregateExecTransformer,
       expand: ExpandExecTransformer,
       attributesToReplace: Map[Attribute, Attribute]): SparkPlan = {
     val groupIdAttribute = expand.output(findGroupingIdIndex(expand))
-    // if the grouping keys contains literal, they should not be in attributesToReplace
-    // And we need to remove them from the grouping keys
-    val groupingExpressions = partialAggregate.groupingExpressions
-    val newGroupingExpresion =
-      groupingExpressions
+    // New grouping expressions should not include the group id column
+    val groupingExpressions =
+      partialAggregate.groupingExpressions
         .filter(
           e => e.toAttribute != groupIdAttribute && attributesToReplace.contains(e.toAttribute))
         .map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
         .distinct
-    logDebug(
-      s"xxx newGroupingExpresion: $newGroupingExpresion,\n" +
-        s"groupingExpressions: $groupingExpressions")
+    logError(
+      s"xxx newGroupingExpressions: $groupingExpressions,\n" +
+        s"groupingExpressions: ${partialAggregate.groupingExpressions}")
 
-    // Also need to remove literal grouping keys from the result expressions
+    // Remove group id column from result expressions
     val resultExpressions = partialAggregate.resultExpressions
-    val newResultExpressions =
-      resultExpressions
-        .filter {
-          e =>
-            e.toAttribute != groupIdAttribute && (groupingExpressions
-              .find(_.toAttribute == e.toAttribute)
-              .isEmpty || attributesToReplace.contains(e.toAttribute))
-        }
-        .map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
-        .distinct
-    logDebug(
-      s"xxx newResultExpressions: $newResultExpressions\n" +
-        s"resultExpressions:$resultExpressions")
+      .filter(_.toAttribute != groupIdAttribute)
+      .map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
+    logError(
+      s"xxx newResultExpressions: $resultExpressions\n" +
+        s"resultExpressions:${partialAggregate.resultExpressions}")
     partialAggregate.copy(
-      groupingExpressions = newGroupingExpresion,
-      resultExpressions = newResultExpressions,
+      groupingExpressions = groupingExpressions,
+      resultExpressions = resultExpressions,
       child = expand.child)
   }
 
-  def buildNewExpandExec(
+  def buildPostExpandExec(
      expand: ExpandExecTransformer,
      partialAggregate: CHHashAggregateExecTransformer,
      child: SparkPlan,
      attributesToReplace: Map[Attribute, Attribute]): SparkPlan = {
     // The output of the native plan is not completely consistent with Spark.
     val aggregateOutput = partialAggregate.output
-    logDebug(s"xxx aggregateResultAttributes: ${partialAggregate.aggregateResultAttributes}")
-    logDebug(s"xxx aggregateOutput: $aggregateOutput")
+    logError(s"xxx aggregateResultAttributes: ${partialAggregate.aggregateResultAttributes}")
+    logError(s"xxx aggregateOutput: $aggregateOutput")
 
-    val newExpandProjections = buildNewExpandProjections(
+    val expandProjections = buildPostExpandProjections(
       expand.projections,
       expand.output,
       aggregateOutput
     )
-    logDebug(s"xxx newExpandProjections: $newExpandProjections\nprojections:${expand.projections}")
-    ExpandExecTransformer(newExpandProjections, aggregateOutput, child)
+    logError(s"xxx expandProjections: $expandProjections\nprojections:${expand.projections}")
+    ExpandExecTransformer(expandProjections, aggregateOutput, child)
   }
 }
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 191325c88195..66aa69b10e34 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -3024,6 +3024,13 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
   }
 
   test("GLLUTEN-7647 lazy expand") {
+    def checkLazyExpand(df: DataFrame): Unit = {
+      val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
+        case e: ExpandExecTransformer if (e.child.isInstanceOf[HashAggregateExecBaseTransformer]) =>
+          e
+      }
+      assert(expands.size == 1)
+    }
     var sql =
       """
         |select n_regionkey, n_nationkey,
@@ -3031,14 +3038,14 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
         |from nation group by n_regionkey, n_nationkey with cube
         |order by n_regionkey, n_nationkey
         |""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
     sql =
       """
         |select n_regionkey, n_nationkey, sum(n_regionkey), count(distinct n_name)
         |from nation group by n_regionkey, n_nationkey with cube
         |order by n_regionkey, n_nationkey
         |""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
     sql =
       """
         |select n_regionkey, n_nationkey,
@@ -3046,7 +3053,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
         |from nation group by n_regionkey, n_nationkey with cube
         |order by n_regionkey, n_nationkey
         |""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
 
     sql =
       """
@@ -3055,7 +3062,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
         |from nation group by n_regionkey, n_nationkey grouping sets((n_regionkey), (n_nationkey))
         |order by n_regionkey, n_nationkey
         |""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
     sql =
      """
         |select n_regionkey, n_nationkey,
@@ -3064,7 +3071,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
         |grouping sets((n_regionkey, null), (null, n_nationkey))
         |order by n_regionkey, n_nationkey
         |""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
     sql =
       """
         |select * from(
@@ -3074,7 +3081,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
         |) where n_regionkey != 0
         |order by n_regionkey, n_nationkey
         |""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
     sql =
       """
         |select * from(
@@ -3084,7 +3091,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
         |) where n_regionkey != 0
         |order by n_regionkey, n_nationkey
         |""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
   }
 }
 // scalastyle:on line.size.limit
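
The rule in this diff rewrites the plan shape Shuffle <- PartialAggregate <- Expand into Shuffle <- Expand <- PartialAggregate, so the partial aggregate runs over the original, un-expanded rows and Expand only multiplies rows that have already been reduced; buildAheadAggregateExec drops the grouping-id key (which only exists after Expand runs) and buildPostExpandExec rebuilds the expand projections against the aggregate's output. Below is a minimal, self-contained Scala sketch of that reordering on a toy plan tree. It is not Gluten's or Spark's API: the ToyPlan node classes, the rewrite function, and the "spark_grouping_id" key name are illustrative assumptions, and attribute remapping and projection rewriting are deliberately omitted.

// Toy model of the lazy-expand rewrite (illustrative only, not Gluten's API).
sealed trait ToyPlan
case class ToyScan(table: String) extends ToyPlan
case class ToyExpand(projections: Seq[Seq[String]], child: ToyPlan) extends ToyPlan
case class ToyPartialAggregate(groupingKeys: Seq[String], child: ToyPlan) extends ToyPlan
case class ToyShuffle(child: ToyPlan) extends ToyPlan

object LazyExpandSketch {
  // Pull the partial aggregate below the expand when the expected pattern matches;
  // anything else is returned unchanged (mirrors the rule's fallback to the original shuffle).
  def rewrite(plan: ToyPlan): ToyPlan = plan match {
    case ToyShuffle(ToyPartialAggregate(keys, ToyExpand(projections, source))) =>
      // The grouping id is produced by Expand, which now runs after the aggregate,
      // so the ahead-of-expand aggregate cannot group on it (key name is assumed here).
      val aheadKeys = keys.filterNot(_ == "spark_grouping_id")
      ToyShuffle(ToyExpand(projections, ToyPartialAggregate(aheadKeys, source)))
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val before = ToyShuffle(
      ToyPartialAggregate(
        Seq("n_regionkey", "n_nationkey", "spark_grouping_id"),
        ToyExpand(
          Seq(Seq("n_regionkey", "null", "0"), Seq("null", "n_nationkey", "1")),
          ToyScan("nation"))))
    // Prints the reordered plan: ToyShuffle(ToyExpand(..., ToyPartialAggregate(...)))
    println(rewrite(before))
  }
}

The checkLazyExpand helper added in the test suite verifies the same property on real plans: after the rewrite, exactly one ExpandExecTransformer has a HashAggregateExecBaseTransformer as its direct child.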