
Commit f63624c

update
lgbo-ustc committed Nov 5, 2024
1 parent 75a588e commit f63624c
Showing 4 changed files with 16 additions and 46 deletions.
@@ -159,45 +159,19 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
       initialInputBufferOffset: Int,
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): HashAggregateExecBaseTransformer = {
-    logError(s"xxx aggregateExpressions:$aggregateExpressions")
-    logError(s"xxx aggregateAttributes:$aggregateAttributes")
-    logError(s"xxx resultExpressions:$resultExpressions")
-    logError(s"xxx agg expr to result: ${aggregateExpressions.map(_.resultAttribute)}")
-    logError(
-      s"xxx agg:" +
-        s"${aggregateExpressions.map(e => e.aggregateFunction.aggBufferAttributes.length)}")
-    aggregateExpressions.foreach {
-      e => logError(s"xxx agg fun:$e, ${e.aggregateFunction.aggBufferAttributes}")
-    }
-    val modes = aggregateExpressions.map(_.mode)
-    logError(s"xxx modes: $modes")
     val replacedResultExpressions = CHHashAggregateExecTransformer.getCHAggregateResultExpressions(
       groupingExpressions,
       aggregateExpressions,
       resultExpressions)
-    logError(s"xxx adjust agg output: $replacedResultExpressions")
-    val agg1 = CHHashAggregateExecTransformer(
-      requiredChildDistributionExpressions,
-      groupingExpressions.distinct,
-      aggregateExpressions,
-      aggregateAttributes,
-      initialInputBufferOffset,
-      resultExpressions.distinct,
-      // replacedResultExpressions,
-      child
-    )
-    logError(s"xxx agg output: ${agg1.output}")
-    val agg = CHHashAggregateExecTransformer(
+    CHHashAggregateExecTransformer(
       requiredChildDistributionExpressions,
       groupingExpressions.distinct,
       aggregateExpressions,
       aggregateAttributes,
       initialInputBufferOffset,
-      // resultExpressions.distinct,
       replacedResultExpressions.distinct,
       child
     )
-    agg
   }

   /** Generate HashAggregateExecPullOutHelper */
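
Note: the helper CHHashAggregateExecTransformer.getCHAggregateResultExpressions that the new code calls is not shown in this diff. From its call sites (it receives the grouping expressions, the aggregate expressions, and the node's result expressions, and its result is either passed on as result expressions or mapped with .toAttribute), its shape is roughly as sketched below; the body here is only an assumption for illustration, not the actual implementation.

import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

// Hypothetical sketch only -- the real body lives outside this diff.
// Assumed intent: keep the grouping part of the node's result expressions and
// align the aggregate part with each AggregateExpression's resultAttribute.
def getCHAggregateResultExpressions(
    groupingExpressions: Seq[NamedExpression],
    aggregateExpressions: Seq[AggregateExpression],
    resultExpressions: Seq[NamedExpression]): Seq[NamedExpression] = {
  val groupingPart = resultExpressions.take(groupingExpressions.size)
  val aggregatePart = aggregateExpressions.map(_.resultAttribute)
  groupingPart ++ aggregatePart
}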
@@ -177,10 +177,13 @@ class CHTransformerApi extends TransformerApi with Logging {
     // output name will be different from grouping expressions,
     // so using output attribute instead of grouping expression
     val groupingExpressions = hash.output.splitAt(hash.groupingExpressions.size)._1
-    val aggResultAttributes = CHHashAggregateExecTransformer.getAggregateResultAttributes(
-      groupingExpressions,
-      hash.aggregateExpressions
-    )
+    val aggResultAttributes = CHHashAggregateExecTransformer
+      .getCHAggregateResultExpressions(
+        groupingExpressions,
+        hash.aggregateExpressions,
+        hash.resultExpressions
+      )
+      .map(_.toAttribute)
     if (aggResultAttributes.size == hash.output.size) {
       aggResultAttributes
     } else {
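
The comment kept in this hunk ("output name will be different from grouping expressions, so using output attribute instead of grouping expression") is the motivation for deriving attributes from result expressions. A toy Catalyst example, not Gluten code and using hypothetical names (runnable in a spark-shell), showing how an aliased result expression yields a different output attribute than the grouping expression it wraps:

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.types.LongType

// A grouping column and a result expression that renames it in the output.
val key = AttributeReference("key", LongType)()
val renamed = Alias(key, "key_out")()

// Deriving the attribute from the grouping expression keeps the old name,
val fromGrouping = key.toAttribute     // name "key"
// while deriving it from the result expression reflects the real output name.
val fromResult = renamed.toAttribute   // name "key_out"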
@@ -18,7 +18,6 @@ package org.apache.gluten.execution

 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.execution.CHHashAggregateExecTransformer.getAggregateResultAttributes
 import org.apache.gluten.expression._
 import org.apache.gluten.extension.ExpressionExtensionTrait
 import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
@@ -82,13 +81,6 @@ object CHHashAggregateExecTransformer {
     }
   }
 
-  def getAggregateResultAttributes(
-      groupingExpressions: Seq[NamedExpression],
-      aggregateExpressions: Seq[AggregateExpression]): Seq[Attribute] = {
-    groupingExpressions.map(ConverterUtils.getAttrFromExpr(_).toAttribute) ++ aggregateExpressions
-      .map(_.resultAttribute)
-  }
-
   private val curId = new java.util.concurrent.atomic.AtomicLong()
 
   def newStructFieldId(): Long = curId.getAndIncrement()
@@ -124,8 +116,7 @@ case class CHHashAggregateExecTransformer(
     resultExpressions,
     child) {
 
-  lazy val aggregateResultAttributes =
-    getAggregateResultAttributes(groupingExpressions, aggregateExpressions)
+  lazy val aggregateResultAttributes = resultExpressions.map(_.toAttribute)
 
   protected val modes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct
 
@@ -284,12 +275,16 @@ case class CHHashAggregateExecTransformer(
"PartialMerge's child not being HashAggregateExecBaseTransformer" +
" is unsupported yet")
}
val hashAggregateChild = child.asInstanceOf[BaseAggregateExec]
val aggTypesExpr = ExpressionConverter
.replaceWithExpressionTransformer(
aggExpr.resultAttribute,
CHHashAggregateExecTransformer.getAggregateResultAttributes(
child.asInstanceOf[BaseAggregateExec].groupingExpressions,
child.asInstanceOf[BaseAggregateExec].aggregateExpressions)
CHHashAggregateExecTransformer
.getCHAggregateResultExpressions(
hashAggregateChild.groupingExpressions,
hashAggregateChild.aggregateExpressions,
hashAggregateChild.resultExpressions)
.map(_.toAttribute)
)
Seq(aggTypesExpr.doTransform(args))
case Final | PartialMerge =>
2 changes: 0 additions & 2 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -124,8 +124,6 @@ void adjustOutput(const DB::QueryPlanPtr & query_plan, const substrait::PlanRel
     auto cols = query_plan->getCurrentHeader().getNamesAndTypesList();
     if (cols.getNames().size() != static_cast<size_t>(root_rel.root().names_size()))
     {
-        LOG_ERROR(getLogger("SerializedPlanParser"), "invalid query plan:\n{}", PlanUtil::explainPlan(*query_plan));
-        LOG_ERROR(getLogger("SerializedPlanParser"), "invalid substrait plan:\n{}", root_rel.DebugString());
         throw Exception(
             ErrorCodes::LOGICAL_ERROR,
             "Missmatch result columns size. plan column size {}, subtrait plan size {}.",
