Skip to content

Commit

Permalink
Implement optimized deserializing of internally tagged enums
Browse files Browse the repository at this point in the history
When tag is the first field of the map, do not use intermediate buffering
to collect all fields and instead feed data to deserialized type directly.

Fixes #1495
  • Loading branch information
Mingun committed Nov 3, 2020
1 parent aa80468 commit 84c311d
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 32 deletions.
44 changes: 42 additions & 2 deletions serde/src/private/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use de::{MapAccess, Unexpected};
pub use self::content::{
Content, ContentDeserializer, ContentRefDeserializer, EnumDeserializer,
InternallyTaggedUnitVisitor, TagContentOtherField,
TagOrContent, TagOrContentVisitor,
TagOrContentField, TagOrContentFieldVisitor, TaggedContentVisitor, UntaggedUnitVisitor,
};

Expand Down Expand Up @@ -80,6 +81,45 @@ where
Ok(None)
}

/// Creates `ContentDeserializer` by consuming map.
///
/// Used by derived code for deserialization of internally tagged enums.
/// Returns tag and constructed deserializer for fetching other enum fields.
///
/// # Parameters
/// - `map`: map that will be drained
/// - `tag_name`: name of tag in `#[serde(tag = "tag_name")]` attribute
/// - `first_key`: first key already fetched from the map
/// - `is_human_readable`: use readable or compact format for deserializer?
pub fn drain_map<'de, T, A>(
mut map: A,
tag_name: &'static str,
first_key: Content<'de>,
is_human_readable: bool,
) -> Result<(Option<T>, ContentDeserializer<'de, A::Error>), A::Error>
where
T: Deserialize<'de>,
A: MapAccess<'de>,
{
let mut tag: Option<T> = None;
let mut vec: Vec<(Content<'de>, Content<'de>)> = Vec::new();

vec.push((first_key, map.next_value()?));
while let Some(key) = map.next_key_seed(TagOrContentVisitor::new(tag_name))? {
match key {
TagOrContent::Tag => {
if tag.is_some() {
return Err(<A::Error as Error>::duplicate_field(tag_name));
}
tag = Some(map.next_value()?);
},
TagOrContent::Content(key) => vec.push((key, map.next_value()?)),
}
}

Ok((tag, ContentDeserializer::new(Content::Map(vec), is_human_readable)))
}

#[cfg(any(feature = "std", feature = "alloc"))]
pub fn borrow_cow_str<'de: 'a, 'a, D, R>(deserializer: D) -> Result<R, D::Error>
where
Expand Down Expand Up @@ -557,13 +597,13 @@ mod content {
Content(Content<'de>),
}

struct TagOrContentVisitor<'de> {
pub struct TagOrContentVisitor<'de> {
name: &'static str,
value: PhantomData<TagOrContent<'de>>,
}

impl<'de> TagOrContentVisitor<'de> {
fn new(name: &'static str) -> Self {
pub fn new(name: &'static str) -> Self {
TagOrContentVisitor {
name: name,
value: PhantomData,
Expand Down
173 changes: 145 additions & 28 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,9 @@ fn deserialize_enum(
}
}

/// Returns tuple with 2 elements:
/// - `VARIANTS` constant definition with names of all variants
/// - `Field` and `FieldVisitor` structs' definitions and `Visitor` implementation
fn prepare_enum_variant_enum(
prefix: &str,
variants: &[Variant],
Expand Down Expand Up @@ -1360,8 +1363,14 @@ fn deserialize_internally_tagged_enum(
cattrs: &attr::Container,
tag: &str,
) -> Fragment {
let this = &params.this;
let (de_impl_generics, de_ty_generics, ty_generics, where_clause) =
split_with_de_lifetime(params);
let delife = params.borrowed.de_lifetime();

let expecting = format!("internally tagged enum {}", params.type_name());

let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(prefix, variants, cattrs);
let field_struct_name = field_struct_name(prefix);

// Match arms to extract a variant from a string
let variants = variants
Expand Down Expand Up @@ -1389,21 +1398,60 @@ fn deserialize_internally_tagged_enum(
_ => None,
}
});
let variant_arms = variants
let variant_name: Vec<_> = variants.clone().map(|(i, _)| field_i(i)).collect();
let block_typed_seq = variants
.clone()
.map(|(i, variant)| {
let variant_name = field_i(i);
let visitor = make_visitor(&format!("Variant{}", i), params, cattrs);

let block = Match(deserialize_internally_tagged_variant(
&format!("Variant{}", i),
Match(deserialize_internally_tagged_variant(
params,
variant,
cattrs,
quote!(__deserializer),
));
quote!(_serde::de::value::SeqAccessDeserializer::with_representation(__seq, self.is_human_readable)),
None,
quote_block!(_serde::de::Visitor::visit_seq(#visitor, __seq)),
))
});
let block_typed_map = variants
.clone()
.map(|(i, variant)| {
let visitor = make_visitor(&format!("Variant{}", i), params, cattrs);

quote! {
#field_struct_name::#variant_name => #block
}
Match(deserialize_internally_tagged_variant(
params,
variant,
cattrs,
quote!(_serde::de::value::MapAccessDeserializer::with_representation(__map, self.is_human_readable)),
Some(quote! {
while try!(_serde::de::MapAccess::next_entry::<
_serde::de::IgnoredAny,
_serde::de::IgnoredAny,
>(&mut __map)).is_some() {}
}),
quote_block!(_serde::de::Visitor::visit_map(#visitor, __map)),
))
});
let block_other_map = variants
.clone()
.map(|(i, variant)| {
let visitor = make_visitor(&format!("Variant{}", i), params, cattrs);
let type_name = params.type_name();
let variant_name = variant.ident.to_string();

Match(deserialize_internally_tagged_variant(
params,
variant,
cattrs,
quote!(__deserializer),
Some(quote! {
try!(_serde::Deserializer::deserialize_any(
__deserializer,
_serde::private::de::InternallyTaggedUnitVisitor::new(#type_name, #variant_name)
));
}),
quote_block!(_serde::Deserializer::deserialize_any(__deserializer, #visitor)),
))
});

quote_block! {
Expand All @@ -1413,15 +1461,67 @@ fn deserialize_internally_tagged_enum(

#variants_stmt

let __form = _serde::Deserializer::is_human_readable(&__deserializer);
let __tagged = try!(_serde::Deserializer::deserialize_any(
__deserializer,
_serde::private::de::TaggedContentVisitor::<#field_struct_name>::new(#tag, __form)));
let __deserializer = _serde::private::de::ContentDeserializer::<__D::Error>::new(__tagged.content, __form);
struct __Visitor #de_impl_generics #where_clause {
is_human_readable: bool,
marker: _serde::export::PhantomData<#this #ty_generics>,
lifetime: _serde::export::PhantomData<&#delife ()>,
}

impl #de_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_ty_generics #where_clause {
type Value = #this #ty_generics;

fn expecting(&self, __formatter: &mut _serde::export::Formatter) -> _serde::export::fmt::Result {
_serde::export::Formatter::write_str(__formatter, #expecting)
}

fn visit_seq<__A>(self, mut __seq: __A) -> _serde::export::Result<Self::Value, __A::Error>
where
__A: _serde::de::SeqAccess<#delife>,
{
match try!(_serde::de::SeqAccess::next_element(&mut __seq)) {
#(_serde::export::Some(__Field::#variant_name) => #block_typed_seq)*
_serde::export::None => _serde::export::Err(<__A::Error as _serde::de::Error>::missing_field(#tag)),
}
}

match __tagged.tag {
#(#variant_arms)*
fn visit_map<__A>(self, mut __map: __A) -> _serde::export::Result<Self::Value, __A::Error>
where
__A: _serde::de::MapAccess<#delife>,
{
// Read the first field. If it is a tag, immediately deserialize the typed data.
// Otherwise, we collect everything until we find the tag, and then deserialize
// using ContentDeserializer.
match try!(_serde::de::MapAccess::next_key_seed(
&mut __map, _serde::private::de::TagOrContentVisitor::new(#tag)
)) {
_serde::export::Some(_serde::private::de::TagOrContent::Tag) => {
match try!(_serde::de::MapAccess::next_value(&mut __map)) {
#(__Field::#variant_name => #block_typed_map)*
}
},
_serde::export::Some(_serde::private::de::TagOrContent::Content(__key)) => {
// Drain map to Content::Map, convert it to ContentDeserializer
// Special handling for tag key -- search them and return as separate result
let (__tag, __deserializer) = try!(_serde::private::de::drain_map(
__map, #tag, __key, self.is_human_readable
));

match __tag {
#(_serde::export::Some(__Field::#variant_name) => #block_other_map)*
_serde::export::None => _serde::export::Err(<__A::Error as _serde::de::Error>::missing_field(#tag)),
}
},
_serde::export::None => _serde::export::Err(<__A::Error as _serde::de::Error>::missing_field(#tag)),
}
}
}

let __visitor = __Visitor {
is_human_readable: _serde::Deserializer::is_human_readable(&__deserializer),
marker: _serde::export::PhantomData::<#this #ty_generics>,
lifetime: _serde::export::PhantomData,
};
_serde::Deserializer::deserialize_any(__deserializer, __visitor)
}
}

Expand Down Expand Up @@ -1792,11 +1892,12 @@ fn deserialize_externally_tagged_variant(
}

fn deserialize_internally_tagged_variant(
prefix: &str,
params: &Parameters,
variant: &Variant,
cattrs: &attr::Container,
deserializer: TokenStream,
unit_skip: Option<TokenStream>,
struct_arm: Fragment,
) -> Fragment {
if variant.attrs.deserialize_with().is_some() {
return deserialize_untagged_variant(params, variant, cattrs, deserializer);
Expand All @@ -1807,14 +1908,12 @@ fn deserialize_internally_tagged_variant(
match effective_style(variant) {
Style::Unit => {
let this = &params.this;
let type_name = params.type_name();
let variant_name = variant.ident.to_string();
let default = variant.fields.get(0).map(|field| {
let default = Expr(expr_is_missing(field, cattrs));
quote!((#default))
});
quote_block! {
try!(_serde::Deserializer::deserialize_any(#deserializer, _serde::private::de::InternallyTaggedUnitVisitor::new(#type_name, #variant_name)));
#unit_skip
_serde::export::Ok(#this::#variant_ident #default)
}
}
Expand All @@ -1824,13 +1923,7 @@ fn deserialize_internally_tagged_variant(
&variant.fields[0],
&deserializer,
),
Style::Struct => Fragment::Block(deserialize_struct_dispatch(
prefix,
params,
cattrs,
quote!(_serde::Deserializer::deserialize_any(#deserializer, __visitor)),
quote!(__form),
)),
Style::Struct => struct_arm,
Style::Tuple => unreachable!("checked in serde_derive_internals"),
}
}
Expand Down Expand Up @@ -2866,6 +2959,30 @@ fn visitor_struct_name(prefix: &str) -> Ident {
Ident::new(&format!("__{}Visitor", prefix), Span::call_site())
}

fn make_visitor(
prefix: &str,
params: &Parameters,
cattrs: &attr::Container,
) -> TokenStream {
let this = &params.this;
let (_, _, ty_generics, _) = split_with_de_lifetime(params);

let name = visitor_struct_name(prefix);
let form_init = if cattrs.has_flatten() {
Some(quote!(is_human_readable: self.is_human_readable,))
} else {
None
};

quote! {
#name {
#form_init
marker: _serde::export::PhantomData::<#this #ty_generics>,
lifetime: _serde::export::PhantomData,
}
}
}

/// This function wraps the expression in `#[serde(deserialize_with = "...")]`
/// in a trait to prevent it from accessing the internal `Deserialize` state.
fn wrap_deserialize_with(
Expand Down
2 changes: 0 additions & 2 deletions test_suite/tests/test_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,6 @@ fn test_internally_tagged_enum() {
Token::Seq { len: Some(2) },
Token::Str("C"),
Token::Map { len: Some(0) },
Token::MapEnd,
Token::SeqEnd,
],
"invalid type: sequence, expected a map",
);
Expand Down

0 comments on commit 84c311d

Please sign in to comment.