From 39a5162770c26ac6ed7c7c38cc078b28b98478cc Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 30 Oct 2023 19:21:14 +0800 Subject: [PATCH] [VL] Iterator's and its payloads' lifecycle improvements (#3526) --- .../clickhouse/CHIteratorApi.scala | 5 +- .../backendsapi/velox/IteratorApiImpl.scala | 79 +++---- .../execution/RowToVeloxColumnarExec.scala | 25 +-- .../execution/VeloxColumnarToRowExec.scala | 26 +-- .../ColumnarCachedBatchSerializer.scala | 50 ++--- .../datasources/VeloxWriteQueue.scala | 6 +- cpp/core/jni/JniWrapper.cc | 12 +- cpp/core/memory/MemoryManager.h | 5 + cpp/velox/memory/VeloxMemoryManager.cc | 17 ++ cpp/velox/memory/VeloxMemoryManager.h | 3 + ...VeloxCelebornColumnarBatchSerializer.scala | 7 +- .../vectorized/GeneralOutIterator.java | 23 +- .../backendsapi/IteratorApi.scala | 7 - .../io/glutenproject/utils/Iterators.scala | 201 ++++++++++++++++++ .../org/apache/spark/util/TaskResources.scala | 20 +- .../glutenproject/utils/IteratorSuite.scala | 144 +++++++++++++ .../memory/nmm/NativeMemoryManager.java | 10 + .../vectorized/ColumnarBatchOutIterator.java | 8 +- .../vectorized/NativePlanEvaluator.java | 39 ++-- .../vectorized/PlanEvaluatorJniWrapper.java | 10 - .../CloseableColumnBatchIterator.scala | 67 ------ .../vectorized/ColumnarBatchSerializer.scala | 6 +- .../execution/ColumnarBuildSideRelation.scala | 79 +++---- .../spark/sql/execution/utils/ExecUtil.scala | 128 ++++------- 24 files changed, 618 insertions(+), 359 deletions(-) create mode 100644 gluten-core/src/main/scala/io/glutenproject/utils/Iterators.scala create mode 100644 gluten-core/src/test/scala/io/glutenproject/utils/IteratorSuite.scala delete mode 100644 gluten-data/src/main/scala/io/glutenproject/vectorized/CloseableColumnBatchIterator.scala diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala index 458edae515c8..d19a1f506c3c 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala @@ -230,11 +230,12 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { /** * Generate closeable ColumnBatch iterator. 
* + * FIXME: This no longer overrides parent's method + * * @param iter * @return */ - override def genCloseableColumnBatchIterator( - iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { + def genCloseableColumnBatchIterator(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { if (iter.isInstanceOf[CloseableCHColumnBatchIterator]) iter else new CloseableCHColumnBatchIterator(iter) } diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala index fb5ff60bd005..bab188277abb 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala @@ -23,9 +23,10 @@ import io.glutenproject.metrics.IMetrics import io.glutenproject.substrait.plan.PlanNode import io.glutenproject.substrait.rel.LocalFilesBuilder import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat +import io.glutenproject.utils.Iterators import io.glutenproject.vectorized._ -import org.apache.spark.{InterruptibleIterator, Partition, SparkConf, SparkContext, TaskContext} +import org.apache.spark.{Partition, SparkConf, SparkContext, TaskContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -39,7 +40,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.{BinaryType, DateType, StructType, TimestampType} import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.{ExecutorManager, TaskResources} +import org.apache.spark.util.ExecutorManager import java.net.URLDecoder import java.nio.charset.StandardCharsets @@ -124,17 +125,6 @@ class IteratorApiImpl extends IteratorApi with Logging { GlutenPartition(index, substraitPlan, localFilesNodesWithLocations.head._2) } - /** - * Generate closeable ColumnBatch iterator. - * - * @param iter - * @return - */ - override def genCloseableColumnBatchIterator( - iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { - new CloseableColumnBatchIterator(iter) - } - /** * Generate Iterator[ColumnarBatch] for first stage. 
* @@ -156,30 +146,18 @@ class IteratorApiImpl extends IteratorApi with Logging { val resIter: GeneralOutIterator = transKernel.createKernelWithBatchIterator(inputPartition.plan, columnarNativeIterators) pipelineTime += TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - beforeBuild) - TaskResources.addRecycler(s"FirstStageIterator_${resIter.getId}", 100)(resIter.close()) - val iter = new Iterator[ColumnarBatch] { - private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics - var finished = false - - override def hasNext: Boolean = { - val res = resIter.hasNext - if (!res) { - updateNativeMetrics(resIter.getMetrics) - updateInputMetrics(inputMetrics) - finished = true - } - res - } - override def next(): ColumnarBatch = { - if (finished) { - throw new java.util.NoSuchElementException("End of stream.") - } - resIter.next() + Iterators + .wrap(resIter.asScala) + .recycleIterator { + updateNativeMetrics(resIter.getMetrics) + updateInputMetrics(TaskContext.get().taskMetrics().inputMetrics) + resIter.close() } - } - - new InterruptibleIterator(context, new CloseableColumnBatchIterator(iter, Some(pipelineTime))) + .recyclePayload(batch => batch.close()) + .addToPipelineTime(pipelineTime) + .asInterruptible(context) + .create() } // scalastyle:off argcount @@ -213,25 +191,15 @@ class IteratorApiImpl extends IteratorApi with Logging { pipelineTime += TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - beforeBuild) - val resIter = new Iterator[ColumnarBatch] { - override def hasNext: Boolean = { - val res = nativeResultIterator.hasNext - if (!res) { - updateNativeMetrics(nativeResultIterator.getMetrics) - } - res - } - - override def next(): ColumnarBatch = { - nativeResultIterator.next + Iterators + .wrap(nativeResultIterator.asScala) + .recycleIterator { + updateNativeMetrics(nativeResultIterator.getMetrics) + nativeResultIterator.close() } - } - - TaskResources.addRecycler(s"FinalStageIterator_${nativeResultIterator.getId}", 100) { - nativeResultIterator.close() - } - - new CloseableColumnBatchIterator(resIter, Some(pipelineTime)) + .recyclePayload(batch => batch.close()) + .addToPipelineTime(pipelineTime) + .create() } // scalastyle:on argcount @@ -254,6 +222,9 @@ class IteratorApiImpl extends IteratorApi with Logging { broadcasted: Broadcast[BuildSideRelation], broadCastContext: BroadCastHashJoinContext): Iterator[ColumnarBatch] = { val relation = broadcasted.value.asReadOnlyCopy(broadCastContext) - new CloseableColumnBatchIterator(relation.deserialized) + Iterators + .wrap(relation.deserialized) + .recyclePayload(batch => batch.close()) + .create() } } diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/RowToVeloxColumnarExec.scala b/backends-velox/src/main/scala/io/glutenproject/execution/RowToVeloxColumnarExec.scala index 373f9c473f96..d898d9f18b53 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/RowToVeloxColumnarExec.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/RowToVeloxColumnarExec.scala @@ -21,7 +21,7 @@ import io.glutenproject.columnarbatch.ColumnarBatches import io.glutenproject.exec.Runtimes import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators import io.glutenproject.memory.nmm.NativeMemoryManagers -import io.glutenproject.utils.ArrowAbiUtil +import io.glutenproject.utils.{ArrowAbiUtil, Iterators} import io.glutenproject.vectorized._ import org.apache.spark.rdd.RDD @@ -93,7 +93,6 @@ object RowToVeloxColumnarExec { val jniWrapper = NativeRowToColumnarJniWrapper.create() val allocator = 
ArrowBufferAllocators.contextInstance() val cSchema = ArrowSchema.allocateNew(allocator) - var closed = false val r2cHandle = try { ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema) @@ -106,13 +105,6 @@ object RowToVeloxColumnarExec { cSchema.close() } - TaskResources.addRecycler(s"RowToColumnar_$r2cHandle", 100) { - if (!closed) { - jniWrapper.close(r2cHandle) - closed = true - } - } - val res: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] { var finished = false @@ -120,12 +112,7 @@ object RowToVeloxColumnarExec { if (finished) { false } else { - val itHasNext = it.hasNext - if (!itHasNext && !closed) { - jniWrapper.close(r2cHandle) - closed = true - } - itHasNext + it.hasNext } } @@ -215,6 +202,12 @@ object RowToVeloxColumnarExec { cb } } - new CloseableColumnBatchIterator(res) + Iterators + .wrap(res) + .recycleIterator { + jniWrapper.close(r2cHandle) + } + .recyclePayload(_.close()) + .create() } } diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/VeloxColumnarToRowExec.scala b/backends-velox/src/main/scala/io/glutenproject/execution/VeloxColumnarToRowExec.scala index e0659a9c67a5..3428e8b7e24a 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/VeloxColumnarToRowExec.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/VeloxColumnarToRowExec.scala @@ -19,6 +19,7 @@ package io.glutenproject.execution import io.glutenproject.columnarbatch.ColumnarBatches import io.glutenproject.extension.ValidationResult import io.glutenproject.memory.nmm.NativeMemoryManagers +import io.glutenproject.utils.Iterators import io.glutenproject.vectorized.NativeColumnarToRowJniWrapper import org.apache.spark.rdd.RDD @@ -28,7 +29,6 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.TaskResources import scala.collection.JavaConverters._ @@ -98,26 +98,13 @@ object VeloxColumnarToRowExec { // TODO:: pass the jni jniWrapper and arrowSchema and serializeSchema method by broadcast val jniWrapper = NativeColumnarToRowJniWrapper.create() - var closed = false val c2rId = jniWrapper.nativeColumnarToRowInit( NativeMemoryManagers.contextInstance("ColumnarToRow").getNativeInstanceHandle) - TaskResources.addRecycler(s"ColumnarToRow_$c2rId", 100) { - if (!closed) { - jniWrapper.nativeClose(c2rId) - closed = true - } - } - val res: Iterator[Iterator[InternalRow]] = new Iterator[Iterator[InternalRow]] { override def hasNext: Boolean = { - val hasNext = batches.hasNext - if (!hasNext && !closed) { - jniWrapper.nativeClose(c2rId) - closed = true - } - hasNext + batches.hasNext } override def next(): Iterator[InternalRow] = { @@ -170,6 +157,13 @@ object VeloxColumnarToRowExec { } } } - res.flatten + Iterators + .wrap(res.flatten) + .protectInvocationFlow() // Spark may call `hasNext()` again after a false output which + // is not allowed by Gluten iterators. E.g. 
GroupedIterator#fetchNextGroupIterator + .recycleIterator { + jniWrapper.nativeClose(c2rId) + } + .create() } } diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala index 93aafcced9b0..1012d8027957 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala @@ -24,8 +24,8 @@ import io.glutenproject.exec.Runtimes import io.glutenproject.execution.{RowToVeloxColumnarExec, VeloxColumnarToRowExec} import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators import io.glutenproject.memory.nmm.NativeMemoryManagers -import io.glutenproject.utils.ArrowAbiUtil -import io.glutenproject.vectorized.{CloseableColumnBatchIterator, ColumnarBatchSerializerJniWrapper} +import io.glutenproject.utils.{ArrowAbiUtil, Iterators} +import io.glutenproject.vectorized.ColumnarBatchSerializerJniWrapper import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession @@ -38,7 +38,6 @@ import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.sql.utils.SparkArrowUtil import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.TaskResources import org.apache.arrow.c.ArrowSchema @@ -244,33 +243,34 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with SQLConfHe nmm.getNativeInstanceHandle ) cSchema.close() - TaskResources.addRecycler( - s"ColumnarCachedBatchSerializer_convertCachedBatchToColumnarBatch_$deserializerHandle", - 50) { - ColumnarBatchSerializerJniWrapper.create().close(deserializerHandle) - } - new CloseableColumnBatchIterator(new Iterator[ColumnarBatch] { - override def hasNext: Boolean = it.hasNext + Iterators + .wrap(new Iterator[ColumnarBatch] { + override def hasNext: Boolean = it.hasNext - override def next(): ColumnarBatch = { - val cachedBatch = it.next().asInstanceOf[CachedColumnarBatch] - val batchHandle = - ColumnarBatchSerializerJniWrapper - .create() - .deserialize(deserializerHandle, cachedBatch.bytes) - val batch = ColumnarBatches.create(Runtimes.contextInstance(), batchHandle) - if (shouldSelectAttributes) { - try { - ColumnarBatches.select(nmm, batch, requestedColumnIndices.toArray) - } finally { - batch.close() + override def next(): ColumnarBatch = { + val cachedBatch = it.next().asInstanceOf[CachedColumnarBatch] + val batchHandle = + ColumnarBatchSerializerJniWrapper + .create() + .deserialize(deserializerHandle, cachedBatch.bytes) + val batch = ColumnarBatches.create(Runtimes.contextInstance(), batchHandle) + if (shouldSelectAttributes) { + try { + ColumnarBatches.select(nmm, batch, requestedColumnIndices.toArray) + } finally { + batch.close() + } + } else { + batch } - } else { - batch } + }) + .recycleIterator { + ColumnarBatchSerializerJniWrapper.create().close(deserializerHandle) } - }) + .recyclePayload(_.close()) + .create() } } diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala index 3ff69a24737c..1aee3ed3394c 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala +++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala
@@ -17,7 +17,8 @@
 package org.apache.spark.sql.execution.datasources
 
 import io.glutenproject.datasource.DatasourceJniWrapper
-import io.glutenproject.vectorized.{CloseableColumnBatchIterator, ColumnarBatchInIterator}
+import io.glutenproject.utils.Iterators
+import io.glutenproject.vectorized.ColumnarBatchInIterator
 
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.execution.datasources.VeloxWriteQueue.EOS_BATCH
@@ -50,7 +51,8 @@ class VeloxWriteQueue(
       try {
         datasourceJniWrapper.write(
           dsHandle,
-          new ColumnarBatchInIterator(new CloseableColumnBatchIterator(scanner).asJava))
+          new ColumnarBatchInIterator(
+            Iterators.wrap(scanner).recyclePayload(_.close()).create().asJava))
       } catch {
         case e: Exception =>
           writeException.set(e)
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index e333a49c85c0..543a8f21fb82 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -389,7 +389,7 @@ JNIEXPORT jboolean JNICALL Java_io_glutenproject_vectorized_ColumnarBatchOutIter
   auto iter = ctx->objectStore()->retrieve(iterHandle);
   if (iter == nullptr) {
-    std::string errorMessage = "faked to get batch iterator";
+    std::string errorMessage = "failed to get batch iterator";
     throw gluten::GlutenException(errorMessage);
   }
   return iter->hasNext();
@@ -1272,6 +1272,16 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_shr
   JNI_METHOD_END(kInvalidResourceHandle)
 }
 
+JNIEXPORT void JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_hold( // NOLINT
+    JNIEnv* env,
+    jclass,
+    jlong memoryManagerHandle) {
+  JNI_METHOD_START
+  auto memoryManager = jniCastOrThrow<MemoryManager>(memoryManagerHandle);
+  memoryManager->hold();
+  JNI_METHOD_END()
+}
+
 JNIEXPORT void JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_release( // NOLINT
     JNIEnv* env,
     jclass,
diff --git a/cpp/core/memory/MemoryManager.h b/cpp/core/memory/MemoryManager.h
index 1a256596cc7c..5ec5213051ae 100644
--- a/cpp/core/memory/MemoryManager.h
+++ b/cpp/core/memory/MemoryManager.h
@@ -33,6 +33,11 @@ class MemoryManager {
   virtual const MemoryUsageStats collectMemoryUsageStats() const = 0;
 
   virtual const int64_t shrink(int64_t size) = 0;
+
+  // Hold this memory manager. The underlying memory pools will not be released until this memory manager is
+  // destroyed. That is, a call to this function guarantees that the memory blocks directly or indirectly
+  // managed by this manager remain safe to access for as long as this manager is alive.
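+  //
+  // Illustrative scenario (editor's sketch, not from this patch): an out-iterator emits
+  // batches backed by this manager's pools and is then closed; calling hold() before the
+  // close keeps those pools referenced, so already-emitted batches stay readable until the
+  // manager itself is released.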
+  virtual void hold() = 0;
 };
 
 } // namespace gluten
diff --git a/cpp/velox/memory/VeloxMemoryManager.cc b/cpp/velox/memory/VeloxMemoryManager.cc
index 843e856db113..a8313310ac8d 100644
--- a/cpp/velox/memory/VeloxMemoryManager.cc
+++ b/cpp/velox/memory/VeloxMemoryManager.cc
@@ -215,6 +215,23 @@ const int64_t VeloxMemoryManager::shrink(int64_t size) {
   return shrinkVeloxMemoryPool(veloxAggregatePool_.get(), size);
 }
 
+namespace {
+void holdInternal(
+    std::vector<std::shared_ptr<velox::memory::MemoryPool>>& heldVeloxPools,
+    const velox::memory::MemoryPool* pool) {
+  pool->visitChildren([&](velox::memory::MemoryPool* child) -> bool {
+    auto shared = child->shared_from_this();
+    heldVeloxPools.push_back(shared);
+    holdInternal(heldVeloxPools, child);
+    return true;
+  });
+}
+} // namespace
+
+void VeloxMemoryManager::hold() {
+  holdInternal(heldVeloxPools_, veloxAggregatePool_.get());
+}
+
 velox::memory::MemoryManager* getDefaultVeloxMemoryManager() {
   return &(facebook::velox::memory::defaultMemoryManager());
 }
diff --git a/cpp/velox/memory/VeloxMemoryManager.h b/cpp/velox/memory/VeloxMemoryManager.h
index f9614043a09b..faf48c344fca 100644
--- a/cpp/velox/memory/VeloxMemoryManager.h
+++ b/cpp/velox/memory/VeloxMemoryManager.h
@@ -53,6 +53,8 @@ class VeloxMemoryManager final : public MemoryManager {
 
   const int64_t shrink(int64_t size) override;
 
+  void hold() override;
+
  private:
   std::string name_;
 
@@ -68,6 +70,7 @@
   std::unique_ptr<velox::memory::MemoryManager> veloxMemoryManager_;
   std::shared_ptr<velox::memory::MemoryPool> veloxAggregatePool_;
   std::shared_ptr<velox::memory::MemoryPool> veloxLeafPool_;
+  std::vector<std::shared_ptr<velox::memory::MemoryPool>> heldVeloxPools_;
 };
 
 /// Not tracked by Spark and should only used in test or validation.
diff --git a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
index 8ac261479876..21be1642ff72 100644
--- a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
+++ b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
@@ -62,6 +62,8 @@ private class CelebornColumnarBatchSerializerInstance(
   extends SerializerInstance
   with Logging {
 
+  private lazy val nmm = NativeMemoryManagers.contextInstance("ShuffleReader")
+
   private lazy val shuffleReaderHandle = {
     val allocator: BufferAllocator = ArrowBufferAllocators
       .contextInstance()
@@ -83,7 +85,7 @@ private class CelebornColumnarBatchSerializerInstance(
       .create()
       .make(
         cSchema.memoryAddress(),
-        NativeMemoryManagers.contextInstance("ShuffleReader").getNativeInstanceHandle,
+        nmm.getNativeInstanceHandle,
         compressionCodec,
         compressionCodecBackend
       )
@@ -106,7 +108,8 @@ private class CelebornColumnarBatchSerializerInstance(
           Runtimes.contextInstance(),
           ShuffleReaderJniWrapper
             .create()
-            .readStream(shuffleReaderHandle, byteIn))
+            .readStream(shuffleReaderHandle, byteIn),
+          nmm)
 
       private var cb: ColumnarBatch = _
 
diff --git a/gluten-core/src/main/java/io/glutenproject/vectorized/GeneralOutIterator.java b/gluten-core/src/main/java/io/glutenproject/vectorized/GeneralOutIterator.java
index 216f00c21d32..28b857ef2a42 100644
--- a/gluten-core/src/main/java/io/glutenproject/vectorized/GeneralOutIterator.java
+++ b/gluten-core/src/main/java/io/glutenproject/vectorized/GeneralOutIterator.java
@@ -16,24 +16,37 @@
  */
 package io.glutenproject.vectorized;
 
+import io.glutenproject.exception.GlutenException;
 import io.glutenproject.metrics.IMetrics;
 
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 
 import java.io.Serializable;
+import java.util.Iterator;
 import java.util.concurrent.atomic.AtomicBoolean;
 
-public abstract class GeneralOutIterator implements AutoCloseable, Serializable {
+public abstract class GeneralOutIterator
+    implements AutoCloseable, Serializable, Iterator<ColumnarBatch> {
   protected final AtomicBoolean closed = new AtomicBoolean(false);
 
   public GeneralOutIterator() {}
 
-  public final boolean hasNext() throws Exception {
-    return hasNextInternal();
+  @Override
+  public final boolean hasNext() {
+    try {
+      return hasNextInternal();
+    } catch (Exception e) {
+      throw new GlutenException(e);
+    }
   }
 
-  public final ColumnarBatch next() throws Exception {
-    return nextInternal();
+  @Override
+  public final ColumnarBatch next() {
+    try {
+      return nextInternal();
+    } catch (Exception e) {
+      throw new GlutenException(e);
+    }
   }
 
   public final IMetrics getMetrics() throws Exception {
diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala
index 13ea170c6378..e6e956a6791f 100644
--- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala
@@ -45,13 +45,6 @@ trait IteratorApi {
       fileFormats: Seq[ReadFileFormat],
       wsCxt: WholeStageTransformContext): BaseGlutenPartition
 
-  /**
-   * Generate closeable ColumnBatch iterator.
-   *
-   * @return
-   */
-  def genCloseableColumnBatchIterator(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch]
-
   /**
    * Generate Iterator[ColumnarBatch] for first stage. ("first" means it does not depend on other
    * SCAN inputs)
diff --git a/gluten-core/src/main/scala/io/glutenproject/utils/Iterators.scala b/gluten-core/src/main/scala/io/glutenproject/utils/Iterators.scala
new file mode 100644
index 000000000000..5e370a42c616
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/utils/Iterators.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package io.glutenproject.utils + +import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.util.TaskResources + +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} + +private class PayloadCloser[A](in: Iterator[A])(closeCallback: A => Unit) extends Iterator[A] { + private var closer: Option[() => Unit] = None + + TaskResources.addRecycler("Iterators#PayloadCloser", 100) { + tryClose() + } + + override def hasNext: Boolean = { + tryClose() + in.hasNext + } + + override def next(): A = { + val a: A = in.next() + closer.synchronized { + closer = Some( + () => { + closeCallback.apply(a) + }) + } + a + } + + private def tryClose(): Unit = { + closer.synchronized { + closer match { + case Some(c) => c.apply() + case None => + } + closer = None // make sure the payload is closed once + } + } +} + +private class IteratorCompleter[A](in: Iterator[A])(completionCallback: => Unit) + extends Iterator[A] { + private val completed = new AtomicBoolean(false) + + TaskResources.addRecycler("Iterators#IteratorRecycler", 100) { + tryComplete() + } + + override def hasNext: Boolean = { + val out = in.hasNext + if (!out) { + tryComplete() + } + out + } + + override def next(): A = { + in.next() + } + + private def tryComplete(): Unit = { + if (!completed.compareAndSet(false, true)) { + return // make sure the iterator is completed once + } + completionCallback + } +} + +private class PipelineTimeAccumulator[A](in: Iterator[A], pipelineTime: SQLMetric) + extends Iterator[A] { + private val accumulatedTime: AtomicLong = new AtomicLong(0L) + + TaskResources.addRecycler("Iterators#PipelineTimeAccumulator", 100) { + tryFinish() + } + + override def hasNext: Boolean = { + val prev = System.nanoTime() + val out = in.hasNext + accumulatedTime.addAndGet(System.nanoTime() - prev) + if (!out) { + tryFinish() + } + out + } + + override def next(): A = { + val prev = System.nanoTime() + val out = in.next() + accumulatedTime.addAndGet(System.nanoTime() - prev) + out + } + + private def tryFinish(): Unit = { + pipelineTime += accumulatedTime.getAndSet( + 0L + ) // make sure the accumulated time is submitted once + } +} + +/** + * To protect the wrapped iterator to avoid undesired order of calls to its `hasNext` and `next` + * methods. + */ +private class InvocationFlowProtection[A](in: Iterator[A]) extends Iterator[A] { + sealed private trait State; + private case object Init extends State; + private case class HasNextCalled(hasNext: Boolean) extends State; + private case object NextCalled extends State; + + private var state: State = Init + + override def hasNext: Boolean = { + val out = state match { + case Init | NextCalled => + in.hasNext + case HasNextCalled(lastHasNext) => + lastHasNext + } + state = HasNextCalled(out) + out + } + + override def next(): A = { + val out = state match { + case Init | NextCalled => + if (!in.hasNext) { + throw new IllegalStateException("End of stream") + } + in.next() + case HasNextCalled(lastHasNext) => + if (!lastHasNext) { + throw new IllegalStateException("End of stream") + } + in.next() + } + state = NextCalled + out + } +} + +class WrapperBuilder[A](in: Iterator[A]) { // FIXME how to make the ctor companion-private? 
+ private var wrapped: Iterator[A] = in + + def recyclePayload(closeCallback: (A) => Unit): WrapperBuilder[A] = { + wrapped = new PayloadCloser(wrapped)(closeCallback) + this + } + + def recycleIterator(completionCallback: => Unit): WrapperBuilder[A] = { + wrapped = new IteratorCompleter(wrapped)(completionCallback) + this + } + + def addToPipelineTime(pipelineTime: SQLMetric): WrapperBuilder[A] = { + wrapped = new PipelineTimeAccumulator[A](wrapped, pipelineTime) + this + } + + def asInterruptible(context: TaskContext): WrapperBuilder[A] = { + wrapped = new InterruptibleIterator[A](context, wrapped) + this + } + + def protectInvocationFlow(): WrapperBuilder[A] = { + wrapped = new InvocationFlowProtection[A](wrapped) + this + } + + def create(): Iterator[A] = { + wrapped + } +} + +/** + * Utility class to provide iterator wrappers for non-trivial use cases. E.g. iterators that manage + * payload's lifecycle. + */ +object Iterators { + def wrap[A](in: Iterator[A]): WrapperBuilder[A] = { + new WrapperBuilder[A](in) + } +} diff --git a/gluten-core/src/main/scala/org/apache/spark/util/TaskResources.scala b/gluten-core/src/main/scala/org/apache/spark/util/TaskResources.scala index e57badad5faa..4cb33186ad1e 100644 --- a/gluten-core/src/main/scala/org/apache/spark/util/TaskResources.scala +++ b/gluten-core/src/main/scala/org/apache/spark/util/TaskResources.scala @@ -36,6 +36,20 @@ object TaskResources extends TaskListener with Logging { } val ACCUMULATED_LEAK_BYTES = new AtomicLong(0L) + // For testing purpose only + private var fallbackRegistry: Option[TaskResourceRegistry] = None + + // For testing purpose only + def setFallbackRegistry(r: TaskResourceRegistry): Unit = { + fallbackRegistry = Some(r) + } + + // For testing purpose only + def unsetFallbackRegistry(): Unit = { + fallbackRegistry.foreach(r => r.releaseAll()) + fallbackRegistry = None + } + private val RESOURCE_REGISTRIES = new java.util.IdentityHashMap[TaskContext, TaskResourceRegistry]() @@ -52,7 +66,11 @@ object TaskResources extends TaskListener with Logging { logWarning( "Using the fallback instance of TaskResourceRegistry. " + "This should only happen when call is not from Spark task.") - throw new IllegalStateException("Found a caller not in Spark task scope.") + return fallbackRegistry match { + case Some(r) => r + case _ => + throw new IllegalStateException("No fallback instance of TaskResourceRegistry found.") + } } val tc = getLocalTaskContext() RESOURCE_REGISTRIES.synchronized { diff --git a/gluten-core/src/test/scala/io/glutenproject/utils/IteratorSuite.scala b/gluten-core/src/test/scala/io/glutenproject/utils/IteratorSuite.scala new file mode 100644 index 000000000000..88b2eb63126b --- /dev/null +++ b/gluten-core/src/test/scala/io/glutenproject/utils/IteratorSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.glutenproject.utils + +import org.apache.spark.util.{TaskResourceRegistry, TaskResources} + +import org.scalatest.funsuite.AnyFunSuite + +class IteratorSuite extends AnyFunSuite { + test("Trivial wrapping") { + val strings = Array[String]("one", "two", "three") + val itr = strings.toIterator + val wrapped = Iterators + .wrap(itr) + .create() + assertResult(strings) { + wrapped.toArray + } + } + + test("Complete iterator") { + var completeCount = 0 + withFakeTaskContext { + val strings = Array[String]("one", "two", "three") + val itr = strings.toIterator + val wrapped = Iterators + .wrap(itr) + .recycleIterator { + completeCount += 1 + } + .create() + assertResult(strings) { + wrapped.toArray + } + assert(completeCount == 1) + } + assert(completeCount == 1) + } + + test("Complete intermediate iterator") { + var completeCount = 0 + withFakeTaskContext { + val strings = Array[String]("one", "two", "three") + val itr = strings.toIterator + val _ = Iterators + .wrap(itr) + .recycleIterator { + completeCount += 1 + } + .create() + assert(completeCount == 0) + } + assert(completeCount == 1) + } + + test("Close payload") { + var closeCount = 0 + withFakeTaskContext { + val strings = Array[String]("one", "two", "three") + val itr = strings.toIterator + val wrapped = Iterators + .wrap(itr) + .recyclePayload { _: String => closeCount += 1 } + .create() + assertResult(strings) { + wrapped.toArray + } + assert(closeCount == 3) + } + assert(closeCount == 3) + } + + test("Close intermediate payload") { + var closeCount = 0 + withFakeTaskContext { + val strings = Array[String]("one", "two", "three") + val itr = strings.toIterator + val wrapped = Iterators + .wrap(itr) + .recyclePayload { _: String => closeCount += 1 } + .create() + assertResult(strings.take(2)) { + wrapped.take(2).toArray + } + assert(closeCount == 1) // the first one is closed after consumed + } + assert(closeCount == 2) // the second one is closed on task exit + } + + test("Protect invocation flow") { + var hasNextCallCount = 0 + var nextCallCount = 0 + val itr = new Iterator[Any] { + override def hasNext: Boolean = { + hasNextCallCount += 1 + true + } + + override def next(): Any = { + nextCallCount += 1 + new Object + } + } + val wrapped = Iterators + .wrap(itr) + .protectInvocationFlow() + .create() + wrapped.hasNext + assert(hasNextCallCount == 1) + assert(nextCallCount == 0) + wrapped.hasNext + assert(hasNextCallCount == 1) + assert(nextCallCount == 0) + wrapped.next + assert(hasNextCallCount == 1) + assert(nextCallCount == 1) + wrapped.next + assert(hasNextCallCount == 2) + assert(nextCallCount == 2) + } + + private def withFakeTaskContext(body: => Unit): Unit = { + TaskResources.setFallbackRegistry(new TaskResourceRegistry) + try { + body + } finally { + TaskResources.unsetFallbackRegistry() + } + } +} diff --git a/gluten-data/src/main/java/io/glutenproject/memory/nmm/NativeMemoryManager.java b/gluten-data/src/main/java/io/glutenproject/memory/nmm/NativeMemoryManager.java index 33a7871d5d7b..797f75a7e9dd 100644 --- a/gluten-data/src/main/java/io/glutenproject/memory/nmm/NativeMemoryManager.java +++ b/gluten-data/src/main/java/io/glutenproject/memory/nmm/NativeMemoryManager.java @@ -62,6 +62,14 @@ public long shrink(long size) { return shrink(nativeInstanceHandle, size); } + // Hold this memory manager. The underlying memory pools will be released as lately as this + // memory manager gets destroyed. 
That is, a call to this function guarantees that the
+  // memory blocks directly or indirectly managed by this manager remain safe to
+  // access for as long as this manager is alive.
+  public void hold() {
+    hold(nativeInstanceHandle);
+  }
+
   private static native long shrink(long memoryManagerId, long size);
 
   private static native long create(
@@ -75,6 +83,8 @@ private static native long create(
 
   private static native byte[] collectMemoryUsage(long memoryManagerId);
 
+  private static native void hold(long memoryManagerId);
+
   @Override
   public void release() throws Exception {
     release(nativeInstanceHandle);
diff --git a/gluten-data/src/main/java/io/glutenproject/vectorized/ColumnarBatchOutIterator.java b/gluten-data/src/main/java/io/glutenproject/vectorized/ColumnarBatchOutIterator.java
index 1b3f3fe697ba..a6428d1fd871 100644
--- a/gluten-data/src/main/java/io/glutenproject/vectorized/ColumnarBatchOutIterator.java
+++ b/gluten-data/src/main/java/io/glutenproject/vectorized/ColumnarBatchOutIterator.java
@@ -19,6 +19,7 @@
 import io.glutenproject.columnarbatch.ColumnarBatches;
 import io.glutenproject.exec.Runtime;
 import io.glutenproject.exec.RuntimeAware;
+import io.glutenproject.memory.nmm.NativeMemoryManager;
 import io.glutenproject.metrics.IMetrics;
 
 import org.apache.spark.sql.vectorized.ColumnarBatch;
@@ -28,11 +29,14 @@
 public class ColumnarBatchOutIterator extends GeneralOutIterator implements RuntimeAware {
   private final Runtime runtime;
   private final long iterHandle;
+  private final NativeMemoryManager nmm;
 
-  public ColumnarBatchOutIterator(Runtime runtime, long iterHandle) throws IOException {
+  public ColumnarBatchOutIterator(Runtime runtime, long iterHandle, NativeMemoryManager nmm)
+      throws IOException {
     super();
     this.runtime = runtime;
     this.iterHandle = iterHandle;
+    this.nmm = nmm;
   }
 
   @Override
@@ -81,6 +85,8 @@ public long spill(long size) {
 
   @Override
   public void closeInternal() {
+    nmm.hold(); // to make sure the output batches are still accessible after the iterator is
+    // closed
     nativeClose(iterHandle);
   }
 }
diff --git a/gluten-data/src/main/java/io/glutenproject/vectorized/NativePlanEvaluator.java b/gluten-data/src/main/java/io/glutenproject/vectorized/NativePlanEvaluator.java
index 84c3ca06e623..44a43f016f4b 100644
--- a/gluten-data/src/main/java/io/glutenproject/vectorized/NativePlanEvaluator.java
+++ b/gluten-data/src/main/java/io/glutenproject/vectorized/NativePlanEvaluator.java
@@ -19,6 +19,7 @@
 import io.glutenproject.backendsapi.BackendsApiManager;
 import io.glutenproject.exec.Runtime;
 import io.glutenproject.exec.Runtimes;
+import io.glutenproject.memory.nmm.NativeMemoryManager;
 import io.glutenproject.memory.nmm.NativeMemoryManagers;
 import io.glutenproject.utils.DebugUtil;
 import io.glutenproject.validate.NativePlanValidationInfo;
@@ -59,22 +60,22 @@ public NativePlanValidationInfo doNativeValidateWithFailureReason(byte[] subPlan
   public GeneralOutIterator createKernelWithBatchIterator(
       Plan wsPlan, List<GeneralInIterator> iterList) throws RuntimeException, IOException {
     final AtomicReference<ColumnarBatchOutIterator> outIterator = new AtomicReference<>();
-    final long memoryManagerHandle =
+    final NativeMemoryManager nmm =
         NativeMemoryManagers.create(
-            "WholeStageIterator",
-            (self, size) -> {
-              ColumnarBatchOutIterator instance =
-                  Optional.of(outIterator.get())
-                      .orElseThrow(
-                          () ->
-                              new IllegalStateException(
-                                  "Fatal: spill() called before a output iterator "
-                                      + "is created. This behavior should be optimized "
-                                      + "by moving memory allocations from create() to "
-                                      + "hasNext()/next()"));
-              return instance.spill(size);
-            })
-        .getNativeInstanceHandle();
+            "WholeStageIterator",
+            (self, size) -> {
+              ColumnarBatchOutIterator instance =
+                  Optional.of(outIterator.get())
+                      .orElseThrow(
+                          () ->
+                              new IllegalStateException(
+                                  "Fatal: spill() called before an output iterator "
+                                      + "is created. This behavior should be optimized "
+                                      + "by moving memory allocations from create() to "
+                                      + "hasNext()/next()"));
+              return instance.spill(size);
+            });
+    final long memoryManagerHandle = nmm.getNativeInstanceHandle();
 
     final String spillDirPath =
         SparkDirectoryUtil.namespace("gluten-spill")
@@ -91,13 +92,13 @@ public GeneralOutIterator createKernelWithBatchIterator(
             TaskContext.get().taskAttemptId(),
             DebugUtil.saveInputToFile(),
             BackendsApiManager.getSparkPlanExecApiInstance().rewriteSpillPath(spillDirPath));
-    outIterator.set(createOutIterator(Runtimes.contextInstance(), iterHandle));
+    outIterator.set(createOutIterator(Runtimes.contextInstance(), iterHandle, nmm));
     return outIterator.get();
   }
 
-  private ColumnarBatchOutIterator createOutIterator(Runtime runtime, long iterHandle)
-      throws IOException {
-    return new ColumnarBatchOutIterator(runtime, iterHandle);
+  private ColumnarBatchOutIterator createOutIterator(
+      Runtime runtime, long iterHandle, NativeMemoryManager nmm) throws IOException {
+    return new ColumnarBatchOutIterator(runtime, iterHandle, nmm);
   }
 
   private byte[] getPlanBytesBuf(Plan planNode) {
diff --git a/gluten-data/src/main/java/io/glutenproject/vectorized/PlanEvaluatorJniWrapper.java b/gluten-data/src/main/java/io/glutenproject/vectorized/PlanEvaluatorJniWrapper.java
index c3e74d245621..a6ad93ec456e 100644
--- a/gluten-data/src/main/java/io/glutenproject/vectorized/PlanEvaluatorJniWrapper.java
+++ b/gluten-data/src/main/java/io/glutenproject/vectorized/PlanEvaluatorJniWrapper.java
@@ -70,14 +70,4 @@ public native long nativeCreateKernelWithIterator(
       boolean saveInputToFile,
       String spillDir)
       throws RuntimeException;
-
-  /** Create a native compute kernel and return a row iterator. */
-  native long nativeCreateKernelWithRowIterator(byte[] wsPlan) throws RuntimeException;
-
-  /**
-   * Closes the projector referenced by nativeHandler.
-   *
-   * @param nativeHandler nativeHandler that needs to be closed
-   */
-  native void nativeClose(long nativeHandler);
 }
diff --git a/gluten-data/src/main/scala/io/glutenproject/vectorized/CloseableColumnBatchIterator.scala b/gluten-data/src/main/scala/io/glutenproject/vectorized/CloseableColumnBatchIterator.scala
deleted file mode 100644
index 021413ddaed9..000000000000
--- a/gluten-data/src/main/scala/io/glutenproject/vectorized/CloseableColumnBatchIterator.scala
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.glutenproject.vectorized - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.TaskResources - -import java.util.concurrent.TimeUnit - -/** - * An Iterator that insures that the batches [[ColumnarBatch]]s it iterates over are all closed - * properly. - */ -class CloseableColumnBatchIterator( - itr: Iterator[ColumnarBatch], - pipelineTime: Option[SQLMetric] = None) - extends Iterator[ColumnarBatch] - with Logging { - var cb: ColumnarBatch = _ - var scanTime = 0L - - override def hasNext: Boolean = { - val beforeTime = System.nanoTime() - val res = itr.hasNext - scanTime += System.nanoTime() - beforeTime - if (!res) { - pipelineTime.foreach(t => t += TimeUnit.NANOSECONDS.toMillis(scanTime)) - closeCurrentBatch() - } - res - } - - TaskResources.addRecycler("CloseableColumnBatchIterator", 100) { - closeCurrentBatch() - } - - override def next(): ColumnarBatch = { - val beforeTime = System.nanoTime() - closeCurrentBatch() - cb = itr.next() - scanTime += System.nanoTime() - beforeTime - cb - } - - private def closeCurrentBatch(): Unit = { - if (cb != null) { - cb.close() - cb = null - } - } -} diff --git a/gluten-data/src/main/scala/io/glutenproject/vectorized/ColumnarBatchSerializer.scala b/gluten-data/src/main/scala/io/glutenproject/vectorized/ColumnarBatchSerializer.scala index 77c28bc23b2b..563e143d9c24 100644 --- a/gluten-data/src/main/scala/io/glutenproject/vectorized/ColumnarBatchSerializer.scala +++ b/gluten-data/src/main/scala/io/glutenproject/vectorized/ColumnarBatchSerializer.scala @@ -78,6 +78,7 @@ private class ColumnarBatchSerializerInstance( extends SerializerInstance with Logging { + private lazy val nmm = NativeMemoryManagers.contextInstance("ShuffleReader") private lazy val shuffleReaderHandle = { val allocator: BufferAllocator = ArrowBufferAllocators .contextInstance() @@ -98,7 +99,7 @@ private class ColumnarBatchSerializerInstance( val jniWrapper = ShuffleReaderJniWrapper.create() val shuffleReaderHandle = jniWrapper.make( cSchema.memoryAddress(), - NativeMemoryManagers.contextInstance("ShuffleReader").getNativeInstanceHandle, + nmm.getNativeInstanceHandle, compressionCodec, compressionCodecBackend ) @@ -128,7 +129,8 @@ private class ColumnarBatchSerializerInstance( Runtimes.contextInstance(), ShuffleReaderJniWrapper .create() - .readStream(shuffleReaderHandle, byteIn)) + .readStream(shuffleReaderHandle, byteIn), + nmm) private var cb: ColumnarBatch = _ diff --git a/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala index a134742f2f9f..3c90ab5ea953 100644 --- a/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala +++ b/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala @@ -21,7 +21,7 @@ import io.glutenproject.exec.Runtimes import io.glutenproject.execution.BroadCastHashJoinContext import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators import io.glutenproject.memory.nmm.NativeMemoryManagers -import io.glutenproject.utils.ArrowAbiUtil +import io.glutenproject.utils.{ArrowAbiUtil, Iterators} import io.glutenproject.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper} import 
org.apache.spark.sql.catalyst.InternalRow @@ -32,7 +32,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.utils.SparkArrowUtil import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.TaskResources import org.apache.arrow.c.ArrowSchema @@ -45,54 +44,46 @@ case class ColumnarBuildSideRelation( extends BuildSideRelation { override def deserialized: Iterator[ColumnarBatch] = { - new Iterator[ColumnarBatch] { - var batchId = 0 - var closed = false - private var finalBatch: ColumnarBatch = null - val serializeHandle: Long = { - val allocator = ArrowBufferAllocators.contextInstance() - val cSchema = ArrowSchema.allocateNew(allocator) - val arrowSchema = SparkArrowUtil.toArrowSchema( - StructType.fromAttributes(output), - SQLConf.get.sessionLocalTimeZone) - ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema) - val handle = ColumnarBatchSerializerJniWrapper - .create() - .init( - cSchema.memoryAddress(), - NativeMemoryManagers - .contextInstance("BuildSideRelation#BatchSerializer") - .getNativeInstanceHandle) - cSchema.close() - handle - } + val serializeHandle: Long = { + val allocator = ArrowBufferAllocators.contextInstance() + val cSchema = ArrowSchema.allocateNew(allocator) + val arrowSchema = SparkArrowUtil.toArrowSchema( + StructType.fromAttributes(output), + SQLConf.get.sessionLocalTimeZone) + ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema) + val handle = ColumnarBatchSerializerJniWrapper + .create() + .init( + cSchema.memoryAddress(), + NativeMemoryManagers + .contextInstance("BuildSideRelation#BatchSerializer") + .getNativeInstanceHandle) + cSchema.close() + handle + } - TaskResources.addRecycler(s"BuildSideRelation_deserialized_$serializeHandle", 50) { - ColumnarBatchSerializerJniWrapper.create().close(serializeHandle) - } + Iterators + .wrap(new Iterator[ColumnarBatch] { + var batchId = 0 - override def hasNext: Boolean = { - val has = batchId < batches.length - if (!has && !closed) { - if (finalBatch != null) { - ColumnarBatches.forceClose(finalBatch) - } - closed = true + override def hasNext: Boolean = { + batchId < batches.length } - has - } - override def next: ColumnarBatch = { - val handle = - ColumnarBatchSerializerJniWrapper.create().deserialize(serializeHandle, batches(batchId)) - batchId += 1 - val batch = ColumnarBatches.create(Runtimes.contextInstance(), handle) - if (batchId == batches.length) { - finalBatch = batch + override def next: ColumnarBatch = { + val handle = + ColumnarBatchSerializerJniWrapper + .create() + .deserialize(serializeHandle, batches(batchId)) + batchId += 1 + ColumnarBatches.create(Runtimes.contextInstance(), handle) } - batch + }) + .recycleIterator { + ColumnarBatchSerializerJniWrapper.create().close(serializeHandle) } - } + .recyclePayload(ColumnarBatches.forceClose) // FIXME why force close? 
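+      // Editor's note (assumption, tied to the FIXME above): forceClose appears to release the
+      // native batch regardless of outstanding shared references; if the deserialized batches
+      // are solely owned by this iterator, a plain close() may suffice here.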
+ .create() } override def asReadOnlyCopy( diff --git a/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala b/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala index b817b3ae3b5c..eb29ba2709a0 100644 --- a/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala +++ b/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.utils import io.glutenproject.columnarbatch.ColumnarBatches import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators import io.glutenproject.memory.nmm.NativeMemoryManagers +import io.glutenproject.utils.Iterators import io.glutenproject.vectorized.{ArrowWritableColumnVector, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper, NativePartitioning} import org.apache.spark.{Partitioner, RangePartitioner, ShuffleDependency} -import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ColumnarShuffleDependency @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} -import org.apache.spark.util.{MutablePair, TaskResources} +import org.apache.spark.util.MutablePair object ExecUtil { @@ -49,35 +49,27 @@ object ExecUtil { .getNativeInstanceHandle) info = jniWrapper.nativeColumnarToRowConvert(batchHandle, c2rHandle) - new Iterator[InternalRow] { - var rowId = 0 - val row = new UnsafeRow(batch.numCols()) - var closed = false + Iterators + .wrap(new Iterator[InternalRow] { + var rowId = 0 + val row = new UnsafeRow(batch.numCols()) - TaskResources.addRecycler(s"ColumnarToRow_$c2rHandle", 100) { - if (!closed) { - jniWrapper.nativeClose(c2rHandle) - closed = true + override def hasNext: Boolean = { + rowId < batch.numRows() } - } - override def hasNext: Boolean = { - val result = rowId < batch.numRows() - if (!result && !closed) { - jniWrapper.nativeClose(c2rHandle) - closed = true + override def next: UnsafeRow = { + if (rowId >= batch.numRows()) throw new NoSuchElementException + val (offset, length) = (info.offsets(rowId), info.lengths(rowId)) + row.pointTo(null, info.memoryAddress + offset, length.toInt) + rowId += 1 + row } - result - } - - override def next: UnsafeRow = { - if (rowId >= batch.numRows()) throw new NoSuchElementException - val (offset, length) = (info.offsets(rowId), info.lengths(rowId)) - row.pointTo(null, info.memoryAddress + offset, length.toInt) - rowId += 1 - row + }) + .recycleIterator { + jniWrapper.nativeClose(c2rHandle) } - } + .create() } // scalastyle:off argcount @@ -125,29 +117,31 @@ object ExecUtil { // only used for fallback range partitioning def computeAndAddPartitionId( cbIter: Iterator[ColumnarBatch], - partitionKeyExtractor: InternalRow => Any): CloseablePairedColumnarBatchIterator = { - CloseablePairedColumnarBatchIterator { - cbIter - .filter(cb => cb.numRows != 0 && cb.numCols != 0) - .map { - cb => - val pidVec = ArrowWritableColumnVector - .allocateColumns(cb.numRows, new StructType().add("pid", IntegerType)) - .head - convertColumnarToRow(cb).zipWithIndex.foreach { - case (row, i) => - val pid = rangePartitioner.get.getPartition(partitionKeyExtractor(row)) - pidVec.putInt(i, pid) - } - val pidBatch = ColumnarBatches.ensureOffloaded( - ArrowBufferAllocators.contextInstance(), - new 
ColumnarBatch(Array[ColumnVector](pidVec), cb.numRows))
+              val newHandle = ColumnarBatches.compose(pidBatch, cb)
+              // Composed batch already holds pidBatch's shared ref, so close is safe.
+              ColumnarBatches.forceClose(pidBatch)
+              (0, ColumnarBatches.create(ColumnarBatches.getRuntime(cb), newHandle))
+          })
+      .recyclePayload(p => ColumnarBatches.forceClose(p._2)) // FIXME why force close?
+      .create()
   }
 
   val nativePartitioning: NativePartitioning = newPartitioning match {
@@ -181,11 +175,6 @@
           row => projection(row)
         }
         val newIter = computeAndAddPartitionId(cbIter, partitionKeyExtractor)
-
-        TaskResources.addRecycler("RangePartitioningIter", 100) {
-          newIter.closeColumnBatch()
-        }
-
         newIter
       },
       isOrderSensitive = isOrderSensitive
@@ -212,34 +201,3 @@
 private[spark] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
   override def getPartition(key: Any): Int = key.asInstanceOf[Int]
 }
-case class CloseablePairedColumnarBatchIterator(iter: Iterator[(Int, ColumnarBatch)])
-  extends Iterator[(Int, ColumnarBatch)]
-  with Logging {
-
-  private var cur: (Int, ColumnarBatch) = _
-
-  override def hasNext: Boolean = {
-    iter.hasNext
-  }
-
-  override def next(): (Int, ColumnarBatch) = {
-    closeColumnBatch()
-    if (iter.hasNext) {
-      cur = iter.next()
-      cur
-    } else {
-      closeColumnBatch()
-      Iterator.empty.next()
-    }
-  }
-
-  def closeColumnBatch(): Unit = {
-    if (cur != null) {
-      logDebug("Close appended partition id vector")
-      cur match {
-        case (_, cb: ColumnarBatch) => ColumnarBatches.forceClose(cb)
-      }
-      cur = null
-    }
-  }
-}
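
Editor's note (not part of the patch): the new Iterators builder nests wrappers in call order,
so the ordering of the cleanup callbacks is worth illustrating. Below is a minimal,
self-contained sketch of the intended semantics; it is hypothetical (the object name and string
payloads are illustrative), and it relies on the testing-only fallback registry this patch adds
to TaskResources, since the wrappers register task-level recyclers:

    import io.glutenproject.utils.Iterators
    import org.apache.spark.util.{TaskResourceRegistry, TaskResources}

    object WrapperOrderExample {
      def main(args: Array[String]): Unit = {
        // Outside a real Spark task, install the testing-only fallback registry,
        // mirroring IteratorSuite#withFakeTaskContext.
        TaskResources.setFallbackRegistry(new TaskResourceRegistry)
        try {
          val wrapped = Iterators
            .wrap(Iterator("b1", "b2"))
            .recyclePayload(b => println(s"closed $b")) // innermost: frees each consumed payload
            .recycleIterator(println("completed")) // runs once, on exhaustion or on task exit
            .create()
          wrapped.foreach(b => println(s"got $b"))
        } finally {
          // Releases all registered resources, covering the not-fully-consumed case.
          TaskResources.unsetFallbackRegistry()
        }
      }
    }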