[VL] Iterator's and its payloads' lifecycle improvements #3526

Merged: 14 commits merged on Oct 30, 2023.
@@ -230,11 +230,12 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
   /**
    * Generate closeable ColumnBatch iterator.
    *
+   * FIXME: This no longer overrides parent's method
Contributor:
keep it in the CHIteratorApi

Member Author:
Any suggested change? I only removed the override modifier, but I'm not sure where to put this method yet.

Contributor:
> Any suggested change? I only removed the override modifier, but I'm not sure where to put this method yet.

This change is OK to me.

    *
    * @param iter
    * @return
    */
-  override def genCloseableColumnBatchIterator(
-      iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
+  def genCloseableColumnBatchIterator(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
     if (iter.isInstanceOf[CloseableCHColumnBatchIterator]) iter
     else new CloseableCHColumnBatchIterator(iter)
   }
@@ -25,7 +25,7 @@ import io.glutenproject.substrait.rel.LocalFilesBuilder
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import io.glutenproject.vectorized._

-import org.apache.spark.{InterruptibleIterator, Partition, SparkConf, SparkContext, TaskContext}
+import org.apache.spark.{Partition, SparkConf, SparkContext, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
@@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.{BinaryType, DateType, StructType, TimestampType}
import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper
import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.util.{ExecutorManager, TaskResources}
+import org.apache.spark.util.{ExecutorManager, Iterators}

import java.net.URLDecoder
import java.nio.charset.StandardCharsets
@@ -124,17 +124,6 @@ class IteratorApiImpl extends IteratorApi with Logging {
GlutenPartition(index, substraitPlan, localFilesNodesWithLocations.head._2)
}

-  /**
-   * Generate closeable ColumnBatch iterator.
-   *
-   * @param iter
-   * @return
-   */
-  override def genCloseableColumnBatchIterator(
-      iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
-    new CloseableColumnBatchIterator(iter)
-  }

/**
* Generate Iterator[ColumnarBatch] for first stage.
*
@@ -156,30 +145,18 @@ class IteratorApiImpl extends IteratorApi with Logging {
     val resIter: GeneralOutIterator =
       transKernel.createKernelWithBatchIterator(inputPartition.plan, columnarNativeIterators)
     pipelineTime += TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)
-    TaskResources.addRecycler(s"FirstStageIterator_${resIter.getId}", 100)(resIter.close())
-    val iter = new Iterator[ColumnarBatch] {
-      private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics
-      var finished = false
-
-      override def hasNext: Boolean = {
-        val res = resIter.hasNext
-        if (!res) {
-          updateNativeMetrics(resIter.getMetrics)
-          updateInputMetrics(inputMetrics)
-          finished = true
-        }
-        res
-      }
-
-      override def next(): ColumnarBatch = {
-        if (finished) {
-          throw new java.util.NoSuchElementException("End of stream.")
-        }
-        resIter.next()
-      }
-    }
-
-    new InterruptibleIterator(context, new CloseableColumnBatchIterator(iter, Some(pipelineTime)))
+    Iterators
+      .wrap(resIter.asScala)
+      .recycleIterator {
+        updateNativeMetrics(resIter.getMetrics)
+        updateInputMetrics(TaskContext.get().taskMetrics().inputMetrics)
+        resIter.close()
+      }
+      .recyclePayload(batch => batch.close())
+      .addToPipelineTime(pipelineTime)
+      .asInterruptible(context)
+      .create()
   }

// scalastyle:off argcount
@@ -213,25 +190,15 @@ class IteratorApiImpl extends IteratorApi with Logging {

     pipelineTime += TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)

-    val resIter = new Iterator[ColumnarBatch] {
-      override def hasNext: Boolean = {
-        val res = nativeResultIterator.hasNext
-        if (!res) {
-          updateNativeMetrics(nativeResultIterator.getMetrics)
-        }
-        res
-      }
-
-      override def next(): ColumnarBatch = {
-        nativeResultIterator.next
-      }
-    }
-
-    TaskResources.addRecycler(s"FinalStageIterator_${nativeResultIterator.getId}", 100) {
-      nativeResultIterator.close()
-    }
-
-    new CloseableColumnBatchIterator(resIter, Some(pipelineTime))
+    Iterators
+      .wrap(nativeResultIterator.asScala)
+      .recycleIterator {
+        updateNativeMetrics(nativeResultIterator.getMetrics)
+        nativeResultIterator.close()
+      }
+      .recyclePayload(batch => batch.close())
+      .addToPipelineTime(pipelineTime)
+      .create()
   }
// scalastyle:on argcount

@@ -254,6 +221,9 @@
broadcasted: Broadcast[BuildSideRelation],
broadCastContext: BroadCastHashJoinContext): Iterator[ColumnarBatch] = {
val relation = broadcasted.value.asReadOnlyCopy(broadCastContext)
-    new CloseableColumnBatchIterator(relation.deserialized)
+    Iterators
+      .wrap(relation.deserialized)
+      .recyclePayload(batch => batch.close())
+      .create()
}
}
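Note on the new utility: every hunk above funnels through the org.apache.spark.util.Iterators builder this patch introduces. What follows is a minimal sketch of a builder with the same surface (wrap / recycleIterator / recyclePayload / addToPipelineTime / asInterruptible / create), written only to illustrate the semantics these call sites rely on. It is not the patch's actual implementation; among other things, the real utility also registers the completion hook with TaskResources so it still fires when a task ends before the iterator is drained.

import java.util.concurrent.TimeUnit

import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.sql.execution.metric.SQLMetric

// Sketch only; semantics assumed from the call sites in this PR.
object IteratorsSketch {
  def wrap[A](in: Iterator[A]): Builder[A] = new Builder[A](in)

  final class Builder[A] private[IteratorsSketch] (private var delegate: Iterator[A]) {

    // Close each payload when the consumer asks for the next one.
    def recyclePayload(close: A => Unit): Builder[A] = {
      val inner = delegate
      delegate = new Iterator[A] {
        private var last: Option[A] = None
        override def hasNext: Boolean = inner.hasNext
        override def next(): A = {
          last.foreach(close) // previous payload is recycled before handing out the next
          val elem = inner.next()
          last = Some(elem)
          elem
        }
      }
      this
    }

    // Run the completion hook exactly once, when the source reports exhaustion.
    def recycleIterator(completion: => Unit): Builder[A] = {
      val inner = delegate
      var done = false
      delegate = new Iterator[A] {
        override def hasNext: Boolean = {
          val res = inner.hasNext
          if (!res && !done) {
            done = true
            completion
          }
          res
        }
        override def next(): A = inner.next()
      }
      this
    }

    // Count wall time spent inside hasNext/next against a SQL metric.
    def addToPipelineTime(metric: SQLMetric): Builder[A] = {
      val inner = delegate
      delegate = new Iterator[A] {
        override def hasNext: Boolean = timed(inner.hasNext)
        override def next(): A = timed(inner.next())
        private def timed[T](body: => T): T = {
          val start = System.nanoTime()
          try body
          finally metric += TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)
        }
      }
      this
    }

    // Honor task kill requests between elements.
    def asInterruptible(context: TaskContext): Builder[A] = {
      delegate = new InterruptibleIterator[A](context, delegate)
      this
    }

    def create(): Iterator[A] = delegate
  }
}

Under this reading, the patch replaces per-call-site TaskResources.addRecycler registrations and ad-hoc closed/finished flags with one exactly-once completion hook per pipeline, which is the lifecycle improvement the PR title refers to.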
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.utils.SparkArrowUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.Platform
-import org.apache.spark.util.TaskResources
+import org.apache.spark.util.{Iterators, TaskResources}

import org.apache.arrow.c.ArrowSchema
import org.apache.arrow.memory.ArrowBuf
@@ -93,7 +93,6 @@ object RowToVeloxColumnarExec {
val jniWrapper = NativeRowToColumnarJniWrapper.create()
val allocator = ArrowBufferAllocators.contextInstance()
val cSchema = ArrowSchema.allocateNew(allocator)
-    var closed = false
val r2cHandle =
try {
ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
@@ -106,26 +105,14 @@
cSchema.close()
}

-    TaskResources.addRecycler(s"RowToColumnar_$r2cHandle", 100) {
-      if (!closed) {
-        jniWrapper.close(r2cHandle)
-        closed = true
-      }
-    }
-
     val res: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] {
       var finished = false

       override def hasNext: Boolean = {
         if (finished) {
           false
         } else {
-          val itHasNext = it.hasNext
-          if (!itHasNext && !closed) {
-            jniWrapper.close(r2cHandle)
-            closed = true
-          }
-          itHasNext
+          it.hasNext
         }
       }

@@ -215,6 +202,12 @@ object RowToVeloxColumnarExec {
cb
}
}
-    new CloseableColumnBatchIterator(res)
+    Iterators
+      .wrap(res)
+      .recycleIterator {
+        jniWrapper.close(r2cHandle)
+      }
+      .recyclePayload(_.close())
+      .create()
}
}
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.util.TaskResources
+import org.apache.spark.util.Iterators

import scala.collection.JavaConverters._

@@ -98,26 +98,13 @@ object VeloxColumnarToRowExec {

// TODO:: pass the jni jniWrapper and arrowSchema and serializeSchema method by broadcast
val jniWrapper = NativeColumnarToRowJniWrapper.create()
-    var closed = false
val c2rId = jniWrapper.nativeColumnarToRowInit(
NativeMemoryManagers.contextInstance("ColumnarToRow").getNativeInstanceHandle)

-    TaskResources.addRecycler(s"ColumnarToRow_$c2rId", 100) {
-      if (!closed) {
-        jniWrapper.nativeClose(c2rId)
-        closed = true
-      }
-    }
-
     val res: Iterator[Iterator[InternalRow]] = new Iterator[Iterator[InternalRow]] {

       override def hasNext: Boolean = {
-        val hasNext = batches.hasNext
-        if (!hasNext && !closed) {
-          jniWrapper.nativeClose(c2rId)
-          closed = true
-        }
-        hasNext
+        batches.hasNext
Contributor:
It seems we did not call recyclePayload for c2r after the refactor, since the iterator is row-based. Is there a leak?

Member Author (@zhztheplayer), Oct 30, 2023:
I didn't change any code related to payload closing. The deleted lines were for iterator completion, so they moved to recycleIterator.

Contributor:
I see it is closed by itself; then it seems c2r did not call recyclePayload?

Member Author (@zhztheplayer), Oct 30, 2023:
> I see it is closed by itself; then it seems c2r did not call recyclePayload?

Yes, as you can see, the batches are closed manually within some if-else conditions. I wasn't sure how much effort would be needed to refactor those usages into #recyclePayload(), so I didn't change that part of the code in this patch. We can probably do that in a separate ticket.

}

override def next(): Iterator[InternalRow] = {
@@ -170,6 +157,11 @@ object VeloxColumnarToRowExec {
}
}
}
-    res.flatten
+    Iterators
+      .wrap(res.flatten)
+      .recycleIterator {
+        jniWrapper.nativeClose(c2rId)
+      }
+      .create()
}
}
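To make the concern in the thread above concrete: recycleIterator owns resources scoped to the whole traversal (here the native c2r handle), while recyclePayload owns per-element resources (columnar batches). The c2r output is a row iterator with no per-row resource to release, which is why only recycleIterator appears in the hunk; the input batches are still closed manually inside the conversion loop. A reduced sketch using the IteratorsSketch builder from earlier on this page, with NativeHandle as a hypothetical stand-in for the NativeColumnarToRowJniWrapper plus c2rId pair:

// Sketch only; NativeHandle is hypothetical.
final class NativeHandle {
  private var closed = false
  def close(): Unit = if (!closed) {
    closed = true
    println("c2r handle closed")
  }
}

object RecycleScopes {
  def main(args: Array[String]): Unit = {
    val handle = new NativeHandle
    val rows = Iterator("r1", "r2", "r3") // stand-in for res.flatten
    val it = IteratorsSketch
      .wrap(rows)
      .recycleIterator(handle.close()) // traversal-scoped: fires once when drained
      .create()                        // no recyclePayload: a row is not closeable
    it.foreach(println)                // prints the rows, then "c2r handle closed"
  }
}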
@@ -25,7 +25,7 @@ import io.glutenproject.execution.{RowToVeloxColumnarExec, VeloxColumnarToRowExe
import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators
import io.glutenproject.memory.nmm.NativeMemoryManagers
import io.glutenproject.utils.ArrowAbiUtil
-import io.glutenproject.vectorized.{CloseableColumnBatchIterator, ColumnarBatchSerializerJniWrapper}
+import io.glutenproject.vectorized.ColumnarBatchSerializerJniWrapper

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
@@ -38,7 +38,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.utils.SparkArrowUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.TaskResources
+import org.apache.spark.util.Iterators

import org.apache.arrow.c.ArrowSchema

@@ -244,33 +244,34 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with SQLConfHe
nmm.getNativeInstanceHandle
)
cSchema.close()
-    TaskResources.addRecycler(
-      s"ColumnarCachedBatchSerializer_convertCachedBatchToColumnarBatch_$deserializerHandle",
-      50) {
-      ColumnarBatchSerializerJniWrapper.create().close(deserializerHandle)
-    }
-
-    new CloseableColumnBatchIterator(new Iterator[ColumnarBatch] {
-      override def hasNext: Boolean = it.hasNext
-
-      override def next(): ColumnarBatch = {
-        val cachedBatch = it.next().asInstanceOf[CachedColumnarBatch]
-        val batchHandle =
-          ColumnarBatchSerializerJniWrapper
-            .create()
-            .deserialize(deserializerHandle, cachedBatch.bytes)
-        val batch = ColumnarBatches.create(Runtimes.contextInstance(), batchHandle)
-        if (shouldSelectAttributes) {
-          try {
-            ColumnarBatches.select(nmm, batch, requestedColumnIndices.toArray)
-          } finally {
-            batch.close()
-          }
-        } else {
-          batch
-        }
-      }
-    })
+    Iterators
+      .wrap(new Iterator[ColumnarBatch] {
+        override def hasNext: Boolean = it.hasNext
+
+        override def next(): ColumnarBatch = {
+          val cachedBatch = it.next().asInstanceOf[CachedColumnarBatch]
+          val batchHandle =
+            ColumnarBatchSerializerJniWrapper
+              .create()
+              .deserialize(deserializerHandle, cachedBatch.bytes)
+          val batch = ColumnarBatches.create(Runtimes.contextInstance(), batchHandle)
+          if (shouldSelectAttributes) {
+            try {
+              ColumnarBatches.select(nmm, batch, requestedColumnIndices.toArray)
+            } finally {
+              batch.close()
+            }
+          } else {
+            batch
+          }
+        }
+      })
+      .recycleIterator {
+        ColumnarBatchSerializerJniWrapper.create().close(deserializerHandle)
+      }
+      .recyclePayload(_.close())
+      .create()
}
}

@@ -17,11 +17,12 @@
package org.apache.spark.sql.execution.datasources

import io.glutenproject.datasource.DatasourceJniWrapper
-import io.glutenproject.vectorized.{CloseableColumnBatchIterator, ColumnarBatchInIterator}
+import io.glutenproject.vectorized.ColumnarBatchInIterator

import org.apache.spark.TaskContext
import org.apache.spark.sql.execution.datasources.VeloxWriteQueue.EOS_BATCH
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.Iterators

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.types.pojo.Schema
@@ -50,7 +51,8 @@ class VeloxWriteQueue(
try {
datasourceJniWrapper.write(
dsHandle,
-        new ColumnarBatchInIterator(new CloseableColumnBatchIterator(scanner).asJava))
+        new ColumnarBatchInIterator(
+          Iterators.wrap(scanner).recyclePayload(_.close()).create().asJava))
} catch {
case e: Exception =>
writeException.set(e)
cpp/core/jni/JniWrapper.cc (10 additions, 0 deletions)
@@ -1272,6 +1272,16 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_shr
JNI_METHOD_END(kInvalidResourceHandle)
}

+JNIEXPORT void JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_hold( // NOLINT
+    JNIEnv* env,
+    jclass,
+    jlong memoryManagerHandle) {
+  JNI_METHOD_START
+  auto memoryManager = jniCastOrThrow<MemoryManager>(memoryManagerHandle);
+  memoryManager->hold();
+  JNI_METHOD_END()
+}

JNIEXPORT void JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_release( // NOLINT
JNIEnv* env,
jclass,
cpp/core/memory/MemoryManager.h (5 additions, 0 deletions)
@@ -33,6 +33,11 @@
virtual const MemoryUsageStats collectMemoryUsageStats() const = 0;

virtual const int64_t shrink(int64_t size) = 0;

+  // Hold this memory manager. The underlying memory pools will not be released until this memory
+  // manager is destroyed. In other words, a call to this function guarantees that the memory
+  // blocks directly or indirectly managed by this manager remain safe to access for as long as
+  // this manager is alive.
+  virtual void hold() = 0;
};

} // namespace gluten
cpp/velox/memory/VeloxMemoryManager.cc (17 additions, 0 deletions)
@@ -215,6 +215,23 @@ const int64_t VeloxMemoryManager::shrink(int64_t size) {
return shrinkVeloxMemoryPool(veloxAggregatePool_.get(), size);
}

+namespace {
+void holdInternal(
+    std::vector<std::shared_ptr<facebook::velox::memory::MemoryPool>>& heldVeloxPools,
+    const velox::memory::MemoryPool* pool) {
+  pool->visitChildren([&](velox::memory::MemoryPool* child) -> bool {
+    auto shared = child->shared_from_this();
+    heldVeloxPools.push_back(shared);
+    holdInternal(heldVeloxPools, child);
+    return true;
+  });
+}
+} // namespace
+
+void VeloxMemoryManager::hold() {
+  holdInternal(heldVeloxPools_, veloxAggregatePool_.get());
+}
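For illustration, a JVM-side analogue of hold()/holdInternal: walk the pool tree and stash a strong reference to every descendant, so reference counting (shared_ptr in the C++ above, reachability on the JVM) cannot free a pool while the holder is alive. Pool and visitChildren are hypothetical stand-ins for velox::memory::MemoryPool and its visitor; this is a sketch of the idea, not project code.

import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for a reference-counted pool tree.
final case class Pool(name: String, children: Seq[Pool]) {
  def visitChildren(f: Pool => Boolean): Unit = children.foreach(f)
}

object HoldSketch {
  // Mirrors holdInternal: depth-first walk retaining every descendant pool.
  def holdInternal(held: ArrayBuffer[Pool], pool: Pool): Unit = {
    pool.visitChildren { child =>
      held += child             // keep the child alive (shared_from_this in C++)
      holdInternal(held, child) // recurse into grandchildren
      true
    }
  }

  def main(args: Array[String]): Unit = {
    val root = Pool("aggregate", Seq(Pool("op1", Seq(Pool("leaf", Nil))), Pool("op2", Nil)))
    val held = ArrayBuffer.empty[Pool]
    holdInternal(held, root)
    println(held.map(_.name).mkString(", ")) // op1, leaf, op2
  }
}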

velox::memory::MemoryManager* getDefaultVeloxMemoryManager() {
return &(facebook::velox::memory::defaultMemoryManager());
}