[VL] Iterator's and its payloads' lifecycle improvements (#3526)
zhztheplayer authored Oct 30, 2023
1 parent a7b68c3 commit 39a5162
Showing 24 changed files with 618 additions and 359 deletions.
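
The common thread in the Scala diffs below is that hand-rolled iterator lifecycle code (manual `closed`/`finished` flags, `TaskResources.addRecycler` registrations, and `CloseableColumnBatchIterator` wrappers) is replaced by a fluent `io.glutenproject.utils.Iterators` builder. As a reading aid, here is a minimal, self-contained sketch of what such a builder could look like, reconstructed only from this commit's call sites; the actual Gluten implementation may differ, and only the two recyclers are sketched (`addToPipelineTime`, `asInterruptible`, and `protectInvocationFlow` appear in the diffs as further builder steps).

// Sketch only: a plausible shape for io.glutenproject.utils.Iterators, inferred
// from this commit's call sites. When exactly each recycler fires is an assumption.
object IteratorsSketch {

  final class WrapperBuilder[A](private var wrapped: Iterator[A]) {

    // Run `completion` exactly once, after the iterator is fully drained.
    def recycleIterator(completion: => Unit): WrapperBuilder[A] = {
      val delegate = wrapped
      wrapped = new Iterator[A] {
        private var recycled = false
        override def hasNext: Boolean = {
          val has = delegate.hasNext
          if (!has && !recycled) {
            recycled = true
            completion
          }
          has
        }
        override def next(): A = delegate.next()
      }
      this
    }

    // Close each payload once the consumer moves on to the following one.
    // (A real implementation would also close the final payload when the
    // iterator itself is recycled; that is omitted here for brevity.)
    def recyclePayload(close: A => Unit): WrapperBuilder[A] = {
      val delegate = wrapped
      wrapped = new Iterator[A] {
        private var previous: Option[A] = None
        override def hasNext: Boolean = delegate.hasNext
        override def next(): A = {
          previous.foreach(close)
          val current = delegate.next()
          previous = Some(current)
          current
        }
      }
      this
    }

    def create(): Iterator[A] = wrapped
  }

  def wrap[A](in: Iterator[A]): WrapperBuilder[A] = new WrapperBuilder(in)
}

Call sites then read like the new code in this commit:

// Usage illustration for the sketch above (toy payloads instead of ColumnarBatch).
object IteratorsSketchDemo {
  def main(args: Array[String]): Unit = {
    val safeIter = IteratorsSketch
      .wrap(Iterator(1, 2, 3))
      .recycleIterator(println("drained: release native handles here"))
      .recyclePayload(n => println(s"recycled payload $n"))
      .create()
    safeIter.foreach(println)
  }
}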
@@ -230,11 +230,12 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
   /**
    * Generate closeable ColumnBatch iterator.
    *
+   * FIXME: This no longer overrides parent's method
+   *
    * @param iter
    * @return
    */
-  override def genCloseableColumnBatchIterator(
-      iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
+  def genCloseableColumnBatchIterator(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
     if (iter.isInstanceOf[CloseableCHColumnBatchIterator]) iter
     else new CloseableCHColumnBatchIterator(iter)
   }
@@ -23,9 +23,10 @@ import io.glutenproject.metrics.IMetrics
 import io.glutenproject.substrait.plan.PlanNode
 import io.glutenproject.substrait.rel.LocalFilesBuilder
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
+import io.glutenproject.utils.Iterators
 import io.glutenproject.vectorized._
 
-import org.apache.spark.{InterruptibleIterator, Partition, SparkConf, SparkContext, TaskContext}
+import org.apache.spark.{Partition, SparkConf, SparkContext, TaskContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
@@ -39,7 +40,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.{BinaryType, DateType, StructType, TimestampType}
 import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper
 import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.util.{ExecutorManager, TaskResources}
+import org.apache.spark.util.ExecutorManager
 
 import java.net.URLDecoder
 import java.nio.charset.StandardCharsets
@@ -124,17 +125,6 @@ class IteratorApiImpl extends IteratorApi with Logging {
     GlutenPartition(index, substraitPlan, localFilesNodesWithLocations.head._2)
   }
 
-  /**
-   * Generate closeable ColumnBatch iterator.
-   *
-   * @param iter
-   * @return
-   */
-  override def genCloseableColumnBatchIterator(
-      iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
-    new CloseableColumnBatchIterator(iter)
-  }
-
   /**
    * Generate Iterator[ColumnarBatch] for first stage.
    *
@@ -156,30 +146,18 @@ class IteratorApiImpl extends IteratorApi with Logging {
     val resIter: GeneralOutIterator =
       transKernel.createKernelWithBatchIterator(inputPartition.plan, columnarNativeIterators)
     pipelineTime += TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)
-    TaskResources.addRecycler(s"FirstStageIterator_${resIter.getId}", 100)(resIter.close())
-    val iter = new Iterator[ColumnarBatch] {
-      private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics
-      var finished = false
-
-      override def hasNext: Boolean = {
-        val res = resIter.hasNext
-        if (!res) {
-          updateNativeMetrics(resIter.getMetrics)
-          updateInputMetrics(inputMetrics)
-          finished = true
-        }
-        res
-      }
-
-      override def next(): ColumnarBatch = {
-        if (finished) {
-          throw new java.util.NoSuchElementException("End of stream.")
-        }
-        resIter.next()
-      }
-    }
-
-    new InterruptibleIterator(context, new CloseableColumnBatchIterator(iter, Some(pipelineTime)))
+    Iterators
+      .wrap(resIter.asScala)
+      .recycleIterator {
+        updateNativeMetrics(resIter.getMetrics)
+        updateInputMetrics(TaskContext.get().taskMetrics().inputMetrics)
+        resIter.close()
+      }
+      .recyclePayload(batch => batch.close())
+      .addToPipelineTime(pipelineTime)
+      .asInterruptible(context)
+      .create()
   }
 
   // scalastyle:off argcount
@@ -213,25 +191,15 @@ class IteratorApiImpl extends IteratorApi with Logging {
 
     pipelineTime += TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)
 
-    val resIter = new Iterator[ColumnarBatch] {
-      override def hasNext: Boolean = {
-        val res = nativeResultIterator.hasNext
-        if (!res) {
-          updateNativeMetrics(nativeResultIterator.getMetrics)
-        }
-        res
-      }
-
-      override def next(): ColumnarBatch = {
-        nativeResultIterator.next
-      }
-    }
-
-    TaskResources.addRecycler(s"FinalStageIterator_${nativeResultIterator.getId}", 100) {
-      nativeResultIterator.close()
-    }
-
-    new CloseableColumnBatchIterator(resIter, Some(pipelineTime))
+    Iterators
+      .wrap(nativeResultIterator.asScala)
+      .recycleIterator {
+        updateNativeMetrics(nativeResultIterator.getMetrics)
+        nativeResultIterator.close()
+      }
+      .recyclePayload(batch => batch.close())
+      .addToPipelineTime(pipelineTime)
+      .create()
   }
   // scalastyle:on argcount
 
@@ -254,6 +222,9 @@ class IteratorApiImpl extends IteratorApi with Logging {
       broadcasted: Broadcast[BuildSideRelation],
       broadCastContext: BroadCastHashJoinContext): Iterator[ColumnarBatch] = {
     val relation = broadcasted.value.asReadOnlyCopy(broadCastContext)
-    new CloseableColumnBatchIterator(relation.deserialized)
+    Iterators
+      .wrap(relation.deserialized)
+      .recyclePayload(batch => batch.close())
+      .create()
   }
 }
@@ -21,7 +21,7 @@ import io.glutenproject.columnarbatch.ColumnarBatches
 import io.glutenproject.exec.Runtimes
 import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators
 import io.glutenproject.memory.nmm.NativeMemoryManagers
-import io.glutenproject.utils.ArrowAbiUtil
+import io.glutenproject.utils.{ArrowAbiUtil, Iterators}
 import io.glutenproject.vectorized._
 
 import org.apache.spark.rdd.RDD
@@ -93,7 +93,6 @@ object RowToVeloxColumnarExec {
     val jniWrapper = NativeRowToColumnarJniWrapper.create()
     val allocator = ArrowBufferAllocators.contextInstance()
     val cSchema = ArrowSchema.allocateNew(allocator)
-    var closed = false
     val r2cHandle =
       try {
         ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
@@ -106,26 +105,14 @@ object RowToVeloxColumnarExec {
         cSchema.close()
       }
 
-    TaskResources.addRecycler(s"RowToColumnar_$r2cHandle", 100) {
-      if (!closed) {
-        jniWrapper.close(r2cHandle)
-        closed = true
-      }
-    }
-
     val res: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] {
       var finished = false
 
       override def hasNext: Boolean = {
         if (finished) {
           false
         } else {
-          val itHasNext = it.hasNext
-          if (!itHasNext && !closed) {
-            jniWrapper.close(r2cHandle)
-            closed = true
-          }
-          itHasNext
+          it.hasNext
         }
       }
 
@@ -215,6 +202,12 @@ object RowToVeloxColumnarExec {
         cb
       }
     }
-    new CloseableColumnBatchIterator(res)
+    Iterators
+      .wrap(res)
+      .recycleIterator {
+        jniWrapper.close(r2cHandle)
+      }
+      .recyclePayload(_.close())
+      .create()
   }
 }
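
The hunks above also show why the old code carried a `var closed` flag: the JNI handle could be closed both when the iterator drained and by the task-end recycler, so every close had to be guarded against double-closing. A toy sketch of that idempotent-close concern, which a `recycleIterator`-style hook can now centralize in one place (assumed design, for illustration only):

// Sketch: the guard the removed `if (!closed)` blocks were implementing by hand.
// With a single recycleIterator hook there is one owner for the close call.
final class IdempotentClose(doClose: () => Unit) {
  private var closed = false

  // Safe to call from both the drain path and a task-end recycler.
  def close(): Unit = synchronized {
    if (!closed) {
      closed = true
      doClose()
    }
  }
}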
@@ -19,6 +19,7 @@ package io.glutenproject.execution
 import io.glutenproject.columnarbatch.ColumnarBatches
 import io.glutenproject.extension.ValidationResult
 import io.glutenproject.memory.nmm.NativeMemoryManagers
+import io.glutenproject.utils.Iterators
 import io.glutenproject.vectorized.NativeColumnarToRowJniWrapper
 
 import org.apache.spark.rdd.RDD
@@ -28,7 +29,6 @@ import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.util.TaskResources
 
 import scala.collection.JavaConverters._
 
@@ -98,26 +98,13 @@ object VeloxColumnarToRowExec {
 
     // TODO:: pass the jni jniWrapper and arrowSchema and serializeSchema method by broadcast
     val jniWrapper = NativeColumnarToRowJniWrapper.create()
-    var closed = false
     val c2rId = jniWrapper.nativeColumnarToRowInit(
       NativeMemoryManagers.contextInstance("ColumnarToRow").getNativeInstanceHandle)
 
-    TaskResources.addRecycler(s"ColumnarToRow_$c2rId", 100) {
-      if (!closed) {
-        jniWrapper.nativeClose(c2rId)
-        closed = true
-      }
-    }
-
    val res: Iterator[Iterator[InternalRow]] = new Iterator[Iterator[InternalRow]] {
 
       override def hasNext: Boolean = {
-        val hasNext = batches.hasNext
-        if (!hasNext && !closed) {
-          jniWrapper.nativeClose(c2rId)
-          closed = true
-        }
-        hasNext
+        batches.hasNext
       }
 
       override def next(): Iterator[InternalRow] = {
@@ -170,6 +157,13 @@ object VeloxColumnarToRowExec {
         }
       }
     }
-    res.flatten
+    Iterators
+      .wrap(res.flatten)
+      .protectInvocationFlow() // Spark may call `hasNext()` again after a false output which
+      // is not allowed by Gluten iterators. E.g. GroupedIterator#fetchNextGroupIterator
+      .recycleIterator {
+        jniWrapper.nativeClose(c2rId)
+      }
+      .create()
   }
 }
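
The `protectInvocationFlow()` step above guards against Spark calling `hasNext()` again after it has already returned false (e.g. `GroupedIterator#fetchNextGroupIterator`), which a native-backed iterator cannot tolerate once its resources have been recycled. A hedged sketch of such a guard, with assumed semantics rather than the actual Gluten code:

// Sketch of an invocation-flow guard: after the first hasNext() == false, never
// touch the delegate again, so an already-recycled native iterator is not re-entered.
final class FlowProtectedIterator[A](delegate: Iterator[A]) extends Iterator[A] {
  private var exhausted = false

  override def hasNext: Boolean = {
    if (exhausted) {
      false
    } else {
      val has = delegate.hasNext
      if (!has) exhausted = true
      has
    }
  }

  override def next(): A =
    if (exhausted) throw new NoSuchElementException("End of stream")
    else delegate.next()
}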
@@ -24,8 +24,8 @@ import io.glutenproject.exec.Runtimes
 import io.glutenproject.execution.{RowToVeloxColumnarExec, VeloxColumnarToRowExec}
 import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators
 import io.glutenproject.memory.nmm.NativeMemoryManagers
-import io.glutenproject.utils.ArrowAbiUtil
-import io.glutenproject.vectorized.{CloseableColumnBatchIterator, ColumnarBatchSerializerJniWrapper}
+import io.glutenproject.utils.{ArrowAbiUtil, Iterators}
+import io.glutenproject.vectorized.ColumnarBatchSerializerJniWrapper
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
@@ -38,7 +38,6 @@ import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.sql.utils.SparkArrowUtil
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.TaskResources
 
 import org.apache.arrow.c.ArrowSchema
 
@@ -244,33 +243,34 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with SQLConfHe
       nmm.getNativeInstanceHandle
     )
     cSchema.close()
-    TaskResources.addRecycler(
-      s"ColumnarCachedBatchSerializer_convertCachedBatchToColumnarBatch_$deserializerHandle",
-      50) {
-      ColumnarBatchSerializerJniWrapper.create().close(deserializerHandle)
-    }
 
-    new CloseableColumnBatchIterator(new Iterator[ColumnarBatch] {
-      override def hasNext: Boolean = it.hasNext
+    Iterators
+      .wrap(new Iterator[ColumnarBatch] {
+        override def hasNext: Boolean = it.hasNext
 
-      override def next(): ColumnarBatch = {
-        val cachedBatch = it.next().asInstanceOf[CachedColumnarBatch]
-        val batchHandle =
-          ColumnarBatchSerializerJniWrapper
-            .create()
-            .deserialize(deserializerHandle, cachedBatch.bytes)
-        val batch = ColumnarBatches.create(Runtimes.contextInstance(), batchHandle)
-        if (shouldSelectAttributes) {
-          try {
-            ColumnarBatches.select(nmm, batch, requestedColumnIndices.toArray)
-          } finally {
-            batch.close()
+        override def next(): ColumnarBatch = {
+          val cachedBatch = it.next().asInstanceOf[CachedColumnarBatch]
+          val batchHandle =
+            ColumnarBatchSerializerJniWrapper
+              .create()
+              .deserialize(deserializerHandle, cachedBatch.bytes)
+          val batch = ColumnarBatches.create(Runtimes.contextInstance(), batchHandle)
+          if (shouldSelectAttributes) {
+            try {
+              ColumnarBatches.select(nmm, batch, requestedColumnIndices.toArray)
+            } finally {
+              batch.close()
+            }
+          } else {
+            batch
           }
-        } else {
-          batch
         }
-      }
-    })
+      })
+      .recycleIterator {
+        ColumnarBatchSerializerJniWrapper.create().close(deserializerHandle)
+      }
+      .recyclePayload(_.close())
+      .create()
  }
 }
@@ -17,7 +17,8 @@
 package org.apache.spark.sql.execution.datasources
 
 import io.glutenproject.datasource.DatasourceJniWrapper
-import io.glutenproject.vectorized.{CloseableColumnBatchIterator, ColumnarBatchInIterator}
+import io.glutenproject.utils.Iterators
+import io.glutenproject.vectorized.ColumnarBatchInIterator
 
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.execution.datasources.VeloxWriteQueue.EOS_BATCH
@@ -50,7 +51,8 @@ class VeloxWriteQueue(
       try {
         datasourceJniWrapper.write(
           dsHandle,
-          new ColumnarBatchInIterator(new CloseableColumnBatchIterator(scanner).asJava))
+          new ColumnarBatchInIterator(
+            Iterators.wrap(scanner).recyclePayload(_.close()).create().asJava))
       } catch {
         case e: Exception =>
           writeException.set(e)
12 changes: 11 additions & 1 deletion cpp/core/jni/JniWrapper.cc
@@ -389,7 +389,7 @@ JNIEXPORT jboolean JNICALL Java_io_glutenproject_vectorized_ColumnarBatchOutIter
 
   auto iter = ctx->objectStore()->retrieve<ResultIterator>(iterHandle);
   if (iter == nullptr) {
-    std::string errorMessage = "faked to get batch iterator";
+    std::string errorMessage = "failed to get batch iterator";
     throw gluten::GlutenException(errorMessage);
   }
   return iter->hasNext();
@@ -1272,6 +1272,16 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_shr
   JNI_METHOD_END(kInvalidResourceHandle)
 }
 
+JNIEXPORT void JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_hold( // NOLINT
+    JNIEnv* env,
+    jclass,
+    jlong memoryManagerHandle) {
+  JNI_METHOD_START
+  auto memoryManager = jniCastOrThrow<MemoryManager>(memoryManagerHandle);
+  memoryManager->hold();
+  JNI_METHOD_END()
+}
+
 JNIEXPORT void JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_release( // NOLINT
     JNIEnv* env,
     jclass,
5 changes: 5 additions & 0 deletions cpp/core/memory/MemoryManager.h
@@ -33,6 +33,11 @@ class MemoryManager {
   virtual const MemoryUsageStats collectMemoryUsageStats() const = 0;
 
   virtual const int64_t shrink(int64_t size) = 0;
+
+  // Hold this memory manager. The underlying memory pools will not be released until this memory
+  // manager is destroyed. That is, calling this function guarantees that the memory blocks directly
+  // or indirectly managed by this manager remain safe to access while the manager is alive.
+  virtual void hold() = 0;
 };
 
 } // namespace gluten
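
For context, `hold()` pins the manager's pools so that any memory blocks it manages stay valid until the manager itself is destroyed. A toy JVM-side illustration of that lifetime contract (hypothetical class; the real mechanism is implemented natively in the C++ `MemoryManager`):

// Toy illustration of the hold() contract in the comment above: held resources
// are released only when the manager is destroyed, never earlier. Hypothetical
// Scala stand-in; the actual logic lives on the native side.
final class HoldingMemoryManager {
  private val held = scala.collection.mutable.ArrayBuffer.empty[AutoCloseable]

  // After hold(pool), the pool remains safe to access while this manager is alive.
  def hold(pool: AutoCloseable): Unit = synchronized { held += pool }

  // Destroying the manager is the only point where held pools are released.
  def destroy(): Unit = synchronized {
    held.foreach(_.close())
    held.clear()
  }
}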