Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-7267][CH]Support nested column pruning for HiveTableScan json/parquet/orc format #7268

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -929,4 +929,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
limitExpr: ExpressionTransformer,
original: StringSplit): ExpressionTransformer =
CHStringSplitTransformer(substraitExprName, Seq(srcExpr, regexExpr, limitExpr), original)

override def supportHiveTableScanNestedColumnPruning(): Boolean = true
}
Original file line number Diff line number Diff line change
Expand Up @@ -1416,4 +1416,72 @@ class GlutenClickHouseHiveTableSuite
spark.sql("DROP TABLE test_tbl_7054")
}

test("test hive table scan nested column pruning") {
val json_table_name = "test_tbl_7267_json"
val pq_table_name = "test_tbl_7267_pq"
val create_table_sql =
s"""
| create table if not exists %s(
| id bigint,
| d1 STRUCT<c: STRING, d: ARRAY<STRUCT<x: STRING, y: STRING>>>,
| d2 STRUCT<c: STRING, d: Map<STRING, STRUCT<x: STRING, y: STRING>>>,
| day string,
| hour string
| ) partitioned by(day, hour)
|""".stripMargin
val create_table_1 = create_table_sql.format(json_table_name) +
s"""
| ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'
| STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat'
| OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
|""".stripMargin
val create_table_2 = create_table_sql.format(pq_table_name) + " STORED AS PARQUET"
val insert_sql =
"""
| insert into %s values(1,
| named_struct('c', 'c123', 'd', array(named_struct('x', 'x123', 'y', 'y123'))),
| named_struct('c', 'c124', 'd', map('m124', named_struct('x', 'x124', 'y', 'y124'))),
| '2024-09-26', '12'
| )
|""".stripMargin
val insert_sql_1 = insert_sql.format(json_table_name)
val insert_sql_2 = insert_sql.format(pq_table_name)
spark.sql(create_table_1)
spark.sql(create_table_2)
spark.sql(insert_sql_1)
spark.sql(insert_sql_2)
val select_sql_1 =
"select id, d1.c, d1.d[0].x, d2.d['m124'].y from %s where day = '2024-09-26' and hour = '12'"
.format(json_table_name)
val select_sql_2 =
"select id, d1.c, d1.d[0].x, d2.d['m124'].y from %s where day = '2024-09-26' and hour = '12'"
.format(pq_table_name)
withSQLConf(
("spark.sql.hive.convertMetastoreParquet" -> "false"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这俩orc和parquet的开关在什么使用场景下是false呢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当需要使用hive parquet/orc serde 读取 table 时,而不是使用spark内置的parquet/orc reader读取时,这两个配置就需要被设置为false @taiyang-li

("spark.sql.hive.convertMetastoreOrc" -> "false")) {
compareResultsAgainstVanillaSpark(
select_sql_1,
compareResult = true,
df => {
val jsonFileScan = collect(df.queryExecution.executedPlan) {
case l: HiveTableScanExecTransformer => l
}
assert(jsonFileScan.size == 1)
}
)
compareResultsAgainstVanillaSpark(
select_sql_2,
compareResult = true,
df => {
val jsonFileScan = collect(df.queryExecution.executedPlan) {
case l: HiveTableScanExecTransformer => l
}
assert(jsonFileScan.size == 1)
}
)
}
spark.sql("drop table if exists %s".format(json_table_name))
spark.sql("drop table if exists %s".format(pq_table_name))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -690,4 +690,6 @@ trait SparkPlanExecApi {
limitExpr: ExpressionTransformer,
original: StringSplit): ExpressionTransformer =
GenericExpressionTransformer(substraitExprName, Seq(srcExpr, regexExpr, limitExpr), original)

def supportHiveTableScanNestedColumnPruning(): Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BatchEvalPythonExec}
import org.apache.spark.sql.execution.window.{WindowExec, WindowGroupLimitExecShim}
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
import org.apache.spark.sql.hive.{HiveTableScanExecTransformer, HiveTableScanNestedColumnPruning}

/**
* Converts a vanilla Spark plan node into Gluten plan node. Gluten plan is supposed to be executed
Expand Down Expand Up @@ -275,7 +275,11 @@ object OffloadOthers {
case plan: ProjectExec =>
val columnarChild = plan.child
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
ProjectExecTransformer(plan.projectList, columnarChild)
if (HiveTableScanNestedColumnPruning.supportNestedColumnPruning(plan)) {
HiveTableScanNestedColumnPruning.apply(plan)
} else {
ProjectExecTransformer(plan.projectList, columnarChild)
}
case plan: SortAggregateExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
HashAggregateExecBaseTransformer.from(plan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ import java.net.URI
case class HiveTableScanExecTransformer(
requestedAttributes: Seq[Attribute],
relation: HiveTableRelation,
partitionPruningPred: Seq[Expression])(@transient session: SparkSession)
partitionPruningPred: Seq[Expression],
prunedOutput: Seq[Attribute] = Seq.empty[Attribute])(@transient session: SparkSession)
extends AbstractHiveTableScanExec(requestedAttributes, relation, partitionPruningPred)(session)
with BasicScanExecTransformer {

Expand All @@ -63,7 +64,13 @@ case class HiveTableScanExecTransformer(

override def getMetadataColumns(): Seq[AttributeReference] = Seq.empty

override def outputAttributes(): Seq[Attribute] = output
override def outputAttributes(): Seq[Attribute] = {
if (prunedOutput.nonEmpty) {
prunedOutput
} else {
output
}
}

override def getPartitions: Seq[InputPartition] = partitions

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.execution.ProjectExecTransformer

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ProjectExec, SparkPlan}
import org.apache.spark.sql.hive.HiveTableScanExecTransformer.{ORC_INPUT_FORMAT_CLASS, PARQUET_INPUT_FORMAT_CLASS, TEXT_INPUT_FORMAT_CLASS}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
import org.apache.spark.sql.util.SchemaUtils._
import org.apache.spark.util.Utils

object HiveTableScanNestedColumnPruning extends Logging {
import org.apache.spark.sql.catalyst.expressions.SchemaPruning._

def supportNestedColumnPruning(projectExec: ProjectExec): Boolean = {
if (BackendsApiManager.getSparkPlanExecApiInstance.supportHiveTableScanNestedColumnPruning()) {
projectExec.child match {
case HiveTableScanExecTransformer(_, relation, _, _) =>
relation.tableMeta.storage.inputFormat match {
case Some(inputFormat)
if TEXT_INPUT_FORMAT_CLASS.isAssignableFrom(Utils.classForName(inputFormat)) =>
relation.tableMeta.storage.serde match {
case Some("org.openx.data.jsonserde.JsonSerDe") | Some(
"org.apache.hive.hcatalog.data.JsonSerDe") =>
return true
case _ =>
}
case Some(inputFormat)
if ORC_INPUT_FORMAT_CLASS.isAssignableFrom(Utils.classForName(inputFormat)) =>
return true
case Some(inputFormat)
if PARQUET_INPUT_FORMAT_CLASS.isAssignableFrom(Utils.classForName(inputFormat)) =>
return true
case _ =>
}
case _ =>
}
}
false
}

def apply(plan: SparkPlan): SparkPlan = {
plan match {
case ProjectExec(projectList, child) =>
child match {
case h: HiveTableScanExecTransformer =>
val newPlan = prunePhysicalColumns(
h.relation,
projectList,
Seq.empty[Expression],
(prunedDataSchema, prunedMetadataSchema) => {
buildNewHiveTableScan(h, prunedDataSchema, prunedMetadataSchema)
},
(schema, requestFields) => {
h.pruneSchema(schema, requestFields)
}
)
if (newPlan.nonEmpty) {
return newPlan.get
} else {
return ProjectExecTransformer(projectList, child)
}
case _ =>
return ProjectExecTransformer(projectList, child)
}
case _ =>
}
plan
}

private def prunePhysicalColumns(
relation: HiveTableRelation,
projects: Seq[NamedExpression],
filters: Seq[Expression],
leafNodeBuilder: (StructType, StructType) => LeafExecNode,
pruneSchemaFunc: (StructType, Seq[SchemaPruning.RootField]) => StructType)
: Option[SparkPlan] = {
val (normalizedProjects, normalizedFilters) =
normalizeAttributeRefNames(relation.output, projects, filters)
val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
// If requestedRootFields includes a nested field, continue. Otherwise,
// return op
if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
val prunedDataSchema = pruneSchemaFunc(relation.tableMeta.dataSchema, requestedRootFields)
val metaFieldNames = relation.tableMeta.schema.fieldNames
val metadataSchema = relation.output.collect {
case attr: AttributeReference if metaFieldNames.contains(attr.name) => attr
}.toStructType
val prunedMetadataSchema = if (metadataSchema.nonEmpty) {
pruneSchemaFunc(metadataSchema, requestedRootFields)
} else {
metadataSchema
}
// If the data schema is different from the pruned data schema
// OR
// the metadata schema is different from the pruned metadata schema, continue.
// Otherwise, return None.
if (
countLeaves(relation.tableMeta.dataSchema) > countLeaves(prunedDataSchema) ||
countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)
) {
val leafNode = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema)
val projectionOverSchema = ProjectionOverSchema(
prunedDataSchema.merge(prunedMetadataSchema),
AttributeSet(relation.output))
Some(
buildNewProjection(
projects,
normalizedProjects,
normalizedFilters,
leafNode,
projectionOverSchema))
} else {
None
}
} else {
None
}
}

/**
* Normalizes the names of the attribute references in the given projects and filters to reflect
* the names in the given logical relation. This makes it possible to compare attributes and
* fields by name. Returns a tuple with the normalized projects and filters, respectively.
*/
private def normalizeAttributeRefNames(
output: Seq[AttributeReference],
projects: Seq[NamedExpression],
filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap
val normalizedProjects = projects
.map(_.transform {
case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
att.withName(normalizedAttNameMap(att.exprId))
})
.map { case expr: NamedExpression => expr }
val normalizedFilters = filters.map(_.transform {
case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
att.withName(normalizedAttNameMap(att.exprId))
})
(normalizedProjects, normalizedFilters)
}

/** Builds the new output [[Project]] Spark SQL operator that has the `leafNode`. */
private def buildNewProjection(
projects: Seq[NamedExpression],
normalizedProjects: Seq[NamedExpression],
filters: Seq[Expression],
leafNode: LeafExecNode,
projectionOverSchema: ProjectionOverSchema): ProjectExecTransformer = {
// Construct a new target for our projection by rewriting and
// including the original filters where available
val projectionChild =
if (filters.nonEmpty) {
val projectedFilters = filters.map(_.transformDown {
case projectionOverSchema(expr) => expr
})
val newFilterCondition = projectedFilters.reduce(And)
FilterExec(newFilterCondition, leafNode)
} else {
leafNode
}

// Construct the new projections of our Project by
// rewriting the original projections
val newProjects =
normalizedProjects.map(_.transformDown { case projectionOverSchema(expr) => expr }).map {
case expr: NamedExpression => expr
}

if (log.isDebugEnabled) {
logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}")
}
ProjectExecTransformer(
restoreOriginalOutputNames(newProjects, projects.map(_.name)),
projectionChild)
}

