Skip to content

Commit

Permalink
Add new option sort_keys to DeserializeOptions (Lua::from_value
Browse files Browse the repository at this point in the history
… method)

Closes #303
  • Loading branch information
khvzak committed Aug 7, 2023
1 parent 0cb0a34 commit c0c6a33
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 18 deletions.
93 changes: 77 additions & 16 deletions src/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ pub struct Options {
///
/// Default: **true**
pub deny_recursive_tables: bool,

/// If true, keys in tables will be iterated in sorted order.
///
/// Default: **false**
pub sort_keys: bool,
}

impl Default for Options {
Expand All @@ -56,6 +61,7 @@ impl Options {
Options {
deny_unsupported_types: true,
deny_recursive_tables: true,
sort_keys: false,
}
}

Expand All @@ -76,6 +82,15 @@ impl Options {
self.deny_recursive_tables = enabled;
self
}

/// Sets [`sort_keys`] option.
///
/// [`sort_keys`]: #structfield.sort_keys
#[must_use]
pub const fn sort_keys(mut self, enabled: bool) -> Self {
self.sort_keys = enabled;
self
}
}

impl<'lua> Deserializer<'lua> {
Expand Down Expand Up @@ -291,8 +306,16 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
Value::Table(t) => {
let _guard = RecursionGuard::new(&t, &self.visited);

let pairs = if self.options.sort_keys {
let mut pairs = t.pairs::<Value, Value>().collect::<Result<Vec<_>>>()?;
pairs.sort_by(|(a, _), (b, _)| b.cmp(a)); // reverse order as we pop values from the end
MapPairs::Vec(pairs)
} else {
MapPairs::Iter(t.pairs::<Value, Value>())
};

let mut deserializer = MapDeserializer {
pairs: t.pairs(),
pairs,
value: None,
options: self.options,
visited: self.visited,
Expand Down Expand Up @@ -443,8 +466,29 @@ impl<'de> de::SeqAccess<'de> for VecDeserializer {
}
}

enum MapPairs<'lua> {
Iter(TablePairs<'lua, Value<'lua>, Value<'lua>>),
Vec(Vec<(Value<'lua>, Value<'lua>)>),
}

impl<'lua> MapPairs<'lua> {
fn count(self) -> usize {
match self {
MapPairs::Iter(iter) => iter.count(),
MapPairs::Vec(vec) => vec.len(),
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
match self {
MapPairs::Iter(iter) => iter.size_hint(),
MapPairs::Vec(vec) => (vec.len(), Some(vec.len())),
}
}
}

struct MapDeserializer<'lua> {
pairs: TablePairs<'lua, Value<'lua>, Value<'lua>>,
pairs: MapPairs<'lua>,
value: Option<Value<'lua>>,
options: Options,
visited: Rc<RefCell<FxHashSet<*const c_void>>>,
Expand All @@ -459,21 +503,38 @@ impl<'lua, 'de> de::MapAccess<'de> for MapDeserializer<'lua> {
T: de::DeserializeSeed<'de>,
{
loop {
match self.pairs.next() {
Some(item) => {
let (key, value) = item?;
if check_value_if_skip(&key, self.options, &self.visited)?
|| check_value_if_skip(&value, self.options, &self.visited)?
{
continue;
match self.pairs {
MapPairs::Iter(ref mut iter) => match iter.next() {
Some(item) => {
let (key, value) = item?;
if check_value_if_skip(&key, self.options, &self.visited)?
|| check_value_if_skip(&value, self.options, &self.visited)?
{
continue;
}
self.processed += 1;
self.value = Some(value);
let visited = Rc::clone(&self.visited);
let key_de = Deserializer::from_parts(key, self.options, visited);
return seed.deserialize(key_de).map(Some);
}
self.processed += 1;
self.value = Some(value);
let visited = Rc::clone(&self.visited);
let key_de = Deserializer::from_parts(key, self.options, visited);
return seed.deserialize(key_de).map(Some);
}
None => return Ok(None),
None => return Ok(None),
},
MapPairs::Vec(ref mut pairs) => match pairs.pop() {
Some((key, value)) => {
if check_value_if_skip(&key, self.options, &self.visited)?
|| check_value_if_skip(&value, self.options, &self.visited)?
{
continue;
}
self.processed += 1;
self.value = Some(value);
let visited = Rc::clone(&self.visited);
let key_de = Deserializer::from_parts(key, self.options, visited);
return seed.deserialize(key_de).map(Some);
}
None => return Ok(None),
},
}
}
}
Expand Down
27 changes: 25 additions & 2 deletions tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use std::collections::HashMap;
use std::error::Error as StdError;

use mlua::{
DeserializeOptions, Error, Lua, LuaSerdeExt, Result as LuaResult, SerializeOptions, UserData,
Value,
DeserializeOptions, Error, ExternalResult, Lua, LuaSerdeExt, Result as LuaResult,
SerializeOptions, UserData, Value,
};
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -611,3 +611,26 @@ fn test_from_value_userdata() -> Result<(), Box<dyn StdError>> {

Ok(())
}

#[test]
fn test_from_value_sorted() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();

let to_json = lua.create_function(|lua, value| {
let json_value: serde_json::Value =
lua.from_value_with(value, DeserializeOptions::new().sort_keys(true))?;
serde_json::to_string(&json_value).into_lua_err()
})?;
lua.globals().set("to_json", to_json)?;

lua.load(
r#"
local json = to_json({c = 3, b = 2, hello = "world", x = {1}, ["0a"] = {z = "z", d = "d"}})
assert(json == '{"0a":{"d":"d","z":"z"},"b":2,"c":3,"hello":"world","x":[1]}', "invalid json")
"#,
)
.exec()
.unwrap();

Ok(())
}

0 comments on commit c0c6a33

Please sign in to comment.