[GLUTEN-7703][VL] Make ColumnarBuildSideRelation transform support multiple columns (#7704)
yikf authored Oct 31, 2024
1 parent 0737113 commit 4efcba4
Showing 2 changed files with 40 additions and 44 deletions.
ColumnarBuildSideRelation.scala
@@ -25,7 +25,7 @@ import org.apache.gluten.utils.ArrowAbiUtil
 import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.utils.SparkArrowUtil
@@ -141,48 +141,7 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
         ColumnarBatches.getNativeHandle(batch),
         0)
       batch.close()
-      val columnNames = key.flatMap {
-        case expression: AttributeReference =>
-          Some(expression)
-        case _ =>
-          None
-      }
-      if (columnNames.isEmpty) {
-        throw new IllegalArgumentException(s"Key column not found in expression: $key")
-      }
-      if (columnNames.size != 1) {
-        throw new IllegalArgumentException(s"Multiple key columns found in expression: $key")
-      }
-      val columnExpr = columnNames.head
-      val oneColumnWithSameName = output.count(_.name == columnExpr.name) == 1
-      val columnInOutput = output.zipWithIndex.filter {
-        p: (Attribute, Int) =>
-          if (oneColumnWithSameName) {
-            // The comparison of exprId can be ignored when
-            // only one attribute name match is found.
-            p._1.name == columnExpr.name
-          } else {
-            // A case where output has multiple columns with same name
-            p._1.name == columnExpr.name && p._1.exprId == columnExpr.exprId
-          }
-      }
-      if (columnInOutput.isEmpty) {
-        throw new IllegalStateException(
-          s"Key $key not found from build side relation output: $output")
-      }
-      if (columnInOutput.size != 1) {
-        throw new IllegalStateException(
-          s"More than one key $key found from build side relation output: $output")
-      }
-      val replacement =
-        BoundReference(columnInOutput.head._2, columnExpr.dataType, columnExpr.nullable)
-
-      val projExpr = key.transformDown {
-        case _: AttributeReference =>
-          replacement
-      }
-
-      val proj = UnsafeProjection.create(projExpr)
+      val proj = UnsafeProjection.create(Seq(key), output)

       new Iterator[InternalRow] {
         var rowId = 0
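The substance of the change is the last hunk: instead of requiring the key to contain exactly one AttributeReference and binding it by hand through BoundReference, the projection is now created with the build side's output as the input schema, letting Spark bind every attribute the key references. Below is a minimal standalone sketch of that API; the object name, column names, types, and key expression are illustrative, not taken from the patch.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

object MultiColumnKeySketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for ColumnarBuildSideRelation.output: two columns.
    val output: Seq[Attribute] = Seq(
      AttributeReference("date", StringType, nullable = false)(),
      AttributeReference("t2_c1", IntegerType, nullable = false)())

    // A key expression referencing BOTH columns; the removed code would
    // have thrown "Multiple key columns found in expression" here.
    val key: Expression = If(
      LessThan(Literal(3), output(1)),
      output.head,
      Literal.create("fallback", StringType))

    // create(exprs, inputSchema) binds each AttributeReference in `key`
    // to its ordinal in `output`, so any number of key columns works.
    val proj = UnsafeProjection.create(Seq(key), output)

    val row = InternalRow(UTF8String.fromString("2024-10-31"), 5)
    println(proj(row).getUTF8String(0)) // 2024-10-31
  }
}

That single UnsafeProjection.create(Seq(key), output) call is what the patch substitutes for the roughly forty removed lines of manual column resolution.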
VeloxHashJoinSuite.scala
@@ -20,13 +20,16 @@ import org.apache.gluten.sql.shims.SparkShimLoader

 import org.apache.spark.SparkConf
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, InputIteratorTransformer}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarSubqueryBroadcastExec, InputIteratorTransformer}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec}
 
 class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
   override protected val resourcePath: String = "/tpch-data-parquet"
   override protected val fileFormat: String = "parquet"
 
+  import testImplicits._
+
   override def beforeAll(): Unit = {
     super.beforeAll()
   }
@@ -144,4 +147,38 @@ class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
       }.size == 1)
     }
   }
+
+  test("ColumnarBuildSideRelation transform support multiple key columns") {
+    withTable("t1", "t2") {
+      val df1 =
+        (0 until 50).map(i => (i % 2, i % 3, s"${i % 25}")).toDF("t1_c1", "t1_c2", "date").as("df1")
+      val df2 = (0 until 50)
+        .map(i => (i % 11, i % 13, s"${i % 10}"))
+        .toDF("t2_c1", "t2_c2", "date")
+        .as("df2")
+      df1.write.partitionBy("date").saveAsTable("t1")
+      df2.write.partitionBy("date").saveAsTable("t2")
+
+      val df = sql("""
+                     |SELECT t1.date, t1.t1_c1, t2.t2_c2
+                     |FROM t1
+                     |JOIN t2 ON t1.date = t2.date
+                     |WHERE t1.date=if(3 <= t2.t2_c2, if(3 < t2.t2_c1, 3, t2.t2_c1), t2.t2_c2)
+                     |ORDER BY t1.date DESC, t1.t1_c1 DESC, t2.t2_c2 DESC
+                     |LIMIT 1
+                     |""".stripMargin)
+
+      checkAnswer(df, Row("3", 1, 4) :: Nil)
+      // Collect the DPP (dynamic partition pruning) subquery broadcasts.
+      val subqueryBroadcastExecs = collectWithSubqueries(df.queryExecution.executedPlan) {
+        case subqueryBroadcast: ColumnarSubqueryBroadcastExec => subqueryBroadcast
+      }
+      assert(subqueryBroadcastExecs.size == 2)
+      val buildKeysAttrs = subqueryBroadcastExecs
+        .flatMap(_.buildKeys)
+        .map(e => e.collect { case a: AttributeReference => a })
+      // buildKeys may hold expressions that reference multiple columns.
+      assert(buildKeysAttrs.exists(_.size > 1))
+    }
+  }
 }
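For reference, the assertion at the end of the new test mirrors how multi-column keys surface in the plan: each buildKeys expression is walked with collect, and every AttributeReference found counts as one referenced column. An illustrative standalone sketch, with the expression shape modeled on the WHERE clause above (not part of the patch):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

object BuildKeyAttrsSketch {
  def main(args: Array[String]): Unit = {
    val c1 = AttributeReference("t2_c1", IntegerType)()
    val c2 = AttributeReference("t2_c2", IntegerType)()

    // Same shape as: if(3 <= t2_c2, if(3 < t2_c1, 3, t2_c1), t2_c2)
    val buildKey: Expression = If(
      LessThanOrEqual(Literal(3), c2),
      If(LessThan(Literal(3), c1), Literal(3), c1),
      c2)

    // collect traverses the expression tree and keeps every
    // AttributeReference, so a multi-column key yields several attributes.
    val attrs = buildKey.collect { case a: AttributeReference => a }
    println(attrs.map(_.name).distinct) // List(t2_c2, t2_c1)
    assert(attrs.distinct.size > 1) // the case the removed code rejected
  }
}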