private def buildNewHiveTableScan(
hiveTableScan: HiveTableScanExecTransformer,
prunedDataSchema: StructType,
prunedMetadataSchema: StructType): HiveTableScanExecTransformer = {
val relation = hiveTableScan.relation
val partitionSchema = relation.tableMeta.partitionSchema
val prunedBaseSchema = StructType(
prunedDataSchema.fields.filterNot(
f => partitionSchema.fieldNames.contains(f.name)) ++ partitionSchema.fields)
val finalSchema = prunedBaseSchema.merge(prunedMetadataSchema)
val prunedOutput = getPrunedOutput(relation.output, finalSchema)
var finalOutput = Seq.empty[Attribute]
for (p <- hiveTableScan.output) {
var flag = false
for (q <- prunedOutput if !flag) {
if (p.name.equals(q.name)) {
finalOutput :+= q
flag = true
}
}
}
HiveTableScanExecTransformer(
hiveTableScan.requestedAttributes,
relation,
hiveTableScan.partitionPruningPred,
finalOutput)(hiveTableScan.session)
}

// Prune the given output to make it consistent with `requiredSchema`.
private def getPrunedOutput(
output: Seq[AttributeReference],
requiredSchema: StructType): Seq[Attribute] = {
// We need to update the data type of the output attributes to use the pruned ones.
// so that references to the original relation's output are not broken
val nameAttributeMap = output.map(att => (att.name, att)).toMap
val requiredAttributes =
requiredSchema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
requiredAttributes.map {
case att if nameAttributeMap.contains(att.name) =>
nameAttributeMap(att.name).withDataType(att.dataType)
case att => att
}
}

/**
* Counts the "leaf" fields of the given dataType. Informally, this is the number of fields of
* non-complex data type in the tree representation of [[DataType]].
*/
private def countLeaves(dataType: DataType): Int = {
dataType match {
case array: ArrayType => countLeaves(array.elementType)
case map: MapType => countLeaves(map.keyType) + countLeaves(map.valueType)
case struct: StructType =>
struct.map(field => countLeaves(field.dataType)).sum
case _ => 1
}
}
}
Loading
Loading