Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hardcode error type in IntoVisitor #41

Merged
merged 9 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/target
/Cargo.lock
.DS_Store
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ members = [
"scale-decode-derive",
"testing/no_std",
]
resolver = "2"

[workspace.package]
version = "0.9.0"
Expand Down
7 changes: 7 additions & 0 deletions scale-decode/src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ impl From<DecodeError> for Error {
}
}

impl From<codec::Error> for Error {
fn from(err: codec::Error) -> Error {
let err: DecodeError = err.into();
Error::new(err.into())
}
}

/// The underlying nature of the error.
#[derive(Debug, derive_more::From, derive_more::Display)]
pub enum ErrorKind {
Expand Down
53 changes: 13 additions & 40 deletions scale-decode/src/impls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ macro_rules! impl_decode_seq_via_collect {
impl <$generic> Visitor for BasicVisitor<$ty<$generic>>
where
$generic: IntoVisitor,
Error: From<<$generic::Visitor as Visitor>::Error>,
$( $($where)* )?
{
type Value<'scale, 'info> = $ty<$generic>;
Expand Down Expand Up @@ -306,11 +305,7 @@ macro_rules! array_method_impl {
Ok(arr)
}};
}
impl<const N: usize, T> Visitor for BasicVisitor<[T; N]>
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
{
impl<const N: usize, T: IntoVisitor> Visitor for BasicVisitor<[T; N]> {
type Value<'scale, 'info> = [T; N];
type Error = Error;

Expand All @@ -331,22 +326,14 @@ where

visit_single_field_composite_tuple_impls!();
}
impl<const N: usize, T> IntoVisitor for [T; N]
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
{
impl<const N: usize, T: IntoVisitor> IntoVisitor for [T; N] {
type Visitor = BasicVisitor<[T; N]>;
fn into_visitor() -> Self::Visitor {
BasicVisitor { _marker: core::marker::PhantomData }
}
}

impl<T> Visitor for BasicVisitor<BTreeMap<String, T>>
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
{
impl<T: IntoVisitor> Visitor for BasicVisitor<BTreeMap<String, T>> {
type Error = Error;
type Value<'scale, 'info> = BTreeMap<String, T>;

Expand All @@ -365,19 +352,15 @@ where
// Decode the value now that we have a valid name.
let Some(val) = value.decode_item(T::into_visitor()) else { break };
// Save to the map.
let val = val.map_err(|e| Error::from(e).at_field(key.to_owned()))?;
let val = val.map_err(|e| e.at_field(key.to_owned()))?;
map.insert(key.to_owned(), val);
}
Ok(map)
}
}
impl_into_visitor!(BTreeMap<String, T>);

impl<T> Visitor for BasicVisitor<Option<T>>
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
{
impl<T: IntoVisitor> Visitor for BasicVisitor<Option<T>> {
type Error = Error;
type Value<'scale, 'info> = Option<T>;

Expand All @@ -391,7 +374,7 @@ where
.fields()
.decode_item(T::into_visitor())
.transpose()
.map_err(|e| Error::from(e).at_variant("Some"))?
.map_err(|e| e.at_variant("Some"))?
.expect("checked for 1 field already so should be ok");
Ok(Some(val))
} else if value.name() == "None" && value.fields().remaining() == 0 {
Expand All @@ -407,13 +390,7 @@ where
}
impl_into_visitor!(Option<T>);

impl<T, E> Visitor for BasicVisitor<Result<T, E>>
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
E: IntoVisitor,
Error: From<<E::Visitor as Visitor>::Error>,
{
impl<T: IntoVisitor, E: IntoVisitor> Visitor for BasicVisitor<Result<T, E>> {
type Error = Error;
type Value<'scale, 'info> = Result<T, E>;

Expand All @@ -427,15 +404,15 @@ where
.fields()
.decode_item(T::into_visitor())
.transpose()
.map_err(|e| Error::from(e).at_variant("Ok"))?
.map_err(|e| e.at_variant("Ok"))?
.expect("checked for 1 field already so should be ok");
Ok(Ok(val))
} else if value.name() == "Err" && value.fields().remaining() == 1 {
let val = value
.fields()
.decode_item(E::into_visitor())
.transpose()
.map_err(|e| Error::from(e).at_variant("Err"))?
.map_err(|e| e.at_variant("Err"))?
.expect("checked for 1 field already so should be ok");
Ok(Err(val))
} else {
Expand Down Expand Up @@ -541,7 +518,7 @@ macro_rules! tuple_method_impl {
let v = $value
.decode_item($t::into_visitor())
.transpose()
.map_err(|e| Error::from(e).at_idx(idx))?
.map_err(|e| e.at_idx(idx))?
.expect("length already checked via .remaining()");
idx += 1;
v
Expand Down Expand Up @@ -593,7 +570,6 @@ macro_rules! impl_decode_tuple {
impl < $($t),* > Visitor for BasicVisitor<($($t,)*)>
where $(
$t: IntoVisitor,
Error: From<<$t::Visitor as Visitor>::Error>,
)*
{
type Value<'scale, 'info> = ($($t,)*);
Expand Down Expand Up @@ -621,7 +597,7 @@ macro_rules! impl_decode_tuple {

// We can turn this tuple into a visitor which knows how to decode it:
impl < $($t),* > IntoVisitor for ($($t,)*)
where $( $t: IntoVisitor, Error: From<<$t::Visitor as Visitor>::Error>, )*
where $( $t: IntoVisitor, )*
{
type Visitor = BasicVisitor<($($t,)*)>;
fn into_visitor() -> Self::Visitor {
Expand All @@ -631,7 +607,7 @@ macro_rules! impl_decode_tuple {

// We can decode given a list of fields (just delegate to the visitor impl:
impl < $($t),* > DecodeAsFields for ($($t,)*)
where $( $t: IntoVisitor, Error: From<<$t::Visitor as Visitor>::Error>, )*
where $( $t: IntoVisitor, )*
{
fn decode_as_fields<'info>(input: &mut &[u8], fields: &mut dyn FieldIter<'info>, types: &'info scale_info::PortableRegistry) -> Result<Self, Error> {
let mut composite = crate::visitor::types::Composite::new(input, crate::EMPTY_SCALE_INFO_PATH, fields, types, false);
Expand Down Expand Up @@ -676,14 +652,11 @@ fn decode_items_using<'a, 'scale, 'info, D: DecodeItemIterator<'scale, 'info>, T
) -> impl Iterator<Item = Result<T, Error>> + 'a
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
D: DecodeItemIterator<'scale, 'info>,
{
let mut idx = 0;
core::iter::from_fn(move || {
let item = decoder
.decode_item(T::into_visitor())
.map(|res| res.map_err(|e| Error::from(e).at_idx(idx)));
let item = decoder.decode_item(T::into_visitor()).map(|res| res.map_err(|e| e.at_idx(idx)));
idx += 1;
item
})
Expand Down
19 changes: 10 additions & 9 deletions scale-decode/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ use alloc::vec::Vec;
/// This trait is implemented for any type `T` where `T` implements [`IntoVisitor`] and the errors returned
/// from this [`Visitor`] can be converted into [`Error`]. It's essentially a convenience wrapper around
/// [`visitor::decode_with_visitor`] that mirrors `scale-encode`'s `EncodeAsType`.
pub trait DecodeAsType: Sized {
pub trait DecodeAsType: Sized + IntoVisitor {
/// Given some input bytes, a `type_id`, and type registry, attempt to decode said bytes into
/// `Self`. Implementations should modify the `&mut` reference to the bytes such that any bytes
/// not used in the course of decoding are still pointed to after decoding is complete.
Expand All @@ -192,11 +192,7 @@ pub trait DecodeAsType: Sized {
) -> Result<Self, Error>;
}

impl<T> DecodeAsType for T
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
{
impl<T: Sized + IntoVisitor> DecodeAsType for T {
fn decode_as_type_maybe_compact(
input: &mut &[u8],
type_id: u32,
Expand Down Expand Up @@ -267,11 +263,16 @@ pub trait FieldIter<'a>: Iterator<Item = Field<'a>> {}
impl<'a, T> FieldIter<'a> for T where T: Iterator<Item = Field<'a>> {}

/// This trait can be implemented on any type that has an associated [`Visitor`] responsible for decoding
/// SCALE encoded bytes to it. If you implement this on some type and the [`Visitor`] that you return has
/// an error type that converts into [`Error`], then you'll also get a [`DecodeAsType`] implementation for free.
/// SCALE encoded bytes to it whose error type is [`Error`]. Anything that implements this trait gets a
/// [`DecodeAsType`] implementation for free.
// Dev note: This used to allow for any Error type that could be converted into `scale_decode::Error`.
// The problem with this is that the `DecodeAsType` trait became tricky to use in some contexts, because it
// didn't automatically imply so much. Realistically, being stricter here shouldn't matter too much; derive
// impls all use `scale_decode::Error` anyway, and manual impls can just manually convert into the error
// rather than rely on auto conversion, if they care about also being able to impl `DecodeAsType`.
pub trait IntoVisitor {
/// The visitor type used to decode SCALE encoded bytes to `Self`.
type Visitor: for<'scale, 'info> visitor::Visitor<Value<'scale, 'info> = Self>;
type Visitor: for<'scale, 'info> visitor::Visitor<Value<'scale, 'info> = Self, Error = Error>;
/// A means of obtaining this visitor.
fn into_visitor() -> Self::Visitor;
}
Expand Down
110 changes: 105 additions & 5 deletions scale-decode/src/visitor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,26 @@ pub enum DecodeAsTypeResult<V, R> {
Decoded(R),
}

impl<V, R> DecodeAsTypeResult<V, R> {
/// If we have a [`DecodeAsTypeResult::Decoded`], the function provided will
/// map this decoded result to whatever it returns.
pub fn map_decoded<T, F: FnOnce(R) -> T>(self, f: F) -> DecodeAsTypeResult<V, T> {
match self {
DecodeAsTypeResult::Skipped(s) => DecodeAsTypeResult::Skipped(s),
DecodeAsTypeResult::Decoded(r) => DecodeAsTypeResult::Decoded(f(r)),
}
}

/// If we have a [`DecodeAsTypeResult::Skipped`], the function provided will
/// map this skipped value to whatever it returns.
pub fn map_skipped<T, F: FnOnce(V) -> T>(self, f: F) -> DecodeAsTypeResult<T, R> {
match self {
DecodeAsTypeResult::Skipped(s) => DecodeAsTypeResult::Skipped(f(s)),
DecodeAsTypeResult::Decoded(r) => DecodeAsTypeResult::Decoded(r),
}
}
}

/// This is implemented for visitor related types which have a `decode_item` method,
/// and allows you to generically talk about decoding unnamed items.
pub trait DecodeItemIterator<'scale, 'info> {
Expand Down Expand Up @@ -358,6 +378,34 @@ impl Visitor for IgnoreVisitor {
}
}

/// Some [`Visitor`] implementations may want to return an error type other than [`crate::Error`], which means
/// that they would not be automatically compatible with [`crate::IntoVisitor`], which requires visitors that do return
/// [`crate::Error`] errors.
///
/// As long as the error type of the visitor implementation can be converted into [`crate::Error`] via [`Into`],
/// the visitor implementation can be wrapped in this [`VisitorWithCrateError`] struct to make it work with
/// [`crate::IntoVisitor`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct VisitorWithCrateError<V>(pub V);

impl<V: Visitor> Visitor for VisitorWithCrateError<V>
where
V::Error: Into<crate::Error>,
{
type Value<'scale, 'info> = V::Value<'scale, 'info>;
type Error = crate::Error;

fn unchecked_decode_as_type<'scale, 'info>(
self,
input: &mut &'scale [u8],
type_id: TypeId,
types: &'info scale_info::PortableRegistry,
) -> DecodeAsTypeResult<Self, Result<Self::Value<'scale, 'info>, Self::Error>> {
let res = decode_with_visitor(input, type_id.0, types, self.0).map_err(Into::into);
DecodeAsTypeResult::Decoded(res)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, this will allow for all sorts of easy conversions.

Copy link
Collaborator Author

@jsdw jsdw Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah; it basically replaces the "auto conversion" we did before but has the nice effect of simplifying the code anyway and making DecodeAsType work better, so I'm quite happy with this solution in the end :)


#[cfg(test)]
mod test {
use crate::visitor::TypeId;
Expand Down Expand Up @@ -397,6 +445,7 @@ mod test {
BitSequence(scale_bits::Bits),
}

#[derive(Clone, Copy)]
struct ValueVisitor;
impl Visitor for ValueVisitor {
type Value<'scale, 'info> = Value;
Expand Down Expand Up @@ -595,22 +644,69 @@ mod test {

/// This just tests that if we try to decode some values we've encoded using a visitor
/// which just ignores everything by default, that we'll consume all of the bytes.
fn encode_decode_check_explicit_info<Ty: scale_info::TypeInfo + 'static, T: Encode>(
fn encode_decode_check_explicit_info<
Ty: scale_info::TypeInfo + 'static,
T: Encode,
V: for<'s, 'i> Visitor<Value<'s, 'i> = Value, Error = E>,
E: core::fmt::Debug,
>(
val: T,
expected: Value,
visitor: V,
) {
let encoded = val.encode();
let (id, types) = make_type::<Ty>();
let bytes = &mut &*encoded;
let val = decode_with_visitor(bytes, id, &types, ValueVisitor)
.expect("decoding should not error");
let val =
decode_with_visitor(bytes, id, &types, visitor).expect("decoding should not error");

assert_eq!(bytes.len(), 0, "Decoding should consume all bytes");
assert_eq!(val, expected);
}

fn encode_decode_check_with_visitor<
T: Encode + scale_info::TypeInfo + 'static,
V: for<'s, 'i> Visitor<Value<'s, 'i> = Value, Error = E>,
E: core::fmt::Debug,
>(
val: T,
expected: Value,
visitor: V,
) {
encode_decode_check_explicit_info::<T, T, _, _>(val, expected, visitor);
}

fn encode_decode_check<T: Encode + scale_info::TypeInfo + 'static>(val: T, expected: Value) {
encode_decode_check_explicit_info::<T, T>(val, expected);
encode_decode_check_explicit_info::<T, T, _, _>(val, expected, ValueVisitor);
}

#[test]
fn decode_with_root_error_wrapper_works() {
use crate::visitor::VisitorWithCrateError;
let visitor = VisitorWithCrateError(ValueVisitor);

encode_decode_check_with_visitor(123u8, Value::U8(123), visitor);
encode_decode_check_with_visitor(123u16, Value::U16(123), visitor);
encode_decode_check_with_visitor(123u32, Value::U32(123), visitor);
encode_decode_check_with_visitor(123u64, Value::U64(123), visitor);
encode_decode_check_with_visitor(123u128, Value::U128(123), visitor);
encode_decode_check_with_visitor(
"Hello there",
Value::Str("Hello there".to_owned()),
visitor,
);

#[derive(Encode, scale_info::TypeInfo)]
struct Unnamed(bool, String, Vec<u8>);
encode_decode_check_with_visitor(
Unnamed(true, "James".into(), vec![1, 2, 3]),
Value::Composite(vec![
(String::new(), Value::Bool(true)),
(String::new(), Value::Str("James".to_string())),
(String::new(), Value::Sequence(vec![Value::U8(1), Value::U8(2), Value::U8(3)])),
]),
visitor,
);
}

#[test]
Expand All @@ -627,7 +723,11 @@ mod test {
encode_decode_check(codec::Compact(123u128), Value::U128(123));
encode_decode_check(true, Value::Bool(true));
encode_decode_check(false, Value::Bool(false));
encode_decode_check_explicit_info::<char, _>('c' as u32, Value::Char('c'));
encode_decode_check_explicit_info::<char, _, _, _>(
'c' as u32,
Value::Char('c'),
ValueVisitor,
);
encode_decode_check("Hello there", Value::Str("Hello there".to_owned()));
encode_decode_check("Hello there".to_string(), Value::Str("Hello there".to_owned()));
}
Expand Down
Loading