Skip to content

Commit

Permalink
mobile: Fix JNI GlobalRef leak (envoyproxy#35504)
Browse files Browse the repository at this point in the history
When the thread terminates, the `thread_local` storage will be
destroyed. However, the `GlobalRef` `jclass` references inside the
`thread_local` don't get destroyed automatically causing a leak. This PR
fixes it by wrapping the `jclass` with an RAII-style wrapper,
`GlobalRefUniquePtr` that has a custom deleter to delete the `GlobalRef`
upon `threal_local` destruction.

Risk Level: low
Testing: CI
Docs Changes: n/a
Release Notes: n/a
Platform Specific Features: mobile

Signed-off-by: Fredy Wijaya <[email protected]>
  • Loading branch information
fredyw authored Jul 30, 2024
1 parent 3f6a8c0 commit ca51f36
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 61 deletions.
70 changes: 49 additions & 21 deletions mobile/library/jni/jni_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@ constexpr const char* THREAD_NAME = "EnvoyMain";
// Non-const variables.
std::atomic<JavaVM*> java_vm_cache_;
thread_local JNIEnv* jni_env_cache_ = nullptr;
// `jclass_cache_map` contains `jclass` references that are statically populated. This field is
// used by `FindClass` to find the `jclass` reference from a given class name.
absl::flat_hash_map<absl::string_view, jclass> jclass_cache_map;
// The `jclass_cache_set` is a superset of `jclass_cache_map`. It contains `jclass` objects that are
// retrieve dynamically via `GetObjectClass`.
thread_local absl::flat_hash_set<jclass> jclass_cache_set;
//
// The `jclass_cache_set` owns the `jclass` global refs, wrapped in `GlobalRefUniquePtr` to allow
// automatic `GlobalRef` destruction. The other fields, such as `jmethod_id_cache_map`,
// `static_jmethod_id_cache_map`, `jfield_id_cache_map`, and `static_jfield_id_cache_map` only
// borrow the `jclass` references.
//
// Note: all these fields are `thread_local` to avoid locking.
thread_local absl::flat_hash_set<GlobalRefUniquePtr<jclass>> jclass_cache_set;
thread_local absl::flat_hash_map<
std::tuple<jclass, absl::string_view /* method */, absl::string_view /* signature */>,
jmethodID>
Expand All @@ -43,13 +52,37 @@ thread_local absl::flat_hash_map<
jclass addClassToCacheIfNotExist(JNIEnv* env, jclass clazz) {
jclass java_class_global_ref = clazz;
if (auto it = jclass_cache_set.find(clazz); it == jclass_cache_set.end()) {
jclass global_ref = reinterpret_cast<jclass>(env->NewGlobalRef(clazz));
jclass_cache_set.emplace(global_ref);
java_class_global_ref = reinterpret_cast<jclass>(env->NewGlobalRef(clazz));
jclass_cache_set.emplace(java_class_global_ref, GlobalRefDeleter());
}
return java_class_global_ref;
}
} // namespace

void GlobalRefDeleter::operator()(jobject object) const {
if (object != nullptr) {
JniHelper::getThreadLocalEnv()->DeleteGlobalRef(object);
}
}

void LocalRefDeleter::operator()(jobject object) const {
if (object != nullptr) {
JniHelper::getThreadLocalEnv()->DeleteLocalRef(object);
}
}

void StringUtfDeleter::operator()(const char* c_str) const {
if (c_str != nullptr) {
JniHelper::getThreadLocalEnv()->ReleaseStringUTFChars(j_str_, c_str);
}
}

void PrimitiveArrayCriticalDeleter::operator()(void* c_array) const {
if (c_array != nullptr) {
JniHelper::getThreadLocalEnv()->ReleasePrimitiveArrayCritical(array_, c_array, 0);
}
}

jint JniHelper::getVersion() { return JNI_VERSION; }

