From 4d874245b1e05faabc807650637d16f97bbe6a23 Mon Sep 17 00:00:00 2001 From: Dan Laine Date: Fri, 1 Mar 2024 14:58:35 -0500 Subject: [PATCH] remove terminal flag from PartialPath (#571) --- firewood/src/merkle.rs | 2 +- firewood/src/merkle/node.rs | 20 ++++++--------- firewood/src/merkle/node/branch.rs | 8 +++--- firewood/src/merkle/node/leaf.rs | 13 ++++++---- firewood/src/merkle/node/partial_path.rs | 31 ++++++++++-------------- 5 files changed, 34 insertions(+), 40 deletions(-) diff --git a/firewood/src/merkle.rs b/firewood/src/merkle.rs index 023358c49..c20e54d77 100644 --- a/firewood/src/merkle.rs +++ b/firewood/src/merkle.rs @@ -159,7 +159,7 @@ where children, value, } => { - let path = PartialPath::decode(&path).0; + let path = PartialPath::decode(&path); let value = value.map(|v| v.0); let branch = NodeType::Branch( BranchNode::new(path, [None; BranchNode::MAX_CHILDREN], value, *children) diff --git a/firewood/src/merkle/node.rs b/firewood/src/merkle/node.rs index 6cb6f0333..c743ac885 100644 --- a/firewood/src/merkle/node.rs +++ b/firewood/src/merkle/node.rs @@ -42,7 +42,6 @@ use super::{TrieHash, TRIE_HASH_LEN}; bitflags! { // should only ever be the size of a nibble struct Flags: u8 { - const TERMINAL = 0b0010; const ODD_LEN = 0b0001; } } @@ -88,7 +87,7 @@ impl NodeType { let decoded_key_nibbles = Nibbles::<0>::new(&decoded_key); - let cur_key_path = PartialPath::from_nibbles(decoded_key_nibbles.into_iter()).0; + let cur_key_path = PartialPath::from_nibbles(decoded_key_nibbles.into_iter()); let cur_key = cur_key_path.into_inner(); #[allow(clippy::unwrap_used)] @@ -522,7 +521,7 @@ impl Serialize for EncodedNode { EncodedNodeType::Leaf(n) => { let data = Some(&*n.data); let chd: Vec<(u64, Vec)> = Default::default(); - let path: Vec<_> = from_nibbles(&n.path.encode(true)).collect(); + let path: Vec<_> = from_nibbles(&n.path.encode()).collect(); (chd, data, path) } @@ -546,7 +545,7 @@ impl Serialize for EncodedNode { let data = value.as_deref(); - let path = from_nibbles(&path.encode(false)).collect(); + let path = from_nibbles(&path.encode()).collect(); (chd, data, path) } @@ -569,7 +568,7 @@ impl<'de> Deserialize<'de> for EncodedNode { { let EncodedBranchNode { chd, data, path } = Deserialize::deserialize(deserializer)?; - let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()).0; + let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()); if chd.is_empty() { let data = if let Some(d) = data { @@ -606,10 +605,7 @@ impl Serialize for EncodedNode { fn serialize(&self, serializer: S) -> Result { match &self.node { EncodedNodeType::Leaf(n) => { - let list = [ - from_nibbles(&n.path.encode(true)).collect(), - n.data.to_vec(), - ]; + let list = [from_nibbles(&n.path.encode()).collect(), n.data.to_vec()]; let mut seq = serializer.serialize_seq(Some(list.len()))?; for e in list { seq.serialize_element(&e)?; @@ -642,7 +638,7 @@ impl Serialize for EncodedNode { list[BranchNode::MAX_CHILDREN] = val.clone(); } - let serialized_path = from_nibbles(&path.encode(true)).collect(); + let serialized_path = from_nibbles(&path.encode()).collect(); list[BranchNode::MAX_CHILDREN + 1] = serialized_path; let mut seq = serializer.serialize_seq(Some(list.len()))?; @@ -680,7 +676,7 @@ impl<'de> Deserialize<'de> for EncodedNode { "incorrect encoded type for leaf node data", )); }; - let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()).0; + let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()); let node = EncodedNodeType::Leaf(LeafNode { path, data: Data(data), @@ -690,7 +686,7 @@ impl<'de> Deserialize<'de> for EncodedNode { BranchNode::MSIZE => { let path = items.pop().expect("length was checked above"); - let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()).0; + let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()); let value = items.pop().expect("length was checked above"); let value = if value.is_empty() { diff --git a/firewood/src/merkle/node/branch.rs b/firewood/src/merkle/node/branch.rs index 749c02586..62e1941fa 100644 --- a/firewood/src/merkle/node/branch.rs +++ b/firewood/src/merkle/node/branch.rs @@ -100,7 +100,7 @@ impl BranchNode { let path = items.pop().ok_or(Error::custom("Invalid Branch Node"))?; let path = Nibbles::<0>::new(&path); - let (path, _term) = PartialPath::from_nibbles(path.into_iter()); + let path = PartialPath::from_nibbles(path.into_iter()); // we've already validated the size, that's why we can safely unwrap #[allow(clippy::unwrap_used)] @@ -176,7 +176,7 @@ impl BranchNode { } #[allow(clippy::unwrap_used)] - let path = from_nibbles(&self.path.encode(false)).collect::>(); + let path = from_nibbles(&self.path.encode()).collect::>(); list[Self::MAX_CHILDREN + 1] = path; @@ -202,7 +202,7 @@ impl Storable for BranchNode { fn serialize(&self, to: &mut [u8]) -> Result<(), crate::shale::ShaleError> { let mut cursor = Cursor::new(to); - let path: Vec = from_nibbles(&self.path.encode(false)).collect(); + let path: Vec = from_nibbles(&self.path.encode()).collect(); cursor.write_all(&[path.len() as PathLen])?; cursor.write_all(&path)?; @@ -271,7 +271,7 @@ impl Storable for BranchNode { addr += path_len as usize; let path: Vec = path.into_iter().flat_map(to_nibble_array).collect(); - let path = PartialPath::decode(&path).0; + let path = PartialPath::decode(&path); let node_raw = mem.get_view(addr, BRANCH_HEADER_SIZE) diff --git a/firewood/src/merkle/node/leaf.rs b/firewood/src/merkle/node/leaf.rs index c38a20328..905cf2b4c 100644 --- a/firewood/src/merkle/node/leaf.rs +++ b/firewood/src/merkle/node/leaf.rs @@ -53,7 +53,7 @@ impl LeafNode { bincode::DefaultOptions::new() .serialize( [ - from_nibbles(&self.path.encode(true)).collect(), + from_nibbles(&self.path.encode()).collect(), self.data.to_vec(), ] .as_slice(), @@ -85,7 +85,7 @@ impl Storable for LeafNode { fn serialize(&self, to: &mut [u8]) -> Result<(), crate::shale::ShaleError> { let mut cursor = Cursor::new(to); - let path = &self.path.encode(true); + let path = &self.path.encode(); let path = from_nibbles(path); let data = &self.data; @@ -149,9 +149,12 @@ mod tests { use test_case::test_case; // these tests will fail if the encoding mechanism changes and should be updated accordingly - #[test_case(0b10 << 4, vec![0x12, 0x34], vec![1, 2, 3, 4]; "even length")] - // first nibble is part of the prefix - #[test_case((0b11 << 4) + 2, vec![0x34], vec![2, 3, 4]; "odd length")] + // + // Even length so ODD_LEN flag is not set so flag byte is 0b0000_0000 + #[test_case(0x00, vec![0x12, 0x34], vec![1, 2, 3, 4]; "even length")] + // Odd length so ODD_LEN flag is set so flag byte is 0b0000_0001 + // This is combined with the first nibble of the path (0b0000_0010) to become 0b0001_0010 + #[test_case(0b0001_0010, vec![0x34], vec![2, 3, 4]; "odd length")] fn encode_regression_test(prefix: u8, path: Vec, nibbles: Vec) { let data = vec![5, 6, 7, 8]; diff --git a/firewood/src/merkle/node/partial_path.rs b/firewood/src/merkle/node/partial_path.rs index 1e1d5a74c..fe1583b3d 100644 --- a/firewood/src/merkle/node/partial_path.rs +++ b/firewood/src/merkle/node/partial_path.rs @@ -40,13 +40,9 @@ impl PartialPath { self.0 } - pub(crate) fn encode(&self, is_terminal: bool) -> Vec { + pub(crate) fn encode(&self) -> Vec { let mut flags = Flags::empty(); - if is_terminal { - flags.insert(Flags::TERMINAL); - } - let has_odd_len = self.0.len() & 1 == 1; let extra_byte = if has_odd_len { @@ -67,24 +63,24 @@ impl PartialPath { // I also think `PartialPath` could probably borrow instead of own data. // /// returns a tuple of the decoded partial path and whether the path is terminal - pub fn decode(raw: &[u8]) -> (Self, bool) { + pub fn decode(raw: &[u8]) -> Self { Self::from_iter(raw.iter().copied()) } /// returns a tuple of the decoded partial path and whether the path is terminal - pub fn from_nibbles(nibbles: NibblesIterator<'_, N>) -> (Self, bool) { + pub fn from_nibbles(nibbles: NibblesIterator<'_, N>) -> Self { Self::from_iter(nibbles) } /// Assumes all bytes are nibbles, prefer to use `from_nibbles` instead. - fn from_iter>(mut iter: Iter) -> (Self, bool) { + fn from_iter>(mut iter: Iter) -> Self { let flags = Flags::from_bits_retain(iter.next().unwrap_or_default()); if !flags.contains(Flags::ODD_LEN) { let _ = iter.next(); } - (Self(iter.collect()), flags.contains(Flags::TERMINAL)) + Self(iter.collect()) } pub(super) fn serialized_len(&self) -> u64 { @@ -107,20 +103,19 @@ mod tests { use super::*; use test_case::test_case; - #[test_case(&[1, 2, 3, 4], true)] - #[test_case(&[1, 2, 3], false)] - #[test_case(&[0, 1, 2], false)] - #[test_case(&[1, 2], true)] - #[test_case(&[1], true)] - fn test_encoding(steps: &[u8], term: bool) { + #[test_case(&[1, 2, 3, 4])] + #[test_case(&[1, 2, 3])] + #[test_case(&[0, 1, 2])] + #[test_case(&[1, 2])] + #[test_case(&[1])] + fn test_encoding(steps: &[u8]) { let path = PartialPath(steps.to_vec()); - let encoded = path.encode(term); + let encoded = path.encode(); assert_eq!(encoded.len(), path.serialized_len() as usize * 2); - let (decoded, decoded_term) = PartialPath::decode(&encoded); + let decoded = PartialPath::decode(&encoded); assert_eq!(&&*decoded, &steps); - assert_eq!(decoded_term, term); } }