diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index c10b7ff8c371..f80d2e1cb5cb 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -347,11 +347,23 @@ bool SubstraitToVeloxPlanValidator::validateExpression( } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel& fetchRel) { - const auto& extension = fetchRel.advanced_extension(); - std::vector types; - if (!validateInputTypes(extension, types)) { - logValidateMsg("native validation failed due to: unsupported input types in FetchRel."); - return false; + RowTypePtr rowType = nullptr; + // Get and validate the input types from extension. + if (fetchRel.has_advanced_extension()) { + const auto& extension = fetchRel.advanced_extension(); + std::vector types; + if (!validateInputTypes(extension, types)) { + logValidateMsg("native validation failed due to: unsupported input types in ExpandRel."); + return false; + } + + int32_t inputPlanNodeId = 0; + std::vector names; + names.reserve(types.size()); + for (auto colIdx = 0; colIdx < types.size(); colIdx++) { + names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); + } + rowType = std::make_shared(std::move(names), std::move(types)); } if (fetchRel.offset() < 0 || fetchRel.count() < 0) { @@ -359,33 +371,25 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel& fetchR return false; } - core::PlanNodePtr childNode; // Check the input of fetchRel, if it's sortRel, we need to check whether the sorting key is duplicated. - ::substrait::SortRel sortRel; bool topNFlag = false; if (fetchRel.has_input()) { topNFlag = fetchRel.input().has_sort(); if (topNFlag) { - sortRel = fetchRel.input().sort(); - childNode = planConverter_.toVeloxPlan(sortRel.input()); - } else { - childNode = planConverter_.toVeloxPlan(fetchRel.input()); - } - } - - if (topNFlag) { - auto [sortingKeys, sortingOrders] = planConverter_.processSortField(sortRel.sorts(), childNode->outputType()); - - folly::F14FastSet sortingKeyNames; - for (const auto& sortingKey : sortingKeys) { - auto result = sortingKeyNames.insert(sortingKey->name()); - if (!result.second) { - logValidateMsg( - "native validation failed due to: if the input of fetchRel is a SortRel, we will convert it to a TopNNode. In Velox, it is important to ensure unique sorting keys. However, duplicate keys were found in this case."); - return false; + ::substrait::SortRel sortRel = fetchRel.input().sort(); + auto [sortingKeys, sortingOrders] = planConverter_.processSortField(sortRel.sorts(), rowType); + folly::F14FastSet sortingKeyNames; + for (const auto& sortingKey : sortingKeys) { + auto result = sortingKeyNames.insert(sortingKey->name()); + if (!result.second) { + logValidateMsg( + "native validation failed due to: if the input of fetchRel is a SortRel, we will convert it to a TopNNode. In Velox, it is important to ensure unique sorting keys. However, duplicate keys were found in this case."); + return false; + } } } } + return true; } diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala index ad7c68a6a9f0..961e0c95201f 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala @@ -56,7 +56,11 @@ case class LimitTransformer(child: SparkPlan, offset: Long, count: Long) override protected def doValidateInternal(): ValidationResult = { val context = new SubstraitContext val operatorId = context.nextOperatorId(this.nodeName) - val relNode = getRelNode(context, operatorId, offset, count, child.output, null, true) + val input = child match { + case c: TransformSupport => c.doTransform(context).root + case _ => null + } + val relNode = getRelNode(context, operatorId, offset, count, child.output, input, true) doNativeValidation(context, relNode) } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala index 7bede35f7abc..97b295b5ca42 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala @@ -26,7 +26,7 @@ import io.glutenproject.utils.PhysicalPlanSelector import org.apache.spark.api.python.EvalPythonExecTransformer import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SortOrder} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.rules.Rule @@ -677,12 +677,17 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { "columnar topK is not enabled in TakeOrderedAndProjectExec") } else { var tagged: ValidationResult = null - val limitPlan = LimitTransformer(plan.child, 0, plan.limit) - tagged = limitPlan.doValidate() - if (tagged.isValid) { + val orderingSatisfies = + SortOrder.orderingSatisfies(plan.child.outputOrdering, plan.sortOrder) + if (orderingSatisfies) { + val limitPlan = LimitTransformer(plan.child, 0, plan.limit) + tagged = limitPlan.doValidate() + } else { val sortPlan = SortExecTransformer(plan.sortOrder, false, plan.child) - tagged = sortPlan.doValidate() + val limitPlan = LimitTransformer(sortPlan, 0, plan.limit) + tagged = limitPlan.doValidate() } + if (tagged.isValid) { val projectPlan = ProjectExecTransformer(plan.projectList, plan.child) tagged = projectPlan.doValidate()