Skip to content

Commit

Permalink
remove terminal flag from PartialPath (#571)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Laine authored Mar 1, 2024
1 parent 3530acc commit 4d87424
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 40 deletions.
2 changes: 1 addition & 1 deletion firewood/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ where
children,
value,
} => {
let path = PartialPath::decode(&path).0;
let path = PartialPath::decode(&path);
let value = value.map(|v| v.0);
let branch = NodeType::Branch(
BranchNode::new(path, [None; BranchNode::MAX_CHILDREN], value, *children)
Expand Down
20 changes: 8 additions & 12 deletions firewood/src/merkle/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ use super::{TrieHash, TRIE_HASH_LEN};
bitflags! {
// should only ever be the size of a nibble
struct Flags: u8 {
const TERMINAL = 0b0010;
const ODD_LEN = 0b0001;
}
}
Expand Down Expand Up @@ -88,7 +87,7 @@ impl NodeType {

let decoded_key_nibbles = Nibbles::<0>::new(&decoded_key);

let cur_key_path = PartialPath::from_nibbles(decoded_key_nibbles.into_iter()).0;
let cur_key_path = PartialPath::from_nibbles(decoded_key_nibbles.into_iter());

let cur_key = cur_key_path.into_inner();
#[allow(clippy::unwrap_used)]
Expand Down Expand Up @@ -522,7 +521,7 @@ impl Serialize for EncodedNode<PlainCodec> {
EncodedNodeType::Leaf(n) => {
let data = Some(&*n.data);
let chd: Vec<(u64, Vec<u8>)> = Default::default();
let path: Vec<_> = from_nibbles(&n.path.encode(true)).collect();
let path: Vec<_> = from_nibbles(&n.path.encode()).collect();
(chd, data, path)
}

Expand All @@ -546,7 +545,7 @@ impl Serialize for EncodedNode<PlainCodec> {

let data = value.as_deref();

let path = from_nibbles(&path.encode(false)).collect();
let path = from_nibbles(&path.encode()).collect();

(chd, data, path)
}
Expand All @@ -569,7 +568,7 @@ impl<'de> Deserialize<'de> for EncodedNode<PlainCodec> {
{
let EncodedBranchNode { chd, data, path } = Deserialize::deserialize(deserializer)?;

let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()).0;
let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter());

if chd.is_empty() {
let data = if let Some(d) = data {
Expand Down Expand Up @@ -606,10 +605,7 @@ impl Serialize for EncodedNode<Bincode> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
match &self.node {
EncodedNodeType::Leaf(n) => {
let list = [
from_nibbles(&n.path.encode(true)).collect(),
n.data.to_vec(),
];
let list = [from_nibbles(&n.path.encode()).collect(), n.data.to_vec()];
let mut seq = serializer.serialize_seq(Some(list.len()))?;
for e in list {
seq.serialize_element(&e)?;
Expand Down Expand Up @@ -642,7 +638,7 @@ impl Serialize for EncodedNode<Bincode> {
list[BranchNode::MAX_CHILDREN] = val.clone();
}

let serialized_path = from_nibbles(&path.encode(true)).collect();
let serialized_path = from_nibbles(&path.encode()).collect();
list[BranchNode::MAX_CHILDREN + 1] = serialized_path;

let mut seq = serializer.serialize_seq(Some(list.len()))?;
Expand Down Expand Up @@ -680,7 +676,7 @@ impl<'de> Deserialize<'de> for EncodedNode<Bincode> {
"incorrect encoded type for leaf node data",
));
};
let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()).0;
let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter());
let node = EncodedNodeType::Leaf(LeafNode {
path,
data: Data(data),
Expand All @@ -690,7 +686,7 @@ impl<'de> Deserialize<'de> for EncodedNode<Bincode> {

BranchNode::MSIZE => {
let path = items.pop().expect("length was checked above");
let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()).0;
let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter());

let value = items.pop().expect("length was checked above");
let value = if value.is_empty() {
Expand Down
8 changes: 4 additions & 4 deletions firewood/src/merkle/node/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl BranchNode {

let path = items.pop().ok_or(Error::custom("Invalid Branch Node"))?;
let path = Nibbles::<0>::new(&path);
let (path, _term) = PartialPath::from_nibbles(path.into_iter());
let path = PartialPath::from_nibbles(path.into_iter());

// we've already validated the size, that's why we can safely unwrap
#[allow(clippy::unwrap_used)]
Expand Down Expand Up @@ -176,7 +176,7 @@ impl BranchNode {
}

#[allow(clippy::unwrap_used)]
let path = from_nibbles(&self.path.encode(false)).collect::<Vec<_>>();
let path = from_nibbles(&self.path.encode()).collect::<Vec<_>>();

list[Self::MAX_CHILDREN + 1] = path;

Expand All @@ -202,7 +202,7 @@ impl Storable for BranchNode {
fn serialize(&self, to: &mut [u8]) -> Result<(), crate::shale::ShaleError> {
let mut cursor = Cursor::new(to);

let path: Vec<u8> = from_nibbles(&self.path.encode(false)).collect();
let path: Vec<u8> = from_nibbles(&self.path.encode()).collect();
cursor.write_all(&[path.len() as PathLen])?;
cursor.write_all(&path)?;

Expand Down Expand Up @@ -271,7 +271,7 @@ impl Storable for BranchNode {
addr += path_len as usize;

let path: Vec<u8> = path.into_iter().flat_map(to_nibble_array).collect();
let path = PartialPath::decode(&path).0;
let path = PartialPath::decode(&path);

let node_raw =
mem.get_view(addr, BRANCH_HEADER_SIZE)
Expand Down
13 changes: 8 additions & 5 deletions firewood/src/merkle/node/leaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl LeafNode {
bincode::DefaultOptions::new()
.serialize(
[
from_nibbles(&self.path.encode(true)).collect(),
from_nibbles(&self.path.encode()).collect(),
self.data.to_vec(),
]
.as_slice(),
Expand Down Expand Up @@ -85,7 +85,7 @@ impl Storable for LeafNode {
fn serialize(&self, to: &mut [u8]) -> Result<(), crate::shale::ShaleError> {
let mut cursor = Cursor::new(to);

let path = &self.path.encode(true);
let path = &self.path.encode();
let path = from_nibbles(path);
let data = &self.data;

Expand Down Expand Up @@ -149,9 +149,12 @@ mod tests {
use test_case::test_case;

// these tests will fail if the encoding mechanism changes and should be updated accordingly
#[test_case(0b10 << 4, vec![0x12, 0x34], vec![1, 2, 3, 4]; "even length")]
// first nibble is part of the prefix
#[test_case((0b11 << 4) + 2, vec![0x34], vec![2, 3, 4]; "odd length")]
//
// Even length so ODD_LEN flag is not set so flag byte is 0b0000_0000
#[test_case(0x00, vec![0x12, 0x34], vec![1, 2, 3, 4]; "even length")]
// Odd length so ODD_LEN flag is set so flag byte is 0b0000_0001
// This is combined with the first nibble of the path (0b0000_0010) to become 0b0001_0010
#[test_case(0b0001_0010, vec![0x34], vec![2, 3, 4]; "odd length")]
fn encode_regression_test(prefix: u8, path: Vec<u8>, nibbles: Vec<u8>) {
let data = vec![5, 6, 7, 8];

Expand Down
31 changes: 13 additions & 18 deletions firewood/src/merkle/node/partial_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,9 @@ impl PartialPath {
self.0
}

pub(crate) fn encode(&self, is_terminal: bool) -> Vec<u8> {
pub(crate) fn encode(&self) -> Vec<u8> {
let mut flags = Flags::empty();

if is_terminal {
flags.insert(Flags::TERMINAL);
}

let has_odd_len = self.0.len() & 1 == 1;

let extra_byte = if has_odd_len {
Expand All @@ -67,24 +63,24 @@ impl PartialPath {
// I also think `PartialPath` could probably borrow instead of own data.
//
/// returns a tuple of the decoded partial path and whether the path is terminal
pub fn decode(raw: &[u8]) -> (Self, bool) {
pub fn decode(raw: &[u8]) -> Self {
Self::from_iter(raw.iter().copied())
}

/// returns a tuple of the decoded partial path and whether the path is terminal
pub fn from_nibbles<const N: usize>(nibbles: NibblesIterator<'_, N>) -> (Self, bool) {
pub fn from_nibbles<const N: usize>(nibbles: NibblesIterator<'_, N>) -> Self {
Self::from_iter(nibbles)
}

/// Assumes all bytes are nibbles, prefer to use `from_nibbles` instead.
fn from_iter<Iter: Iterator<Item = u8>>(mut iter: Iter) -> (Self, bool) {
fn from_iter<Iter: Iterator<Item = u8>>(mut iter: Iter) -> Self {
let flags = Flags::from_bits_retain(iter.next().unwrap_or_default());

if !flags.contains(Flags::ODD_LEN) {
let _ = iter.next();
}

(Self(iter.collect()), flags.contains(Flags::TERMINAL))
Self(iter.collect())
}

pub(super) fn serialized_len(&self) -> u64 {
Expand All @@ -107,20 +103,19 @@ mod tests {
use super::*;
use test_case::test_case;

#[test_case(&[1, 2, 3, 4], true)]
#[test_case(&[1, 2, 3], false)]
#[test_case(&[0, 1, 2], false)]
#[test_case(&[1, 2], true)]
#[test_case(&[1], true)]
fn test_encoding(steps: &[u8], term: bool) {
#[test_case(&[1, 2, 3, 4])]
#[test_case(&[1, 2, 3])]
#[test_case(&[0, 1, 2])]
#[test_case(&[1, 2])]
#[test_case(&[1])]
fn test_encoding(steps: &[u8]) {
let path = PartialPath(steps.to_vec());
let encoded = path.encode(term);
let encoded = path.encode();

assert_eq!(encoded.len(), path.serialized_len() as usize * 2);

let (decoded, decoded_term) = PartialPath::decode(&encoded);
let decoded = PartialPath::decode(&encoded);

assert_eq!(&&*decoded, &steps);
assert_eq!(decoded_term, term);
}
}

0 comments on commit 4d87424

Please sign in to comment.