diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 59d912d8e75d..6f07fc321f06 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -348,6 +348,18 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
     )
   }
 
+  // Tries to move the expand node after the pre-aggregate node, i.e. to rewrite the plan
+  //   expand -> pre-aggregate -> shuffle -> final-aggregate
+  // into
+  //   pre-aggregate -> expand -> shuffle -> final-aggregate
+  // which reduces the overhead of the pre-aggregate node.
+  def enableLazyAggregateExpand(): Boolean = {
+    SparkEnv.get.conf.getBoolean(
+      CHConf.runtimeConfig("enable_lazy_aggregate_expand"),
+      defaultValue = true
+    )
+  }
+
   override def enableNativeWriteFiles(): Boolean = {
     GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
   }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyExpandRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyExpandRule.scala
index f39c7d9f4841..b78feb78b193 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyExpandRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyExpandRule.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.gluten.extension
 
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions._
@@ -48,104 +50,172 @@ import org.apache.spark.sql.execution.exchange._
  */
 case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Logging {
 
-  override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case finalAggregate @ HashAggregateExec(
-          _,
-          _,
-          _,
-          _,
-          _,
-          _,
-          _,
-          _,
-          ShuffleExchangeExec(
-            HashPartitioning(hashExpressions, _),
-            HashAggregateExec(
-              _,
-              _,
-              _,
-              groupingExpressions,
-              aggregateExpressions,
-              _,
-              _,
-              resultExpressions,
-              ExpandExec(projections, output, child)),
-            _
+  override def apply(plan: SparkPlan): SparkPlan = {
+    logDebug(s"xxx enable lazy aggregate expand: ${CHBackendSettings.enableLazyAggregateExpand}")
+    if (!CHBackendSettings.enableLazyAggregateExpand) {
+      return plan
+    }
+    plan.transformUp {
+      case finalAggregate @ HashAggregateExec(
+            _,
+            _,
+            _,
+            _,
+            _,
+            _,
+            _,
+            _,
+            ShuffleExchangeExec(
+              HashPartitioning(hashExpressions, _),
+              HashAggregateExec(
+                _,
+                _,
+                _,
+                groupingExpressions,
+                aggregateExpressions,
+                _,
+                _,
+                resultExpressions,
+                ExpandExec(projections, output, child)),
+              _
+            )
+          ) =>
+        logError(s"xxx match plan:$finalAggregate")
+        // move expand node after shuffle node
+        if (
+          groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
+          hashExpressions.forall(_.isInstanceOf[Attribute]) &&
+          aggregateExpressions.forall(_.filter.isEmpty)
+        ) {
+          val shuffle =
+            finalAggregate.asInstanceOf[HashAggregateExec].child.asInstanceOf[ShuffleExchangeExec]
+          val partialAggregate = shuffle.child.asInstanceOf[HashAggregateExec]
+          val expand = partialAggregate.child.asInstanceOf[ExpandExec]
+
+          val attributesToReplace = buildReplaceAttributeMapForAggregate(
+            groupingExpressions,
+            projections,
+            output
           )
-        ) =>
-      // move expand node after shuffle node
-      if (
-        projections.exists(
-          projection =>
- projection.forall( - e => !e.isInstanceOf[Literal] || e.asInstanceOf[Literal].value != null)) && - groupingExpressions.forall(_.isInstanceOf[Attribute]) && - hashExpressions.forall(_.isInstanceOf[Attribute]) && - aggregateExpressions.forall(_.filter.isEmpty) - ) { - val shuffle = - finalAggregate.asInstanceOf[HashAggregateExec].child.asInstanceOf[ShuffleExchangeExec] - val partialAggregate = shuffle.child.asInstanceOf[HashAggregateExec] - - val attributesToReplace = buildReplaceAttributeMapForAggregate( - groupingExpressions, - projections, - output - ) - val newGroupingExpresion = - groupingExpressions - .filter(_.name.startsWith("spark_grouping_id") == false) - .map(e => attributesToReplace.getOrElse(e.name, e)) - val newResultExpressions = - resultExpressions - .filter(_.name.startsWith("spark_grouping_id") == false) - .map(e => attributesToReplace.getOrElse(e.name, e)) - val newHashExpresions = - hashExpressions - .filter(_.asInstanceOf[Attribute].name.startsWith("spark_grouping_id") == false) - .map { - e => - e match { - case ne: NamedExpression => attributesToReplace.getOrElse(ne.name, e) - case _ => e - } - } - val newExpandProjectionTemplate = - partialAggregate.output.map(e => attributesToReplace.getOrElse(e.name, e)) - val newExpandProjections = buildNewExpandProjections( - groupingExpressions, - projections, - output, - newExpandProjectionTemplate - ) - val newPartialAggregate = partialAggregate.copy( - groupingExpressions = newGroupingExpresion, - resultExpressions = newResultExpressions, - child = child - ) - val newExpand = - ExpandExec(newExpandProjections, partialAggregate.output, newPartialAggregate) - val newShuffle = shuffle.copy(child = newExpand) - finalAggregate.copy(child = newShuffle) - } else { - finalAggregate - } + logError(s"xxx attributesToReplace: $attributesToReplace") + + val newPartialAggregate = buildNewAggregateExec( + partialAggregate, + expand, + attributesToReplace + ) + + val newExpand = buildNewExpandExec( + expand, + partialAggregate, + newPartialAggregate, + attributesToReplace + ) + + val newShuffle = shuffle.copy(child = newExpand) + val newFinalAggregate = finalAggregate.copy(child = newShuffle) + logError(s"xxx new plan: $newFinalAggregate") + newFinalAggregate + } else { + finalAggregate + } + case finalAggregate @ HashAggregateExec( + _, + _, + _, + _, + _, + _, + _, + _, + ShuffleExchangeExec( + HashPartitioning(hashExpressions, _), + HashAggregateExec( + _, + _, + _, + groupingExpressions, + aggregateExpressions, + _, + _, + resultExpressions, + FilterExec(_, ExpandExec(projections, output, child))), + _ + ) + ) => + logError(s"xxx match plan:$finalAggregate") + if ( + groupingExpressions.forall(_.isInstanceOf[Attribute]) && + hashExpressions.forall(_.isInstanceOf[Attribute]) && + aggregateExpressions.forall(_.filter.isEmpty) + ) { + val shuffle = + finalAggregate.asInstanceOf[HashAggregateExec].child.asInstanceOf[ShuffleExchangeExec] + val partialAggregate = shuffle.child.asInstanceOf[HashAggregateExec] + val filter = partialAggregate.child.asInstanceOf[FilterExec] + val expand = filter.child.asInstanceOf[ExpandExec] + + val attributesToReplace = buildReplaceAttributeMapForAggregate( + groupingExpressions, + projections, + output + ) + logDebug(s"xxx attributesToReplace: $attributesToReplace") + + val newPartialAggregate = buildNewAggregateExec( + partialAggregate, + expand, + attributesToReplace + ) + + val newExpand = buildNewExpandExec( + expand, + partialAggregate, + newPartialAggregate, + attributesToReplace + ) + + val newFilter = 
filter.copy(child = newExpand) + val newShuffle = shuffle.copy(child = newFilter) + val newFinalAggregate = finalAggregate.copy(child = newShuffle) + logError(s"xxx new plan: $newFinalAggregate") + newFinalAggregate + } else { + finalAggregate + } + } + } + + def getReplaceAttribute( + toReplace: Attribute, + attributesToReplace: Map[Attribute, Attribute]): Attribute = { + attributesToReplace.getOrElse(toReplace, toReplace) } def buildReplaceAttributeMapForAggregate( originalGroupingExpressions: Seq[NamedExpression], originalExpandProjections: Seq[Seq[Expression]], - originalExpandOutput: Seq[Attribute]): Map[String, Attribute] = { - val fullExpandProjection = originalExpandProjections - .filter( - projection => - projection.forall( - e => !e.isInstanceOf[Literal] || e.asInstanceOf[Literal].value != null))(0) - var attributeMap = Map[String, Attribute]() - originalGroupingExpressions.filter(_.name.startsWith("spark_grouping_id") == false).foreach { + originalExpandOutput: Seq[Attribute]): Map[Attribute, Attribute] = { + + var fullExpandProjection = Seq[Expression]() + for (i <- 0 until originalExpandProjections(0).length) { + val attr = originalExpandProjections.find(x => x(i).isInstanceOf[Attribute]) match { + case Some(projection) => projection(i).asInstanceOf[Attribute] + case None => null + } + fullExpandProjection = fullExpandProjection :+ attr + } + var attributeMap = Map[Attribute, Attribute]() + val groupIdAttribute = originalExpandOutput(originalExpandOutput.length - 1) + originalGroupingExpressions.filter(_.toAttribute != groupIdAttribute).foreach { e => val index = originalExpandOutput.indexWhere(_.semanticEquals(e.toAttribute)) - attributeMap += (e.name -> fullExpandProjection(index).asInstanceOf[Attribute]) + val attr = fullExpandProjection(index).asInstanceOf[Attribute] + // if the grouping key is a literal, cast it to Attribute will be null + if (attr != null) { + attributeMap += (e.toAttribute -> attr) + } + // attributeMap +=(e.toAttribute -> fullExpandProjection(index).asInstanceOf[Attribute]) } attributeMap } @@ -172,9 +242,13 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo projection => val res = newExpandOutput.map { attr => - groupingKeysPosition.get(attr.name) match { - case Some(attrPos) => projection(attrPos) - case None => attr + if (attr.isInstanceOf[Attribute]) { + groupingKeysPosition.get(attr.name) match { + case Some(attrPos) => projection(attrPos) + case None => attr + } + } else { + attr } } res @@ -182,4 +256,66 @@ case class LazyExpandRule(session: SparkSession) extends Rule[SparkPlan] with Lo newExpandProjections } + // Make the child of expand be the child of aggregate + // Need to replace some attributes + def buildNewAggregateExec( + partialAggregate: HashAggregateExec, + expand: ExpandExec, + attributesToReplace: Map[Attribute, Attribute]): SparkPlan = { + val expandOutput = expand.output + // As far as know, the last attribute in the output is the groupId attribute. 
+ val groupIdAttribute = expandOutput(expandOutput.length - 1) + + // if the grouping keys contains literal, they should not be in attributesToReplace + // And we need to remove them from the grouping keys + val groupingExpressions = partialAggregate.groupingExpressions + val newGroupingExpresion = + groupingExpressions + .filter( + e => e.toAttribute != groupIdAttribute && attributesToReplace.contains(e.toAttribute)) + .map(e => getReplaceAttribute(e.toAttribute, attributesToReplace)) + logError( + s"xxx newGroupingExpresion: $newGroupingExpresion,\n" + + s"groupingExpressions: $groupingExpressions") + + // Also need to remove literal grouping keys from the result expressions + val resultExpressions = partialAggregate.resultExpressions + val newResultExpressions = + resultExpressions + .filter { + e => + e.toAttribute != groupIdAttribute && (groupingExpressions + .find(_.toAttribute == e.toAttribute) + .isEmpty || attributesToReplace.contains(e.toAttribute)) + } + .map(e => getReplaceAttribute(e.toAttribute, attributesToReplace)) + logError( + s"xxx newResultExpressions: $newResultExpressions\n" + + s"resultExpressions:$resultExpressions") + partialAggregate.copy( + groupingExpressions = newGroupingExpresion, + resultExpressions = newResultExpressions, + child = expand.child) + } + + def buildNewExpandExec( + expand: ExpandExec, + partialAggregate: HashAggregateExec, + child: SparkPlan, + attributesToReplace: Map[Attribute, Attribute]): SparkPlan = { + val newExpandProjectionTemplate = + partialAggregate.output + .map(e => getReplaceAttribute(e, attributesToReplace)) + logError(s"xxx newExpandProjectionTemplate: $newExpandProjectionTemplate") + + val newExpandProjections = buildNewExpandProjections( + partialAggregate.groupingExpressions, + expand.projections, + expand.output, + newExpandProjectionTemplate + ) + logError(s"xxx newExpandProjections: $newExpandProjections\nprojections:${expand.projections}") + ExpandExec(newExpandProjections, partialAggregate.output, child) + } + } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala index a56f45d1ba3d..2f9b2a78b4a7 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala @@ -563,5 +563,63 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite { compareResultsAgainstVanillaSpark(sql, true, { _ => }) spark.sql("drop table t1") } + + test("GLLUTEN-7647 lazy expand") { + var sql = + """ + |select n_regionkey, n_nationkey, sum(n_regionkey), count(n_name) + |from nation group by n_regionkey, n_nationkey with cube + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select n_regionkey, n_nationkey, sum(n_regionkey), count(distinct n_name) + |from nation group by n_regionkey, n_nationkey with cube + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select n_regionkey, n_nationkey, sum(distinct n_regionkey), count(distinct n_name) + |from nation group by n_regionkey, n_nationkey with cube + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = + """ + |select n_regionkey, n_nationkey, sum(distinct n_regionkey), 
count(distinct n_name) + |from nation group by n_regionkey, n_nationkey grouping sets((n_regionkey), (n_nationkey)) + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select n_regionkey, n_nationkey, sum(distinct n_regionkey), count(distinct n_name) + |from nation group by n_regionkey, n_nationkey + |grouping sets((n_regionkey, null), (null, n_nationkey)) + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select * from( + |select n_regionkey, n_nationkey, sum(n_regionkey), count(n_name) + |from nation group by n_regionkey, n_nationkey with cube + |) where n_regionkey != 0 + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select * from( + |select n_regionkey, n_nationkey, sum(n_regionkey), count(distinct n_name) + |from nation group by n_regionkey, n_nationkey with cube + |) where n_regionkey != 0 + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + } } // scalastyle:off line.size.limit diff --git a/cpp-ch/local-engine/Operator/AdvancedExpandStep.cpp b/cpp-ch/local-engine/Operator/AdvancedExpandStep.cpp index fc3814bfe8fd..61815af123fe 100644 --- a/cpp-ch/local-engine/Operator/AdvancedExpandStep.cpp +++ b/cpp-ch/local-engine/Operator/AdvancedExpandStep.cpp @@ -26,9 +26,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -77,7 +79,7 @@ static DB::ITransformingStep::Traits getTraits() AdvancedExpandStep::AdvancedExpandStep( DB::ContextPtr context_, const DB::Block & input_header_, - const DB::Names & grouping_keys_, + size_t grouping_keys_, const DB::AggregateDescriptions & aggregate_descriptions_, const ExpandField & project_set_exprs_) : DB::ITransformingStep(input_header_, buildOutputHeader(input_header_, project_set_exprs_), getTraits()) @@ -102,12 +104,17 @@ DB::Block AdvancedExpandStep::buildOutputHeader(const DB::Block &, const ExpandF return DB::Block(std::move(cols)); } -void AdvancedExpandStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings &) +void AdvancedExpandStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & pipeline_settings) { const auto & settings = context->getSettingsRef(); - // aggregate grouping keys need a extra grouping id column. 
- auto aggregate_grouping_keys = grouping_keys; - aggregate_grouping_keys.push_back(output_header->getByPosition(grouping_keys.size()).name); + DB::Names aggregate_grouping_keys; + for (size_t i = 0; i < output_header->columns(); ++i) + { + const auto & col = output_header->getByPosition(i); + if (typeid_cast(col.column.get())) + break; + aggregate_grouping_keys.push_back(col.name); + } DB::Aggregator::Params params( aggregate_grouping_keys, aggregate_descriptions, @@ -137,19 +144,20 @@ void AdvancedExpandStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, for (auto & output : outputs) { auto expand_processor - = std::make_shared(input_header, *output_header, grouping_keys.size(), project_set_exprs); + = std::make_shared(input_header, *output_header, grouping_keys, project_set_exprs); DB::connect(*output, expand_processor->getInputs().front()); new_processors.push_back(expand_processor); auto expand_output_header = expand_processor->getOutputs().front().getHeader(); - + auto transform_params = std::make_shared(expand_output_header, params, false); auto aggregate_processor = std::make_shared(expand_output_header, transform_params, context, false, false); DB::connect(expand_processor->getOutputs().back(), aggregate_processor->getInputs().front()); new_processors.push_back(aggregate_processor); + auto aggregate_output_header = aggregate_processor->getOutputs().front().getHeader(); - auto resize_processor = std::make_shared(*output_header, 2, 1); + auto resize_processor = std::make_shared(expand_output_header, 2, 1); DB::connect(aggregate_processor->getOutputs().front(), resize_processor->getInputs().front()); DB::connect(expand_processor->getOutputs().front(), resize_processor->getInputs().back()); new_processors.push_back(resize_processor); diff --git a/cpp-ch/local-engine/Operator/AdvancedExpandStep.h b/cpp-ch/local-engine/Operator/AdvancedExpandStep.h index 8edce1abff9e..295084658531 100644 --- a/cpp-ch/local-engine/Operator/AdvancedExpandStep.h +++ b/cpp-ch/local-engine/Operator/AdvancedExpandStep.h @@ -37,7 +37,7 @@ class AdvancedExpandStep : public DB::ITransformingStep explicit AdvancedExpandStep( DB::ContextPtr context_, const DB::Block & input_header_, - const DB::Names & grouping_keys_, + size_t grouping_keys_, const DB::AggregateDescriptions & aggregate_descriptions_, const ExpandField & project_set_exprs_); ~AdvancedExpandStep() override = default; @@ -51,7 +51,7 @@ class AdvancedExpandStep : public DB::ITransformingStep protected: DB::ContextPtr context; - DB::Names grouping_keys; + size_t grouping_keys; DB::AggregateDescriptions aggregate_descriptions; ExpandField project_set_exprs; diff --git a/cpp-ch/local-engine/Parser/RelParsers/ExpandRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/ExpandRelParser.cpp index 77cbd5894af4..f3d8b4ab11ec 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/ExpandRelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParsers/ExpandRelParser.cpp @@ -89,6 +89,10 @@ ExpandField ExpandRelParser::buildExpandField(const DB::Block & header, const su auto field = project_expr.selection().direct_reference().struct_field().field(); kinds.push_back(ExpandFieldKind::EXPAND_FIELD_KIND_SELECTION); fields.push_back(field); + if (field >= header.columns()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Field index out of range: {}, header: {}", field, header.dumpStructure()); + } updateType(types[i], header.getByPosition(field).type); const auto & name = header.getByPosition(field).name; if (names[i].empty()) @@ -168,12 +172,7 @@ 
DB::QueryPlanPtr ExpandRelParser::lazyAggregateExpandParse( auto aggregate_rel = rel.expand().input().aggregate(); auto aggregate_descriptions = buildAggregations(input_header, expand_field, aggregate_rel); - DB::Names grouping_keys; - for (size_t i = 0; i < aggregate_rel.groupings(0).grouping_expressions_size(); ++i) - { - const auto & col = input_header.getByPosition(i); - grouping_keys.push_back(col.name); - } + size_t grouping_keys = aggregate_rel.groupings(0).grouping_expressions_size(); auto expand_step = std::make_unique(getContext(), input_header, grouping_keys, aggregate_descriptions, expand_field); @@ -187,25 +186,21 @@ DB::AggregateDescriptions ExpandRelParser::buildAggregations( const DB::Block & input_header, const ExpandField & expand_field, const substrait::AggregateRel & aggregate_rel) { auto header = AdvancedExpandStep::buildOutputHeader(input_header, expand_field); - //auto header = input_header; - LOG_ERROR(getLogger("ExpandRelParser"), "xxx this header is: {}", header.dumpStructure()); DB::AggregateDescriptions descriptions; - size_t grouping_keys_size = aggregate_rel.groupings(0).grouping_expressions_size(); + DB::ColumnsWithTypeAndName aggregate_columns; + for (const auto & col : header.getColumnsWithTypeAndName()) + { + if (typeid_cast(col.column.get())) + aggregate_columns.push_back(col); + } for (size_t i = 0; i < aggregate_rel.measures_size(); ++i) { - /// The output header of the aggregate is [grouping keys] ++ [grouping id] ++ [aggregation columns] + /// The output header of the aggregate is [grouping keys] ++ [aggregation columns] const auto & measure = aggregate_rel.measures(i); - const auto & col = header.getByPosition(grouping_keys_size + i + 1); + const auto & col = aggregate_columns[i]; DB::AggregateDescription description; auto aggregate_col = typeid_cast(col.column.get()); - if (!aggregate_col) - throw DB::Exception( - DB::ErrorCodes::LOGICAL_ERROR, - "The column is not an aggregate column: {}, grouping_keys_size: {}, i: {}", - col.column->dumpStructure(), - grouping_keys_size, - i); description.column_name = col.name; description.argument_names = {col.name}; @@ -218,11 +213,6 @@ DB::AggregateDescriptions ExpandRelParser::buildAggregations( DB::AggregateFunctionProperties aggregate_function_properties; description.function = getAggregateFunction(function_name_with_combinator, {col.type}, aggregate_function_properties, description.parameters); - LOG_ERROR( - getLogger("ExpandRelParser"), - "xxx this aggregate function is: {}, {}", - function_name_with_combinator, - description.function->getName()); descriptions.emplace_back(description); }
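
For context, below is a minimal, self-contained Scala sketch (not part of the patch) of the query shape this rule targets and of how the new toggle would be set. The fully qualified config key is an assumption derived from CHConf.runtimeConfig("enable_lazy_aggregate_expand"); the object name LazyExpandSketch and the synthetic nation view are illustrative only. Per the comment added in CHBackendSettings, with the rule enabled the physical plan for such a cube aggregation is rewritten from expand -> partial aggregate -> shuffle -> final aggregate into partial aggregate -> expand -> shuffle -> final aggregate.

import org.apache.spark.sql.SparkSession

// Sketch only: it runs with plain Spark; the config line takes effect only when the
// Gluten ClickHouse backend is on the classpath and enabled.
object LazyExpandSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .appName("lazy-expand-sketch")
      // Assumed fully qualified key for CHConf.runtimeConfig("enable_lazy_aggregate_expand").
      .config(
        "spark.gluten.sql.columnar.backend.ch.runtime_config.enable_lazy_aggregate_expand",
        "true")
      .getOrCreate()

    // A tiny stand-in for the TPC-H nation table used by the new test cases.
    spark
      .range(0, 100)
      .selectExpr("id % 5 as n_regionkey", "id as n_nationkey", "cast(id as string) as n_name")
      .createOrReplaceTempView("nation")

    // Same query shape as the GLUTEN-7647 tests: a grouping-sets/cube aggregate whose
    // expand node the rule pushes below the partial aggregate.
    val df = spark.sql("""
      |select n_regionkey, n_nationkey, sum(n_regionkey), count(n_name)
      |from nation group by n_regionkey, n_nationkey with cube
      |order by n_regionkey, n_nationkey
      |""".stripMargin)

    df.explain()
    df.show(5)
    spark.stop()
  }
}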