Skip to content

Commit

Permalink
[rust] Fixes memory leak (#3433)
Browse files Browse the repository at this point in the history
Fixes: #3413
  • Loading branch information
frankfliu authored Aug 20, 2024
1 parent 3344809 commit 5042fc4
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
25 changes: 13 additions & 12 deletions extensions/tokenizers/rust/src/ndarray/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use candle::{DType, Device, Error, Result, Shape, Tensor, WithDType};
use half::{bf16, f16};
use jni::objects::{JByteBuffer, JIntArray, JLongArray, JObject, JString, ReleaseMode};
use jni::sys::{jint, jlong};
use jni::objects::{JByteArray, JIntArray, JLongArray, JObject, JString, ReleaseMode};
use jni::sys::{jbyte, jint, jlong, jsize};
use jni::JNIEnv;

use crate::{cast_handle, drop_handle, to_handle};
Expand Down Expand Up @@ -69,13 +69,13 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_getShape<'local>(
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_getByteBuffer<'local>(
mut env: JNIEnv<'local>,
pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_toByteArray<'local>(
env: JNIEnv<'local>,
_: JObject,
handle: jlong,
) -> JByteBuffer<'local> {
) -> JByteArray<'local> {
let tensor = cast_handle::<Tensor>(handle).flatten_all().unwrap();
let (ptr, len) = match tensor.dtype() {
let vs = match tensor.dtype() {
DType::U8 => convert_back_::<u8>(tensor.to_vec1().unwrap()),
DType::U32 => convert_back_::<u32>(tensor.to_vec1().unwrap()),
DType::I64 => convert_back_::<i64>(tensor.to_vec1().unwrap()),
Expand All @@ -84,9 +84,9 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_getByteBuffer<'local>
DType::F32 => convert_back_::<f32>(tensor.to_vec1().unwrap()),
DType::F64 => convert_back_::<f64>(tensor.to_vec1().unwrap()),
};

let buf = unsafe { env.new_direct_byte_buffer(ptr, len) }.unwrap();
buf
let array = env.new_byte_array(vs.len() as jsize).unwrap();
env.set_byte_array_region(&array, 0, &vs).unwrap();
array
}

#[no_mangle]
Expand Down Expand Up @@ -237,17 +237,18 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_deleteTensor(
drop_handle::<Tensor>(handle);
}

fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> (*mut u8, usize) {
fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<jbyte> {
let size_in_bytes = T::DTYPE.size_in_bytes();
let length = vs.len() * size_in_bytes;
let ptr = vs.as_mut_ptr() as *mut u8;
let capacity = vs.capacity() * size_in_bytes;
let ptr = vs.as_mut_ptr() as *mut jbyte;
// Don't run the destructor for Vec<T>
std::mem::forget(vs);
// SAFETY:
//
// Every T is larger than u8, so there is no issue regarding alignment.
// This re-interpret the Vec<T> as a Vec<u8>.
(ptr, length)
unsafe { Vec::from_raw_parts(ptr, length, capacity) }
}

fn as_shape<'local>(env: &mut JNIEnv, shape: &JLongArray<'local>) -> Shape {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ public NDArray stopGradient() {
/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer(boolean tryDirect) {
ByteBuffer bb = RustLibrary.getByteBuffer(getHandle());
byte[] buf = RustLibrary.toByteArray(getHandle());
ByteBuffer bb = ByteBuffer.wrap(buf);
bb.order(ByteOrder.nativeOrder());
return bb;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public static long hannWindow(long numPoints, String deviceType, int deviceId) {

public static native long toDataType(long handle, int dataType);

public static native ByteBuffer getByteBuffer(long handle);
public static native byte[] toByteArray(long handle);

public static native long fullSlice(long handle, long[] min, long[] max, long[] step);

Expand Down

0 comments on commit 5042fc4

Please sign in to comment.