From fa217d3706ebebfd1519743c3e620a940c218967 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Thu, 28 Mar 2024 13:05:01 +0000 Subject: [PATCH] Better Luau buffer type support. - Add `Lua::create_buffer()` function - Support serializing buffer type as a byte slice - Support accessing copy of underlying bytes using `BString` --- src/conversion.rs | 40 +++++++++++++++++++++++++++++++++++----- src/lua.rs | 21 +++++++++++++++++++++ src/serde/de.rs | 8 ++++++++ src/userdata.rs | 13 +++++++++++++ src/util/mod.rs | 14 ++++++++++++++ tests/conversion.rs | 44 ++++++++++++++++++++++++++++++++++++++++++++ tests/serde.rs | 27 +++++++++++++++++++++++++++ 7 files changed, 162 insertions(+), 5 deletions(-) diff --git a/src/conversion.rs b/src/conversion.rs index 2c57269e..8ca8fe86 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -679,19 +679,49 @@ impl<'lua> IntoLua<'lua> for BString { } impl<'lua> FromLua<'lua> for BString { - #[inline] fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result { let ty = value.type_name(); - Ok(BString::from( - lua.coerce_string(value)? + match value { + Value::String(s) => Ok(s.as_bytes().into()), + #[cfg(feature = "luau")] + Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe { + let mut size = 0usize; + let buf = ffi::lua_tobuffer(ud.0.lua.ref_thread(), ud.0.index, &mut size); + mlua_assert!(!buf.is_null(), "invalid Luau buffer"); + Ok(slice::from_raw_parts(buf as *const u8, size).into()) + }, + _ => Ok(lua + .coerce_string(value)? .ok_or_else(|| Error::FromLuaConversionError { from: ty, to: "BString", message: Some("expected string or number".to_string()), })? .as_bytes() - .to_vec(), - )) + .into()), + } + } + + unsafe fn from_stack(idx: c_int, lua: &'lua Lua) -> Result { + let state = lua.state(); + match ffi::lua_type(state, idx) { + ffi::LUA_TSTRING => { + let mut size = 0; + let data = ffi::lua_tolstring(state, idx, &mut size); + Ok(slice::from_raw_parts(data as *const u8, size).into()) + } + #[cfg(feature = "luau")] + ffi::LUA_TBUFFER => { + let mut size = 0; + let buf = ffi::lua_tobuffer(state, idx, &mut size); + mlua_assert!(!buf.is_null(), "invalid Luau buffer"); + Ok(slice::from_raw_parts(buf as *const u8, size).into()) + } + _ => { + // Fallback to default + Self::from_lua(lua.stack_value(idx), lua) + } + } } } diff --git a/src/lua.rs b/src/lua.rs index 9938a74c..6691bf0e 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -1373,6 +1373,27 @@ impl Lua { } } + /// Create and return a Luau [buffer] object from a byte slice of data. + /// + /// Requires `feature = "luau"` + /// + /// [buffer]: https://luau-lang.org/library#buffer-library + #[cfg(feature = "luau")] + pub fn create_buffer(&self, buf: impl AsRef<[u8]>) -> Result { + let state = self.state(); + unsafe { + if self.unlikely_memory_error() { + crate::util::push_buffer(self.ref_thread(), buf.as_ref(), false)?; + return Ok(AnyUserData(self.pop_ref_thread(), SubtypeId::Buffer)); + } + + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + crate::util::push_buffer(state, buf.as_ref(), true)?; + Ok(AnyUserData(self.pop_ref(), SubtypeId::Buffer)) + } + } + /// Creates and returns a new empty table. pub fn create_table(&self) -> Result { self.create_table_with_capacity(0, 0) diff --git a/src/serde/de.rs b/src/serde/de.rs index 5d3512c2..6933e4e2 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -148,6 +148,14 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { Value::UserData(ud) if ud.is_serializable() => { serde_userdata(ud, |value| value.deserialize_any(visitor)) } + #[cfg(feature = "luau")] + Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe { + let mut size = 0usize; + let buf = ffi::lua_tobuffer(ud.0.lua.ref_thread(), ud.0.index, &mut size); + mlua_assert!(!buf.is_null(), "invalid Luau buffer"); + let buf = std::slice::from_raw_parts(buf as *const u8, size); + visitor.visit_bytes(buf) + }, Value::Function(_) | Value::Thread(_) | Value::UserData(_) diff --git a/src/userdata.rs b/src/userdata.rs index e2c4a748..dbbbffa3 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -1340,6 +1340,19 @@ impl<'lua> Serialize for AnyUserData<'lua> { S: Serializer, { let lua = self.0.lua; + + // Special case for Luau buffer type + #[cfg(feature = "luau")] + if self.1 == SubtypeId::Buffer { + let buf = unsafe { + let mut size = 0usize; + let buf = ffi::lua_tobuffer(lua.ref_thread(), self.0.index, &mut size); + mlua_assert!(!buf.is_null(), "invalid Luau buffer"); + std::slice::from_raw_parts(buf as *const u8, size) + }; + return serializer.serialize_bytes(buf); + } + let data = unsafe { let _ = lua .get_userdata_ref_type_id(&self.0) diff --git a/src/util/mod.rs b/src/util/mod.rs index 595407e1..ff8b28e9 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -253,6 +253,20 @@ pub unsafe fn push_string(state: *mut ffi::lua_State, s: &[u8], protect: bool) - } } +// Uses 3 stack spaces (when protect), does not call checkstack. +#[cfg(feature = "luau")] +#[inline(always)] +pub unsafe fn push_buffer(state: *mut ffi::lua_State, b: &[u8], protect: bool) -> Result<()> { + let data = if protect { + protect_lua!(state, 0, 1, |state| ffi::lua_newbuffer(state, b.len()))? + } else { + ffi::lua_newbuffer(state, b.len()) + }; + let buf = slice::from_raw_parts_mut(data as *mut u8, b.len()); + buf.copy_from_slice(b); + Ok(()) +} + // Uses 3 stack spaces, does not call checkstack. #[inline] pub unsafe fn push_table( diff --git a/tests/conversion.rs b/tests/conversion.rs index a312f114..ad2c09ad 100644 --- a/tests/conversion.rs +++ b/tests/conversion.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::ffi::{CStr, CString}; +use bstr::BString; use maplit::{btreemap, btreeset, hashmap, hashset}; use mlua::{ AnyUserData, Error, Function, IntoLua, Lua, RegistryKey, Result, Table, Thread, UserDataRef, @@ -409,3 +410,46 @@ fn test_conv_array() -> Result<()> { Ok(()) } + +#[test] +fn test_bstring_from_lua() -> Result<()> { + let lua = Lua::new(); + + let s = lua.create_string("hello, world")?; + let bstr = lua.unpack::(Value::String(s))?; + assert_eq!(bstr, "hello, world"); + + let bstr = lua.unpack::(Value::Integer(123))?; + assert_eq!(bstr, "123"); + + let bstr = lua.unpack::(Value::Number(-123.55))?; + assert_eq!(bstr, "-123.55"); + + // Test from stack + let f = lua.create_function(|_, bstr: BString| Ok(bstr))?; + let bstr = f.call::<_, BString>("hello, world")?; + assert_eq!(bstr, "hello, world"); + + let bstr = f.call::<_, BString>(-43.22)?; + assert_eq!(bstr, "-43.22"); + + Ok(()) +} + +#[cfg(feature = "luau")] +#[test] +fn test_bstring_from_lua_buffer() -> Result<()> { + let lua = Lua::new(); + + let b = lua.create_buffer("hello, world")?; + let bstr = lua.unpack::(Value::UserData(b))?; + assert_eq!(bstr, "hello, world"); + + // Test from stack + let f = lua.create_function(|_, bstr: BString| Ok(bstr))?; + let buf = lua.create_buffer("hello, world")?; + let bstr = f.call::<_, BString>(buf)?; + assert_eq!(bstr, "hello, world"); + + Ok(()) +} diff --git a/tests/serde.rs b/tests/serde.rs index 7e2e251a..9c287a47 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -728,3 +728,30 @@ fn test_arbitrary_precision() { "{\n [\"$serde_json::private::Number\"] = \"124.4\",\n}" ); } + +#[cfg(feature = "luau")] +#[test] +fn test_buffer_serialize() { + let lua = Lua::new(); + + let buf = lua.create_buffer(&[1, 2, 3, 4]).unwrap(); + let val = serde_value::to_value(&buf).unwrap(); + assert_eq!(val, serde_value::Value::Bytes(vec![1, 2, 3, 4])); + + // Try empty buffer + let buf = lua.create_buffer(&[]).unwrap(); + let val = serde_value::to_value(&buf).unwrap(); + assert_eq!(val, serde_value::Value::Bytes(vec![])); +} + +#[cfg(feature = "luau")] +#[test] +fn test_buffer_from_value() { + let lua = Lua::new(); + + let buf = lua.create_buffer(&[1, 2, 3, 4]).unwrap(); + let val = lua + .from_value::(Value::UserData(buf)) + .unwrap(); + assert_eq!(val, serde_value::Value::Bytes(vec![1, 2, 3, 4])); +}