void JniHelper::initialize(JavaVM* java_vm) {
Expand All @@ -67,7 +100,7 @@ void JniHelper::finalize() {
static_jfield_id_cache_map.clear();
jclass_cache_map.clear();
for (const auto& clazz : jclass_cache_set) {
env->DeleteGlobalRef(clazz);
env->DeleteGlobalRef(clazz.get());
}
jclass_cache_set.clear();
}
Expand All @@ -80,7 +113,7 @@ void JniHelper::addClassToCache(const char* class_name) {
ASSERT(java_class != nullptr, absl::StrFormat("Unable to find class '%s'.", class_name));
jclass java_class_global_ref = reinterpret_cast<jclass>(env->NewGlobalRef(java_class));
jclass_cache_map.emplace(class_name, java_class_global_ref);
jclass_cache_set.emplace(java_class_global_ref);
jclass_cache_set.emplace(java_class_global_ref, GlobalRefDeleter());
}

JavaVM* JniHelper::getJavaVm() { return java_vm_cache_.load(std::memory_order_acquire); }
Expand All @@ -90,18 +123,15 @@ void JniHelper::detachCurrentThread() {
}

JNIEnv* JniHelper::getThreadLocalEnv() {
if (jni_env_cache_ != nullptr) {
return jni_env_cache_;
}
JavaVM* java_vm = getJavaVm();
ASSERT(java_vm != nullptr, "Unable to get JavaVM.");
jint result = java_vm->GetEnv(reinterpret_cast<void**>(&jni_env_cache_), getVersion());
if (result == JNI_EDETACHED) {
JavaVMAttachArgs args = {getVersion(), const_cast<char*>(THREAD_NAME), nullptr};
#if defined(__ANDROID__)
result = java_vm->AttachCurrentThread(&jni_env_cache_, &args);
result = java_vm->AttachCurrentThreadAsDaemon(&jni_env_cache_, &args);
#else
result = java_vm->AttachCurrentThread(reinterpret_cast<void**>(&jni_env_cache_), &args);
result = java_vm->AttachCurrentThreadAsDaemon(reinterpret_cast<void**>(&jni_env_cache_), &args);
#endif
}
ASSERT(result == JNI_OK, "Unable to get JNIEnv.");
Expand Down Expand Up @@ -193,7 +223,7 @@ jclass JniHelper::findClass(const char* class_name) {
}

LocalRefUniquePtr<jclass> JniHelper::getObjectClass(jobject object) {
return {env_->GetObjectClass(object), LocalRefDeleter(env_)};
return {env_->GetObjectClass(object), LocalRefDeleter()};
}

void JniHelper::throwNew(const char* java_class_name, const char* message) {
Expand All @@ -207,34 +237,33 @@ void JniHelper::throwNew(const char* java_class_name, const char* message) {
jboolean JniHelper::exceptionCheck() { return env_->ExceptionCheck(); }

LocalRefUniquePtr<jthrowable> JniHelper::exceptionOccurred() {
return {env_->ExceptionOccurred(), LocalRefDeleter(env_)};
return {env_->ExceptionOccurred(), LocalRefDeleter()};
}

void JniHelper::exceptionCleared() { env_->ExceptionClear(); }

GlobalRefUniquePtr<jobject> JniHelper::newGlobalRef(jobject object) {
GlobalRefUniquePtr<jobject> result(env_->NewGlobalRef(object), GlobalRefDeleter(env_));
GlobalRefUniquePtr<jobject> result(env_->NewGlobalRef(object), GlobalRefDeleter());
return result;
}

LocalRefUniquePtr<jobject> JniHelper::newObject(jclass clazz, jmethodID method_id, ...) {
va_list args;
va_start(args, method_id);
LocalRefUniquePtr<jobject> result(env_->NewObjectV(clazz, method_id, args),
LocalRefDeleter(env_));
LocalRefUniquePtr<jobject> result(env_->NewObjectV(clazz, method_id, args), LocalRefDeleter());
rethrowException();
va_end(args);
return result;
}

LocalRefUniquePtr<jstring> JniHelper::newStringUtf(const char* str) {
LocalRefUniquePtr<jstring> result(env_->NewStringUTF(str), LocalRefDeleter(env_));
LocalRefUniquePtr<jstring> result(env_->NewStringUTF(str), LocalRefDeleter());
rethrowException();
return result;
}

StringUtfUniquePtr JniHelper::getStringUtfChars(jstring str, jboolean* is_copy) {
StringUtfUniquePtr result(env_->GetStringUTFChars(str, is_copy), StringUtfDeleter(env_, str));
StringUtfUniquePtr result(env_->GetStringUTFChars(str, is_copy), StringUtfDeleter(str));
rethrowException();
return result;
}
Expand All @@ -243,8 +272,7 @@ jsize JniHelper::getArrayLength(jarray array) { return env_->GetArrayLength(arra

#define DEFINE_NEW_ARRAY(JAVA_TYPE, JNI_TYPE) \
LocalRefUniquePtr<JNI_TYPE> JniHelper::new##JAVA_TYPE##Array(jsize length) { \
LocalRefUniquePtr<JNI_TYPE> result(env_->New##JAVA_TYPE##Array(length), \
LocalRefDeleter(env_)); \
LocalRefUniquePtr<JNI_TYPE> result(env_->New##JAVA_TYPE##Array(length), LocalRefDeleter()); \
rethrowException(); \
return result; \
}
Expand All @@ -261,7 +289,7 @@ DEFINE_NEW_ARRAY(Boolean, jbooleanArray)
LocalRefUniquePtr<jobjectArray> JniHelper::newObjectArray(jsize length, jclass element_class,
jobject initial_element) {
LocalRefUniquePtr<jobjectArray> result(
env_->NewObjectArray(length, element_class, initial_element), LocalRefDeleter(env_));
env_->NewObjectArray(length, element_class, initial_element), LocalRefDeleter());

return result;
}
Expand Down Expand Up @@ -362,7 +390,7 @@ void JniHelper::callStaticVoidMethod(jclass clazz, jmethodID method_id, ...) {

LocalRefUniquePtr<jobject> JniHelper::newDirectByteBuffer(void* address, jlong capacity) {
LocalRefUniquePtr<jobject> result(env_->NewDirectByteBuffer(address, capacity),
LocalRefDeleter(env_));
LocalRefDeleter());
rethrowException();
return result;
}
Expand Down
56 changes: 18 additions & 38 deletions mobile/library/jni/jni_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@ namespace JNI {
/** A custom deleter to delete JNI global ref. */
class GlobalRefDeleter {
public:
explicit GlobalRefDeleter(JNIEnv* env) : env_(env) {}
explicit GlobalRefDeleter() = default;

void operator()(jobject object) const {
if (object != nullptr) {
env_->DeleteGlobalRef(object);
}
}
GlobalRefDeleter(const GlobalRefDeleter&) = default;

private:
JNIEnv* const env_;
// This is to allow move semantics in `GlobalRefUniquePtr`.
GlobalRefDeleter& operator=(const GlobalRefDeleter&) = default;

void operator()(jobject object) const;
};

/** A unique pointer for JNI global ref. */
Expand All @@ -29,21 +27,14 @@ using GlobalRefUniquePtr = std::unique_ptr<typename std::remove_pointer<T>::type
/** A custom deleter to delete JNI local ref. */
class LocalRefDeleter {
public:
explicit LocalRefDeleter(JNIEnv* env) : env_(env) {}
explicit LocalRefDeleter() = default;

LocalRefDeleter(const LocalRefDeleter&) = default;

// This is to allow move semantics in `LocalRefUniquePtr`.
LocalRefDeleter& operator=(const LocalRefDeleter&) { return *this; }

void operator()(jobject object) const {
if (object != nullptr) {
env_->DeleteLocalRef(object);
}
}
LocalRefDeleter& operator=(const LocalRefDeleter&) = default;

private:
JNIEnv* const env_;
void operator()(jobject object) const;
};

/** A unique pointer for JNI local ref. */
Expand All @@ -53,16 +44,11 @@ using LocalRefUniquePtr = std::unique_ptr<typename std::remove_pointer<T>::type,
/** A custom deleter for UTF strings. */
class StringUtfDeleter {
public:
StringUtfDeleter(JNIEnv* env, jstring j_str) : env_(env), j_str_(j_str) {}
explicit StringUtfDeleter(jstring j_str) : j_str_(j_str) {}

void operator()(const char* c_str) const {
if (c_str != nullptr) {
env_->ReleaseStringUTFChars(j_str_, c_str);
}
}
void operator()(const char* c_str) const;

private:
JNIEnv* const env_;
jstring j_str_;
};

Expand Down Expand Up @@ -111,16 +97,11 @@ using ArrayElementsUniquePtr = std::unique_ptr<
/** A custom deleter for JNI primitive array critical. */
class PrimitiveArrayCriticalDeleter {
public:
PrimitiveArrayCriticalDeleter(JNIEnv* env, jarray array) : env_(env), array_(array) {}
explicit PrimitiveArrayCriticalDeleter(jarray array) : array_(array) {}

void operator()(void* c_array) const {
if (c_array != nullptr) {
env_->ReleasePrimitiveArrayCritical(array_, c_array, 0);
}
}
void operator()(void* c_array) const;

private:
JNIEnv* const env_;
jarray array_;
};

Expand Down Expand Up @@ -221,7 +202,7 @@ class JniHelper {
template <typename T = jobject>
[[nodiscard]] LocalRefUniquePtr<T> getObjectField(jobject object, jfieldID field_id) {
LocalRefUniquePtr<T> result(static_cast<T>(env_->GetObjectField(object, field_id)),
LocalRefDeleter(env_));
LocalRefDeleter());
return result;
}

Expand Down Expand Up @@ -360,7 +341,7 @@ class JniHelper {
template <typename T = jobject>
[[nodiscard]] LocalRefUniquePtr<T> getObjectArrayElement(jobjectArray array, jsize index) {
LocalRefUniquePtr<T> result(static_cast<T>(env_->GetObjectArrayElement(array, index)),
LocalRefDeleter(env_));
LocalRefDeleter());
rethrowException();
return result;
}
Expand All @@ -382,7 +363,7 @@ class JniHelper {
jboolean* is_copy) {
PrimitiveArrayCriticalUniquePtr<T> result(
static_cast<T>(env_->GetPrimitiveArrayCritical(array, is_copy)),
PrimitiveArrayCriticalDeleter(env_, array));
PrimitiveArrayCriticalDeleter(array));
return result;
}

Expand Down Expand Up @@ -429,7 +410,7 @@ class JniHelper {
va_list args;
va_start(args, method_id);
LocalRefUniquePtr<T> result(static_cast<T>(env_->CallObjectMethodV(object, method_id, args)),
LocalRefDeleter(env_));
LocalRefDeleter());
va_end(args);
rethrowException();
return result;
Expand Down Expand Up @@ -461,8 +442,7 @@ class JniHelper {
va_list args;
va_start(args, method_id);
LocalRefUniquePtr<T> result(
static_cast<T>(env_->CallStaticObjectMethodV(clazz, method_id, args)),
LocalRefDeleter(env_));
static_cast<T>(env_->CallStaticObjectMethodV(clazz, method_id, args)), LocalRefDeleter());
va_end(args);
rethrowException();
return result;
Expand Down
4 changes: 2 additions & 2 deletions mobile/library/jni/jni_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,8 @@ jvm_http_filter_on_resume(const char* method, envoy_headers* headers, envoy_data
headers_length = static_cast<jlong>(headers->length);
passHeaders("passHeader", *headers, j_context);
}
Envoy::JNI::LocalRefUniquePtr<jobject> j_in_data = Envoy::JNI::LocalRefUniquePtr<jobject>(
nullptr, Envoy::JNI::LocalRefDeleter(jni_helper.getEnv()));
Envoy::JNI::LocalRefUniquePtr<jobject> j_in_data =
Envoy::JNI::LocalRefUniquePtr<jobject>(nullptr, Envoy::JNI::LocalRefDeleter());
if (data) {
j_in_data = Envoy::JNI::envoyDataToJavaByteBuffer(jni_helper, *data);
}
Expand Down

0 comments on commit ca51f36

Please sign in to comment.