Skip to content

Commit

Permalink
Add MaybeSend requirement to Lua futures
Browse files Browse the repository at this point in the history
  • Loading branch information
khvzak committed Aug 10, 2024
1 parent 0c08cda commit c58f67b
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 46 deletions.
6 changes: 3 additions & 3 deletions examples/async_tcp_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};

use mlua::{chunk, Function, Lua, String as LuaString, UserData, UserDataMethods};
use mlua::{chunk, BString, Function, Lua, UserData, UserDataMethods};

struct LuaTcpStream(TcpStream);

Expand All @@ -19,8 +19,8 @@ impl UserData for LuaTcpStream {
lua.create_string(&buf)
});

methods.add_async_method_mut("write", |_, this, data: LuaString| async move {
let n = this.0.write(&data.as_bytes()).await?;
methods.add_async_method_mut("write", |_, this, data: BString| async move {
let n = this.0.write(&data).await?;
Ok(n)
});

Expand Down
8 changes: 3 additions & 5 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,17 +559,15 @@ impl Function {
A: FromLuaMulti,
R: IntoLuaMulti,
F: Fn(&Lua, A) -> FR + MaybeSend + 'static,
FR: Future<Output = Result<R>> + 'static,
FR: Future<Output = Result<R>> + MaybeSend + 'static,
{
WrappedAsyncFunction(Box::new(move |rawlua, args| unsafe {
let lua = rawlua.lua();
WrappedAsyncFunction(Box::new(move |lua, args| unsafe {
let args = match A::from_lua_args(args, 1, None, lua) {
Ok(args) => args,
Err(e) => return Box::pin(future::ready(Err(e))),
};
let fut = func(lua, args);
let weak = rawlua.weak().clone();
Box::pin(async move { fut.await?.push_into_stack_multi(&weak.lock()) })
Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) })
}))
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ mod value;

pub mod prelude;

pub use bstr::BString;
pub use ffi::{self, lua_CFunction, lua_State};

pub use crate::chunk::{AsChunk, Chunk, ChunkMode};
Expand All @@ -113,7 +114,7 @@ pub use crate::stdlib::StdLib;
pub use crate::string::{BorrowedBytes, BorrowedStr, String};
pub use crate::table::{Table, TableExt, TablePairs, TableSequence};
pub use crate::thread::{Thread, ThreadStatus};
pub use crate::types::{AppDataRef, AppDataRefMut, Integer, LightUserData, Number, RegistryKey};
pub use crate::types::{AppDataRef, AppDataRefMut, Integer, LightUserData, MaybeSend, Number, RegistryKey};
pub use crate::userdata::{
AnyUserData, AnyUserDataExt, MetaMethod, UserData, UserDataFields, UserDataMetatable, UserDataMethods,
UserDataRef, UserDataRefMut, UserDataRegistry,
Expand Down
16 changes: 11 additions & 5 deletions src/state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::any::TypeId;
use std::cell::RefCell;
// use std::collections::VecDeque;
use std::marker::PhantomData;
use std::ops::Deref;
use std::os::raw::{c_int, c_void};
Expand Down Expand Up @@ -1163,17 +1162,16 @@ impl Lua {
'lua: 'a,
F: Fn(&'a Lua, A) -> FR + MaybeSend + 'static,
A: FromLuaMulti,
FR: Future<Output = Result<R>> + 'a,
FR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti,
{
(self.lock()).create_async_callback(Box::new(move |rawlua, args| unsafe {
let lua = rawlua.lua();
(self.lock()).create_async_callback(Box::new(move |lua, args| unsafe {
let args = match A::from_lua_args(args, 1, None, lua) {
Ok(args) => args,
Err(e) => return Box::pin(future::ready(Err(e))),
};
let fut = func(lua, args);
Box::pin(async move { fut.await?.push_into_stack_multi(rawlua) })
Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) })
}))
}

Expand Down Expand Up @@ -1840,6 +1838,14 @@ impl Lua {
pub(crate) fn weak(&self) -> WeakLua {
WeakLua(XRc::downgrade(&self.0))
}

/// Returns a handle to the unprotected Lua state without any synchronization.
///
/// This is useful where we know that the lock is already held by the caller.
#[inline(always)]
pub(crate) unsafe fn raw_lua(&self) -> &RawLua {
&*self.0.data_ptr()
}
}

impl WeakLua {
Expand Down
2 changes: 1 addition & 1 deletion src/state/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ impl RawLua {

let args = MultiValue::from_stack_multi(nargs, rawlua)?;
let func = &*(*upvalue).data;
let fut = func(rawlua, args);
let fut = func(rawlua.lua(), args);
let extra = XRc::clone(&(*upvalue).extra);
let protect = !rawlua.unlikely_memory_error();
push_internal_userdata(state, AsyncPollUpvalue { data: fut, extra }, protect)?;
Expand Down
4 changes: 4 additions & 0 deletions src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,4 +528,8 @@ mod assertions {
static_assertions::assert_not_impl_any!(Thread: Send);
#[cfg(feature = "send")]
static_assertions::assert_impl_all!(Thread: Send, Sync);
#[cfg(all(feature = "async", not(feature = "send")))]
static_assertions::assert_not_impl_any!(AsyncThread<()>: Send);
#[cfg(all(feature = "async", feature = "send"))]
static_assertions::assert_impl_all!(AsyncThread<()>: Send, Sync);
}
13 changes: 10 additions & 3 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ use crate::hook::Debug;
use crate::state::{ExtraData, Lua, RawLua, WeakLua};

#[cfg(feature = "async")]
use {crate::value::MultiValue, futures_util::future::LocalBoxFuture};
use crate::value::MultiValue;

#[cfg(all(feature = "async", feature = "send"))]
pub(crate) type BoxFuture<'a, T> = futures_util::future::BoxFuture<'a, T>;

#[cfg(all(feature = "async", not(feature = "send")))]
pub(crate) type BoxFuture<'a, T> = futures_util::future::LocalBoxFuture<'a, T>;

#[cfg(all(feature = "luau", feature = "serialize"))]
use serde::ser::{Serialize, SerializeTupleStruct, Serializer};
Expand Down Expand Up @@ -57,13 +63,13 @@ pub(crate) type CallbackUpvalue = Upvalue<Callback<'static>>;

#[cfg(feature = "async")]
pub(crate) type AsyncCallback<'a> =
Box<dyn Fn(&'a RawLua, MultiValue) -> LocalBoxFuture<'a, Result<c_int>> + 'static>;
Box<dyn Fn(&'a Lua, MultiValue) -> BoxFuture<'a, Result<c_int>> + 'static>;

#[cfg(feature = "async")]
pub(crate) type AsyncCallbackUpvalue = Upvalue<AsyncCallback<'static>>;

#[cfg(feature = "async")]
pub(crate) type AsyncPollUpvalue = Upvalue<LocalBoxFuture<'static, Result<c_int>>>;
pub(crate) type AsyncPollUpvalue = Upvalue<BoxFuture<'static, Result<c_int>>>;

/// Type to set next Luau VM action after executing interrupt function.
#[cfg(any(feature = "luau", doc))]
Expand Down Expand Up @@ -91,6 +97,7 @@ pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &str, bool) -> Result<()> + Send
#[cfg(all(not(feature = "send"), feature = "lua54"))]
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &str, bool) -> Result<()>>;

/// A trait that adds `Send` requirement if `send` feature is enabled.
#[cfg(feature = "send")]
pub trait MaybeSend: Send {}
#[cfg(feature = "send")]
Expand Down
12 changes: 6 additions & 6 deletions src/userdata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ pub trait UserDataMethods<'a, T> {
T: 'static,
M: Fn(&'a Lua, &'a T, A) -> MR + MaybeSend + 'static,
A: FromLuaMulti,
MR: Future<Output = Result<R>> + 'a,
MR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti;

/// Add an async method which accepts a `&mut T` as the first parameter and returns Future.
Expand All @@ -304,7 +304,7 @@ pub trait UserDataMethods<'a, T> {
T: 'static,
M: Fn(&'a Lua, &'a mut T, A) -> MR + MaybeSend + 'static,
A: FromLuaMulti,
MR: Future<Output = Result<R>> + 'a,
MR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti;

/// Add a regular method as a function which accepts generic arguments, the first argument will
Expand Down Expand Up @@ -348,7 +348,7 @@ pub trait UserDataMethods<'a, T> {
where
F: Fn(&'a Lua, A) -> FR + MaybeSend + 'static,
A: FromLuaMulti,
FR: Future<Output = Result<R>> + 'a,
FR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti;

/// Add a metamethod which accepts a `&T` as the first parameter.
Expand Down Expand Up @@ -393,7 +393,7 @@ pub trait UserDataMethods<'a, T> {
T: 'static,
M: Fn(&'a Lua, &'a T, A) -> MR + MaybeSend + 'static,
A: FromLuaMulti,
MR: Future<Output = Result<R>> + 'a,
MR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti;

/// Add an async metamethod which accepts a `&mut T` as the first parameter and returns Future.
Expand All @@ -410,7 +410,7 @@ pub trait UserDataMethods<'a, T> {
T: 'static,
M: Fn(&'a Lua, &'a mut T, A) -> MR + MaybeSend + 'static,
A: FromLuaMulti,
MR: Future<Output = Result<R>> + 'a,
MR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti;

/// Add a metamethod which accepts generic arguments.
Expand Down Expand Up @@ -448,7 +448,7 @@ pub trait UserDataMethods<'a, T> {
where
F: Fn(&'a Lua, A) -> FR + MaybeSend + 'static,
A: FromLuaMulti,
FR: Future<Output = Result<R>> + 'a,
FR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti;
}

Expand Down
37 changes: 17 additions & 20 deletions src/userdata/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl<'a, T: 'static> UserDataRegistry<'a, T> {
where
M: Fn(&'a Lua, &'a T, A) -> MR + MaybeSend + 'static,
A: FromLuaMulti,
MR: Future<Output = Result<R>> + 'a,
MR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti,
{
let name = get_function_name::<T>(name);
Expand All @@ -145,15 +145,14 @@ impl<'a, T: 'static> UserDataRegistry<'a, T> {
};
}

Box::new(move |rawlua, mut args| unsafe {
Box::new(move |lua, mut args| unsafe {
let this = args
.pop_front()
.ok_or_else(|| Error::from_lua_conversion("missing argument", "userdata", None));
let lua = rawlua.lua();
let this = try_self_arg!(AnyUserData::from_lua(try_self_arg!(this), lua));
let args = A::from_lua_args(args, 2, Some(&name), lua);

let (ref_thread, index) = (rawlua.ref_thread(), this.0.index);
let (ref_thread, index) = (lua.raw_lua().ref_thread(), this.0.index);
match try_self_arg!(this.type_id()) {
Some(id) if id == TypeId::of::<T>() => {
let ud = try_self_arg!(borrow_userdata_ref::<T>(ref_thread, index));
Expand All @@ -162,7 +161,7 @@ impl<'a, T: 'static> UserDataRegistry<'a, T> {
Err(e) => return Box::pin(future::ready(Err(e))),
};
let fut = method(lua, ud.get_ref(), args);
Box::pin(async move { fut.await?.push_into_stack_multi(rawlua) })
Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) })
}
_ => {
let err = Error::bad_self_argument(&name, Error::UserDataTypeMismatch);
Expand All @@ -177,7 +176,7 @@ impl<'a, T: 'static> UserDataRegistry<'a, T> {
where
M: Fn(&'a Lua, &'a mut T, A) -> MR + MaybeSend + 'static,
A: FromLuaMulti,
MR: Future<Output = Result<R>> + 'a,
MR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti,
{
let name = get_function_name::<T>(name);
Expand All @@ -190,15 +189,14 @@ impl<'a, T: 'static> UserDataRegistry<'a, T> {
};
}

Box::new(move |rawlua, mut args| unsafe {
Box::new(move |lua, mut args| unsafe {
let this = args
.pop_front()
.ok_or_else(|| Error::from_lua_conversion("missing argument", "userdata", None));
let lua = rawlua.lua();
let this = try_self_arg!(AnyUserData::from_lua(try_self_arg!(this), lua));
let args = A::from_lua_args(args, 2, Some(&name), lua);

let (ref_thread, index) = (rawlua.ref_thread(), this.0.index);
let (ref_thread, index) = (lua.raw_lua().ref_thread(), this.0.index);
match try_self_arg!(this.type_id()) {
Some(id) if id == TypeId::of::<T>() => {
let mut ud = try_self_arg!(borrow_userdata_mut::<T>(ref_thread, index));
Expand All @@ -207,7 +205,7 @@ impl<'a, T: 'static> UserDataRegistry<'a, T> {
Err(e) => return Box::pin(future::ready(Err(e))),
};
let fut = method(lua, ud.get_mut(), args);
Box::pin(async move { fut.await?.push_into_stack_multi(rawlua) })
Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) })
}
_ => {
let err = Error::bad_self_argument(&name, Error::UserDataTypeMismatch);
Expand Down Expand Up @@ -252,18 +250,17 @@ impl<'a, T: 'static> UserDataRegistry<'a, T> {
where
F: Fn(&'a Lua, A) -> FR + MaybeSend + 'static,
A: FromLuaMulti,
FR: Future<Output = Result<R>> + 'a,
FR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti,
{
let name = get_function_name::<T>(name);
Box::new(move |rawlua, args| unsafe {
let lua = rawlua.lua();
Box::new(move |lua, args| unsafe {
let args = match A::from_lua_args(args, 1, Some(&name), lua) {
Ok(args) => args,
Err(e) => return Box::pin(future::ready(Err(e))),
};
let fut = function(lua, args);
Box::pin(async move { fut.await?.push_into_stack_multi(rawlua) })
Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) })
})
}

Expand Down Expand Up @@ -397,7 +394,7 @@ impl<'a, T: 'static> UserDataMethods<'a, T> for UserDataRegistry<'a, T> {
where
M: Fn(&'a Lua, &'a T, A) -> MR + MaybeSend + 'static,
A: FromLuaMulti,
MR: Future<Output = Result<R>> + 'a,
MR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti,
{
let name = name.to_string();
Expand All @@ -410,7 +407,7 @@ impl<'a, T: 'static> UserDataMethods<'a, T> for UserDataRegistry<'a, T> {
where
M: Fn(&'a Lua, &'a mut T, A) -> MR + MaybeSend + 'static,
A: FromLuaMulti,
MR: Future<Output = Result<R>> + 'a,
MR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti,
{
let name = name.to_string();
Expand Down Expand Up @@ -445,7 +442,7 @@ impl<'a, T: 'static> UserDataMethods<'a, T> for UserDataRegistry<'a, T> {
where
F: Fn(&'a Lua, A) -> FR + MaybeSend + 'static,
A: FromLuaMulti,
FR: Future<Output = Result<R>> + 'a,
FR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti,
{
let name = name.to_string();
Expand Down Expand Up @@ -480,7 +477,7 @@ impl<'a, T: 'static> UserDataMethods<'a, T> for UserDataRegistry<'a, T> {
where
M: Fn(&'a Lua, &'a T, A) -> MR + MaybeSend + 'static,
A: FromLuaMulti,
MR: Future<Output = Result<R>> + 'a,
MR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti,
{
let name = name.to_string();
Expand All @@ -493,7 +490,7 @@ impl<'a, T: 'static> UserDataMethods<'a, T> for UserDataRegistry<'a, T> {
where
M: Fn(&'a Lua, &'a mut T, A) -> MR + MaybeSend + 'static,
A: FromLuaMulti,
MR: Future<Output = Result<R>> + 'a,
MR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti,
{
let name = name.to_string();
Expand Down Expand Up @@ -528,7 +525,7 @@ impl<'a, T: 'static> UserDataMethods<'a, T> for UserDataRegistry<'a, T> {
where
F: Fn(&'a Lua, A) -> FR + MaybeSend + 'static,
A: FromLuaMulti,
FR: Future<Output = Result<R>> + 'a,
FR: Future<Output = Result<R>> + MaybeSend + 'a,
R: IntoLuaMulti,
{
let name = name.to_string();
Expand Down
5 changes: 3 additions & 2 deletions tests/async.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#![cfg(feature = "async")]

use std::sync::{Arc, Mutex};
use std::sync::Arc;
use std::time::Duration;

use futures_util::stream::TryStreamExt;
use tokio::sync::Mutex;

use mlua::{
AnyUserDataExt, Error, Function, Lua, LuaOptions, MultiValue, Result, StdLib, Table, TableExt, UserData,
Expand Down Expand Up @@ -504,7 +505,7 @@ async fn test_async_terminate() -> Result<()> {
let func = lua.create_async_function(move |_, ()| {
let mutex = mutex2.clone();
async move {
let _guard = mutex.lock();
let _guard = mutex.lock().await;
sleep_ms(100).await;
Ok(())
}
Expand Down

0 comments on commit c58f67b

Please sign in to comment.