Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return visitor errors over skip_decode errors if there are any #58

Merged
merged 6 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions scale-decode/src/visitor/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ impl<'a, 'scale, 'resolver, V: Visitor> Decoder<'a, 'scale, 'resolver, V> {
}
}

// Our types like Composite/Variant/Sequence/Array/Tuple all use the same
// approach to skip over any bytes that the visitor didn't consume, so this
// macro performs that logic.
macro_rules! skip_decoding_and_return {
($self:ident, $visit_result:ident, $visitor_ty:ident) => {{
// Skip over any bytes that the visitor chose not to decode:
let skip_res = $visitor_ty.skip_decoding();
if skip_res.is_ok() {
*$self.data = $visitor_ty.bytes_from_undecoded();
}

// Prioritize returning visitor errors over skip_decoding errors.
match ($visit_result, skip_res) {
(Err(e), _) => Err(e),
(_, Err(e)) => Err(e.into()),
(Ok(v), _) => Ok(v),
}
}};
}

impl<'temp, 'scale, 'resolver, V: Visitor> ResolvedTypeVisitor<'resolver>
for Decoder<'temp, 'scale, 'resolver, V>
{
Expand Down Expand Up @@ -122,11 +142,7 @@ impl<'temp, 'scale, 'resolver, V: Visitor> ResolvedTypeVisitor<'resolver>
let mut items = Composite::new(path, self.data, &mut fields, self.types, self.is_compact);
let res = self.visitor.visit_composite(&mut items, self.type_id);

// Skip over any bytes that the visitor chose not to decode:
items.skip_decoding()?;
*self.data = items.bytes_from_undecoded();

res
skip_decoding_and_return!(self, res, items)
}

fn visit_variant<Path, Fields, Var>(self, _path: Path, variants: Var) -> Self::Value
Expand All @@ -142,11 +158,7 @@ impl<'temp, 'scale, 'resolver, V: Visitor> ResolvedTypeVisitor<'resolver>
let mut variant = Variant::new(self.data, variants, self.types)?;
let res = self.visitor.visit_variant(&mut variant, self.type_id);

// Skip over any bytes that the visitor chose not to decode:
variant.skip_decoding()?;
*self.data = variant.bytes_from_undecoded();

res
skip_decoding_and_return!(self, res, variant)
}

fn visit_sequence<Path>(self, _path: Path, inner_type_id: Self::TypeId) -> Self::Value
Expand All @@ -160,11 +172,7 @@ impl<'temp, 'scale, 'resolver, V: Visitor> ResolvedTypeVisitor<'resolver>
let mut items = Sequence::new(self.data, inner_type_id, self.types)?;
let res = self.visitor.visit_sequence(&mut items, self.type_id);

// Skip over any bytes that the visitor chose not to decode:
items.skip_decoding()?;
*self.data = items.bytes_from_undecoded();

res
skip_decoding_and_return!(self, res, items)
}

fn visit_array(self, inner_type_id: Self::TypeId, len: usize) -> Self::Value {
Expand All @@ -175,11 +183,7 @@ impl<'temp, 'scale, 'resolver, V: Visitor> ResolvedTypeVisitor<'resolver>
let mut arr = Array::new(self.data, inner_type_id, len, self.types);
let res = self.visitor.visit_array(&mut arr, self.type_id);

// Skip over any bytes that the visitor chose not to decode:
arr.skip_decoding()?;
*self.data = arr.bytes_from_undecoded();

res
skip_decoding_and_return!(self, res, arr)
}

fn visit_tuple<TypeIds>(self, type_ids: TypeIds) -> Self::Value
Expand All @@ -195,11 +199,7 @@ impl<'temp, 'scale, 'resolver, V: Visitor> ResolvedTypeVisitor<'resolver>
let mut items = Tuple::new(self.data, &mut fields, self.types, self.is_compact);
let res = self.visitor.visit_tuple(&mut items, self.type_id);

// Skip over any bytes that the visitor chose not to decode:
items.skip_decoding()?;
*self.data = items.bytes_from_undecoded();

res
skip_decoding_and_return!(self, res, items)
}

fn visit_primitive(self, primitive: Primitive) -> Self::Value {
Expand Down
87 changes: 87 additions & 0 deletions scale-decode/src/visitor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,93 @@ mod test {
);
}

// We want to make sure that if the visitor returns an error, then that error is propagated
// up to the user. with some types (Sequence/Composite/Tuple/Array/Variant), we skip over
// undecoded items after the visitor runs, and want to ensure that any error skipping over
// things doesn't mask any visitor error.
//
// These tests all fail prior to https://github.com/paritytech/scale-decode/pull/58 and pass
// after it.
macro_rules! decoding_returns_error_first {
($name:ident $expr:expr) => {
#[test]
fn $name() {
fn visitor_err() -> DecodeError {
DecodeError::TypeResolvingError("Whoops".to_string())
}

#[derive(codec::Encode)]
struct HasBadTypeInfo;
impl scale_info::TypeInfo for HasBadTypeInfo {
type Identity = Self;
fn type_info() -> scale_info::Type {
// The actual struct is zero bytes but the type info says it is 1 byte,
// so using type info to decode it will lead to failures.
scale_info::meta_type::<u8>().type_info()
}
}

struct VisitorImpl;
impl Visitor for VisitorImpl {
type Value<'scale, 'resolver> = ();
type Error = DecodeError;
type TypeResolver = PortableRegistry;

fn visit_unexpected<'scale, 'resolver>(
self,
_unexpected: Unexpected,
) -> Result<Self::Value<'scale, 'resolver>, Self::Error> {
// Our visitor just returns a specific error, so we can check that
// we get it back when trying to decode.
Err(visitor_err())
}
}

fn assert_visitor_err<E: codec::Encode + scale_info::TypeInfo + 'static>(input: E) {
let input_encoded = input.encode();
let (ty_id, types) = make_type::<E>();
let err = decode_with_visitor(&mut &*input_encoded, ty_id, &types, VisitorImpl)
.unwrap_err();
assert_eq!(err, visitor_err());
}

assert_visitor_err($expr);
}
};
}

decoding_returns_error_first!(decode_composite_returns_error_first {
#[derive(codec::Encode, scale_info::TypeInfo)]
struct SomeComposite {
a: bool,
b: HasBadTypeInfo,
c: Vec<u8>
}

SomeComposite { a: true, b: HasBadTypeInfo, c: vec![1,2,3] }
});

decoding_returns_error_first!(decode_variant_returns_error_first {
#[derive(codec::Encode, scale_info::TypeInfo)]
enum SomeVariant {
Foo(u32, HasBadTypeInfo, String)
}

SomeVariant::Foo(32, HasBadTypeInfo, "hi".to_owned())
});

decoding_returns_error_first!(decode_array_returns_error_first {
[HasBadTypeInfo, HasBadTypeInfo]
});

decoding_returns_error_first!(decode_sequence_returns_error_first {
vec![HasBadTypeInfo, HasBadTypeInfo]
});

decoding_returns_error_first!(decode_tuple_returns_error_first {
(32u64, HasBadTypeInfo, true)
});

#[test]
fn zero_copy_string_decoding() {
let input = ("hello", "world");
Expand Down
Loading