Skip to content

Commit

Permalink
Support options for Value::serialize() implementation
Browse files Browse the repository at this point in the history
To match `lua.from_value_with()` functionality.
  • Loading branch information
khvzak committed Aug 12, 2023
1 parent c137da7 commit 86a42a8
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 94 deletions.
105 changes: 53 additions & 52 deletions src/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::cell::RefCell;
use std::convert::TryInto;
use std::os::raw::c_void;
use std::rc::Rc;
use std::result::Result as StdResult;
use std::string::String as StdString;

use rustc_hash::FxHashSet;
Expand Down Expand Up @@ -156,10 +157,8 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
| Value::LightUserData(_)
| Value::Error(_) => {
if self.options.deny_unsupported_types {
Err(de::Error::custom(format!(
"unsupported value type `{}`",
self.value.type_name()
)))
let msg = format!("unsupported value type `{}`", self.value.type_name());
Err(de::Error::custom(msg))
} else {
visitor.visit_unit()
}
Expand Down Expand Up @@ -210,7 +209,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
&"map with a single key",
));
}
if check_value_if_skip(&value, self.options, &self.visited)? {
let skip = check_value_for_skip(&value, self.options, &self.visited)
.map_err(|err| Error::DeserializeError(err.to_string()))?;
if skip {
return Err(de::Error::custom("bad enum value"));
}

Expand Down Expand Up @@ -306,16 +307,8 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
Value::Table(t) => {
let _guard = RecursionGuard::new(&t, &self.visited);

let pairs = if self.options.sort_keys {
let mut pairs = t.pairs::<Value, Value>().collect::<Result<Vec<_>>>()?;
pairs.sort_by(|(a, _), (b, _)| b.cmp(a)); // reverse order as we pop values from the end
MapPairs::Vec(pairs)
} else {
MapPairs::Iter(t.pairs::<Value, Value>())
};

let mut deserializer = MapDeserializer {
pairs,
pairs: MapPairs::new(t, self.options.sort_keys)?,
value: None,
options: self.options,
visited: self.visited,
Expand Down Expand Up @@ -413,7 +406,9 @@ impl<'lua, 'de> de::SeqAccess<'de> for SeqDeserializer<'lua> {
match self.seq.next() {
Some(value) => {
let value = value?;
if check_value_if_skip(&value, self.options, &self.visited)? {
let skip = check_value_for_skip(&value, self.options, &self.visited)
.map_err(|err| Error::DeserializeError(err.to_string()))?;
if skip {
continue;
}
let visited = Rc::clone(&self.visited);
Expand Down Expand Up @@ -466,27 +461,48 @@ impl<'de> de::SeqAccess<'de> for VecDeserializer {
}
}

enum MapPairs<'lua> {
pub(crate) enum MapPairs<'lua> {
Iter(TablePairs<'lua, Value<'lua>, Value<'lua>>),
Vec(Vec<(Value<'lua>, Value<'lua>)>),
}

impl<'lua> MapPairs<'lua> {
fn count(self) -> usize {
pub(crate) fn new(t: Table<'lua>, sort_keys: bool) -> Result<Self> {
if sort_keys {
let mut pairs = t.pairs::<Value, Value>().collect::<Result<Vec<_>>>()?;
pairs.sort_by(|(a, _), (b, _)| b.cmp(a)); // reverse order as we pop values from the end
Ok(MapPairs::Vec(pairs))
} else {
Ok(MapPairs::Iter(t.pairs::<Value, Value>()))
}
}

pub(crate) fn count(self) -> usize {
match self {
MapPairs::Iter(iter) => iter.count(),
MapPairs::Vec(vec) => vec.len(),
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
pub(crate) fn size_hint(&self) -> (usize, Option<usize>) {
match self {
MapPairs::Iter(iter) => iter.size_hint(),
MapPairs::Vec(vec) => (vec.len(), Some(vec.len())),
}
}
}

impl<'lua> Iterator for MapPairs<'lua> {
type Item = Result<(Value<'lua>, Value<'lua>)>;

fn next(&mut self) -> Option<Self::Item> {
match self {
MapPairs::Iter(iter) => iter.next(),
MapPairs::Vec(vec) => vec.pop().map(Ok),
}
}
}

struct MapDeserializer<'lua> {
pairs: MapPairs<'lua>,
value: Option<Value<'lua>>,
Expand All @@ -503,38 +519,23 @@ impl<'lua, 'de> de::MapAccess<'de> for MapDeserializer<'lua> {
T: de::DeserializeSeed<'de>,
{
loop {
match self.pairs {
MapPairs::Iter(ref mut iter) => match iter.next() {
Some(item) => {
let (key, value) = item?;
if check_value_if_skip(&key, self.options, &self.visited)?
|| check_value_if_skip(&value, self.options, &self.visited)?
{
continue;
}
self.processed += 1;
self.value = Some(value);
let visited = Rc::clone(&self.visited);
let key_de = Deserializer::from_parts(key, self.options, visited);
return seed.deserialize(key_de).map(Some);
}
None => return Ok(None),
},
MapPairs::Vec(ref mut pairs) => match pairs.pop() {
Some((key, value)) => {
if check_value_if_skip(&key, self.options, &self.visited)?
|| check_value_if_skip(&value, self.options, &self.visited)?
{
continue;
}
self.processed += 1;
self.value = Some(value);
let visited = Rc::clone(&self.visited);
let key_de = Deserializer::from_parts(key, self.options, visited);
return seed.deserialize(key_de).map(Some);
match self.pairs.next() {
Some(item) => {
let (key, value) = item?;
let skip_key = check_value_for_skip(&key, self.options, &self.visited)
.map_err(|err| Error::DeserializeError(err.to_string()))?;
let skip_value = check_value_for_skip(&value, self.options, &self.visited)
.map_err(|err| Error::DeserializeError(err.to_string()))?;
if skip_key || skip_value {
continue;
}
None => return Ok(None),
},
self.processed += 1;
self.value = Some(value);
let visited = Rc::clone(&self.visited);
let key_de = Deserializer::from_parts(key, self.options, visited);
return seed.deserialize(key_de).map(Some);
}
None => return Ok(None),
}
}
}
Expand Down Expand Up @@ -676,17 +677,17 @@ impl Drop for RecursionGuard {
}

// Checks `options` and decides should we emit an error or skip next element
fn check_value_if_skip(
pub(crate) fn check_value_for_skip(
value: &Value,
options: Options,
visited: &RefCell<FxHashSet<*const c_void>>,
) -> Result<bool> {
) -> StdResult<bool, &'static str> {
match value {
Value::Table(table) => {
let ptr = table.to_pointer();
if visited.borrow().contains(&ptr) {
if options.deny_recursive_tables {
return Err(de::Error::custom("recursive table detected"));
return Err("recursive table detected");
}
return Ok(true); // skip
}
Expand Down
104 changes: 71 additions & 33 deletions src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use std::os::raw::c_void;
#[cfg(feature = "serialize")]
use {
rustc_hash::FxHashSet,
serde::ser::{self, Serialize, SerializeMap, SerializeSeq, Serializer},
std::{cell::RefCell, result::Result as StdResult},
serde::ser::{Serialize, SerializeMap, SerializeSeq, Serializer},
std::{cell::RefCell, rc::Rc, result::Result as StdResult},
};

use crate::error::{Error, Result};
Expand Down Expand Up @@ -1016,47 +1016,85 @@ impl<'lua> TableExt<'lua> for Table<'lua> {
}
}

/// A wrapped [`Table`] with customized serialization behavior.
#[cfg(feature = "serialize")]
pub(crate) struct SerializableTable<'a, 'lua> {
table: &'a Table<'lua>,
options: crate::serde::de::Options,
visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}

#[cfg(feature = "serialize")]
impl<'lua> Serialize for Table<'lua> {
#[inline]
fn serialize<S: Serializer>(&self, serializer: S) -> StdResult<S::Ok, S::Error> {
SerializableTable::new(self, Default::default(), Default::default()).serialize(serializer)
}
}

impl<'a, 'lua> SerializableTable<'a, 'lua> {
#[inline]
pub(crate) fn new(
table: &'a Table<'lua>,
options: crate::serde::de::Options,
visited: Rc<RefCell<FxHashSet<*const c_void>>>,
) -> Self {
Self {
table,
options,
visited,
}
}
}

#[cfg(feature = "serialize")]
impl<'a, 'lua> Serialize for SerializableTable<'a, 'lua> {
fn serialize<S>(&self, serializer: S) -> StdResult<S::Ok, S::Error>
where
S: Serializer,
{
thread_local! {
static VISITED: RefCell<FxHashSet<*const c_void>> = RefCell::new(FxHashSet::default());
}

let ptr = self.to_pointer();
let res = VISITED.with(|visited| {
{
let mut visited = visited.borrow_mut();
if visited.contains(&ptr) {
return Err(ser::Error::custom("recursive table detected"));
use crate::serde::de::{check_value_for_skip, MapPairs};
use crate::value::SerializableValue;

let options = self.options;
let visited = &self.visited;
visited.borrow_mut().insert(self.table.to_pointer());

// Array
let len = self.table.raw_len();
if len > 0 || self.table.is_array() {
let mut seq = serializer.serialize_seq(Some(len))?;
for value in self.table.clone().sequence_values_by_len::<Value>(None) {
let value = &value.map_err(serde::ser::Error::custom)?;
let skip = check_value_for_skip(value, self.options, &self.visited)
.map_err(serde::ser::Error::custom)?;
if skip {
continue;
}
visited.insert(ptr);
}

let len = self.raw_len();
if len > 0 || self.is_array() {
let mut seq = serializer.serialize_seq(Some(len))?;
for v in self.clone().sequence_values_by_len::<Value>(None) {
let v = v.map_err(serde::ser::Error::custom)?;
seq.serialize_element(&v)?;
}
return seq.end();
seq.serialize_element(&SerializableValue::new(value, options, Some(visited)))?;
}
return seq.end();
}

let mut map = serializer.serialize_map(None)?;
for kv in self.clone().pairs::<Value, Value>() {
let (k, v) = kv.map_err(serde::ser::Error::custom)?;
map.serialize_entry(&k, &v)?;
// HashMap
let mut map = serializer.serialize_map(None)?;
let pairs = MapPairs::new(self.table.clone(), self.options.sort_keys)
.map_err(serde::ser::Error::custom)?;
for kv in pairs {
let (key, value) = kv.map_err(serde::ser::Error::custom)?;
let skip_key = check_value_for_skip(&key, self.options, &self.visited)
.map_err(serde::ser::Error::custom)?;
let skip_value = check_value_for_skip(&value, self.options, &self.visited)
.map_err(serde::ser::Error::custom)?;
if skip_key || skip_value {
continue;
}
map.end()
});
VISITED.with(|visited| {
visited.borrow_mut().remove(&ptr);
});
res
map.serialize_entry(
&SerializableValue::new(&key, options, Some(visited)),
&SerializableValue::new(&value, options, Some(visited)),
)?;
}
map.end()
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/userdata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,7 @@ impl<'lua> AnyUserData<'lua> {
Ok(false)
}

/// Returns true if this `AnyUserData` is serializable (eg. was created using `create_ser_userdata`).
/// Returns `true` if this `AnyUserData` is serializable (eg. was created using `create_ser_userdata`).
#[cfg(feature = "serialize")]
pub(crate) fn is_serializable(&self) -> bool {
let lua = self.0.lua;
Expand Down
Loading

0 comments on commit 86a42a8

Please sign in to comment.