Skip to content

Commit

Permalink
Support multi threads under send feature flag
Browse files Browse the repository at this point in the history
  • Loading branch information
khvzak committed Jul 8, 2024
1 parent c4e0c4c commit 00f646a
Show file tree
Hide file tree
Showing 13 changed files with 222 additions and 103 deletions.
15 changes: 9 additions & 6 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,12 @@ impl IntoLua for WrappedAsyncFunction {
}
}

// #[cfg(test)]
// mod assertions {
// use super::*;

// static_assertions::assert_not_impl_any!(Function: Send);
// }
#[cfg(test)]
mod assertions {
use super::*;

#[cfg(not(feature = "send"))]
static_assertions::assert_not_impl_any!(Function: Send);
#[cfg(feature = "send")]
static_assertions::assert_impl_all!(Function: Send, Sync);
}
2 changes: 1 addition & 1 deletion src/hook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use std::ops::{BitOr, BitOrAssign};
use std::os::raw::c_int;

use ffi::lua_Debug;
use parking_lot::ReentrantMutexGuard;

use crate::state::RawLua;
use crate::types::ReentrantMutexGuard;
use crate::util::{linenumber_to_usize, ptr_to_lossy_str, ptr_to_str};

/// Contains information about currently executing Lua code.
Expand Down
58 changes: 29 additions & 29 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@ use std::marker::PhantomData;
use std::ops::Deref;
use std::os::raw::{c_int, c_void};
use std::panic::Location;
use std::rc::Rc;
use std::result::Result as StdResult;
use std::sync::{Arc, Weak};
use std::{mem, ptr};

use parking_lot::{ReentrantMutex, ReentrantMutexGuard};

use crate::chunk::{AsChunk, Chunk};
use crate::error::{Error, Result};
use crate::function::Function;
Expand All @@ -24,7 +22,7 @@ use crate::table::Table;
use crate::thread::Thread;
use crate::types::{
AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LightUserData, MaybeSend, Number,
RegistryKey,
ReentrantMutex, ReentrantMutexGuard, RegistryKey, XRc, XWeak,
};
use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataVariant};
use crate::util::{assert_stack, check_stack, push_string, push_table, rawset_field, StackGuard};
Expand All @@ -49,11 +47,11 @@ use util::{callback_error_ext, StateGuard};
/// Top level Lua struct which represents an instance of Lua VM.
#[derive(Clone)]
#[repr(transparent)]
pub struct Lua(Arc<ReentrantMutex<RawLua>>);
pub struct Lua(XRc<ReentrantMutex<RawLua>>);

#[derive(Clone)]
#[repr(transparent)]
pub(crate) struct WeakLua(Weak<ReentrantMutex<RawLua>>);
pub(crate) struct WeakLua(XWeak<ReentrantMutex<RawLua>>);

pub(crate) struct LuaGuard(ArcReentrantMutexGuard<RawLua>);

Expand Down Expand Up @@ -142,11 +140,6 @@ impl LuaOptions {
}
}

/// Requires `feature = "send"`
#[cfg(feature = "send")]
#[cfg_attr(docsrs, doc(cfg(feature = "send")))]
unsafe impl Send for Lua {}

