Skip to content

Commit

Permalink
Add SerializeOptions::detect_serde_json_arbitrary_precision to dete…
Browse files Browse the repository at this point in the history
…ct `serde_json::Number` with `arbitrary_precision` and convert it to Lua number.

By default the option is disabled and such numbers represented as Lua objects with `$serde_json::private::Number` key.
Fixes #385
  • Loading branch information
khvzak committed Mar 23, 2024
1 parent 038cc5f commit 39afe4c
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ http-body-util = "0.1.1"
reqwest = { version = "0.12", features = ["json"] }
tokio = { version = "1.0", features = ["macros", "rt", "time"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_json = { version = "1.0", features = ["arbitrary_precision"] }
maplit = "1.0"
tempfile = "3"
static_assertions = "1.0"
Expand Down
85 changes: 77 additions & 8 deletions src/serde/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ pub struct Options {
/// [`null`]: crate::LuaSerdeExt::null
/// [`Nil`]: crate::Value::Nil
pub serialize_unit_to_null: bool,

/// If true, serialize `serde_json::Number` with arbitrary_precision to a Lua number.
/// Otherwise it will be serialized as an object (what serde does).
///
/// Default: **false**
pub detect_serde_json_arbitrary_precision: bool,
}

impl Default for Options {
Expand All @@ -58,6 +64,7 @@ impl Options {
set_array_metatable: true,
serialize_none_to_null: true,
serialize_unit_to_null: true,
detect_serde_json_arbitrary_precision: false,
}
}

Expand Down Expand Up @@ -87,6 +94,20 @@ impl Options {
self.serialize_unit_to_null = enabled;
self
}

/// Sets [`detect_serde_json_arbitrary_precision`] option.
///
/// This option is used to serialize `serde_json::Number` with arbitrary precision to a Lua number.
/// Otherwise it will be serialized as an object (what serde does).
///
/// This option is disabled by default.
///
/// [`detect_serde_json_arbitrary_precision`]: #structfield.detect_serde_json_arbitrary_precision
#[must_use]
pub const fn detect_serde_json_arbitrary_precision(mut self, enabled: bool) -> Self {
self.detect_serde_json_arbitrary_precision = enabled;
self
}
}

impl<'lua> Serializer<'lua> {
Expand Down Expand Up @@ -121,7 +142,7 @@ impl<'lua> ser::Serializer for Serializer<'lua> {
type SerializeTupleStruct = SerializeSeq<'lua>;
type SerializeTupleVariant = SerializeTupleVariant<'lua>;
type SerializeMap = SerializeMap<'lua>;
type SerializeStruct = SerializeMap<'lua>;
type SerializeStruct = SerializeStruct<'lua>;
type SerializeStructVariant = SerializeStructVariant<'lua>;

#[inline]
Expand Down Expand Up @@ -282,8 +303,23 @@ impl<'lua> ser::Serializer for Serializer<'lua> {
}

#[inline]
fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
self.serialize_map(Some(len))
fn serialize_struct(self, name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
if self.options.detect_serde_json_arbitrary_precision
&& name == "$serde_json::private::Number"
&& len == 1
{
return Ok(SerializeStruct {
lua: self.lua,
inner: None,
options: self.options,
});
}

Ok(SerializeStruct {
lua: self.lua,
inner: Some(Value::Table(self.lua.create_table_with_capacity(0, len)?)),
options: self.options,
})
}

#[inline]
Expand Down Expand Up @@ -465,20 +501,53 @@ impl<'lua> ser::SerializeMap for SerializeMap<'lua> {
}
}

impl<'lua> ser::SerializeStruct for SerializeMap<'lua> {
#[doc(hidden)]
pub struct SerializeStruct<'lua> {
lua: &'lua Lua,
inner: Option<Value<'lua>>,
options: Options,
}

impl<'lua> ser::SerializeStruct for SerializeStruct<'lua> {
type Ok = Value<'lua>;
type Error = Error;

fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
where
T: Serialize + ?Sized,
{
ser::SerializeMap::serialize_key(self, key)?;
ser::SerializeMap::serialize_value(self, value)
match self.inner {
Some(Value::Table(ref table)) => {
table.raw_set(key, self.lua.to_value_with(value, self.options)?)?;
}
None if self.options.detect_serde_json_arbitrary_precision => {
// A special case for `serde_json::Number` with arbitrary precision.
assert_eq!(key, "$serde_json::private::Number");
self.inner = Some(self.lua.to_value_with(value, self.options)?);
}
_ => unreachable!(),
}
Ok(())
}

fn end(self) -> Result<Value<'lua>> {
ser::SerializeMap::end(self)
match self.inner {
Some(table @ Value::Table(_)) => Ok(table),
Some(value) if self.options.detect_serde_json_arbitrary_precision => {
let number_s = value.as_str().expect("not an arbitrary precision number");
if number_s.contains(&['.', 'e', 'E']) {
if let Ok(number) = number_s.parse().map(Value::Number) {
return Ok(number);
}
}
Ok(number_s
.parse()
.map(Value::Integer)
.or_else(|_| number_s.parse().map(Value::Number))
.unwrap_or_else(|_| value))
}
_ => unreachable!(),
}
}
}

Expand All @@ -505,7 +574,7 @@ impl<'lua> ser::SerializeStructVariant for SerializeStructVariant<'lua> {

fn end(self) -> Result<Value<'lua>> {
let lua = self.table.0.lua;
let table = lua.create_table()?;
let table = lua.create_table_with_capacity(0, 1)?;
table.raw_set(self.name, self.table)?;
Ok(Value::Table(table))
}
Expand Down
31 changes: 31 additions & 0 deletions tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,34 @@ fn test_from_value_sorted() -> Result<(), Box<dyn StdError>> {

Ok(())
}

#[test]
fn test_arbitrary_precision() {
let lua = Lua::new();

let opts = SerializeOptions::new().detect_serde_json_arbitrary_precision(true);

// Number
let num = serde_json::Value::Number(serde_json::Number::from_f64(1.244e2).unwrap());
let num = lua.to_value_with(&num, opts).unwrap();
assert_eq!(num, Value::Number(1.244e2));

// Integer
let num = serde_json::Value::Number(serde_json::Number::from_f64(123.0).unwrap());
let num = lua.to_value_with(&num, opts).unwrap();
assert_eq!(num, Value::Integer(123));

// Max u64
let num = serde_json::Value::Number(serde_json::Number::from(i64::MAX));
let num = lua.to_value_with(&num, opts).unwrap();
assert_eq!(num, Value::Number(i64::MAX as f64));

// Check that the option is disabled by default
let num = serde_json::Value::Number(serde_json::Number::from_f64(1.244e2).unwrap());
let num = lua.to_value(&num).unwrap();
assert_eq!(num.type_name(), "table");
assert_eq!(
format!("{:#?}", num),
"{\n [\"$serde_json::private::Number\"] = \"124.4\",\n}"
);
}

0 comments on commit 39afe4c

Please sign in to comment.