Skip to content

Commit

Permalink
Better Luau buffer type support.
Browse files Browse the repository at this point in the history
- Add `Lua::create_buffer()` function
- Support serializing buffer type as a byte slice
- Support accessing copy of underlying bytes using `BString`
  • Loading branch information
khvzak committed Mar 28, 2024
1 parent b62f2ee commit fa217d3
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 5 deletions.
40 changes: 35 additions & 5 deletions src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -679,19 +679,49 @@ impl<'lua> IntoLua<'lua> for BString {
}

impl<'lua> FromLua<'lua> for BString {
#[inline]
fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result<Self> {
let ty = value.type_name();
Ok(BString::from(
lua.coerce_string(value)?
match value {
Value::String(s) => Ok(s.as_bytes().into()),
#[cfg(feature = "luau")]
Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe {
let mut size = 0usize;
let buf = ffi::lua_tobuffer(ud.0.lua.ref_thread(), ud.0.index, &mut size);
mlua_assert!(!buf.is_null(), "invalid Luau buffer");
Ok(slice::from_raw_parts(buf as *const u8, size).into())
},
_ => Ok(lua
.coerce_string(value)?
.ok_or_else(|| Error::FromLuaConversionError {
from: ty,
to: "BString",
message: Some("expected string or number".to_string()),
})?
.as_bytes()
.to_vec(),
))
.into()),
}
}

unsafe fn from_stack(idx: c_int, lua: &'lua Lua) -> Result<Self> {
let state = lua.state();
match ffi::lua_type(state, idx) {
ffi::LUA_TSTRING => {
let mut size = 0;
let data = ffi::lua_tolstring(state, idx, &mut size);
Ok(slice::from_raw_parts(data as *const u8, size).into())
}
#[cfg(feature = "luau")]
ffi::LUA_TBUFFER => {
let mut size = 0;
let buf = ffi::lua_tobuffer(state, idx, &mut size);
mlua_assert!(!buf.is_null(), "invalid Luau buffer");
Ok(slice::from_raw_parts(buf as *const u8, size).into())
}
_ => {
// Fallback to default
Self::from_lua(lua.stack_value(idx), lua)
}
}
}
}

Expand Down
21 changes: 21 additions & 0 deletions src/lua.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,27 @@ impl Lua {
}
}

/// Create and return a Luau [buffer] object from a byte slice of data.
///
/// Requires `feature = "luau"`
///
/// [buffer]: https://luau-lang.org/library#buffer-library
#[cfg(feature = "luau")]
pub fn create_buffer(&self, buf: impl AsRef<[u8]>) -> Result<AnyUserData> {
let state = self.state();
unsafe {
if self.unlikely_memory_error() {
crate::util::push_buffer(self.ref_thread(), buf.as_ref(), false)?;
return Ok(AnyUserData(self.pop_ref_thread(), SubtypeId::Buffer));
}

let _sg = StackGuard::new(state);
check_stack(state, 4)?;
crate::util::push_buffer(state, buf.as_ref(), true)?;
Ok(AnyUserData(self.pop_ref(), SubtypeId::Buffer))
}
}

