Skip to content

Commit

Permalink
Implement IntoLua for ref to String/Table/Function/AnyUserData
Browse files Browse the repository at this point in the history
This would prevent cloning plus has better performance when pushing values to Lua stack (`IntoLua::push_into_stack` method)
  • Loading branch information
khvzak committed Jan 23, 2024
1 parent fe6ab25 commit 8200bee
Show file tree
Hide file tree
Showing 8 changed files with 295 additions and 40 deletions.
102 changes: 102 additions & 0 deletions src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ impl<'lua> IntoLua<'lua> for String<'lua> {
}
}

impl<'lua> IntoLua<'lua> for &String<'lua> {
#[inline]
fn into_lua(self, _: &'lua Lua) -> Result<Value<'lua>> {
Ok(Value::String(self.clone()))
}

#[inline]
unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> {
Ok(lua.push_ref(&self.0))
}
}

impl<'lua> FromLua<'lua> for String<'lua> {
#[inline]
fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result<String<'lua>> {
Expand All @@ -64,6 +76,18 @@ impl<'lua> IntoLua<'lua> for Table<'lua> {
}
}

impl<'lua> IntoLua<'lua> for &Table<'lua> {
#[inline]
fn into_lua(self, _: &'lua Lua) -> Result<Value<'lua>> {
Ok(Value::Table(self.clone()))
}

#[inline]
unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> {
Ok(lua.push_ref(&self.0))
}
}

impl<'lua> FromLua<'lua> for Table<'lua> {
#[inline]
fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result<Table<'lua>> {
Expand All @@ -87,6 +111,20 @@ impl<'lua> IntoLua<'lua> for OwnedTable {
}
}

#[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))]
impl<'lua> IntoLua<'lua> for &OwnedTable {
#[inline]
fn into_lua(self, lua: &'lua Lua) -> Result<Value<'lua>> {
OwnedTable::into_lua(self.clone(), lua)
}

#[inline]
unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> {
Ok(lua.push_owned_ref(&self.0))
}
}

#[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))]
impl<'lua> FromLua<'lua> for OwnedTable {
Expand All @@ -103,6 +141,18 @@ impl<'lua> IntoLua<'lua> for Function<'lua> {
}
}

impl<'lua> IntoLua<'lua> for &Function<'lua> {
#[inline]
fn into_lua(self, _: &'lua Lua) -> Result<Value<'lua>> {
Ok(Value::Function(self.clone()))
}

#[inline]
unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> {
Ok(lua.push_ref(&self.0))
}
}

impl<'lua> FromLua<'lua> for Function<'lua> {
#[inline]
fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result<Function<'lua>> {
Expand All @@ -126,6 +176,20 @@ impl<'lua> IntoLua<'lua> for OwnedFunction {
}
}

#[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))]
impl<'lua> IntoLua<'lua> for &OwnedFunction {
#[inline]
fn into_lua(self, lua: &'lua Lua) -> Result<Value<'lua>> {
OwnedFunction::into_lua(self.clone(), lua)
}

#[inline]
unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> {
Ok(lua.push_owned_ref(&self.0))
}
}

#[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))]
impl<'lua> FromLua<'lua> for OwnedFunction {
Expand All @@ -142,6 +206,18 @@ impl<'lua> IntoLua<'lua> for Thread<'lua> {
}
}

impl<'lua> IntoLua<'lua> for &Thread<'lua> {
#[inline]
fn into_lua(self, _: &'lua Lua) -> Result<Value<'lua>> {
Ok(Value::Thread(self.clone()))
}

#[inline]
unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> {
Ok(lua.push_ref(&self.0))
}
}

impl<'lua> FromLua<'lua> for Thread<'lua> {
#[inline]
fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result<Thread<'lua>> {
Expand All @@ -163,6 +239,18 @@ impl<'lua> IntoLua<'lua> for AnyUserData<'lua> {
}
}

