Commit c88207f: fix

Yohahaha committed Nov 26, 2023
1 parent f456a7e
Showing 5 changed files with 58 additions and 50 deletions.
@@ -108,17 +108,17 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
       updateNativeMetrics: IMetrics => Unit,
       inputIterators: Seq[Iterator[ColumnarBatch]] = Seq()
   ): Iterator[ColumnarBatch] = {
-    val resIter: GeneralOutIterator = GlutenTimeMetric.millis(pipelineTime) {
-      _ =>
-        val transKernel = new CHNativeExpressionEvaluator()
-        val inBatchIters = new JArrayList[GeneralInIterator](inputIterators.map {
-          iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava)
-        }.asJava)
-        transKernel.createKernelWithBatchIterator(inputPartition.plan, inBatchIters, false)
-    }
-    TaskContext.get().addTaskCompletionListener[Unit](_ => resIter.close())
+    val transKernel = new CHNativeExpressionEvaluator()
+    val inBatchIters = new JArrayList[GeneralInIterator](inputIterators.map {
+      iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava)
+    }.asJava)
+    val resIter: GeneralOutIterator =
+      transKernel.createKernelWithBatchIterator(inputPartition.plan, inBatchIters, false)
+
+    context.addTaskCompletionListener[Unit](_ => resIter.close())
     val iter = new Iterator[Any] {
-      private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics
+      private val inputMetrics = context.taskMetrics().inputMetrics
       private var outputRowCount = 0L
       private var outputVectorCount = 0L
       private var metricsUpdated = false
@@ -155,6 +155,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
   // Generate Iterator[ColumnarBatch] for final stage.
   // scalastyle:off argcount
   override def genFinalStageIterator(
+      context: TaskContext,
       inputIterators: Seq[Iterator[ColumnarBatch]],
       numaBindingInfo: GlutenNumaBindingInfo,
       sparkConf: SparkConf,
@@ -165,19 +166,17 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
       materializeInput: Boolean): Iterator[ColumnarBatch] = {
     // scalastyle:on argcount
     GlutenConfig.getConf
-    val nativeIterator = GlutenTimeMetric.millis(pipelineTime) {
-      _ =>
-        val transKernel = new CHNativeExpressionEvaluator()
-        val columnarNativeIterator =
-          new JArrayList[GeneralInIterator](inputIterators.map {
-            iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava)
-          }.asJava)
-        // we need to complete the dependency RDDs first
-        transKernel.createKernelWithBatchIterator(
-          rootNode.toProtobuf.toByteArray,
-          columnarNativeIterator,
-          materializeInput)
-    }
+    val transKernel = new CHNativeExpressionEvaluator()
+    val columnarNativeIterator =
+      new JArrayList[GeneralInIterator](inputIterators.map {
+        iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava)
+      }.asJava)
+    // we need to complete the dependency RDDs first
+    val nativeIterator = transKernel.createKernelWithBatchIterator(
+      rootNode.toProtobuf.toByteArray,
+      columnarNativeIterator,
+      materializeInput)

     val resIter = new Iterator[ColumnarBatch] {
       private var outputRowCount = 0L
@@ -212,7 +211,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
       // relationHolder.clear()
     }

-    TaskContext.get().addTaskCompletionListener[Unit](_ => close())
+    context.addTaskCompletionListener[Unit](_ => close())
     new CloseableCHColumnBatchIterator(resIter, Some(pipelineTime))
   }
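In `CHIteratorApi`, the change is twofold: kernel construction is no longer wrapped in `GlutenTimeMetric.millis` (that timing moves down into the RDD `compute` methods, see the later hunks), and task-completion cleanup now uses the explicitly passed `context` instead of the thread-local `TaskContext.get()`. Below is a minimal, self-contained sketch of that cleanup pattern; `NativeIterator` and `FakeTaskContext` are hypothetical stand-ins, not Gluten or Spark APIs.

// A minimal sketch of the cleanup pattern, not Gluten code: the kernel-backed
// iterator is created once, and its close() is tied to task completion through
// the explicitly passed context rather than a thread-local TaskContext.get().
object CompletionListenerSketch {

  // Hypothetical stand-in for a native-backed iterator that must be closed once.
  final class NativeIterator extends AutoCloseable {
    private var closed = false
    override def close(): Unit =
      if (!closed) { closed = true; println("native kernel released") }
  }

  // Hypothetical stand-in for Spark's TaskContext completion-listener API.
  final class FakeTaskContext {
    private var listeners: List[() => Unit] = Nil
    def addTaskCompletionListener(f: () => Unit): Unit = listeners ::= f
    def markTaskCompleted(): Unit = listeners.foreach(f => f())
  }

  def main(args: Array[String]): Unit = {
    val context = new FakeTaskContext
    val resIter = new NativeIterator
    // Mirrors `context.addTaskCompletionListener[Unit](_ => resIter.close())`.
    context.addTaskCompletionListener(() => resIter.close())
    context.markTaskCompleted() // prints "native kernel released"
  }
}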

@@ -158,6 +158,7 @@ class IteratorApiImpl extends IteratorApi with Logging {
    * @return
    */
   override def genFinalStageIterator(
+      context: TaskContext,
       inputIterators: Seq[Iterator[ColumnarBatch]],
       numaBindingInfo: GlutenNumaBindingInfo,
       sparkConf: SparkConf,
@@ -67,6 +67,7 @@ trait IteratorApi {
    */
   // scalastyle:off argcount
   def genFinalStageIterator(
+      context: TaskContext,
       inputIterators: Seq[Iterator[ColumnarBatch]],
       numaBindingInfo: GlutenNumaBindingInfo,
       sparkConf: SparkConf,
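The `IteratorApi` trait and both backend implementations gain a leading `context: TaskContext` parameter, so the caller's task context is threaded through explicitly. A plausible motivation (my reading; the commit message doesn't say): `TaskContext.get()` is a thread-local lookup, so it only resolves on the thread where Spark registered the context. The sketch below illustrates that failure mode with a plain `ThreadLocal`, no Spark involved.

// Self-contained illustration (not Spark code) of why an explicit parameter is
// safer than a thread-local lookup like TaskContext.get(): the lookup returns
// null on any thread other than the one that set the value.
object ThreadLocalHazardSketch {
  private val current = new ThreadLocal[String] // stands in for TaskContext.get()

  def main(args: Array[String]): Unit = {
    current.set("task-0")
    println(current.get()) // "task-0" on the registering thread

    val worker = new Thread(() => {
      // On another thread the lookup silently comes back null; an explicit
      // `context` parameter cannot go stale this way.
      println(current.get()) // null
    })
    worker.start()
    worker.join()
  }
}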
@@ -18,7 +18,7 @@ package io.glutenproject.execution

 import io.glutenproject.GlutenConfig
 import io.glutenproject.backendsapi.BackendsApiManager
-import io.glutenproject.metrics.IMetrics
+import io.glutenproject.metrics.{GlutenTimeMetric, IMetrics}
 import io.glutenproject.substrait.plan.PlanBuilder

 import org.apache.spark.{Partition, SparkContext, SparkException, TaskContext}
@@ -99,17 +99,20 @@ class GlutenWholeStageColumnarRDD(
   val numaBindingInfo = GlutenConfig.getConf.numaBindingInfo

   override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
-    ExecutorManager.tryTaskSet(numaBindingInfo)
-    val (inputPartition, inputColumnarRDDPartitions) = castNativePartition(split)
-    val inputIterators = rdds.getIterators(inputColumnarRDDPartitions, context)
-    BackendsApiManager.getIteratorApiInstance.genFirstStageIterator(
-      inputPartition,
-      context,
-      pipelineTime,
-      updateInputMetrics,
-      updateNativeMetrics,
-      inputIterators
-    )
+    GlutenTimeMetric.millis(pipelineTime) {
+      _ =>
+        ExecutorManager.tryTaskSet(numaBindingInfo)
+        val (inputPartition, inputColumnarRDDPartitions) = castNativePartition(split)
+        val inputIterators = rdds.getIterators(inputColumnarRDDPartitions, context)
+        BackendsApiManager.getIteratorApiInstance.genFirstStageIterator(
+          inputPartition,
+          context,
+          pipelineTime,
+          updateInputMetrics,
+          updateNativeMetrics,
+          inputIterators
+        )
+    }
   }

   private def castNativePartition(split: Partition): (BaseGlutenPartition, Seq[Partition]) = {
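`compute` now wraps its entire body in `GlutenTimeMetric.millis(pipelineTime)`, so NUMA binding, partition casting, and input-iterator setup all count toward `pipelineTime`, not just kernel construction. Below is a minimal sketch of a millis-style wrapper under assumed semantics (measure wall-clock time around a block and report it); `TimeSketch` and its `report` callback are hypothetical, and only the call shape mirrors `GlutenTimeMetric.millis`.

// Hypothetical timing helper assuming GlutenTimeMetric.millis-like semantics:
// run the block, then report the elapsed wall-clock milliseconds.
object TimeSketch {
  def millis[T](report: Long => Unit)(block: Long => T): T = {
    val start = System.currentTimeMillis()
    try block(start)
    finally report(System.currentTimeMillis() - start)
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the diff's shape: the entire body of compute() sits inside the
    // wrapper, so setup work is included in the reported pipeline time.
    val result = millis((elapsed: Long) => println(s"pipelineTime += $elapsed ms")) {
      _ =>
        (1 to 1000).sum // stands in for partition casting + iterator setup
    }
    println(result)
  }
}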
@@ -18,7 +18,7 @@ package io.glutenproject.execution

 import io.glutenproject.GlutenNumaBindingInfo
 import io.glutenproject.backendsapi.BackendsApiManager
-import io.glutenproject.metrics.IMetrics
+import io.glutenproject.metrics.{GlutenTimeMetric, IMetrics}

 import org.apache.spark.{Partition, SparkConf, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
@@ -45,19 +45,23 @@ class WholeStageZippedPartitionsRDD(
   extends RDD[ColumnarBatch](sc, rdds.getDependencies) {

   override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
-    val partitions = split.asInstanceOf[ZippedPartitionsPartition].inputColumnarRDDPartitions
-    val inputIterators: Seq[Iterator[ColumnarBatch]] = rdds.getIterators(partitions, context)
-    BackendsApiManager.getIteratorApiInstance
-      .genFinalStageIterator(
-        inputIterators,
-        numaBindingInfo,
-        sparkConf,
-        resCtx.root,
-        pipelineTime,
-        updateNativeMetrics,
-        buildRelationBatchHolder,
-        materializeInput
-      )
+    GlutenTimeMetric.millis(pipelineTime) {
+      _ =>
+        val partitions = split.asInstanceOf[ZippedPartitionsPartition].inputColumnarRDDPartitions
+        val inputIterators: Seq[Iterator[ColumnarBatch]] = rdds.getIterators(partitions, context)
+        BackendsApiManager.getIteratorApiInstance
+          .genFinalStageIterator(
+            context,
+            inputIterators,
+            numaBindingInfo,
+            sparkConf,
+            resCtx.root,
+            pipelineTime,
+            updateNativeMetrics,
+            buildRelationBatchHolder,
+            materializeInput
+          )
+    }
   }

   override def getPartitions: Array[Partition] = {
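The final-stage path now mirrors the first-stage one: `compute` wraps its whole body in `GlutenTimeMetric.millis(pipelineTime)` and passes `context` into `genFinalStageIterator`, which is what lets `CHIteratorApi` register its completion listener on the caller's context instead of calling `TaskContext.get()`.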