#[cfg(not(feature = "module"))]
impl Drop for Lua {
fn drop(&mut self) {
Expand Down Expand Up @@ -605,7 +598,7 @@ impl Lua {
let interrupt_cb = (*extra).interrupt_callback.clone();
let interrupt_cb =
mlua_expect!(interrupt_cb, "no interrupt callback set in interrupt_proc");
if Arc::strong_count(&interrupt_cb) > 2 {
if Rc::strong_count(&interrupt_cb) > 2 {
return Ok(VmState::Continue); // Don't allow recursion
}
let _guard = StateGuard::new((*extra).raw_lua(), state);
Expand All @@ -622,7 +615,7 @@ impl Lua {
// Set interrupt callback
let lua = self.lock();
unsafe {
(*lua.extra.get()).interrupt_callback = Some(Arc::new(callback));
(*lua.extra.get()).interrupt_callback = Some(Rc::new(callback));
(*ffi::lua_callbacks(lua.main_state)).interrupt = Some(interrupt_proc);
}
}
Expand Down Expand Up @@ -947,7 +940,8 @@ impl Lua {
#[cfg(any(feature = "luau-jit", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau-jit")))]
pub fn enable_jit(&self, enable: bool) {
unsafe { (*self.extra.get()).enable_jit = enable };
let lua = self.lock();
unsafe { (*lua.extra.get()).enable_jit = enable };
}

/// Sets Luau feature flag (global setting).
Expand Down Expand Up @@ -1879,15 +1873,15 @@ impl Lua {

#[inline(always)]
pub(crate) fn weak(&self) -> WeakLua {
WeakLua(Arc::downgrade(&self.0))
WeakLua(XRc::downgrade(&self.0))
}
}

impl WeakLua {
#[track_caller]
#[inline(always)]
pub(crate) fn lock(&self) -> LuaGuard {
LuaGuard::new(self.0.upgrade().unwrap())
LuaGuard::new(self.0.upgrade().expect("Lua instance is destroyed"))
}

#[inline(always)]
Expand All @@ -1898,15 +1892,21 @@ impl WeakLua {

impl PartialEq for WeakLua {
fn eq(&self, other: &Self) -> bool {
Weak::ptr_eq(&self.0, &other.0)
XWeak::ptr_eq(&self.0, &other.0)
}
}

impl Eq for WeakLua {}

impl LuaGuard {
pub(crate) fn new(handle: Arc<ReentrantMutex<RawLua>>) -> Self {
Self(handle.lock_arc())
#[cfg(feature = "send")]
pub(crate) fn new(handle: XRc<ReentrantMutex<RawLua>>) -> Self {
LuaGuard(handle.lock_arc())
}

#[cfg(not(feature = "send"))]
pub(crate) fn new(handle: XRc<ReentrantMutex<RawLua>>) -> Self {
LuaGuard(handle.into_lock_arc())
}
}

Expand All @@ -1922,15 +1922,15 @@ pub(crate) mod extra;
mod raw;
pub(crate) mod util;

// #[cfg(test)]
// mod assertions {
// use super::*;
#[cfg(test)]
mod assertions {
use super::*;

// // Lua has lots of interior mutability, should not be RefUnwindSafe
// static_assertions::assert_not_impl_any!(Lua: std::panic::RefUnwindSafe);
// Lua has lots of interior mutability, should not be RefUnwindSafe
static_assertions::assert_not_impl_any!(Lua: std::panic::RefUnwindSafe);

// #[cfg(not(feature = "send"))]
// static_assertions::assert_not_impl_any!(Lua: Send);
// #[cfg(feature = "send")]
// static_assertions::assert_impl_all!(Lua: Send);
// }
#[cfg(not(feature = "send"))]
static_assertions::assert_not_impl_any!(Lua: Send);
#[cfg(feature = "send")]
static_assertions::assert_impl_all!(Lua: Send, Sync);
}
18 changes: 9 additions & 9 deletions src/state/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ use std::rc::Rc;
use std::mem::{self, MaybeUninit};
use std::os::raw::{c_int, c_void};
use std::ptr;
use std::sync::{Arc, Weak};
use std::sync::Arc;

use parking_lot::{Mutex, ReentrantMutex};
use parking_lot::Mutex;
use rustc_hash::FxHashMap;

use crate::error::Result;
use crate::state::RawLua;
use crate::stdlib::StdLib;
use crate::types::AppData;
use crate::types::{AppData, ReentrantMutex, XRc, XWeak};
use crate::util::{get_gc_metatable, push_gc_userdata, WrappedFailure};

#[cfg(any(feature = "luau", doc))]
Expand All @@ -34,9 +34,9 @@ const REF_STACK_RESERVE: c_int = 1;
/// Data associated with the Lua state.
pub(crate) struct ExtraData {
// Same layout as `Lua`
pub(super) lua: MaybeUninit<Arc<ReentrantMutex<RawLua>>>,
pub(super) lua: MaybeUninit<XRc<ReentrantMutex<RawLua>>>,
// Same layout as `WeakLua`
pub(super) weak: MaybeUninit<Weak<ReentrantMutex<RawLua>>>,
pub(super) weak: MaybeUninit<XWeak<ReentrantMutex<RawLua>>>,

pub(super) registered_userdata: FxHashMap<TypeId, c_int>,
pub(super) registered_userdata_mt: FxHashMap<*const c_void, Option<TypeId>>,
Expand Down Expand Up @@ -179,12 +179,12 @@ impl ExtraData {
extra
}

pub(super) unsafe fn set_lua(&mut self, lua: &Arc<ReentrantMutex<RawLua>>) {
self.lua.write(Arc::clone(lua));
pub(super) unsafe fn set_lua(&mut self, lua: &XRc<ReentrantMutex<RawLua>>) {
self.lua.write(XRc::clone(lua));
if cfg!(not(feature = "module")) {
Arc::decrement_strong_count(Arc::as_ptr(lua));
XRc::decrement_strong_count(XRc::as_ptr(lua));
}
self.weak.write(Arc::downgrade(lua));
self.weak.write(XRc::downgrade(lua));
}

pub(super) unsafe fn get(state: *mut ffi::lua_State) -> *mut Self {
Expand Down
21 changes: 11 additions & 10 deletions src/state/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ use std::result::Result as StdResult;
use std::sync::Arc;
use std::{mem, ptr};

use parking_lot::ReentrantMutex;

use crate::chunk::ChunkMode;
use crate::error::{Error, Result};
use crate::function::Function;
Expand All @@ -21,7 +19,7 @@ use crate::table::Table;
use crate::thread::Thread;
use crate::types::{
AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer,
LightUserData, MaybeSend, RegistryKey, SubtypeId, ValueRef,
LightUserData, MaybeSend, ReentrantMutex, RegistryKey, SubtypeId, ValueRef, XRc,
};
use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataRegistry, UserDataVariant};
use crate::util::{
Expand Down Expand Up @@ -69,6 +67,9 @@ impl Drop for RawLua {
}
}

#[cfg(feature = "send")]
unsafe impl Send for RawLua {}

impl RawLua {
#[inline(always)]
pub(crate) fn lua(&self) -> &Lua {
Expand Down Expand Up @@ -96,7 +97,7 @@ impl RawLua {
unsafe { (*self.extra.get()).ref_thread }
}

pub(super) unsafe fn new(libs: StdLib, options: LuaOptions) -> Arc<ReentrantMutex<Self>> {
pub(super) unsafe fn new(libs: StdLib, options: LuaOptions) -> XRc<ReentrantMutex<Self>> {
let mem_state: *mut MemoryState = Box::into_raw(Box::default());
let mut state = ffi::lua_newstate(ALLOCATOR, mem_state as *mut c_void);
// If state is null then switch to Lua internal allocator
Expand Down Expand Up @@ -154,7 +155,7 @@ impl RawLua {
rawlua
}

pub(super) unsafe fn init_from_ptr(state: *mut ffi::lua_State) -> Arc<ReentrantMutex<Self>> {
pub(super) unsafe fn init_from_ptr(state: *mut ffi::lua_State) -> XRc<ReentrantMutex<Self>> {
assert!(!state.is_null(), "Lua state is NULL");
if let Some(lua) = Self::try_from_ptr(state) {
return lua;
Expand Down Expand Up @@ -209,7 +210,7 @@ impl RawLua {
assert_stack(main_state, ffi::LUA_MINSTACK);

#[allow(clippy::arc_with_non_send_sync)]
let rawlua = Arc::new(ReentrantMutex::new(RawLua {
let rawlua = XRc::new(ReentrantMutex::new(RawLua {
state: Cell::new(state),
main_state,
extra: Rc::clone(&extra),
Expand All @@ -221,10 +222,10 @@ impl RawLua {

pub(super) unsafe fn try_from_ptr(
state: *mut ffi::lua_State,
) -> Option<Arc<ReentrantMutex<Self>>> {
) -> Option<XRc<ReentrantMutex<Self>>> {
match ExtraData::get(state) {
extra if extra.is_null() => None,
extra => Some(Arc::clone(&(*extra).lua().0)),
extra => Some(XRc::clone(&(*extra).lua().0)),
}
}

Expand Down Expand Up @@ -369,7 +370,7 @@ impl RawLua {
callback_error_ext(state, extra, move |_| {
let hook_cb = (*extra).hook_callback.clone();
let hook_cb = mlua_expect!(hook_cb, "no hook callback set in hook_proc");
if Arc::strong_count(&hook_cb) > 2 {
if Rc::strong_count(&hook_cb) > 2 {
return Ok(()); // Don't allow recursion
}
let rawlua = (*extra).raw_lua();
Expand All @@ -379,7 +380,7 @@ impl RawLua {
})
}

(*self.extra.get()).hook_callback = Some(Arc::new(callback));
(*self.extra.get()).hook_callback = Some(Rc::new(callback));
(*self.extra.get()).hook_thread = state; // Mark for what thread the hook is set
ffi::lua_sethook(state, Some(hook_proc), triggers.mask(), triggers.count());
}
Expand Down
15 changes: 9 additions & 6 deletions src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,12 @@ impl Serialize for String {
}
}

// #[cfg(test)]
// mod assertions {
// use super::*;

// static_assertions::assert_not_impl_any!(String: Send);
// }
#[cfg(test)]
mod assertions {
use super::*;

#[cfg(not(feature = "send"))]
static_assertions::assert_not_impl_any!(String: Send);
#[cfg(feature = "send")]
static_assertions::assert_impl_all!(String: Send, Sync);
}
15 changes: 9 additions & 6 deletions src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1219,9 +1219,12 @@ where
}
}

// #[cfg(test)]
// mod assertions {
// use super::*;

// static_assertions::assert_not_impl_any!(Table: Send);
// }
#[cfg(test)]
mod assertions {
use super::*;

#[cfg(not(feature = "send"))]
static_assertions::assert_not_impl_any!(Table: Send);
#[cfg(feature = "send")]
static_assertions::assert_impl_all!(Table: Send, Sync);
}
8 changes: 8 additions & 0 deletions src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ pub enum ThreadStatus {
#[derive(Clone, Debug)]
pub struct Thread(pub(crate) ValueRef, pub(crate) *mut ffi::lua_State);

#[cfg(feature = "send")]
unsafe impl Send for Thread {}
#[cfg(feature = "send")]
unsafe impl Sync for Thread {}

/// Thread (coroutine) representation as an async [`Future`] or [`Stream`].
///
/// Requires `feature = "async"`
Expand Down Expand Up @@ -526,5 +531,8 @@ impl<'lua, 'a> Drop for WakerGuard<'lua, 'a> {
mod assertions {
use super::*;

#[cfg(not(feature = "send"))]
static_assertions::assert_not_impl_any!(Thread: Send);
#[cfg(feature = "send")]
static_assertions::assert_impl_all!(Thread: Send, Sync);
}
Loading

0 comments on commit 00f646a

Please sign in to comment.