/// Creates and returns a new empty table.
pub fn create_table(&self) -> Result<Table> {
self.create_table_with_capacity(0, 0)
Expand Down
8 changes: 8 additions & 0 deletions src/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
Value::UserData(ud) if ud.is_serializable() => {
serde_userdata(ud, |value| value.deserialize_any(visitor))
}
#[cfg(feature = "luau")]
Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe {
let mut size = 0usize;
let buf = ffi::lua_tobuffer(ud.0.lua.ref_thread(), ud.0.index, &mut size);
mlua_assert!(!buf.is_null(), "invalid Luau buffer");
let buf = std::slice::from_raw_parts(buf as *const u8, size);
visitor.visit_bytes(buf)
},
Value::Function(_)
| Value::Thread(_)
| Value::UserData(_)
Expand Down
13 changes: 13 additions & 0 deletions src/userdata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,19 @@ impl<'lua> Serialize for AnyUserData<'lua> {
S: Serializer,
{
let lua = self.0.lua;

// Special case for Luau buffer type
#[cfg(feature = "luau")]
if self.1 == SubtypeId::Buffer {
let buf = unsafe {
let mut size = 0usize;
let buf = ffi::lua_tobuffer(lua.ref_thread(), self.0.index, &mut size);
mlua_assert!(!buf.is_null(), "invalid Luau buffer");
std::slice::from_raw_parts(buf as *const u8, size)
};
return serializer.serialize_bytes(buf);
}

let data = unsafe {
let _ = lua
.get_userdata_ref_type_id(&self.0)
Expand Down
14 changes: 14 additions & 0 deletions src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,20 @@ pub unsafe fn push_string(state: *mut ffi::lua_State, s: &[u8], protect: bool) -
}
}

// Uses 3 stack spaces (when protect), does not call checkstack.
#[cfg(feature = "luau")]
#[inline(always)]
pub unsafe fn push_buffer(state: *mut ffi::lua_State, b: &[u8], protect: bool) -> Result<()> {
let data = if protect {
protect_lua!(state, 0, 1, |state| ffi::lua_newbuffer(state, b.len()))?
} else {
ffi::lua_newbuffer(state, b.len())
};
let buf = slice::from_raw_parts_mut(data as *mut u8, b.len());
buf.copy_from_slice(b);
Ok(())
}

// Uses 3 stack spaces, does not call checkstack.
#[inline]
pub unsafe fn push_table(
Expand Down
44 changes: 44 additions & 0 deletions tests/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::ffi::{CStr, CString};

use bstr::BString;
use maplit::{btreemap, btreeset, hashmap, hashset};
use mlua::{
AnyUserData, Error, Function, IntoLua, Lua, RegistryKey, Result, Table, Thread, UserDataRef,
Expand Down Expand Up @@ -409,3 +410,46 @@ fn test_conv_array() -> Result<()> {

Ok(())
}

#[test]
fn test_bstring_from_lua() -> Result<()> {
let lua = Lua::new();

let s = lua.create_string("hello, world")?;
let bstr = lua.unpack::<BString>(Value::String(s))?;
assert_eq!(bstr, "hello, world");

let bstr = lua.unpack::<BString>(Value::Integer(123))?;
assert_eq!(bstr, "123");

let bstr = lua.unpack::<BString>(Value::Number(-123.55))?;
assert_eq!(bstr, "-123.55");

// Test from stack
let f = lua.create_function(|_, bstr: BString| Ok(bstr))?;
let bstr = f.call::<_, BString>("hello, world")?;
assert_eq!(bstr, "hello, world");

let bstr = f.call::<_, BString>(-43.22)?;
assert_eq!(bstr, "-43.22");

Ok(())
}

#[cfg(feature = "luau")]
#[test]
fn test_bstring_from_lua_buffer() -> Result<()> {
let lua = Lua::new();

let b = lua.create_buffer("hello, world")?;
let bstr = lua.unpack::<BString>(Value::UserData(b))?;
assert_eq!(bstr, "hello, world");

// Test from stack
let f = lua.create_function(|_, bstr: BString| Ok(bstr))?;
let buf = lua.create_buffer("hello, world")?;
let bstr = f.call::<_, BString>(buf)?;
assert_eq!(bstr, "hello, world");

Ok(())
}
27 changes: 27 additions & 0 deletions tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -728,3 +728,30 @@ fn test_arbitrary_precision() {
"{\n [\"$serde_json::private::Number\"] = \"124.4\",\n}"
);
}

#[cfg(feature = "luau")]
#[test]
fn test_buffer_serialize() {
let lua = Lua::new();

let buf = lua.create_buffer(&[1, 2, 3, 4]).unwrap();
let val = serde_value::to_value(&buf).unwrap();
assert_eq!(val, serde_value::Value::Bytes(vec![1, 2, 3, 4]));

// Try empty buffer
let buf = lua.create_buffer(&[]).unwrap();
let val = serde_value::to_value(&buf).unwrap();
assert_eq!(val, serde_value::Value::Bytes(vec![]));
}

#[cfg(feature = "luau")]
#[test]
fn test_buffer_from_value() {
let lua = Lua::new();

let buf = lua.create_buffer(&[1, 2, 3, 4]).unwrap();
let val = lua
.from_value::<serde_value::Value>(Value::UserData(buf))
.unwrap();
assert_eq!(val, serde_value::Value::Bytes(vec![1, 2, 3, 4]));
}

0 comments on commit fa217d3

Please sign in to comment.