From 5042fc4f98343ae000eeeb303991b5f142c91038 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 20 Aug 2024 11:22:09 -0700 Subject: [PATCH] [rust] Fixes memory leak (#3433) Fixes: #3413 --- extensions/tokenizers/rust/src/ndarray/mod.rs | 25 ++++++++++--------- .../java/ai/djl/engine/rust/RsNDArray.java | 3 ++- .../java/ai/djl/engine/rust/RustLibrary.java | 2 +- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/extensions/tokenizers/rust/src/ndarray/mod.rs b/extensions/tokenizers/rust/src/ndarray/mod.rs index 059b69cd4b1..478cde826c7 100644 --- a/extensions/tokenizers/rust/src/ndarray/mod.rs +++ b/extensions/tokenizers/rust/src/ndarray/mod.rs @@ -1,7 +1,7 @@ use candle::{DType, Device, Error, Result, Shape, Tensor, WithDType}; use half::{bf16, f16}; -use jni::objects::{JByteBuffer, JIntArray, JLongArray, JObject, JString, ReleaseMode}; -use jni::sys::{jint, jlong}; +use jni::objects::{JByteArray, JIntArray, JLongArray, JObject, JString, ReleaseMode}; +use jni::sys::{jbyte, jint, jlong, jsize}; use jni::JNIEnv; use crate::{cast_handle, drop_handle, to_handle}; @@ -69,13 +69,13 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_getShape<'local>( } #[no_mangle] -pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_getByteBuffer<'local>( - mut env: JNIEnv<'local>, +pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_toByteArray<'local>( + env: JNIEnv<'local>, _: JObject, handle: jlong, -) -> JByteBuffer<'local> { +) -> JByteArray<'local> { let tensor = cast_handle::(handle).flatten_all().unwrap(); - let (ptr, len) = match tensor.dtype() { + let vs = match tensor.dtype() { DType::U8 => convert_back_::(tensor.to_vec1().unwrap()), DType::U32 => convert_back_::(tensor.to_vec1().unwrap()), DType::I64 => convert_back_::(tensor.to_vec1().unwrap()), @@ -84,9 +84,9 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_getByteBuffer<'local> DType::F32 => convert_back_::(tensor.to_vec1().unwrap()), DType::F64 => convert_back_::(tensor.to_vec1().unwrap()), }; - - let buf = unsafe { env.new_direct_byte_buffer(ptr, len) }.unwrap(); - buf + let array = env.new_byte_array(vs.len() as jsize).unwrap(); + env.set_byte_array_region(&array, 0, &vs).unwrap(); + array } #[no_mangle] @@ -237,17 +237,18 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_deleteTensor( drop_handle::(handle); } -fn convert_back_(mut vs: Vec) -> (*mut u8, usize) { +fn convert_back_(mut vs: Vec) -> Vec { let size_in_bytes = T::DTYPE.size_in_bytes(); let length = vs.len() * size_in_bytes; - let ptr = vs.as_mut_ptr() as *mut u8; + let capacity = vs.capacity() * size_in_bytes; + let ptr = vs.as_mut_ptr() as *mut jbyte; // Don't run the destructor for Vec std::mem::forget(vs); // SAFETY: // // Every T is larger than u8, so there is no issue regarding alignment. // This re-interpret the Vec as a Vec. - (ptr, length) + unsafe { Vec::from_raw_parts(ptr, length, capacity) } } fn as_shape<'local>(env: &mut JNIEnv, shape: &JLongArray<'local>) -> Shape { diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArray.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArray.java index 8e6e4f3c7dc..cca87a082e1 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArray.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArray.java @@ -209,7 +209,8 @@ public NDArray stopGradient() { /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer(boolean tryDirect) { - ByteBuffer bb = RustLibrary.getByteBuffer(getHandle()); + byte[] buf = RustLibrary.toByteArray(getHandle()); + ByteBuffer bb = ByteBuffer.wrap(buf); bb.order(ByteOrder.nativeOrder()); return bb; } diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java index ef848e0291a..bdcdabe8df6 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java @@ -87,7 +87,7 @@ public static long hannWindow(long numPoints, String deviceType, int deviceId) { public static native long toDataType(long handle, int dataType); - public static native ByteBuffer getByteBuffer(long handle); + public static native byte[] toByteArray(long handle); public static native long fullSlice(long handle, long[] min, long[] max, long[] step);