diff --git a/src/serde/de.rs b/src/serde/de.rs index feb8f81f..b82205a0 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -42,6 +42,11 @@ pub struct Options { /// /// Default: **true** pub deny_recursive_tables: bool, + + /// If true, keys in tables will be iterated in sorted order. + /// + /// Default: **false** + pub sort_keys: bool, } impl Default for Options { @@ -56,6 +61,7 @@ impl Options { Options { deny_unsupported_types: true, deny_recursive_tables: true, + sort_keys: false, } } @@ -76,6 +82,15 @@ impl Options { self.deny_recursive_tables = enabled; self } + + /// Sets [`sort_keys`] option. + /// + /// [`sort_keys`]: #structfield.sort_keys + #[must_use] + pub const fn sort_keys(mut self, enabled: bool) -> Self { + self.sort_keys = enabled; + self + } } impl<'lua> Deserializer<'lua> { @@ -291,8 +306,16 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { Value::Table(t) => { let _guard = RecursionGuard::new(&t, &self.visited); + let pairs = if self.options.sort_keys { + let mut pairs = t.pairs::().collect::>>()?; + pairs.sort_by(|(a, _), (b, _)| b.cmp(a)); // reverse order as we pop values from the end + MapPairs::Vec(pairs) + } else { + MapPairs::Iter(t.pairs::()) + }; + let mut deserializer = MapDeserializer { - pairs: t.pairs(), + pairs, value: None, options: self.options, visited: self.visited, @@ -443,8 +466,29 @@ impl<'de> de::SeqAccess<'de> for VecDeserializer { } } +enum MapPairs<'lua> { + Iter(TablePairs<'lua, Value<'lua>, Value<'lua>>), + Vec(Vec<(Value<'lua>, Value<'lua>)>), +} + +impl<'lua> MapPairs<'lua> { + fn count(self) -> usize { + match self { + MapPairs::Iter(iter) => iter.count(), + MapPairs::Vec(vec) => vec.len(), + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + MapPairs::Iter(iter) => iter.size_hint(), + MapPairs::Vec(vec) => (vec.len(), Some(vec.len())), + } + } +} + struct MapDeserializer<'lua> { - pairs: TablePairs<'lua, Value<'lua>, Value<'lua>>, + pairs: MapPairs<'lua>, value: Option>, options: Options, visited: Rc>>, @@ -459,21 +503,38 @@ impl<'lua, 'de> de::MapAccess<'de> for MapDeserializer<'lua> { T: de::DeserializeSeed<'de>, { loop { - match self.pairs.next() { - Some(item) => { - let (key, value) = item?; - if check_value_if_skip(&key, self.options, &self.visited)? - || check_value_if_skip(&value, self.options, &self.visited)? - { - continue; + match self.pairs { + MapPairs::Iter(ref mut iter) => match iter.next() { + Some(item) => { + let (key, value) = item?; + if check_value_if_skip(&key, self.options, &self.visited)? + || check_value_if_skip(&value, self.options, &self.visited)? + { + continue; + } + self.processed += 1; + self.value = Some(value); + let visited = Rc::clone(&self.visited); + let key_de = Deserializer::from_parts(key, self.options, visited); + return seed.deserialize(key_de).map(Some); } - self.processed += 1; - self.value = Some(value); - let visited = Rc::clone(&self.visited); - let key_de = Deserializer::from_parts(key, self.options, visited); - return seed.deserialize(key_de).map(Some); - } - None => return Ok(None), + None => return Ok(None), + }, + MapPairs::Vec(ref mut pairs) => match pairs.pop() { + Some((key, value)) => { + if check_value_if_skip(&key, self.options, &self.visited)? + || check_value_if_skip(&value, self.options, &self.visited)? + { + continue; + } + self.processed += 1; + self.value = Some(value); + let visited = Rc::clone(&self.visited); + let key_de = Deserializer::from_parts(key, self.options, visited); + return seed.deserialize(key_de).map(Some); + } + None => return Ok(None), + }, } } } diff --git a/tests/serde.rs b/tests/serde.rs index 660c83ca..f4b8b1fd 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -4,8 +4,8 @@ use std::collections::HashMap; use std::error::Error as StdError; use mlua::{ - DeserializeOptions, Error, Lua, LuaSerdeExt, Result as LuaResult, SerializeOptions, UserData, - Value, + DeserializeOptions, Error, ExternalResult, Lua, LuaSerdeExt, Result as LuaResult, + SerializeOptions, UserData, Value, }; use serde::{Deserialize, Serialize}; @@ -611,3 +611,26 @@ fn test_from_value_userdata() -> Result<(), Box> { Ok(()) } + +#[test] +fn test_from_value_sorted() -> Result<(), Box> { + let lua = Lua::new(); + + let to_json = lua.create_function(|lua, value| { + let json_value: serde_json::Value = + lua.from_value_with(value, DeserializeOptions::new().sort_keys(true))?; + serde_json::to_string(&json_value).into_lua_err() + })?; + lua.globals().set("to_json", to_json)?; + + lua.load( + r#" + local json = to_json({c = 3, b = 2, hello = "world", x = {1}, ["0a"] = {z = "z", d = "d"}}) + assert(json == '{"0a":{"d":"d","z":"z"},"b":2,"c":3,"hello":"world","x":[1]}', "invalid json") + "#, + ) + .exec() + .unwrap(); + + Ok(()) +}