diff --git a/mobile/library/jni/jni_helper.cc b/mobile/library/jni/jni_helper.cc index 7b5e515dd518..1dd285e1116b 100644 --- a/mobile/library/jni/jni_helper.cc +++ b/mobile/library/jni/jni_helper.cc @@ -16,10 +16,19 @@ constexpr const char* THREAD_NAME = "EnvoyMain"; // Non-const variables. std::atomic java_vm_cache_; thread_local JNIEnv* jni_env_cache_ = nullptr; +// `jclass_cache_map` contains `jclass` references that are statically populated. This field is +// used by `FindClass` to find the `jclass` reference from a given class name. absl::flat_hash_map jclass_cache_map; // The `jclass_cache_set` is a superset of `jclass_cache_map`. It contains `jclass` objects that are // retrieve dynamically via `GetObjectClass`. -thread_local absl::flat_hash_set jclass_cache_set; +// +// The `jclass_cache_set` owns the `jclass` global refs, wrapped in `GlobalRefUniquePtr` to allow +// automatic `GlobalRef` destruction. The other fields, such as `jmethod_id_cache_map`, +// `static_jmethod_id_cache_map`, `jfield_id_cache_map`, and `static_jfield_id_cache_map` only +// borrow the `jclass` references. +// +// Note: all these fields are `thread_local` to avoid locking. +thread_local absl::flat_hash_set> jclass_cache_set; thread_local absl::flat_hash_map< std::tuple, jmethodID> @@ -43,13 +52,37 @@ thread_local absl::flat_hash_map< jclass addClassToCacheIfNotExist(JNIEnv* env, jclass clazz) { jclass java_class_global_ref = clazz; if (auto it = jclass_cache_set.find(clazz); it == jclass_cache_set.end()) { - jclass global_ref = reinterpret_cast(env->NewGlobalRef(clazz)); - jclass_cache_set.emplace(global_ref); + java_class_global_ref = reinterpret_cast(env->NewGlobalRef(clazz)); + jclass_cache_set.emplace(java_class_global_ref, GlobalRefDeleter()); } return java_class_global_ref; } } // namespace +void GlobalRefDeleter::operator()(jobject object) const { + if (object != nullptr) { + JniHelper::getThreadLocalEnv()->DeleteGlobalRef(object); + } +} + +void LocalRefDeleter::operator()(jobject object) const { + if (object != nullptr) { + JniHelper::getThreadLocalEnv()->DeleteLocalRef(object); + } +} + +void StringUtfDeleter::operator()(const char* c_str) const { + if (c_str != nullptr) { + JniHelper::getThreadLocalEnv()->ReleaseStringUTFChars(j_str_, c_str); + } +} + +void PrimitiveArrayCriticalDeleter::operator()(void* c_array) const { + if (c_array != nullptr) { + JniHelper::getThreadLocalEnv()->ReleasePrimitiveArrayCritical(array_, c_array, 0); + } +} + jint JniHelper::getVersion() { return JNI_VERSION; } void JniHelper::initialize(JavaVM* java_vm) { @@ -67,7 +100,7 @@ void JniHelper::finalize() { static_jfield_id_cache_map.clear(); jclass_cache_map.clear(); for (const auto& clazz : jclass_cache_set) { - env->DeleteGlobalRef(clazz); + env->DeleteGlobalRef(clazz.get()); } jclass_cache_set.clear(); } @@ -80,7 +113,7 @@ void JniHelper::addClassToCache(const char* class_name) { ASSERT(java_class != nullptr, absl::StrFormat("Unable to find class '%s'.", class_name)); jclass java_class_global_ref = reinterpret_cast(env->NewGlobalRef(java_class)); jclass_cache_map.emplace(class_name, java_class_global_ref); - jclass_cache_set.emplace(java_class_global_ref); + jclass_cache_set.emplace(java_class_global_ref, GlobalRefDeleter()); } JavaVM* JniHelper::getJavaVm() { return java_vm_cache_.load(std::memory_order_acquire); } @@ -90,18 +123,15 @@ void JniHelper::detachCurrentThread() { } JNIEnv* JniHelper::getThreadLocalEnv() { - if (jni_env_cache_ != nullptr) { - return jni_env_cache_; - } JavaVM* java_vm = getJavaVm(); ASSERT(java_vm != nullptr, "Unable to get JavaVM."); jint result = java_vm->GetEnv(reinterpret_cast(&jni_env_cache_), getVersion()); if (result == JNI_EDETACHED) { JavaVMAttachArgs args = {getVersion(), const_cast(THREAD_NAME), nullptr}; #if defined(__ANDROID__) - result = java_vm->AttachCurrentThread(&jni_env_cache_, &args); + result = java_vm->AttachCurrentThreadAsDaemon(&jni_env_cache_, &args); #else - result = java_vm->AttachCurrentThread(reinterpret_cast(&jni_env_cache_), &args); + result = java_vm->AttachCurrentThreadAsDaemon(reinterpret_cast(&jni_env_cache_), &args); #endif } ASSERT(result == JNI_OK, "Unable to get JNIEnv."); @@ -193,7 +223,7 @@ jclass JniHelper::findClass(const char* class_name) { } LocalRefUniquePtr JniHelper::getObjectClass(jobject object) { - return {env_->GetObjectClass(object), LocalRefDeleter(env_)}; + return {env_->GetObjectClass(object), LocalRefDeleter()}; } void JniHelper::throwNew(const char* java_class_name, const char* message) { @@ -207,34 +237,33 @@ void JniHelper::throwNew(const char* java_class_name, const char* message) { jboolean JniHelper::exceptionCheck() { return env_->ExceptionCheck(); } LocalRefUniquePtr JniHelper::exceptionOccurred() { - return {env_->ExceptionOccurred(), LocalRefDeleter(env_)}; + return {env_->ExceptionOccurred(), LocalRefDeleter()}; } void JniHelper::exceptionCleared() { env_->ExceptionClear(); } GlobalRefUniquePtr JniHelper::newGlobalRef(jobject object) { - GlobalRefUniquePtr result(env_->NewGlobalRef(object), GlobalRefDeleter(env_)); + GlobalRefUniquePtr result(env_->NewGlobalRef(object), GlobalRefDeleter()); return result; } LocalRefUniquePtr JniHelper::newObject(jclass clazz, jmethodID method_id, ...) { va_list args; va_start(args, method_id); - LocalRefUniquePtr result(env_->NewObjectV(clazz, method_id, args), - LocalRefDeleter(env_)); + LocalRefUniquePtr result(env_->NewObjectV(clazz, method_id, args), LocalRefDeleter()); rethrowException(); va_end(args); return result; } LocalRefUniquePtr JniHelper::newStringUtf(const char* str) { - LocalRefUniquePtr result(env_->NewStringUTF(str), LocalRefDeleter(env_)); + LocalRefUniquePtr result(env_->NewStringUTF(str), LocalRefDeleter()); rethrowException(); return result; } StringUtfUniquePtr JniHelper::getStringUtfChars(jstring str, jboolean* is_copy) { - StringUtfUniquePtr result(env_->GetStringUTFChars(str, is_copy), StringUtfDeleter(env_, str)); + StringUtfUniquePtr result(env_->GetStringUTFChars(str, is_copy), StringUtfDeleter(str)); rethrowException(); return result; } @@ -243,8 +272,7 @@ jsize JniHelper::getArrayLength(jarray array) { return env_->GetArrayLength(arra #define DEFINE_NEW_ARRAY(JAVA_TYPE, JNI_TYPE) \ LocalRefUniquePtr JniHelper::new##JAVA_TYPE##Array(jsize length) { \ - LocalRefUniquePtr result(env_->New##JAVA_TYPE##Array(length), \ - LocalRefDeleter(env_)); \ + LocalRefUniquePtr result(env_->New##JAVA_TYPE##Array(length), LocalRefDeleter()); \ rethrowException(); \ return result; \ } @@ -261,7 +289,7 @@ DEFINE_NEW_ARRAY(Boolean, jbooleanArray) LocalRefUniquePtr JniHelper::newObjectArray(jsize length, jclass element_class, jobject initial_element) { LocalRefUniquePtr result( - env_->NewObjectArray(length, element_class, initial_element), LocalRefDeleter(env_)); + env_->NewObjectArray(length, element_class, initial_element), LocalRefDeleter()); return result; } @@ -362,7 +390,7 @@ void JniHelper::callStaticVoidMethod(jclass clazz, jmethodID method_id, ...) { LocalRefUniquePtr JniHelper::newDirectByteBuffer(void* address, jlong capacity) { LocalRefUniquePtr result(env_->NewDirectByteBuffer(address, capacity), - LocalRefDeleter(env_)); + LocalRefDeleter()); rethrowException(); return result; } diff --git a/mobile/library/jni/jni_helper.h b/mobile/library/jni/jni_helper.h index d21b33804a6b..cb3b07f20b84 100644 --- a/mobile/library/jni/jni_helper.h +++ b/mobile/library/jni/jni_helper.h @@ -10,16 +10,14 @@ namespace JNI { /** A custom deleter to delete JNI global ref. */ class GlobalRefDeleter { public: - explicit GlobalRefDeleter(JNIEnv* env) : env_(env) {} + explicit GlobalRefDeleter() = default; - void operator()(jobject object) const { - if (object != nullptr) { - env_->DeleteGlobalRef(object); - } - } + GlobalRefDeleter(const GlobalRefDeleter&) = default; -private: - JNIEnv* const env_; + // This is to allow move semantics in `GlobalRefUniquePtr`. + GlobalRefDeleter& operator=(const GlobalRefDeleter&) = default; + + void operator()(jobject object) const; }; /** A unique pointer for JNI global ref. */ @@ -29,21 +27,14 @@ using GlobalRefUniquePtr = std::unique_ptr::type /** A custom deleter to delete JNI local ref. */ class LocalRefDeleter { public: - explicit LocalRefDeleter(JNIEnv* env) : env_(env) {} + explicit LocalRefDeleter() = default; LocalRefDeleter(const LocalRefDeleter&) = default; // This is to allow move semantics in `LocalRefUniquePtr`. - LocalRefDeleter& operator=(const LocalRefDeleter&) { return *this; } - - void operator()(jobject object) const { - if (object != nullptr) { - env_->DeleteLocalRef(object); - } - } + LocalRefDeleter& operator=(const LocalRefDeleter&) = default; -private: - JNIEnv* const env_; + void operator()(jobject object) const; }; /** A unique pointer for JNI local ref. */ @@ -53,16 +44,11 @@ using LocalRefUniquePtr = std::unique_ptr::type, /** A custom deleter for UTF strings. */ class StringUtfDeleter { public: - StringUtfDeleter(JNIEnv* env, jstring j_str) : env_(env), j_str_(j_str) {} + explicit StringUtfDeleter(jstring j_str) : j_str_(j_str) {} - void operator()(const char* c_str) const { - if (c_str != nullptr) { - env_->ReleaseStringUTFChars(j_str_, c_str); - } - } + void operator()(const char* c_str) const; private: - JNIEnv* const env_; jstring j_str_; }; @@ -111,16 +97,11 @@ using ArrayElementsUniquePtr = std::unique_ptr< /** A custom deleter for JNI primitive array critical. */ class PrimitiveArrayCriticalDeleter { public: - PrimitiveArrayCriticalDeleter(JNIEnv* env, jarray array) : env_(env), array_(array) {} + explicit PrimitiveArrayCriticalDeleter(jarray array) : array_(array) {} - void operator()(void* c_array) const { - if (c_array != nullptr) { - env_->ReleasePrimitiveArrayCritical(array_, c_array, 0); - } - } + void operator()(void* c_array) const; private: - JNIEnv* const env_; jarray array_; }; @@ -221,7 +202,7 @@ class JniHelper { template [[nodiscard]] LocalRefUniquePtr getObjectField(jobject object, jfieldID field_id) { LocalRefUniquePtr result(static_cast(env_->GetObjectField(object, field_id)), - LocalRefDeleter(env_)); + LocalRefDeleter()); return result; } @@ -360,7 +341,7 @@ class JniHelper { template [[nodiscard]] LocalRefUniquePtr getObjectArrayElement(jobjectArray array, jsize index) { LocalRefUniquePtr result(static_cast(env_->GetObjectArrayElement(array, index)), - LocalRefDeleter(env_)); + LocalRefDeleter()); rethrowException(); return result; } @@ -382,7 +363,7 @@ class JniHelper { jboolean* is_copy) { PrimitiveArrayCriticalUniquePtr result( static_cast(env_->GetPrimitiveArrayCritical(array, is_copy)), - PrimitiveArrayCriticalDeleter(env_, array)); + PrimitiveArrayCriticalDeleter(array)); return result; } @@ -429,7 +410,7 @@ class JniHelper { va_list args; va_start(args, method_id); LocalRefUniquePtr result(static_cast(env_->CallObjectMethodV(object, method_id, args)), - LocalRefDeleter(env_)); + LocalRefDeleter()); va_end(args); rethrowException(); return result; @@ -461,8 +442,7 @@ class JniHelper { va_list args; va_start(args, method_id); LocalRefUniquePtr result( - static_cast(env_->CallStaticObjectMethodV(clazz, method_id, args)), - LocalRefDeleter(env_)); + static_cast(env_->CallStaticObjectMethodV(clazz, method_id, args)), LocalRefDeleter()); va_end(args); rethrowException(); return result; diff --git a/mobile/library/jni/jni_impl.cc b/mobile/library/jni/jni_impl.cc index ef7d6573886d..847182a8c241 100644 --- a/mobile/library/jni/jni_impl.cc +++ b/mobile/library/jni/jni_impl.cc @@ -565,8 +565,8 @@ jvm_http_filter_on_resume(const char* method, envoy_headers* headers, envoy_data headers_length = static_cast(headers->length); passHeaders("passHeader", *headers, j_context); } - Envoy::JNI::LocalRefUniquePtr j_in_data = Envoy::JNI::LocalRefUniquePtr( - nullptr, Envoy::JNI::LocalRefDeleter(jni_helper.getEnv())); + Envoy::JNI::LocalRefUniquePtr j_in_data = + Envoy::JNI::LocalRefUniquePtr(nullptr, Envoy::JNI::LocalRefDeleter()); if (data) { j_in_data = Envoy::JNI::envoyDataToJavaByteBuffer(jni_helper, *data); }