From 4efcba47bbb9ac7fbe5e5ab192b7a402dea50ab4 Mon Sep 17 00:00:00 2001
From: Kaifei Yi
Date: Thu, 31 Oct 2024 16:47:12 +0800
Subject: [PATCH] [GLUTEN-7703][VL] Make ColumnarBuildSideRelation transform support multiple columns (#7704)

---
 .../execution/ColumnarBuildSideRelation.scala | 45 +------------------
 .../gluten/execution/VeloxHashJoinSuite.scala | 39 +++++++++++++++-
 2 files changed, 40 insertions(+), 44 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index 5b34104a3f29..feaf72f64fb2 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -25,7 +25,7 @@ import org.apache.gluten.utils.ArrowAbiUtil
 import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.utils.SparkArrowUtil
@@ -141,48 +141,7 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
         ColumnarBatches.getNativeHandle(batch),
         0)
       batch.close()
-      val columnNames = key.flatMap {
-        case expression: AttributeReference =>
-          Some(expression)
-        case _ =>
-          None
-      }
-      if (columnNames.isEmpty) {
-        throw new IllegalArgumentException(s"Key column not found in expression: $key")
-      }
-      if (columnNames.size != 1) {
-        throw new IllegalArgumentException(s"Multiple key columns found in expression: $key")
-      }
-      val columnExpr = columnNames.head
-      val oneColumnWithSameName = output.count(_.name == columnExpr.name) == 1
-      val columnInOutput = output.zipWithIndex.filter {
-        p: (Attribute, Int) =>
-          if (oneColumnWithSameName) {
-            // The comparison of exprId can be ignored when
-            // only one attribute name match is found.
-            p._1.name == columnExpr.name
-          } else {
-            // A case where output has multiple columns with same name
-            p._1.name == columnExpr.name && p._1.exprId == columnExpr.exprId
-          }
-      }
-      if (columnInOutput.isEmpty) {
-        throw new IllegalStateException(
-          s"Key $key not found from build side relation output: $output")
-      }
-      if (columnInOutput.size != 1) {
-        throw new IllegalStateException(
-          s"More than one key $key found from build side relation output: $output")
-      }
-      val replacement =
-        BoundReference(columnInOutput.head._2, columnExpr.dataType, columnExpr.nullable)
-
-      val projExpr = key.transformDown {
-        case _: AttributeReference =>
-          replacement
-      }
-
-      val proj = UnsafeProjection.create(projExpr)
+      val proj = UnsafeProjection.create(Seq(key), output)
 
       new Iterator[InternalRow] {
         var rowId = 0
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
index 49678189bdd3..b58ea4d3974f 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
@@ -20,13 +20,16 @@ import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, InputIteratorTransformer}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarSubqueryBroadcastExec, InputIteratorTransformer}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec}
 
 class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
   override protected val resourcePath: String = "/tpch-data-parquet"
   override protected val fileFormat: String = "parquet"
 
+  import testImplicits._
+
   override def beforeAll(): Unit = {
     super.beforeAll()
   }
@@ -144,4 +147,38 @@ class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
       }.size == 1)
     }
   }
+
+  test("ColumnarBuildSideRelation transform support multiple key columns") {
+    withTable("t1", "t2") {
+      val df1 =
+        (0 until 50).map(i => (i % 2, i % 3, s"${i % 25}")).toDF("t1_c1", "t1_c2", "date").as("df1")
+      val df2 = (0 until 50)
+        .map(i => (i % 11, i % 13, s"${i % 10}"))
+        .toDF("t2_c1", "t2_c2", "date")
+        .as("df2")
+      df1.write.partitionBy("date").saveAsTable("t1")
+      df2.write.partitionBy("date").saveAsTable("t2")
+
+      val df = sql("""
+                     |SELECT t1.date, t1.t1_c1, t2.t2_c2
+                     |FROM t1
+                     |JOIN t2 ON t1.date = t2.date
+                     |WHERE t1.date=if(3 <= t2.t2_c2, if(3 < t2.t2_c1, 3, t2.t2_c1), t2.t2_c2)
+                     |ORDER BY t1.date DESC, t1.t1_c1 DESC, t2.t2_c2 DESC
+                     |LIMIT 1
+                     |""".stripMargin)
+
+      checkAnswer(df, Row("3", 1, 4) :: Nil)
+      // Collect the DPP subquery broadcast plans.
+      val subqueryBroadcastExecs = collectWithSubqueries(df.queryExecution.executedPlan) {
+        case subqueryBroadcast: ColumnarSubqueryBroadcastExec => subqueryBroadcast
+      }
+      assert(subqueryBroadcastExecs.size == 2)
+      val buildKeysAttrs = subqueryBroadcastExecs
+        .flatMap(_.buildKeys)
+        .map(e => e.collect { case a: AttributeReference => a })
+      // buildKeys may contain expressions that reference multiple columns.
+      assert(buildKeysAttrs.exists(_.size > 1))
+    }
+  }
 }
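
The one-line replacement works because the UnsafeProjection.create(exprs, inputSchema) overload walks the key expression and binds every AttributeReference to its ordinal in `output` by exprId, which is what the deleted BoundReference plumbing did by hand for the single-column case only. Below is a minimal sketch of that behavior, outside Gluten, assuming only spark-catalyst on the classpath; the object name and the a/b attributes are invented for illustration, not taken from the patch.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, UnsafeProjection}
    import org.apache.spark.sql.types.IntegerType

    object MultiColumnKeySketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical stand-ins for ColumnarBuildSideRelation.output.
        val a = AttributeReference("a", IntegerType, nullable = false)()
        val b = AttributeReference("b", IntegerType, nullable = false)()
        val output = Seq(a, b)

        // A key referencing two columns; the removed code threw
        // IllegalArgumentException("Multiple key columns found ...") for this shape.
        val key = Add(a, b)

        // create(exprs, inputSchema) binds each AttributeReference in `key`
        // to its position in `output`, then compiles an UnsafeProjection.
        val proj = UnsafeProjection.create(Seq(key), output)

        val row = InternalRow(3, 4)
        assert(proj(row).getInt(0) == 7) // 3 + 4 projected through the bound key
      }
    }

The design choice is to reuse Catalyst's own reference binding through this create overload rather than hand-rolling name/exprId matching, which is why multi-column keys and duplicate column names now work without special-casing.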