Skip to content

Commit

Permalink
refactor 1101
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Nov 1, 2024
1 parent b152a26 commit 8269184
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ private object CHRuleApi {
// Inject the regular Spark rules directly.
injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply)
injector.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark))
// injector.injectQueryStagePrepRule(spark => LazyExpandRule(spark))
injector.injectParser(
(spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface))
injector.injectParser(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ import org.apache.spark.sql.types._

case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan] with Logging {
override def apply(plan: SparkPlan): SparkPlan = {
logDebug(
s"xxx 1031 enable lazy aggregate expand: " +
s"${CHBackendSettings.enableLazyAggregateExpand}")
logError(s"xxx enable lazy aggregate expand: {CHBackendSettings.enableLazyAggregateExpand}")
if (!CHBackendSettings.enableLazyAggregateExpand) {
return plan
}
Expand All @@ -73,35 +71,35 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
_,
_
) =>
logDebug(s"xxx match plan:$shuffle")
logError(s"xxx match plan:$shuffle")
val partialAggregate = shuffle.child.asInstanceOf[CHHashAggregateExecTransformer]
val expand = partialAggregate.child.asInstanceOf[ExpandExecTransformer]
logDebug(
logError(
s"xxx partialAggregate: groupingExpressions:" +
s"${partialAggregate.groupingExpressions}\n" +
s"aggregateAttributes:${partialAggregate.aggregateAttributes}\n" +
s"aggregateExpressions:${partialAggregate.aggregateExpressions}\n" +
s"resultExpressions:${partialAggregate.resultExpressions}")
if (isSupportedAggregate(partialAggregate, expand, shuffle)) {
if (doValidation(partialAggregate, expand, shuffle)) {

val attributesToReplace = buildReplaceAttributeMap(expand)
logDebug(s"xxx attributesToReplace: $attributesToReplace")
logError(s"xxx attributesToReplace: $attributesToReplace")

val newPartialAggregate = buildNewAggregateExec(
val newPartialAggregate = buildAheadAggregateExec(
partialAggregate,
expand,
attributesToReplace
)

val newExpand = buildNewExpandExec(
val newExpand = buildPostExpandExec(
expand,
partialAggregate,
newPartialAggregate,
attributesToReplace
)

val newShuffle = shuffle.copy(child = newExpand)
logDebug(s"xxx new plan: $newShuffle")
logError(s"xxx new plan: $newShuffle")
newShuffle
} else {
shuffle
Expand All @@ -123,23 +121,23 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
val partialAggregate = shuffle.child.asInstanceOf[CHHashAggregateExecTransformer]
val filter = partialAggregate.child.asInstanceOf[FilterExecTransformer]
val expand = filter.child.asInstanceOf[ExpandExecTransformer]
logDebug(
logError(
s"xxx partialAggregate: groupingExpressions:" +
s"${partialAggregate.groupingExpressions}\n" +
s"aggregateAttributes:${partialAggregate.aggregateAttributes}\n" +
s"aggregateExpressions:${partialAggregate.aggregateExpressions}\n" +
s"resultExpressions:${partialAggregate.resultExpressions}")
if (isSupportedAggregate(partialAggregate, expand, shuffle)) {
if (doValidation(partialAggregate, expand, shuffle)) {
val attributesToReplace = buildReplaceAttributeMap(expand)
logDebug(s"xxx attributesToReplace: $attributesToReplace")
logError(s"xxx attributesToReplace: $attributesToReplace")

val newPartialAggregate = buildNewAggregateExec(
val newPartialAggregate = buildAheadAggregateExec(
partialAggregate,
expand,
attributesToReplace
)

val newExpand = buildNewExpandExec(
val newExpand = buildPostExpandExec(
expand,
partialAggregate,
newPartialAggregate,
Expand All @@ -149,7 +147,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
val newFilter = filter.copy(child = newExpand)

val newShuffle = shuffle.copy(child = newFilter)
logDebug(s"xxx new plan: $newShuffle")
logError(s"xxx new plan: $newShuffle")
newShuffle

} else {
Expand All @@ -164,24 +162,27 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
// 2. select n_name, count(distinct n_regionkey) as col1,
// count(distinct concat(n_regionkey, n_nationkey)) as col2 from
// nation group by n_name;
def isSupportedAggregate(
def doValidation(
aggregate: CHHashAggregateExecTransformer,
expand: ExpandExecTransformer,
shuffle: ColumnarShuffleExchangeExec): Boolean = {
// all grouping keys must be attribute references
val expandOutputAttributes = expand.child.output.toSet
if (aggregate.groupingExpressions.exists(!_.isInstanceOf[Attribute])) {
logDebug(s"xxx Not all grouping expression are attribute references")
if (
!aggregate.groupingExpressions.forall(
e => e.isInstanceOf[Attribute] || e.isInstanceOf[Literal])
) {
logError(s"xxx Not all grouping expression are attribute references")
return false
}
// all shuffle keys must be attribute references
if (
shuffle.outputPartitioning
!shuffle.outputPartitioning
.asInstanceOf[HashPartitioning]
.expressions
.exists(!_.isInstanceOf[Attribute])
.forall(e => e.isInstanceOf[Attribute] || e.isInstanceOf[Literal])
) {
logDebug(s"xxx Not all shuffle hash expression are attribute references")
logError(s"xxx Not all shuffle hash expression are attribute references")
return false
}

Expand All @@ -190,57 +191,62 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
if (
!aggregate.aggregateExpressions.forall(
e =>
isSupportedAggregateFunction(e) &&
isValidAggregateFunction(e) &&
e.aggregateFunction.references.forall(expandOutputAttributes.contains(_)))
) {
logDebug(s"xxx Some aggregate functions are not supported")
logError(s"xxx Some aggregate functions are not supported")
return false
}

// ensure the last column of expand is grouping id
val groupIdIndex = findGroupingIdIndex(expand)
logDebug(s"xxx Find group id at index: $groupIdIndex")
if (groupIdIndex == -1) {
return false;
}
val groupIdAttribute = expand.output(groupIdIndex)
if (
!groupIdAttribute.name.startsWith("grouping_id") && !groupIdAttribute.name.startsWith("gid")
&& !groupIdAttribute.name.startsWith("spark_grouping_id")
) {
logDebug(s"xxx Not found group id column at index $groupIdIndex")
return false
}
expand.projections.forall {
projection =>
val groupId = projection(groupIdIndex)
groupId
.isInstanceOf[Literal] && (groupId.dataType.isInstanceOf[LongType] || groupId.dataType
.isInstanceOf[IntegerType])
}
// get the group id's position in the expand's output
val gidIndex = findGroupingIdIndex(expand)
gidIndex != -1
}

// group id column doesn't have a fixed position, so we need to find it.
def findGroupingIdIndex(expand: ExpandExecTransformer): Int = {
def isValidGroupIdColumn(e: Expression, gids: Set[Long]): Long = {
if (!e.isInstanceOf[Literal]) {
return -1
}
val literalValue = e.asInstanceOf[Literal].value
e.dataType match {
case _: LongType =>
if (gids.contains(literalValue.asInstanceOf[Long])) {
-1
} else {
literalValue.asInstanceOf[Long]
}
case _: IntegerType =>
if (gids.contains(literalValue.asInstanceOf[Int].toLong)) {
-1
} else {
literalValue.asInstanceOf[Int].toLong
}
case _ => -1
}
}

var groupIdIndexes = Seq[Int]()
for (col <- 0 until expand.output.length) {
val expandCol = expand.projections(0)(col)
if (
expandCol.isInstanceOf[Literal] && (expandCol.dataType
.isInstanceOf[LongType] || expandCol.dataType.isInstanceOf[IntegerType])
) {
// gids should be unique
var gids = Set[Long]()
if (isValidGroupIdColumn(expandCol, gids) != -1) {
if (
expand.projections.forall {
projection =>
val e = projection(col)
e.isInstanceOf[Literal] &&
(e.dataType.isInstanceOf[LongType] || e.dataType.isInstanceOf[IntegerType])
val res = isValidGroupIdColumn(projection(col), gids)
gids += res
res != -1
}
) {
groupIdIndexes +:= col
}
}
}
if (groupIdIndexes.length == 1) {
logError(s"xxx gid is at pos ${groupIdIndexes(0)}")
groupIdIndexes(0)
} else {
-1
Expand All @@ -252,7 +258,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
// column, avg.
// - sum: if the input's type is decimal, the output are sum and isEmpty, but gluten doesn't use
// the isEmpty column.
def isSupportedAggregateFunction(aggregateExpression: AggregateExpression): Boolean = {
def isValidAggregateFunction(aggregateExpression: AggregateExpression): Boolean = {
if (aggregateExpression.filter.isDefined) {
return false
}
Expand Down Expand Up @@ -290,7 +296,7 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
attributeMap
}

def buildNewExpandProjections(
def buildPostExpandProjections(
originalExpandProjections: Seq[Seq[Expression]],
originalExpandOutput: Seq[Attribute],
newExpandOutput: Seq[Attribute]): Seq[Seq[Expression]] = {
Expand All @@ -309,65 +315,55 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
newExpandProjections
}

// Make the child of expand be the child of aggregate
// Need to replace some attributes
def buildNewAggregateExec(
// 1. make expand's child be aggregate's child
// 2. replace the attributes in groupingExpressions and resultExpressions as needed
def buildAheadAggregateExec(
partialAggregate: CHHashAggregateExecTransformer,
expand: ExpandExecTransformer,
attributesToReplace: Map[Attribute, Attribute]): SparkPlan = {
val groupIdAttribute = expand.output(findGroupingIdIndex(expand))

// if the grouping keys contains literal, they should not be in attributesToReplace
// And we need to remove them from the grouping keys
val groupingExpressions = partialAggregate.groupingExpressions
val newGroupingExpresion =
groupingExpressions
// New grouping expressions should include the group id column
val groupingExpressions =
partialAggregate.groupingExpressions
.filter(
e => e.toAttribute != groupIdAttribute && attributesToReplace.contains(e.toAttribute))
.map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
.distinct
logDebug(
s"xxx newGroupingExpresion: $newGroupingExpresion,\n" +
s"groupingExpressions: $groupingExpressions")
logError(
s"xxx newGroupingExpresion: $groupingExpressions,\n" +
s"groupingExpressions: ${partialAggregate.groupingExpressions}")

// Also need to remove literal grouping keys from the result expressions
// Remove group id column from result expressions
val resultExpressions = partialAggregate.resultExpressions
val newResultExpressions =
resultExpressions
.filter {
e =>
e.toAttribute != groupIdAttribute && (groupingExpressions
.find(_.toAttribute == e.toAttribute)
.isEmpty || attributesToReplace.contains(e.toAttribute))
}
.map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
.distinct
logDebug(
s"xxx newResultExpressions: $newResultExpressions\n" +
s"resultExpressions:$resultExpressions")
.filter(_.toAttribute != groupIdAttribute)
.map(e => getReplaceAttribute(e.toAttribute, attributesToReplace))
logError(
s"xxx newResultExpressions: $resultExpressions\n" +
s"resultExpressions:${partialAggregate.resultExpressions}")
partialAggregate.copy(
groupingExpressions = newGroupingExpresion,
resultExpressions = newResultExpressions,
groupingExpressions = groupingExpressions,
resultExpressions = resultExpressions,
child = expand.child)
}

def buildNewExpandExec(
def buildPostExpandExec(
expand: ExpandExecTransformer,
partialAggregate: CHHashAggregateExecTransformer,
child: SparkPlan,
attributesToReplace: Map[Attribute, Attribute]): SparkPlan = {
// The output of the native plan is not completely consistent with Spark.
val aggregateOutput = partialAggregate.output
logDebug(s"xxx aggregateResultAttributes: ${partialAggregate.aggregateResultAttributes}")
logDebug(s"xxx aggregateOutput: $aggregateOutput")
logError(s"xxx aggregateResultAttributes: ${partialAggregate.aggregateResultAttributes}")
logError(s"xxx aggregateOutput: $aggregateOutput")

val newExpandProjections = buildNewExpandProjections(
val expandProjections = buildPostExpandProjections(
expand.projections,
expand.output,
aggregateOutput
)
logDebug(s"xxx newExpandProjections: $newExpandProjections\nprojections:${expand.projections}")
ExpandExecTransformer(newExpandProjections, aggregateOutput, child)
logError(s"xxx expandProjections: $expandProjections\nprojections:${expand.projections}")
ExpandExecTransformer(expandProjections, aggregateOutput, child)
}

}
Loading

0 comments on commit 8269184

Please sign in to comment.