diff --git a/firewood/src/merkle.rs b/firewood/src/merkle.rs index 23f76c6cb..812a24b33 100644 --- a/firewood/src/merkle.rs +++ b/firewood/src/merkle.rs @@ -19,8 +19,7 @@ mod stream; mod trie_hash; pub use node::{ - BinarySerde, Bincode, BranchNode, Data, EncodedNode, EncodedNodeType, LeafNode, Node, NodeType, - Path, + BinarySerde, Bincode, BranchNode, Data, EncodedNode, LeafNode, Node, NodeType, Path, }; pub use proof::{Proof, ProofError}; pub use stream::MerkleKeyValueStream; @@ -115,11 +114,17 @@ where #[allow(dead_code)] fn encode(&self, node: &NodeType) -> Result, MerkleError> { let encoded = match node { - NodeType::Leaf(n) => EncodedNode::new(EncodedNodeType::Leaf(n.clone())), + NodeType::Leaf(n) => { + let children: [Option>; BranchNode::MAX_CHILDREN] = Default::default(); + EncodedNode { + partial_path: n.partial_path.clone(), + children: Box::new(children), + value: n.data.clone().into(), + phantom: PhantomData, + } + } NodeType::Branch(n) => { - let path = n.partial_path.clone(); - // pair up DiskAddresses with encoded children and pick the right one let encoded_children = n.chd().iter().zip(n.children_encoded.iter()); let children = encoded_children @@ -138,11 +143,13 @@ where .try_into() .expect("MAX_CHILDREN will always be yielded"); - EncodedNode::new(EncodedNodeType::Branch { - path, + let value = n.value.as_ref().map(|v| v.0.clone()); + EncodedNode { + partial_path: n.partial_path.clone(), children, - value: n.value.clone(), - }) + value, + phantom: PhantomData, + } } }; @@ -154,23 +161,23 @@ where let encoded: EncodedNode = T::deserialize(buf).map_err(|e| MerkleError::BinarySerdeError(e.to_string()))?; - match encoded.node { - EncodedNodeType::Leaf(leaf) => Ok(NodeType::Leaf(leaf)), - EncodedNodeType::Branch { - path, - children, - value, - } => { - let path = Path::decode(&path); - let value = value.map(|v| v.0); - let branch = NodeType::Branch( - BranchNode::new(path, [None; BranchNode::MAX_CHILDREN], value, *children) - .into(), - ); - - Ok(branch) - } + if encoded.children.iter().all(|b| b.is_none()) { + // This is a leaf node + return Ok(NodeType::Leaf(LeafNode::new( + encoded.partial_path, + Data(encoded.value.expect("leaf nodes must always have a value")), + ))); } + + Ok(NodeType::Branch( + BranchNode::new( + encoded.partial_path, + [None; BranchNode::MAX_CHILDREN], + encoded.value, + *encoded.children, + ) + .into(), + )) } } @@ -1451,8 +1458,11 @@ mod tests { let path = Path(path.into_iter().collect()); let children = Default::default(); - // TODO: Properly test empty data as a value - let value = Some(Data(value)); + let value = if value.is_empty() { + None + } else { + Some(Data(value)) + }; let mut children_encoded = <[Option>; BranchNode::MAX_CHILDREN]>::default(); if let Some(child) = encoded_child { @@ -1513,16 +1523,16 @@ mod tests { assert_eq!(encoded, new_node_encoded); } - #[test_case(Bincode::new(), leaf(Vec::new(), Vec::new()) ; "empty leaf encoding with Bincode")] - #[test_case(Bincode::new(), leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf encoding with Bincode")] - #[test_case(Bincode::new(), branch(b"", b"value", vec![1, 2, 3].into()) ; "branch with chd with Bincode")] - #[test_case(Bincode::new(), branch(b"", b"value", None); "branch without chd with Bincode")] - #[test_case(Bincode::new(), branch_without_data(b"", None); "branch without value and chd with Bincode")] - #[test_case(PlainCodec::new(), leaf(Vec::new(), Vec::new()) ; "empty leaf encoding with PlainCodec")] - #[test_case(PlainCodec::new(), leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf encoding with PlainCodec")] - #[test_case(PlainCodec::new(), branch(b"", b"value", vec![1, 2, 3].into()) ; "branch with chd with PlainCodec")] - #[test_case(PlainCodec::new(), branch(b"", b"value", Some(Vec::new())); "branch with empty chd with PlainCodec")] - #[test_case(PlainCodec::new(), branch(b"", b"", vec![1, 2, 3].into()); "branch with empty value with PlainCodec")] + #[test_case(Bincode::new(), leaf(vec![], vec![4, 5]) ; "leaf without partial path encoding with Bincode")] + #[test_case(Bincode::new(), leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf with partial path encoding with Bincode")] + #[test_case(Bincode::new(), branch(b"abcd", b"value", vec![1, 2, 3].into()) ; "branch with partial path and value with Bincode")] + #[test_case(Bincode::new(), branch(b"abcd", &[], vec![1, 2, 3].into()) ; "branch with partial path and no value with Bincode")] + #[test_case(Bincode::new(), branch(b"", &[1,3,3,7], vec![1, 2, 3].into()) ; "branch with no partial path and value with Bincode")] + #[test_case(PlainCodec::new(), leaf(Vec::new(), vec![4, 5]) ; "leaf without partial path encoding with PlainCodec")] + #[test_case(PlainCodec::new(), leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf with partial path encoding with PlainCodec")] + #[test_case(PlainCodec::new(), branch(b"abcd", b"value", vec![1, 2, 3].into()) ; "branch with partial path and value with PlainCodec")] + #[test_case(PlainCodec::new(), branch(b"abcd", &[], vec![1, 2, 3].into()) ; "branch with partial path and no value with PlainCodec")] + #[test_case(PlainCodec::new(), branch(b"", &[1,3,3,7], vec![1, 2, 3].into()) ; "branch with no partial path and value with PlainCodec")] fn node_encode_decode(_codec: T, node: Node) where T: BinarySerde, diff --git a/firewood/src/merkle/node.rs b/firewood/src/merkle/node.rs index 03531b13b..78152f395 100644 --- a/firewood/src/merkle/node.rs +++ b/firewood/src/merkle/node.rs @@ -56,6 +56,12 @@ impl std::ops::Deref for Data { } } +impl From for Option> { + fn from(val: Data) -> Self { + Some(val.0) + } +} + impl From> for Data { fn from(v: Vec) -> Self { Self(v) @@ -482,82 +488,55 @@ impl Storable for Node { } } +/// Contains the fields that we include in a node's hash. +/// If this is a leaf node, `children` is empty and `value` is Some. +/// If this is a branch node, `children` is non-empty. +#[derive(Debug)] pub struct EncodedNode { - pub(crate) node: EncodedNodeType, + pub(crate) partial_path: Path, + /// If a child is None, it doesn't exist. + /// If it's Some, it's the value or value hash of the child. + pub(crate) children: Box<[Option>; BranchNode::MAX_CHILDREN]>, + pub(crate) value: Option>, pub(crate) phantom: PhantomData, } -impl EncodedNode { - pub const fn new(node: EncodedNodeType) -> Self { - Self { - node, - phantom: PhantomData, - } +impl PartialEq for EncodedNode { + fn eq(&self, other: &Self) -> bool { + self.partial_path == other.partial_path + && self.children == other.children + && self.value == other.value } } -#[derive(Debug, PartialEq)] -pub enum EncodedNodeType { - Leaf(LeafNode), - Branch { - path: Path, - children: Box<[Option>; BranchNode::MAX_CHILDREN]>, - value: Option, - }, -} - -// TODO: probably can merge with `EncodedNodeType`. -#[derive(Debug, Deserialize)] -struct EncodedBranchNode { - chd: Vec<(u64, Vec)>, - data: Option>, - path: Vec, -} - // Note that the serializer passed in should always be the same type as T in EncodedNode. impl Serialize for EncodedNode { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { - let (chd, data, path) = match &self.node { - EncodedNodeType::Leaf(n) => { - let data = Some(&*n.data); - let chd: Vec<(u64, Vec)> = Default::default(); - let path: Vec<_> = nibbles_to_bytes_iter(&n.partial_path.encode()).collect(); - (chd, data, path) - } + let chd: Vec<(u64, Vec)> = self + .children + .iter() + .enumerate() + .filter_map(|(i, c)| c.as_ref().map(|c| (i as u64, c))) + .map(|(i, c)| { + if c.len() >= TRIE_HASH_LEN { + (i, Keccak256::digest(c).to_vec()) + } else { + (i, c.to_vec()) + } + }) + .collect(); - EncodedNodeType::Branch { - path, - children, - value, - } => { - let chd: Vec<(u64, Vec)> = children - .iter() - .enumerate() - .filter_map(|(i, c)| c.as_ref().map(|c| (i as u64, c))) - .map(|(i, c)| { - if c.len() >= TRIE_HASH_LEN { - (i, Keccak256::digest(c).to_vec()) - } else { - (i, c.to_vec()) - } - }) - .collect(); - - let data = value.as_deref(); - - let path = nibbles_to_bytes_iter(&path.encode()).collect(); - - (chd, data, path) - } - }; + let value = self.value.as_deref(); + + let path: Vec = nibbles_to_bytes_iter(&self.partial_path.encode()).collect(); let mut s = serializer.serialize_tuple(3)?; s.serialize_element(&chd)?; - s.serialize_element(&data)?; + s.serialize_element(&value)?; s.serialize_element(&path)?; s.end() @@ -569,96 +548,63 @@ impl<'de> Deserialize<'de> for EncodedNode { where D: serde::Deserializer<'de>, { - let EncodedBranchNode { chd, data, path } = Deserialize::deserialize(deserializer)?; - - let path = Path::from_nibbles(Nibbles::<0>::new(&path).into_iter()); - - if chd.is_empty() { - let data = if let Some(d) = data { - Data(d) - } else { - Data(Vec::new()) - }; - - let node = EncodedNodeType::Leaf(LeafNode { - partial_path: path, - data, - }); - - Ok(Self::new(node)) - } else { - let mut children: [Option>; BranchNode::MAX_CHILDREN] = Default::default(); - let value = data.map(Data); + let chd: Vec<(u64, Vec)>; + let value: Option>; + let path: Vec; - #[allow(clippy::indexing_slicing)] - for (i, chd) in chd { - children[i as usize] = Some(chd); - } + (chd, value, path) = Deserialize::deserialize(deserializer)?; - let node = EncodedNodeType::Branch { - path, - children: children.into(), - value, - }; + let path = Path::from_nibbles(Nibbles::<0>::new(&path).into_iter()); - Ok(Self::new(node)) + let mut children: [Option>; BranchNode::MAX_CHILDREN] = Default::default(); + #[allow(clippy::indexing_slicing)] + for (i, chd) in chd { + children[i as usize] = Some(chd); } + + Ok(Self { + partial_path: path, + children: children.into(), + value, + phantom: PhantomData, + }) } } // Note that the serializer passed in should always be the same type as T in EncodedNode. impl Serialize for EncodedNode { fn serialize(&self, serializer: S) -> Result { - match &self.node { - EncodedNodeType::Leaf(n) => { - let list = [ - nibbles_to_bytes_iter(&n.partial_path.encode()).collect(), - n.data.to_vec(), - ]; - let mut seq = serializer.serialize_seq(Some(list.len()))?; - for e in list { - seq.serialize_element(&e)?; - } - seq.end() + let mut list = <[Vec; BranchNode::MAX_CHILDREN + 2]>::default(); + let children = self + .children + .iter() + .enumerate() + .filter_map(|(i, c)| c.as_ref().map(|c| (i, c))); + + #[allow(clippy::indexing_slicing)] + for (i, child) in children { + if child.len() >= TRIE_HASH_LEN { + let serialized_hash = Keccak256::digest(child).to_vec(); + list[i] = serialized_hash; + } else { + list[i] = child.to_vec(); } + } - EncodedNodeType::Branch { - path, - children, - value, - } => { - let mut list = <[Vec; BranchNode::MAX_CHILDREN + 2]>::default(); - let children = children - .iter() - .enumerate() - .filter_map(|(i, c)| c.as_ref().map(|c| (i, c))); - - #[allow(clippy::indexing_slicing)] - for (i, child) in children { - if child.len() >= TRIE_HASH_LEN { - let serialized_hash = Keccak256::digest(child).to_vec(); - list[i] = serialized_hash; - } else { - list[i] = child.to_vec(); - } - } - - if let Some(Data(val)) = &value { - list[BranchNode::MAX_CHILDREN] = val.clone(); - } - - let serialized_path = nibbles_to_bytes_iter(&path.encode()).collect(); - list[BranchNode::MAX_CHILDREN + 1] = serialized_path; + if let Some(val) = &self.value { + list[BranchNode::MAX_CHILDREN] = val.clone(); + } - let mut seq = serializer.serialize_seq(Some(list.len()))?; + let serialized_path = nibbles_to_bytes_iter(&self.partial_path.encode()).collect(); + list[BranchNode::MAX_CHILDREN + 1] = serialized_path; - for e in list { - seq.serialize_element(&e)?; - } + let mut seq = serializer.serialize_seq(Some(list.len()))?; - seq.end() - } + for e in list { + seq.serialize_element(&e)?; } + + seq.end() } } @@ -686,11 +632,13 @@ impl<'de> Deserialize<'de> for EncodedNode { )); }; let path = Path::from_nibbles(Nibbles::<0>::new(&path).into_iter()); - let node = EncodedNodeType::Leaf(LeafNode { + let children: [Option>; BranchNode::MAX_CHILDREN] = Default::default(); + Ok(Self { partial_path: path, - data: Data(data), - }); - Ok(Self::new(node)) + children: children.into(), + value: Some(data), + phantom: PhantomData, + }) } BranchNode::MSIZE => { @@ -698,11 +646,7 @@ impl<'de> Deserialize<'de> for EncodedNode { let path = Path::from_nibbles(Nibbles::<0>::new(&path).into_iter()); let value = items.pop().expect("length was checked above"); - let value = if value.is_empty() { - None - } else { - Some(Data(value)) - }; + let value = if value.is_empty() { None } else { Some(value) }; let mut children: [Option>; BranchNode::MAX_CHILDREN] = Default::default(); @@ -711,14 +655,10 @@ impl<'de> Deserialize<'de> for EncodedNode { (children[i] = Some(chd).filter(|chd| !chd.is_empty())); } - let node = EncodedNodeType::Branch { - path, + Ok(Self { + partial_path: path, children: children.into(), value, - }; - - Ok(Self { - node, phantom: PhantomData, }) } @@ -885,17 +825,18 @@ mod tests { #[test_case(&[0x0F,0x0F])] #[test_case(&[0x0F,0x01,0x0F])] fn encoded_branch_node_bincode_serialize(path_nibbles: &[u8]) -> Result<(), Error> { - let node = EncodedNode::::new(EncodedNodeType::Branch { - path: Path(path_nibbles.to_vec()), + let node = EncodedNode:: { + partial_path: Path(path_nibbles.to_vec()), children: Default::default(), - value: Some(Data(vec![1, 2, 3, 4])), - }); + value: Some(vec![1, 2, 3, 4]), + phantom: PhantomData, + }; let node_bytes = Bincode::serialize(&node)?; let deserialized_node: EncodedNode = Bincode::deserialize(&node_bytes)?; - assert_eq!(&node.node, &deserialized_node.node); + assert_eq!(&node, &deserialized_node); Ok(()) }