impl<'lua> IntoLua<'lua> for &AnyUserData<'lua> {
#[inline]
fn into_lua(self, _: &'lua Lua) -> Result<Value<'lua>> {
Ok(Value::UserData(self.clone()))
}

#[inline]
unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> {
Ok(lua.push_ref(&self.0))
}
}

impl<'lua> FromLua<'lua> for AnyUserData<'lua> {
#[inline]
fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result<AnyUserData<'lua>> {
Expand All @@ -189,6 +277,20 @@ impl<'lua> IntoLua<'lua> for OwnedAnyUserData {
}
}

#[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))]
impl<'lua> IntoLua<'lua> for &OwnedAnyUserData {
#[inline]
fn into_lua(self, lua: &'lua Lua) -> Result<Value<'lua>> {
OwnedAnyUserData::into_lua(self.clone(), lua)
}

#[inline]
unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> {
Ok(lua.push_owned_ref(&self.0))
}
}

#[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))]
impl<'lua> FromLua<'lua> for OwnedAnyUserData {
Expand Down
9 changes: 9 additions & 0 deletions src/lua.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2558,6 +2558,15 @@ impl Lua {
ffi::lua_xpush(self.ref_thread(), self.state(), lref.index);
}

#[cfg(all(feature = "unstable", not(feature = "send")))]
pub(crate) unsafe fn push_owned_ref(&self, loref: &crate::types::LuaOwnedRef) {
assert!(
Arc::ptr_eq(&loref.inner, &self.0),
"Lua instance passed Value created from a different main Lua state"
);
ffi::lua_xpush(self.ref_thread(), self.state(), loref.index);
}

// Pops the topmost element of the stack and stores a reference to it. This pins the object,
// preventing garbage collection until the returned `LuaRef` is dropped.
//
Expand Down
2 changes: 1 addition & 1 deletion tests/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ async fn test_async_userdata() -> Result<()> {
let globals = lua.globals();

let userdata = lua.create_userdata(MyUserData(11))?;
globals.set("userdata", userdata.clone())?;
globals.set("userdata", &userdata)?;

lua.load(
r#"
Expand Down
175 changes: 174 additions & 1 deletion tests/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,180 @@ use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::ffi::{CStr, CString};

use maplit::{btreemap, btreeset, hashmap, hashset};
use mlua::{Error, Lua, Result};
use mlua::{AnyUserData, Error, Function, IntoLua, Lua, Result, Table, Thread, UserDataRef, Value};

#[test]
fn test_string_into_lua() -> Result<()> {
let lua = Lua::new();

// Direct conversion
let s = lua.create_string("hello, world!")?;
let s2 = (&s).into_lua(&lua)?;
assert_eq!(s, s2.as_string().unwrap());

// Push into stack
let table = lua.create_table()?;
table.set("s", &s)?;
assert_eq!(s, table.get::<_, String>("s")?);

Ok(())
}

#[test]
fn test_table_into_lua() -> Result<()> {
let lua = Lua::new();

// Direct conversion
let t = lua.create_table()?;
let t2 = (&t).into_lua(&lua)?;
assert_eq!(&t, t2.as_table().unwrap());

// Push into stack
let f = lua.create_function(|_, (t, s): (Table, String)| t.set("s", s))?;
f.call((&t, "hello"))?;
assert_eq!("hello", t.get::<_, String>("s")?);

Ok(())
}

#[cfg(all(feature = "unstable", not(feature = "send")))]
#[test]
fn test_owned_table_into_lua() -> Result<()> {
let lua = Lua::new();

// Direct conversion
let t = lua.create_table()?.into_owned();
let t2 = (&t).into_lua(&lua)?;
assert_eq!(t.to_ref(), *t2.as_table().unwrap());

// Push into stack
let f = lua.create_function(|_, (t, s): (Table, String)| t.set("s", s))?;
f.call((&t, "hello"))?;
assert_eq!("hello", t.to_ref().get::<_, String>("s")?);

Ok(())
}

#[test]
fn test_function_into_lua() -> Result<()> {
let lua = Lua::new();

// Direct conversion
let f = lua.create_function(|_, ()| Ok::<_, Error>(()))?;
let f2 = (&f).into_lua(&lua)?;
assert_eq!(&f, f2.as_function().unwrap());

// Push into stack
let table = lua.create_table()?;
table.set("f", &f)?;
assert_eq!(f, table.get::<_, Function>("f")?);

Ok(())
}

#[cfg(all(feature = "unstable", not(feature = "send")))]
#[test]
fn test_owned_function_into_lua() -> Result<()> {
let lua = Lua::new();

// Direct conversion
let f = lua
.create_function(|_, ()| Ok::<_, Error>(()))?
.into_owned();
let f2 = (&f).into_lua(&lua)?;
assert_eq!(f.to_ref(), *f2.as_function().unwrap());

// Push into stack
let table = lua.create_table()?;
table.set("f", &f)?;
assert_eq!(f.to_ref(), table.get::<_, Function>("f")?);

Ok(())
}

#[test]
fn test_thread_into_lua() -> Result<()> {
let lua = Lua::new();

// Direct conversion
let f = lua.create_function(|_, ()| Ok::<_, Error>(()))?;
let th = lua.create_thread(f)?;
let th2 = (&th).into_lua(&lua)?;
assert_eq!(&th, th2.as_thread().unwrap());

// Push into stack
let table = lua.create_table()?;
table.set("th", &th)?;
assert_eq!(th, table.get::<_, Thread>("th")?);

Ok(())
}

#[test]
fn test_anyuserdata_into_lua() -> Result<()> {
let lua = Lua::new();

// Direct conversion
let ud = lua.create_any_userdata(String::from("hello"))?;
let ud2 = (&ud).into_lua(&lua)?;
assert_eq!(&ud, ud2.as_userdata().unwrap());

// Push into stack
let table = lua.create_table()?;
table.set("ud", &ud)?;
assert_eq!(ud, table.get::<_, AnyUserData>("ud")?);
assert_eq!("hello", *table.get::<_, UserDataRef<String>>("ud")?);

Ok(())
}

#[cfg(all(feature = "unstable", not(feature = "send")))]
#[test]
fn test_owned_anyuserdata_into_lua() -> Result<()> {
let lua = Lua::new();

// Direct conversion
let ud = lua.create_any_userdata(String::from("hello"))?.into_owned();
let ud2 = (&ud).into_lua(&lua)?;
assert_eq!(ud.to_ref(), *ud2.as_userdata().unwrap());

// Push into stack
let table = lua.create_table()?;
table.set("ud", &ud)?;
assert_eq!(ud.to_ref(), table.get::<_, AnyUserData>("ud")?);
assert_eq!("hello", *table.get::<_, UserDataRef<String>>("ud")?);

Ok(())
}

#[test]
fn test_registry_value_into_lua() -> Result<()> {
let lua = Lua::new();

let t = lua.create_table()?;
let r = lua.create_registry_value(t)?;
let f = lua.create_function(|_, t: Table| t.raw_set("hello", "world"))?;

f.call(&r)?;
let v = r.into_lua(&lua)?;
let t = v.as_table().unwrap();
assert_eq!(t.get::<_, String>("hello")?, "world");

// Try to set nil registry key
let r_nil = lua.create_registry_value(Value::Nil)?;
t.set("hello", &r_nil)?;
assert_eq!(t.get::<_, Value>("hello")?, Value::Nil);

// Check non-owned registry key
let lua2 = Lua::new();
let r2 = lua2.create_registry_value("abc")?;
assert!(matches!(
f.call::<_, ()>(&r2),
Err(Error::MismatchedRegistryKey)
));

Ok(())
}

#[test]
fn test_conv_vec() -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ fn test_from_value_with_options() -> Result<(), Box<dyn StdError>> {

// Check recursion when using `Serialize` impl
let t = lua.create_table()?;
t.set("t", t.clone())?;
t.set("t", &t)?;
assert!(serde_json::to_string(&t).is_err());

// Serialize Lua globals table
Expand Down
Loading

0 comments on commit 8200bee

Please sign in to comment.