diff --git a/firewood/src/db/proposal.rs b/firewood/src/db/proposal.rs index ecf18ec11..e842c1087 100644 --- a/firewood/src/db/proposal.rs +++ b/firewood/src/db/proposal.rs @@ -10,15 +10,13 @@ use crate::shale::CachedStore; use crate::{ merkle::{TrieHash, TRIE_HASH_LEN}, storage::{buffer::BufferWrite, AshRecord, StoreRevMut}, - v2::api::{self, KeyType, ValueType}, + v2::api::{self, Batch, BatchOp, KeyType, ValueType}, }; use async_trait::async_trait; use parking_lot::{Mutex, RwLock}; use std::{io::ErrorKind, sync::Arc}; use tokio::task::block_in_place; -pub use crate::v2::api::{Batch, BatchOp}; - /// An atomic batch of changes proposed against the latest committed revision, /// or any existing [Proposal]. Multiple proposals can be created against the /// latest committed revision at the same time. [Proposal] is immutable meaning diff --git a/firewood/src/merkle.rs b/firewood/src/merkle.rs index f8b6fb0f1..25a9aea33 100644 --- a/firewood/src/merkle.rs +++ b/firewood/src/merkle.rs @@ -110,7 +110,10 @@ where fn encode(&self, node: &NodeObjRef) -> Result, MerkleError> { let encoded = match node.inner() { NodeType::Leaf(n) => EncodedNode::new(EncodedNodeType::Leaf(n.clone())), + NodeType::Branch(n) => { + let path = n.path.clone(); + // pair up DiskAddresses with encoded children and pick the right one let encoded_children = n.chd().iter().zip(n.children_encoded.iter()); let children = encoded_children @@ -127,6 +130,7 @@ where .expect("MAX_CHILDREN will always be yielded"); EncodedNode::new(EncodedNodeType::Branch { + path, children, value: n.value.clone(), }) @@ -145,13 +149,19 @@ where match encoded.node { EncodedNodeType::Leaf(leaf) => Ok(NodeType::Leaf(leaf)), - EncodedNodeType::Branch { children, value } => { - let path = Vec::new().into(); + EncodedNodeType::Branch { + path, + children, + value, + } => { + let path = PartialPath::decode(&path).0; let value = value.map(|v| v.0); - Ok(NodeType::Branch( + let branch = NodeType::Branch( BranchNode::new(path, [None; 
BranchNode::MAX_CHILDREN], value, *children) .into(), - )) + ); + + Ok(branch) } } } @@ -162,7 +172,7 @@ impl + Send + Sync, T> Merkle { self.store .put_item( Node::from_branch(BranchNode { - // path: vec![].into(), + path: vec![].into(), children: [None; BranchNode::MAX_CHILDREN], value: None, children_encoded: Default::default(), @@ -190,9 +200,9 @@ impl + Send + Sync, T> Merkle { }) } - pub fn root_hash(&self, root: DiskAddress) -> Result { + pub fn root_hash(&self, sentinel: DiskAddress) -> Result { let root = self - .get_node(root)? + .get_node(sentinel)? .inner .as_branch() .ok_or(MerkleError::NotBranchNode)? @@ -213,14 +223,14 @@ impl + Send + Sync, T> Merkle { fn dump_(&self, u: DiskAddress, w: &mut dyn Write) -> Result<(), MerkleError> { let u_ref = self.get_node(u)?; - write!( - w, - "{u:?} => {}: ", - match u_ref.root_hash.get() { - Some(h) => hex::encode(**h), - None => "".to_string(), - } - )?; + + let hash = match u_ref.root_hash.get() { + Some(h) => h, + None => u_ref.get_root_hash::(self.store.as_ref()), + }; + + write!(w, "{u:?} => {}: ", hex::encode(**hash))?; + match &u_ref.inner { NodeType::Branch(n) => { writeln!(w, "{n:?}")?; @@ -235,6 +245,7 @@ impl + Send + Sync, T> Merkle { self.dump_(n.chd(), w)? } } + Ok(()) } @@ -247,23 +258,25 @@ impl + Send + Sync, T> Merkle { Ok(()) } + // TODO: replace `split` with a `split_at` function. Handle the logic for matching paths in `insert` instead. 
#[allow(clippy::too_many_arguments)] - fn split( - &self, - mut node_to_split: NodeObjRef, - parents: &mut [(NodeObjRef, u8)], + fn split<'a>( + &'a self, + mut node_to_split: NodeObjRef<'a>, + parents: &mut [(NodeObjRef<'a>, u8)], insert_path: &[u8], n_path: Vec, n_value: Option, val: Vec, deleted: &mut Vec, - ) -> Result>, MerkleError> { + ) -> Result, Vec)>, MerkleError> { let node_to_split_address = node_to_split.as_ptr(); let split_index = insert_path .iter() .zip(n_path.iter()) .position(|(a, b)| a != b); + #[allow(clippy::indexing_slicing)] let new_child_address = if let Some(idx) = split_index { // paths diverge let new_split_node_path = n_path.split_at(idx + 1).1; @@ -286,10 +299,8 @@ impl + Send + Sync, T> Merkle { let mut chd = [None; BranchNode::MAX_CHILDREN]; - #[allow(clippy::indexing_slicing)] let last_matching_nibble = matching_path[idx]; - #[allow(clippy::indexing_slicing)] - (chd[last_matching_nibble as usize] = Some(leaf_address)); + chd[last_matching_nibble as usize] = Some(leaf_address); let address = match &node_to_split.inner { NodeType::Extension(u) if u.path.len() == 0 => { @@ -299,35 +310,22 @@ impl + Send + Sync, T> Merkle { _ => node_to_split_address, }; - #[allow(clippy::indexing_slicing)] - (chd[n_path[idx] as usize] = Some(address)); + chd[n_path[idx] as usize] = Some(address); let new_branch = Node::from_branch(BranchNode { - // path: PartialPath(matching_path[..idx].to_vec()), + path: PartialPath(matching_path[..idx].to_vec()), children: chd, value: None, children_encoded: Default::default(), }); - let new_branch_address = self.put_node(new_branch)?.as_ptr(); - - if idx > 0 { - self.put_node(Node::from(NodeType::Extension(ExtNode { - #[allow(clippy::indexing_slicing)] - path: PartialPath(matching_path[..idx].to_vec()), - child: new_branch_address, - child_encoded: None, - })))? 
- .as_ptr() - } else { - new_branch_address - } + self.put_node(new_branch)?.as_ptr() } else { // paths do not diverge let (leaf_address, prefix, idx, value) = match (insert_path.len().cmp(&n_path.len()), n_value) { // no node-value means this is an extension node and we can therefore continue walking the tree - (Ordering::Greater, None) => return Ok(Some(val)), + (Ordering::Greater, None) => return Ok(Some((node_to_split, val))), // if the paths are equal, we overwrite the data (Ordering::Equal, _) => { @@ -368,7 +366,9 @@ impl + Send + Sync, T> Merkle { result = Err(e); } } - NodeType::Branch(_) => unreachable!(), + NodeType::Branch(u) => { + u.value = Some(Data(val)); + } } u.rehash(); @@ -440,24 +440,13 @@ impl + Send + Sync, T> Merkle { #[allow(clippy::indexing_slicing)] (children[idx] = leaf_address.into()); - let branch_address = self - .put_node(Node::from_branch(BranchNode { - children, - value, - children_encoded: Default::default(), - }))? - .as_ptr(); - - if !prefix.is_empty() { - self.put_node(Node::from(NodeType::Extension(ExtNode { - path: PartialPath(prefix.to_vec()), - child: branch_address, - child_encoded: None, - })))? - .as_ptr() - } else { - branch_address - } + self.put_node(Node::from_branch(BranchNode { + path: PartialPath(prefix.to_vec()), + children, + value, + children_encoded: Default::default(), + }))? + .as_ptr() }; // observation: @@ -509,7 +498,7 @@ impl + Send + Sync, T> Merkle { // walk down the merkle tree starting from next_node, currently the root // return None if the value is inserted let next_node_and_val = loop { - let Some(current_nibble) = key_nibbles.next() else { + let Some(mut next_nibble) = key_nibbles.next() else { break Some((node, val)); }; @@ -518,28 +507,138 @@ impl + Send + Sync, T> Merkle { // to another node, we walk down that. 
Otherwise, we can store our // value as a leaf and we're done NodeType::Leaf(n) => { - // we collided with another key; make a copy - // of the stored key to pass into split - let n_path = n.path.to_vec(); - let n_value = Some(n.data.clone()); - let rem_path = once(current_nibble).chain(key_nibbles).collect::>(); + // TODO: avoid extra allocation + let key_remainder = once(next_nibble) + .chain(key_nibbles.clone()) + .collect::>(); - self.split( - node, - &mut parents, - &rem_path, - n_path, - n_value, - val, - &mut deleted, - )?; + let overlap = PrefixOverlap::from(&n.path, &key_remainder); + + #[allow(clippy::indexing_slicing)] + match (overlap.unique_a.len(), overlap.unique_b.len()) { + // same node, overwrite the data + (0, 0) => { + node.write(|node| { + node.inner.set_data(Data(val)); + node.rehash(); + })?; + } + + // new node is a child of the old node + (0, _) => { + let (new_leaf_index, new_leaf_path) = { + let (index, path) = overlap.unique_b.split_at(1); + (index[0], path.to_vec()) + }; + + let new_leaf = Node::from_leaf(LeafNode::new( + PartialPath(new_leaf_path), + Data(val), + )); + + let new_leaf = self.put_node(new_leaf)?.as_ptr(); + + let mut children = [None; BranchNode::MAX_CHILDREN]; + children[new_leaf_index as usize] = Some(new_leaf); + + let new_branch = BranchNode { + path: PartialPath(overlap.shared.to_vec()), + children, + value: n.data.clone().into(), + children_encoded: Default::default(), + }; + + let new_branch = Node::from_branch(new_branch); + + let new_branch = self.put_node(new_branch)?.as_ptr(); + + set_parent(new_branch, &mut parents); + + deleted.push(node.as_ptr()); + } + + // old node is a child of the new node + (_, 0) => { + let (old_leaf_index, old_leaf_path) = { + let (index, path) = overlap.unique_a.split_at(1); + (index[0], path.to_vec()) + }; + + let new_branch_path = overlap.shared.to_vec(); + + node.write(move |old_leaf| { + *old_leaf.inner.path_mut() = PartialPath(old_leaf_path.to_vec()); + old_leaf.rehash(); + })?; 
+ + let old_leaf = node.as_ptr(); + + let mut new_branch = BranchNode { + path: PartialPath(new_branch_path), + children: [None; BranchNode::MAX_CHILDREN], + value: Some(val.into()), + children_encoded: Default::default(), + }; + + new_branch.children[old_leaf_index as usize] = Some(old_leaf); + + let node = Node::from_branch(new_branch); + let node = self.put_node(node)?.as_ptr(); + + set_parent(node, &mut parents); + } + + // nodes are siblings + _ => { + let (old_leaf_index, old_leaf_path) = { + let (index, path) = overlap.unique_a.split_at(1); + (index[0], path.to_vec()) + }; + + let (new_leaf_index, new_leaf_path) = { + let (index, path) = overlap.unique_b.split_at(1); + (index[0], path.to_vec()) + }; + + let new_branch_path = overlap.shared.to_vec(); + + node.write(move |old_leaf| { + *old_leaf.inner.path_mut() = PartialPath(old_leaf_path.to_vec()); + old_leaf.rehash(); + })?; + + let new_leaf = Node::from_leaf(LeafNode::new( + PartialPath(new_leaf_path), + Data(val), + )); + + let old_leaf = node.as_ptr(); + + let new_leaf = self.put_node(new_leaf)?.as_ptr(); + + let mut new_branch = BranchNode { + path: PartialPath(new_branch_path), + children: [None; BranchNode::MAX_CHILDREN], + value: None, + children_encoded: Default::default(), + }; + + new_branch.children[old_leaf_index as usize] = Some(old_leaf); + new_branch.children[new_leaf_index as usize] = Some(new_leaf); + + let node = Node::from_branch(new_branch); + let node = self.put_node(node)?.as_ptr(); + + set_parent(node, &mut parents); + } + } break None; } - NodeType::Branch(n) => { + NodeType::Branch(n) if n.path.len() == 0 => { #[allow(clippy::indexing_slicing)] - match n.children[current_nibble as usize] { + match n.children[next_nibble as usize] { Some(c) => (node, c), None => { // insert the leaf to the empty slot @@ -550,15 +649,152 @@ impl + Send + Sync, T> Merkle { Data(val), )))? 
.as_ptr(); + // set the current child to point to this leaf - #[allow(clippy::unwrap_used)] - node.write(|u| { - let uu = u.inner.as_branch_mut().unwrap(); - #[allow(clippy::indexing_slicing)] - (uu.children[current_nibble as usize] = Some(leaf_ptr)); - u.rehash(); - }) - .unwrap(); + #[allow(clippy::indexing_slicing)] + node.write(|node| { + node.as_branch_mut().children[next_nibble as usize] = + Some(leaf_ptr); + node.rehash(); + })?; + + break None; + } + } + } + + NodeType::Branch(n) => { + // TODO: avoid extra allocation + let key_remainder = once(next_nibble) + .chain(key_nibbles.clone()) + .collect::>(); + + let overlap = PrefixOverlap::from(&n.path, &key_remainder); + + #[allow(clippy::indexing_slicing)] + match (overlap.unique_a.len(), overlap.unique_b.len()) { + // same node, overwrite the data + (0, 0) => { + node.write(|node| { + node.inner.set_data(Data(val)); + node.rehash(); + })?; + + break None; + } + + // new node is a child of the old node + (0, _) => { + let (new_leaf_index, new_leaf_path) = { + let (index, path) = overlap.unique_b.split_at(1); + (index[0], path) + }; + + (0..overlap.shared.len()).for_each(|_| { + key_nibbles.next(); + }); + + next_nibble = new_leaf_index; + + match n.children[next_nibble as usize] { + Some(ptr) => (node, ptr), + None => { + let new_leaf = Node::from_leaf(LeafNode::new( + PartialPath(new_leaf_path.to_vec()), + Data(val), + )); + + let new_leaf = self.put_node(new_leaf)?.as_ptr(); + + #[allow(clippy::indexing_slicing)] + node.write(|node| { + node.as_branch_mut().children[next_nibble as usize] = + Some(new_leaf); + node.rehash(); + })?; + + break None; + } + } + } + + // old node is a child of the new node + (_, 0) => { + let (old_branch_index, old_branch_path) = { + let (index, path) = overlap.unique_a.split_at(1); + (index[0], path.to_vec()) + }; + + let new_branch_path = overlap.shared.to_vec(); + + node.write(move |old_branch| { + *old_branch.inner.path_mut() = + PartialPath(old_branch_path.to_vec()); + 
old_branch.rehash(); + })?; + + let old_branch = node.as_ptr(); + + let mut new_branch = BranchNode { + path: PartialPath(new_branch_path), + children: [None; BranchNode::MAX_CHILDREN], + value: Some(val.into()), + children_encoded: Default::default(), + }; + + new_branch.children[old_branch_index as usize] = Some(old_branch); + + let node = Node::from_branch(new_branch); + let node = self.put_node(node)?.as_ptr(); + + set_parent(node, &mut parents); + + break None; + } + + // nodes are siblings + _ => { + let (old_branch_index, old_branch_path) = { + let (index, path) = overlap.unique_a.split_at(1); + (index[0], path.to_vec()) + }; + + let (new_leaf_index, new_leaf_path) = { + let (index, path) = overlap.unique_b.split_at(1); + (index[0], path.to_vec()) + }; + + let new_branch_path = overlap.shared.to_vec(); + + node.write(move |old_branch| { + *old_branch.inner.path_mut() = + PartialPath(old_branch_path.to_vec()); + old_branch.rehash(); + })?; + + let new_leaf = Node::from_leaf(LeafNode::new( + PartialPath(new_leaf_path), + Data(val), + )); + + let old_branch = node.as_ptr(); + + let new_leaf = self.put_node(new_leaf)?.as_ptr(); + + let mut new_branch = BranchNode { + path: PartialPath(new_branch_path), + children: [None; BranchNode::MAX_CHILDREN], + value: None, + children_encoded: Default::default(), + }; + + new_branch.children[old_branch_index as usize] = Some(old_branch); + new_branch.children[new_leaf_index as usize] = Some(new_leaf); + + let node = Node::from_branch(new_branch); + let node = self.put_node(node)?.as_ptr(); + + set_parent(node, &mut parents); break None; } @@ -568,13 +804,12 @@ impl + Send + Sync, T> Merkle { NodeType::Extension(n) => { let n_path = n.path.to_vec(); let n_ptr = n.chd(); - let rem_path = once(current_nibble) + let rem_path = once(next_nibble) .chain(key_nibbles.clone()) .collect::>(); let n_path_len = n_path.len(); - let node_ptr = node.as_ptr(); - if let Some(v) = self.split( + if let Some((node, v)) = self.split( node, &mut 
parents, &rem_path, @@ -592,7 +827,7 @@ impl + Send + Sync, T> Merkle { // extension node's next pointer val = v; - (self.get_node(node_ptr)?, n_ptr) + (node, n_ptr) } else { // successfully inserted break None; @@ -601,7 +836,7 @@ impl + Send + Sync, T> Merkle { }; // push another parent, and follow the next pointer - parents.push((node_ref, current_nibble)); + parents.push((node_ref, next_nibble)); node = self.get_node(next_node_ptr)?; }; @@ -674,7 +909,7 @@ impl + Send + Sync, T> Merkle { let branch = self .put_node(Node::from_branch(BranchNode { - // path: vec![].into(), + path: vec![].into(), children: chd, value: Some(Data(val)), children_encoded: Default::default(), @@ -1000,6 +1235,183 @@ impl + Send + Sync, T> Merkle { return Ok(None); } + let mut deleted = Vec::new(); + + let data = { + let (node, mut parents) = + self.get_node_and_parents_by_key(self.get_node(root)?, key)?; + + let Some(mut node) = node else { + return Ok(None); + }; + + let data = match &node.inner { + NodeType::Branch(branch) => { + let data = branch.value.clone(); + let children = branch.children; + + if data.is_none() { + return Ok(None); + } + + let children: Vec<_> = children + .iter() + .enumerate() + .filter_map(|(i, child)| child.map(|child| (i, child))) + .collect(); + + // don't change the sentinal node + if children.len() == 1 && !parents.is_empty() { + let branch_path = &branch.path.0; + + #[allow(clippy::indexing_slicing)] + let (child_index, child) = children[0]; + let mut child = self.get_node(child)?; + + child.write(|child| { + let child_path = child.inner.path_mut(); + let path = branch_path + .iter() + .copied() + .chain(once(child_index as u8)) + .chain(child_path.0.iter().copied()) + .collect(); + *child_path = PartialPath(path); + + child.rehash(); + })?; + + set_parent(child.as_ptr(), &mut parents); + + deleted.push(node.as_ptr()); + } else { + node.write(|node| { + node.as_branch_mut().value = None; + node.rehash(); + })? 
+ } + + data + } + + NodeType::Leaf(n) => { + let data = Some(n.data.clone()); + + // TODO: handle unwrap better + deleted.push(node.as_ptr()); + + let (mut parent, child_index) = parents.pop().expect("parents is never empty"); + + #[allow(clippy::indexing_slicing)] + parent.write(|parent| { + parent.as_branch_mut().children[child_index as usize] = None; + })?; + + let branch = parent + .inner + .as_branch() + .expect("parents are always branch nodes"); + + let children: Vec<_> = branch + .children + .iter() + .enumerate() + .filter_map(|(i, child)| child.map(|child| (i, child))) + .collect(); + + match (children.len(), &branch.value, !parents.is_empty()) { + // node is invalid, all single-child nodes should have data + (1, None, true) => { + let parent_path = &branch.path.0; + + #[allow(clippy::indexing_slicing)] + let (child_index, child) = children[0]; + let child = self.get_node(child)?; + + // TODO: + // there's an optimization here for when the paths are the same length + // and that clone isn't great but ObjRef causes problems + // we can't write directly to the child because we could be changing its size + let new_child = match child.inner.clone() { + NodeType::Branch(mut child) => { + let path = parent_path + .iter() + .copied() + .chain(once(child_index as u8)) + .chain(child.path.0.iter().copied()) + .collect(); + + child.path = PartialPath(path); + + Node::from_branch(child) + } + NodeType::Leaf(mut child) => { + let path = parent_path + .iter() + .copied() + .chain(once(child_index as u8)) + .chain(child.path.0.iter().copied()) + .collect(); + + child.path = PartialPath(path); + + Node::from_leaf(child) + } + NodeType::Extension(_) => todo!(), + }; + + let child = self.put_node(new_child)?.as_ptr(); + + set_parent(child, &mut parents); + + deleted.push(parent.as_ptr()); + } + + // branch nodes shouldn't have no children + (0, Some(data), true) => { + let leaf = Node::from_leaf(LeafNode::new( + PartialPath(branch.path.0.clone()), + data.clone(), + )); + 
+ let leaf = self.put_node(leaf)?.as_ptr(); + set_parent(leaf, &mut parents); + + deleted.push(parent.as_ptr()); + } + + _ => parent.write(|parent| parent.rehash())?, + } + + data + } + + NodeType::Extension(_) => todo!(), + }; + + for (mut parent, _) in parents { + parent.write(|u| u.rehash())?; + } + + data + }; + + for ptr in deleted.into_iter() { + self.free_node(ptr)?; + } + + Ok(data.map(|data| data.0)) + } + + pub fn remove_old>( + &mut self, + key: K, + root: DiskAddress, + ) -> Result>, MerkleError> { + if root.is_null() { + return Ok(None); + } + let (found, parents, deleted) = { let (node_ref, mut parents) = self.get_node_and_parents_by_key(self.get_node(root)?, key)?; @@ -1136,7 +1548,7 @@ impl + Send + Sync, T> Merkle { let mut key_nibbles = Nibbles::<1>::new(key.as_ref()).into_iter(); loop { - let Some(nib) = key_nibbles.next() else { + let Some(mut nib) = key_nibbles.next() else { break; }; @@ -1144,10 +1556,41 @@ impl + Send + Sync, T> Merkle { let next_ptr = match &node_ref.inner { #[allow(clippy::indexing_slicing)] - NodeType::Branch(n) => match n.children[nib as usize] { + NodeType::Branch(n) if n.path.len() == 0 => match n.children[nib as usize] { Some(c) => c, None => return Ok(None), }, + NodeType::Branch(n) => { + let mut n_path_iter = n.path.iter().copied(); + + if n_path_iter.next() != Some(nib) { + return Ok(None); + } + + let path_matches = n_path_iter + .map(Some) + .all(|n_path_nibble| key_nibbles.next() == n_path_nibble); + + if !path_matches { + return Ok(None); + } + + nib = if let Some(nib) = key_nibbles.next() { + nib + } else { + return Ok(if n.value.is_some() { + Some(node_ref) + } else { + None + }); + }; + + #[allow(clippy::indexing_slicing)] + match n.children[nib as usize] { + Some(c) => c, + None => return Ok(None), + } + } NodeType::Leaf(n) => { let node_ref = if once(nib).chain(key_nibbles).eq(n.path.iter().copied()) { Some(node_ref) @@ -1183,7 +1626,9 @@ impl + Send + Sync, T> Merkle { // when we're done iterating over 
nibbles, check if the node we're at has a value let node_ref = match &node_ref.inner { - NodeType::Branch(n) if n.value.as_ref().is_some() => Some(node_ref), + NodeType::Branch(n) if n.value.as_ref().is_some() && n.path.is_empty() => { + Some(node_ref) + } NodeType::Leaf(n) if n.path.len() == 0 => Some(node_ref), _ => None, }; @@ -1472,6 +1917,41 @@ pub fn from_nibbles(nibbles: &[u8]) -> impl Iterator + '_ { nibbles.chunks_exact(2).map(|p| (p[0] << 4) | p[1]) } +/// The [`PrefixOverlap`] type represents the _shared_ and _unique_ parts of two potentially overlapping slices. +/// As the type-name implies, the `shared` property only constitues a shared *prefix*. +/// The `unique_*` properties, [`unique_a`][`PrefixOverlap::unique_a`] and [`unique_b`][`PrefixOverlap::unique_b`] +/// are set based on the argument order passed into the [`from`][`PrefixOverlap::from`] constructor. +#[derive(Debug)] +struct PrefixOverlap<'a, T> { + shared: &'a [T], + unique_a: &'a [T], + unique_b: &'a [T], +} + +impl<'a, T: PartialEq> PrefixOverlap<'a, T> { + fn from(a: &'a [T], b: &'a [T]) -> Self { + let mut split_index = 0; + + #[allow(clippy::indexing_slicing)] + for i in 0..std::cmp::min(a.len(), b.len()) { + if a[i] != b[i] { + break; + } + + split_index += 1; + } + + let (shared, unique_a) = a.split_at(split_index); + let (_, unique_b) = b.split_at(split_index); + + Self { + shared, + unique_a, + unique_b, + } + } +} + #[cfg(test)] #[allow(clippy::indexing_slicing, clippy::unwrap_used)] mod tests { @@ -1542,9 +2022,36 @@ mod tests { create_generic_test_merkle::() } - fn branch(value: Option>, encoded_child: Option>) -> Node { + fn branch(path: &[u8], value: &[u8], encoded_child: Option>) -> Node { + let (path, value) = (path.to_vec(), value.to_vec()); + let path = Nibbles::<0>::new(&path); + let path = PartialPath(path.into_iter().collect()); + + let children = Default::default(); + // TODO: Properly test empty data as a value + let value = Some(Data(value)); + let mut 
children_encoded = <[Option>; BranchNode::MAX_CHILDREN]>::default(); + + if let Some(child) = encoded_child { + children_encoded[0] = Some(child); + } + + Node::from_branch(BranchNode { + path, + children, + value, + children_encoded, + }) + } + + fn branch_without_data(path: &[u8], encoded_child: Option>) -> Node { + let path = path.to_vec(); + let path = Nibbles::<0>::new(&path); + let path = PartialPath(path.into_iter().collect()); + let children = Default::default(); - let value = value.map(Data); + // TODO: Properly test empty data as a value + let value = None; let mut children_encoded = <[Option>; BranchNode::MAX_CHILDREN]>::default(); if let Some(child) = encoded_child { @@ -1552,7 +2059,7 @@ mod tests { } Node::from_branch(BranchNode { - // path: vec![].into(), + path, children, value, children_encoded, @@ -1561,11 +2068,19 @@ mod tests { #[test_case(leaf(Vec::new(), Vec::new()) ; "empty leaf encoding")] #[test_case(leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf encoding")] - #[test_case(branch(Some(b"value".to_vec()), vec![1, 2, 3].into()) ; "branch with chd")] - #[test_case(branch(Some(b"value".to_vec()), None); "branch without chd")] - #[test_case(branch(None, None); "branch without value and chd")] + #[test_case(branch(b"", b"value", vec![1, 2, 3].into()) ; "branch with chd")] + #[test_case(branch(b"", b"value", None); "branch without chd")] + #[test_case(branch_without_data(b"", None); "branch without value and chd")] + #[test_case(branch(b"", b"", None); "branch without path value or children")] + #[test_case(branch(b"", b"value", None) ; "branch with value")] + #[test_case(branch(&[2], b"", None); "branch with path")] + #[test_case(branch(b"", b"", vec![1, 2, 3].into()); "branch with children")] + #[test_case(branch(&[2], b"value", None); "branch with path and value")] + #[test_case(branch(b"", b"value", vec![1, 2, 3].into()); "branch with value and children")] + #[test_case(branch(&[2], b"", vec![1, 2, 3].into()); "branch with path and children")] + 
#[test_case(branch(&[2], b"value", vec![1, 2, 3].into()); "branch with path value and children")] #[test_case(extension(vec![1, 2, 3], DiskAddress::null(), vec![4, 5].into()) ; "extension without child address")] - fn encode_(node: Node) { + fn encode(node: Node) { let merkle = create_test_merkle(); let node_ref = merkle.put_node(node).unwrap(); @@ -1578,14 +2093,14 @@ mod tests { #[test_case(Bincode::new(), leaf(Vec::new(), Vec::new()) ; "empty leaf encoding with Bincode")] #[test_case(Bincode::new(), leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf encoding with Bincode")] - #[test_case(Bincode::new(), branch(Some(b"value".to_vec()), vec![1, 2, 3].into()) ; "branch with chd with Bincode")] - #[test_case(Bincode::new(), branch(Some(b"value".to_vec()), None); "branch without chd with Bincode")] - #[test_case(Bincode::new(), branch(None, None); "branch without value and chd with Bincode")] + #[test_case(Bincode::new(), branch(b"", b"value", vec![1, 2, 3].into()) ; "branch with chd with Bincode")] + #[test_case(Bincode::new(), branch(b"", b"value", None); "branch without chd with Bincode")] + #[test_case(Bincode::new(), branch_without_data(b"", None); "branch without value and chd with Bincode")] #[test_case(PlainCodec::new(), leaf(Vec::new(), Vec::new()) ; "empty leaf encoding with PlainCodec")] #[test_case(PlainCodec::new(), leaf(vec![1, 2, 3], vec![4, 5]) ; "leaf encoding with PlainCodec")] - #[test_case(PlainCodec::new(), branch(Some(b"value".to_vec()), vec![1, 2, 3].into()) ; "branch with chd with PlainCodec")] - #[test_case(PlainCodec::new(), branch(Some(b"value".to_vec()), Some(Vec::new())); "branch with empty chd with PlainCodec")] - #[test_case(PlainCodec::new(), branch(Some(Vec::new()), vec![1, 2, 3].into()); "branch with empty value with PlainCodec")] + #[test_case(PlainCodec::new(), branch(b"", b"value", vec![1, 2, 3].into()) ; "branch with chd with PlainCodec")] + #[test_case(PlainCodec::new(), branch(b"", b"value", Some(Vec::new())); "branch with empty chd with 
PlainCodec")] + #[test_case(PlainCodec::new(), branch(b"", b"", vec![1, 2, 3].into()); "branch with empty value with PlainCodec")] fn node_encode_decode(_codec: T, node: Node) where T: BinarySerde, @@ -1644,6 +2159,60 @@ mod tests { } } + #[test] + fn long_insert_and_retrieve_multiple() { + let key_val: Vec<(&'static [u8], _)> = vec![ + ( + &[0, 0, 0, 1, 0, 101, 151, 236], + [16, 15, 159, 195, 34, 101, 227, 73], + ), + ( + &[0, 0, 1, 107, 198, 92, 205], + [26, 147, 21, 200, 138, 106, 137, 218], + ), + (&[0, 1, 0, 1, 0, 56], [194, 147, 168, 193, 19, 226, 51, 204]), + (&[1, 90], [101, 38, 25, 65, 181, 79, 88, 223]), + ( + &[1, 1, 1, 0, 0, 0, 1, 59], + [105, 173, 182, 126, 67, 166, 166, 196], + ), + ( + &[0, 1, 0, 0, 1, 1, 55, 33, 38, 194], + [90, 140, 160, 53, 230, 100, 237, 236], + ), + ( + &[1, 1, 0, 1, 249, 46, 69], + [16, 104, 134, 6, 57, 46, 200, 35], + ), + ( + &[1, 1, 0, 1, 0, 0, 1, 33, 163], + [95, 97, 187, 124, 198, 28, 75, 226], + ), + ( + &[1, 1, 0, 1, 0, 57, 156], + [184, 18, 69, 29, 96, 252, 188, 58], + ), + (&[1, 0, 1, 1, 0, 218], [155, 38, 43, 54, 93, 134, 73, 209]), + ]; + + let mut merkle = create_test_merkle(); + let root = merkle.init_root().unwrap(); + + for (key, val) in &key_val { + merkle.insert(key, val.to_vec(), root).unwrap(); + + let fetched_val = merkle.get(key, root).unwrap(); + + assert_eq!(fetched_val.as_deref(), val.as_slice().into()); + } + + for (key, val) in key_val { + let fetched_val = merkle.get(key, root).unwrap(); + + assert_eq!(fetched_val.as_deref(), val.as_slice().into()); + } + } + #[test] fn remove_one() { let key = b"hello"; @@ -1689,7 +2258,10 @@ mod tests { let key = &[key_val]; let val = &[key_val]; - let removed_val = merkle.remove(key, root).unwrap(); + let Ok(removed_val) = merkle.remove(key, root) else { + panic!("({key_val}, {key_val}) missing"); + }; + assert_eq!(removed_val.as_deref(), val.as_slice().into()); let fetched_val = merkle.get(key, root).unwrap(); diff --git a/firewood/src/merkle/node.rs 
b/firewood/src/merkle/node.rs index 076aa4cb7..3e3ac0f7a 100644 --- a/firewood/src/merkle/node.rs +++ b/firewood/src/merkle/node.rs @@ -91,6 +91,13 @@ impl> Encoded { Encoded::Data(data) => bincode::DefaultOptions::new().deserialize(data.as_ref()), } } + + pub fn deserialize(self) -> Result { + match self { + Encoded::Raw(raw) => Ok(raw), + Encoded::Data(data) => De::deserialize(data.as_ref()), + } + } } #[derive(PartialEq, Eq, Clone, Debug, EnumAsInner)] @@ -148,11 +155,19 @@ impl NodeType { pub fn path_mut(&mut self) -> &mut PartialPath { match self { - NodeType::Branch(_u) => todo!(), + NodeType::Branch(u) => &mut u.path, NodeType::Leaf(node) => &mut node.path, NodeType::Extension(node) => &mut node.path, } } + + pub fn set_data(&mut self, data: Data) { + match self { + NodeType::Branch(u) => u.value = Some(data), + NodeType::Leaf(node) => node.data = data, + NodeType::Extension(_) => (), + } + } } #[derive(Debug)] @@ -233,7 +248,7 @@ impl Node { is_encoded_longer_than_hash_len: OnceLock::new(), inner: NodeType::Branch( BranchNode { - // path: vec![].into(), + path: vec![].into(), children: [Some(DiskAddress::null()); BranchNode::MAX_CHILDREN], value: Some(Data(Vec::new())), children_encoded: Default::default(), @@ -316,6 +331,12 @@ impl Node { pub(super) fn set_dirty(&self, is_dirty: bool) { self.lazy_dirty.store(is_dirty, Ordering::Relaxed) } + + pub(crate) fn as_branch_mut(&mut self) -> &mut Box { + self.inner_mut() + .as_branch_mut() + .expect("must be a branch node") + } } #[derive(Clone, Copy, CheckedBitPattern, NoUninit)] @@ -531,6 +552,7 @@ impl EncodedNode { pub enum EncodedNodeType { Leaf(LeafNode), Branch { + path: PartialPath, children: Box<[Option>; BranchNode::MAX_CHILDREN]>, value: Option, }, @@ -550,14 +572,19 @@ impl Serialize for EncodedNode { where S: serde::Serializer, { - let n = match &self.node { + let (chd, data, path) = match &self.node { EncodedNodeType::Leaf(n) => { - let data = Some(n.data.to_vec()); + let data = Some(&*n.data); let 
chd: Vec<(u64, Vec)> = Default::default(); - let path = from_nibbles(&n.path.encode(true)).collect(); - EncodedBranchNode { chd, data, path } + let path: Vec<_> = from_nibbles(&n.path.encode(true)).collect(); + (chd, data, path) } - EncodedNodeType::Branch { children, value } => { + + EncodedNodeType::Branch { + path, + children, + value, + } => { let chd: Vec<(u64, Vec)> = children .iter() .enumerate() @@ -571,19 +598,20 @@ impl Serialize for EncodedNode { }) .collect(); - let data = value.as_ref().map(|v| v.0.to_vec()); - EncodedBranchNode { - chd, - data, - path: Vec::new(), - } + let data = value.as_deref(); + + let path = from_nibbles(&path.encode(false)).collect(); + + (chd, data, path) } }; let mut s = serializer.serialize_tuple(3)?; - s.serialize_element(&n.chd)?; - s.serialize_element(&n.data)?; - s.serialize_element(&n.path)?; + + s.serialize_element(&chd)?; + s.serialize_element(&data)?; + s.serialize_element(&path)?; + s.end() } } @@ -593,30 +621,35 @@ impl<'de> Deserialize<'de> for EncodedNode { where D: serde::Deserializer<'de>, { - let node: EncodedBranchNode = Deserialize::deserialize(deserializer)?; - if node.chd.is_empty() { - let data = if let Some(d) = node.data { + let EncodedBranchNode { chd, data, path } = Deserialize::deserialize(deserializer)?; + + let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()).0; + + if chd.is_empty() { + let data = if let Some(d) = data { Data(d) } else { Data(Vec::new()) }; - let path = PartialPath::from_nibbles(Nibbles::<0>::new(&node.path).into_iter()).0; let node = EncodedNodeType::Leaf(LeafNode { path, data }); + Ok(Self::new(node)) } else { let mut children: [Option>; BranchNode::MAX_CHILDREN] = Default::default(); - let value = node.data.map(Data); + let value = data.map(Data); - for (i, chd) in node.chd { - #[allow(clippy::indexing_slicing)] - (children[i as usize] = Some(chd)); + #[allow(clippy::indexing_slicing)] + for (i, chd) in chd { + children[i as usize] = Some(chd); } let node = 
EncodedNodeType::Branch { + path, children: children.into(), value, }; + Ok(Self::new(node)) } } @@ -639,34 +672,50 @@ impl Serialize for EncodedNode { } seq.end() } - EncodedNodeType::Branch { children, value } => { - let mut list = <[Encoded>; BranchNode::MAX_CHILDREN + 1]>::default(); - for (i, c) in children + EncodedNodeType::Branch { + path, + children, + value, + } => { + let mut list = <[Encoded>; BranchNode::MAX_CHILDREN + 2]>::default(); + let children = children .iter() .enumerate() - .filter_map(|(i, c)| c.as_ref().map(|c| (i, c))) - { - if c.len() >= TRIE_HASH_LEN { - let serialized_hash = Bincode::serialize(&Keccak256::digest(c).to_vec()) - .map_err(|e| S::Error::custom(format!("bincode error: {e}")))?; - #[allow(clippy::indexing_slicing)] - (list[i] = Encoded::Data(serialized_hash)); + .filter_map(|(i, c)| c.as_ref().map(|c| (i, c))); + + #[allow(clippy::indexing_slicing)] + for (i, child) in children { + if child.len() >= TRIE_HASH_LEN { + let serialized_hash = + Bincode::serialize(&Keccak256::digest(child).to_vec()) + .map_err(|e| S::Error::custom(format!("bincode error: {e}")))?; + list[i] = Encoded::Data(serialized_hash); } else { - #[allow(clippy::indexing_slicing)] - (list[i] = Encoded::Raw(c.to_vec())); + list[i] = Encoded::Raw(child.to_vec()); } } - if let Some(Data(val)) = &value { + + list[BranchNode::MAX_CHILDREN] = if let Some(Data(val)) = &value { let serialized_val = Bincode::serialize(val) .map_err(|e| S::Error::custom(format!("bincode error: {e}")))?; - list[BranchNode::MAX_CHILDREN] = Encoded::Data(serialized_val); - } + + Encoded::Data(serialized_val) + } else { + Encoded::default() + }; + + let serialized_path = Bincode::serialize(&path.encode(false)) + .map_err(|e| S::Error::custom(format!("bincode error: {e}")))?; + + list[BranchNode::MAX_CHILDREN + 1] = Encoded::Data(serialized_path); let mut seq = serializer.serialize_seq(Some(list.len()))?; + for e in list { seq.serialize_element(&e)?; } + seq.end() } } @@ -680,8 +729,9 @@ 
impl<'de> Deserialize<'de> for EncodedNode { { use serde::de::Error; - let items: Vec>> = Deserialize::deserialize(deserializer)?; + let mut items: Vec>> = Deserialize::deserialize(deserializer)?; let len = items.len(); + match len { LEAF_NODE_SIZE => { let mut items = items.into_iter(); @@ -702,10 +752,25 @@ impl<'de> Deserialize<'de> for EncodedNode { }); Ok(Self::new(node)) } + BranchNode::MSIZE => { + let path = items + .pop() + .unwrap_or_default() + .deserialize::() + .map_err(D::Error::custom)?; + let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()).0; + + let mut value = items + .pop() + .unwrap_or_default() + .deserialize::() + .map_err(D::Error::custom) + .map(Data) + .map(Some)? + .filter(|data| !data.is_empty()); + let mut children: [Option>; BranchNode::MAX_CHILDREN] = Default::default(); - let mut value: Option = Default::default(); - let len = items.len(); for (i, chd) in items.into_iter().enumerate() { if i == len - 1 { @@ -729,11 +794,17 @@ impl<'de> Deserialize<'de> for EncodedNode { (children[i] = Some(chd).filter(|chd| !chd.is_empty())); } } + let node = EncodedNodeType::Branch { + path, children: children.into(), value, }; - Ok(Self::new(node)) + + Ok(Self { + node, + phantom: PhantomData, + }) } size => Err(D::Error::custom(format!("invalid size: {size}"))), } @@ -847,7 +918,7 @@ mod tests { ) { let leaf = NodeType::Leaf(LeafNode::new(PartialPath(vec![1, 2, 3]), Data(vec![4, 5]))); let branch = NodeType::Branch(Box::new(BranchNode { - // path: vec![].into(), + path: vec![].into(), children: [Some(DiskAddress::from(1)); BranchNode::MAX_CHILDREN], value: Some(Data(vec![1, 2, 3])), children_encoded: std::array::from_fn(|_| Some(vec![1])), @@ -904,6 +975,7 @@ mod tests { } #[test_matrix( + [&[], &[0xf], &[0xf, 0xf]], [vec![], vec![1,0,0,0,0,0,0,1], vec![1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1], repeat(1).take(16).collect()], [Nil, 0, 15], [ @@ -915,10 +987,13 @@ mod tests { ] )] fn branch_encoding( + path: &[u8], children: Vec, 
value: impl Into>, children_encoded: [Option>; BranchNode::MAX_CHILDREN], ) { + let path = PartialPath(path.iter().copied().map(|x| x & 0xf).collect()); + let mut children = children.into_iter().map(|x| { if x == 0 { None @@ -934,7 +1009,7 @@ mod tests { .map(|x| Data(std::iter::repeat(x).take(x as usize).collect())); let node = Node::from_branch(BranchNode { - // path: vec![].into(), + path, children, value, children_encoded, diff --git a/firewood/src/merkle/node/branch.rs b/firewood/src/merkle/node/branch.rs index 19b15116f..53b8bac6e 100644 --- a/firewood/src/merkle/node/branch.rs +++ b/firewood/src/merkle/node/branch.rs @@ -3,17 +3,19 @@ use super::{Data, Encoded, Node}; use crate::{ - merkle::{PartialPath, TRIE_HASH_LEN}, - shale::{DiskAddress, Storable}, - shale::{ShaleError, ShaleStore}, + merkle::{from_nibbles, to_nibble_array, PartialPath, TRIE_HASH_LEN}, + nibbles::Nibbles, + shale::{DiskAddress, ShaleError, ShaleStore, Storable}, }; use bincode::{Error, Options}; +use serde::de::Error as DeError; use std::{ fmt::{Debug, Error as FmtError, Formatter}, io::{Cursor, Read, Write}, mem::size_of, }; +type PathLen = u8; pub type DataLen = u32; pub type EncodedChildLen = u8; @@ -21,7 +23,7 @@ const MAX_CHILDREN: usize = 16; #[derive(PartialEq, Eq, Clone)] pub struct BranchNode { - // pub(crate) path: PartialPath, + pub(crate) path: PartialPath, pub(crate) children: [Option; MAX_CHILDREN], pub(crate) value: Option, pub(crate) children_encoded: [Option>; MAX_CHILDREN], @@ -30,7 +32,7 @@ pub struct BranchNode { impl Debug for BranchNode { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { write!(f, "[Branch")?; - // write!(f, " path={:?}", self.path)?; + write!(f, r#" path="{:?}""#, self.path)?; for (i, c) in self.children.iter().enumerate() { if let Some(c) = c { @@ -57,16 +59,16 @@ impl Debug for BranchNode { impl BranchNode { pub const MAX_CHILDREN: usize = MAX_CHILDREN; - pub const MSIZE: usize = Self::MAX_CHILDREN + 1; + pub const MSIZE: usize = 
Self::MAX_CHILDREN + 2; pub fn new( - _path: PartialPath, + path: PartialPath, chd: [Option; Self::MAX_CHILDREN], value: Option>, chd_encoded: [Option>; Self::MAX_CHILDREN], ) -> Self { BranchNode { - // path, + path, children: chd, value: value.map(Data), children_encoded: chd_encoded, @@ -112,6 +114,13 @@ impl BranchNode { pub(super) fn decode(buf: &[u8]) -> Result { let mut items: Vec>> = bincode::DefaultOptions::new().deserialize(buf)?; + let path = items + .pop() + .ok_or(Error::custom("Invalid Branch Node"))? + .decode()?; + let path = Nibbles::<0>::new(&path); + let (path, _term) = PartialPath::from_nibbles(path.into_iter()); + // we've already validated the size, that's why we can safely unwrap #[allow(clippy::unwrap_used)] let data = items.pop().unwrap().decode()?; @@ -128,9 +137,6 @@ impl BranchNode { (chd_encoded[i] = Some(data).filter(|data| !data.is_empty())); } - // TODO: add path - let path = Vec::new().into(); - Ok(BranchNode::new( path, [None; Self::MAX_CHILDREN], @@ -140,8 +146,8 @@ impl BranchNode { } pub(super) fn encode>(&self, store: &S) -> Vec { - // TODO: add path to encoded node - let mut list = <[Encoded>; Self::MAX_CHILDREN + 1]>::default(); + // path + children + value + let mut list = <[Encoded>; Self::MSIZE]>::default(); for (i, c) in self.children.iter().enumerate() { match c { @@ -202,9 +208,17 @@ impl BranchNode { } #[allow(clippy::unwrap_used)] + let path = from_nibbles(&self.path.encode(false)).collect::>(); + + list[Self::MAX_CHILDREN + 1] = Encoded::Data( + bincode::DefaultOptions::new() + .serialize(&path) + .expect("serializing raw bytes to always succeed"), + ); + bincode::DefaultOptions::new() .serialize(list.as_slice()) - .unwrap() + .expect("serializing `Encoded` to always succeed") } } @@ -215,13 +229,19 @@ impl Storable for BranchNode { let children_encoded_len = self.children_encoded.iter().fold(0, |len, child| { len + optional_data_len::(child.as_ref()) }); + let path_len_size = size_of::() as u64; + let path_len = 
self.path.serialized_len(); - children_len + data_len + children_encoded_len + children_len + data_len + children_encoded_len + path_len_size + path_len } fn serialize(&self, to: &mut [u8]) -> Result<(), crate::shale::ShaleError> { let mut cursor = Cursor::new(to); + let path: Vec = from_nibbles(&self.path.encode(false)).collect(); + cursor.write_all(&[path.len() as PathLen])?; + cursor.write_all(&path)?; + for child in &self.children { let bytes = child.map(|addr| addr.to_le_bytes()).unwrap_or_default(); cursor.write_all(&bytes)?; @@ -253,10 +273,42 @@ impl Storable for BranchNode { mut addr: usize, mem: &T, ) -> Result { + const PATH_LEN_SIZE: u64 = size_of::() as u64; const DATA_LEN_SIZE: usize = size_of::(); const BRANCH_HEADER_SIZE: u64 = BranchNode::MAX_CHILDREN as u64 * DiskAddress::MSIZE + DATA_LEN_SIZE as u64; + let path_len = mem + .get_view(addr, PATH_LEN_SIZE) + .ok_or(ShaleError::InvalidCacheView { + offset: addr, + size: PATH_LEN_SIZE, + })? + .as_deref(); + + addr += PATH_LEN_SIZE as usize; + + let path_len = { + let mut buf = [0u8; PATH_LEN_SIZE as usize]; + let mut cursor = Cursor::new(path_len); + cursor.read_exact(buf.as_mut())?; + + PathLen::from_le_bytes(buf) as u64 + }; + + let path = mem + .get_view(addr, path_len) + .ok_or(ShaleError::InvalidCacheView { + offset: addr, + size: path_len, + })? 
+ .as_deref(); + + addr += path_len as usize; + + let path: Vec = path.into_iter().flat_map(to_nibble_array).collect(); + let path = PartialPath::decode(&path).0; + let node_raw = mem.get_view(addr, BRANCH_HEADER_SIZE) .ok_or(ShaleError::InvalidCacheView { @@ -342,8 +394,7 @@ impl Storable for BranchNode { } let node = BranchNode { - // TODO: add path - // path: Vec::new().into(), + path, children, value, children_encoded, diff --git a/firewood/src/merkle/node/partial_path.rs b/firewood/src/merkle/node/partial_path.rs index 4646ea5ca..1e1d5a74c 100644 --- a/firewood/src/merkle/node/partial_path.rs +++ b/firewood/src/merkle/node/partial_path.rs @@ -40,7 +40,7 @@ impl PartialPath { self.0 } - pub(super) fn encode(&self, is_terminal: bool) -> Vec { + pub(crate) fn encode(&self, is_terminal: bool) -> Vec { let mut flags = Flags::empty(); if is_terminal { diff --git a/firewood/src/merkle/proof.rs b/firewood/src/merkle/proof.rs index 02281a5ea..4eaf0890b 100644 --- a/firewood/src/merkle/proof.rs +++ b/firewood/src/merkle/proof.rs @@ -142,7 +142,7 @@ impl + Send> Proof { } } - pub fn concat_proofs(&mut self, other: Proof) { + pub fn extend(&mut self, other: Proof) { self.0.extend(other.0) } @@ -356,6 +356,16 @@ impl + Send> Proof { } NodeType::Branch(n) => { + let paths_match = n + .path + .iter() + .copied() + .all(|nibble| Some(nibble) == key_nibbles.next()); + + if !paths_match { + break None; + } + if let Some(index) = key_nibbles.peek() { let subproof = n .chd_encode() @@ -440,6 +450,17 @@ fn locate_subproof( Ok((sub_proof.into(), key_nibbles)) } NodeType::Branch(n) => { + let partial_path = &n.path.0; + + let does_not_match = key_nibbles.size_hint().0 < partial_path.len() + || !partial_path + .iter() + .all(|val| key_nibbles.next() == Some(*val)); + + if does_not_match { + return Ok((None, Nibbles::<0>::new(&[]).into_iter())); + } + let Some(index) = key_nibbles.next().map(|nib| nib as usize) else { let encoded = n.value; @@ -512,13 +533,30 @@ fn unset_internal, S: 
ShaleStore + Send + Sync, T: BinarySe let mut u_ref = merkle.get_node(root).map_err(|_| ProofError::NoSuchNode)?; let mut parent = DiskAddress::null(); - let mut fork_left: Ordering = Ordering::Equal; - let mut fork_right: Ordering = Ordering::Equal; + let mut fork_left = Ordering::Equal; + let mut fork_right = Ordering::Equal; let mut index = 0; loop { match &u_ref.inner() { + #[allow(clippy::indexing_slicing)] NodeType::Branch(n) => { + // If either the key of left proof or right proof doesn't match with + // stop here, this is the forkpoint. + let path = &*n.path; + + if !path.is_empty() { + [fork_left, fork_right] = [&left_chunks[index..], &right_chunks[index..]] + .map(|chunks| chunks.chunks(path.len()).next().unwrap_or_default()) + .map(|key| key.cmp(path)); + + if !fork_left.is_eq() || !fork_right.is_eq() { + break; + } + + index += path.len(); + } + // If either the node pointed by left proof or right proof is nil, // stop here and the forkpoint is the fullnode. #[allow(clippy::indexing_slicing)] @@ -571,6 +609,61 @@ fn unset_internal, S: ShaleStore + Send + Sync, T: BinarySe match &u_ref.inner() { NodeType::Branch(n) => { + if fork_left.is_lt() && fork_right.is_lt() { + return Err(ProofError::EmptyRange); + } + + if fork_left.is_gt() && fork_right.is_gt() { + return Err(ProofError::EmptyRange); + } + + if fork_left.is_ne() && fork_right.is_ne() { + // The fork point is root node, unset the entire trie + if parent.is_null() { + return Ok(true); + } + + let mut p_ref = merkle + .get_node(parent) + .map_err(|_| ProofError::NoSuchNode)?; + #[allow(clippy::unwrap_used)] + p_ref + .write(|p| { + let pp = p.inner_mut().as_branch_mut().expect("not a branch node"); + #[allow(clippy::indexing_slicing)] + (pp.chd_mut()[left_chunks[index - 1] as usize] = None); + #[allow(clippy::indexing_slicing)] + (pp.chd_encoded_mut()[left_chunks[index - 1] as usize] = None); + }) + .unwrap(); + + return Ok(false); + } + + let p = u_ref.as_ptr(); + index += n.path.len(); + + // 
Only one proof points to non-existent key. + if fork_right.is_ne() { + #[allow(clippy::indexing_slicing)] + let left_node = n.chd()[left_chunks[index] as usize]; + + drop(u_ref); + #[allow(clippy::indexing_slicing)] + unset_node_ref(merkle, p, left_node, &left_chunks[index..], 1, false)?; + return Ok(false); + } + + if fork_left.is_ne() { + #[allow(clippy::indexing_slicing)] + let right_node = n.chd()[right_chunks[index] as usize]; + + drop(u_ref); + #[allow(clippy::indexing_slicing)] + unset_node_ref(merkle, p, right_node, &right_chunks[index..], 1, true)?; + return Ok(false); + }; + #[allow(clippy::indexing_slicing)] let left_node = n.chd()[left_chunks[index] as usize]; #[allow(clippy::indexing_slicing)] @@ -590,12 +683,14 @@ fn unset_internal, S: ShaleStore + Send + Sync, T: BinarySe .unwrap(); } - let p = u_ref.as_ptr(); drop(u_ref); + #[allow(clippy::indexing_slicing)] unset_node_ref(merkle, p, left_node, &left_chunks[index..], 1, false)?; + #[allow(clippy::indexing_slicing)] unset_node_ref(merkle, p, right_node, &right_chunks[index..], 1, true)?; + Ok(false) } @@ -606,9 +701,6 @@ fn unset_internal, S: ShaleStore + Send + Sync, T: BinarySe // - left proof is less and right proof is greater => valid range, unset the shortnode entirely // - left proof points to the shortnode, but right proof is greater // - right proof points to the shortnode, but left proof is less - let node = n.chd(); - let cur_key = n.path.clone().into_inner(); - if fork_left.is_lt() && fork_right.is_lt() { return Err(ProofError::EmptyRange); } @@ -640,6 +732,9 @@ fn unset_internal, S: ShaleStore + Send + Sync, T: BinarySe return Ok(false); } + let node = n.chd(); + let index = n.path.len(); + let p = u_ref.as_ptr(); drop(u_ref); @@ -651,7 +746,7 @@ fn unset_internal, S: ShaleStore + Send + Sync, T: BinarySe Some(node), #[allow(clippy::indexing_slicing)] &left_chunks[index..], - cur_key.len(), + index, false, )?; @@ -665,7 +760,7 @@ fn unset_internal, S: ShaleStore + Send + Sync, T: BinarySe 
Some(node), #[allow(clippy::indexing_slicing)] &right_chunks[index..], - cur_key.len(), + index, true, )?; @@ -752,27 +847,29 @@ fn unset_node_ref, S: ShaleStore + Send + Sync, T: BinarySe index: usize, remove_left: bool, ) -> Result<(), ProofError> { - if node.is_none() { + let Some(node) = node else { // If the node is nil, then it's a child of the fork point // fullnode(it's a non-existent branch). return Ok(()); - } + }; let mut chunks = Vec::new(); chunks.extend(key.as_ref()); #[allow(clippy::unwrap_used)] - let mut u_ref = merkle - .get_node(node.unwrap()) - .map_err(|_| ProofError::NoSuchNode)?; + let mut u_ref = merkle.get_node(node).map_err(|_| ProofError::NoSuchNode)?; let p = u_ref.as_ptr(); + if index >= chunks.len() { + return Err(ProofError::InvalidProof); + } + + #[allow(clippy::indexing_slicing)] match &u_ref.inner() { - NodeType::Branch(n) => { - #[allow(clippy::indexing_slicing)] + NodeType::Branch(n) if chunks[index..].starts_with(&n.path) => { + let index = index + n.path.len(); let child_index = chunks[index] as usize; - #[allow(clippy::indexing_slicing)] let node = n.chd()[child_index]; let iter = if remove_left { @@ -799,6 +896,52 @@ fn unset_node_ref, S: ShaleStore + Send + Sync, T: BinarySe unset_node_ref(merkle, p, node, key, index + 1, remove_left) } + NodeType::Branch(n) => { + let cur_key = &n.path; + + // Find the fork point, it's a non-existent branch. + // + // for (true, Ordering::Less) + // The key of fork shortnode is less than the path + // (it belongs to the range), unset the entire + // branch. The parent must be a fullnode. + // + // for (false, Ordering::Greater) + // The key of fork shortnode is greater than the + // path (it belongs to the range), unset the entire + // branch. The parent must be a fullnode. Otherwise the + // key is not part of the range and should remain in the + // cached hash. 
+ #[allow(clippy::indexing_slicing)] + let should_unset_entire_branch = matches!( + (remove_left, cur_key.cmp(&chunks[index..])), + (true, Ordering::Less) | (false, Ordering::Greater) + ); + + #[allow(clippy::indexing_slicing, clippy::unwrap_used)] + if should_unset_entire_branch { + let mut p_ref = merkle + .get_node(parent) + .map_err(|_| ProofError::NoSuchNode)?; + + p_ref + .write(|p| match p.inner_mut() { + NodeType::Branch(pp) => { + pp.chd_mut()[chunks[index - 1] as usize] = None; + pp.chd_encoded_mut()[chunks[index - 1] as usize] = None; + } + NodeType::Extension(n) => { + *n.chd_mut() = DiskAddress::null(); + *n.chd_encoded_mut() = None; + } + NodeType::Leaf(_) => (), + }) + .unwrap(); + } + + Ok(()) + } + #[allow(clippy::indexing_slicing)] NodeType::Extension(n) if chunks[index..].starts_with(&n.path) => { let node = Some(n.chd()); @@ -807,9 +950,6 @@ fn unset_node_ref, S: ShaleStore + Send + Sync, T: BinarySe NodeType::Extension(n) => { let cur_key = &n.path; - let mut p_ref = merkle - .get_node(parent) - .map_err(|_| ProofError::NoSuchNode)?; // Find the fork point, it's a non-existent branch. 
// @@ -832,6 +972,10 @@ fn unset_node_ref, S: ShaleStore + Send + Sync, T: BinarySe #[allow(clippy::indexing_slicing, clippy::unwrap_used)] if should_unset_entire_branch { + let mut p_ref = merkle + .get_node(parent) + .map_err(|_| ProofError::NoSuchNode)?; + p_ref .write(|p| { let pp = p.inner_mut().as_branch_mut().expect("not a branch node"); diff --git a/firewood/src/merkle/stream.rs b/firewood/src/merkle/stream.rs index 773b1a519..6cd2c7f2a 100644 --- a/firewood/src/merkle/stream.rs +++ b/firewood/src/merkle/stream.rs @@ -3,6 +3,7 @@ use super::{node::Node, BranchNode, Merkle, NodeObjRef, NodeType}; use crate::{ + nibbles::Nibbles, shale::{DiskAddress, ShaleStore}, v2::api, }; @@ -18,6 +19,7 @@ enum IteratorState<'a> { StartAtKey(Key), /// Continue iterating after the last node in the `visited_node_path` Iterating { + check_child_nibble: bool, visited_node_path: Vec<(NodeObjRef<'a>, u8)>, }, } @@ -41,7 +43,7 @@ pub struct MerkleKeyValueStream<'a, S, T> { impl<'a, S: ShaleStore + Send + Sync, T> FusedStream for MerkleKeyValueStream<'a, S, T> { fn is_terminated(&self) -> bool { - matches!(&self.key_state, IteratorState::Iterating { visited_node_path } if visited_node_path.is_empty()) + matches!(&self.key_state, IteratorState::Iterating { visited_node_path, .. } if visited_node_path.is_empty()) } } @@ -88,6 +90,8 @@ impl<'a, S: ShaleStore + Send + Sync, T> Stream for MerkleKeyValueStream<' .get_node(*merkle_root) .map_err(|e| api::Error::InternalError(Box::new(e)))?; + let mut check_child_nibble = false; + // traverse the trie along each nibble until we find a node with a value // TODO: merkle.iter_by_key(key) will simplify this entire code-block. 
let (found_node, mut visited_node_path) = { @@ -97,14 +101,51 @@ impl<'a, S: ShaleStore + Send + Sync, T> Stream for MerkleKeyValueStream<' .get_node_by_key_with_callbacks( root_node, &key, - |node_addr, i| visited_node_path.push((node_addr, i)), + |node_addr, _| visited_node_path.push(node_addr), |_, _| {}, ) .map_err(|e| api::Error::InternalError(Box::new(e)))?; + let mut nibbles = Nibbles::<1>::new(key).into_iter(); + let visited_node_path = visited_node_path .into_iter() - .map(|(node, pos)| merkle.get_node(node).map(|node| (node, pos))) + .map(|node| merkle.get_node(node)) + .map(|node_result| { + let nibbles = &mut nibbles; + + node_result + .map(|node| match node.inner() { + NodeType::Branch(branch) => { + let mut partial_path_iter = branch.path.iter(); + let next_nibble = nibbles + .map(|nibble| (Some(nibble), partial_path_iter.next())) + .find(|(a, b)| a.as_ref() != *b); + + match next_nibble { + // this case will be hit by all but the last nodes + // unless there is a deviation between the key and the path + None | Some((None, _)) => None, + + Some((Some(key_nibble), Some(path_nibble))) => { + check_child_nibble = key_nibble < *path_nibble; + None + } + + // path is subset of the key + Some((Some(nibble), None)) => { + check_child_nibble = true; + Some((node, nibble)) + } + } + } + NodeType::Leaf(_) => Some((node, 0)), + NodeType::Extension(_) => Some((node, 0)), + }) + .transpose() + }) + .take_while(|node| node.is_some()) + .flatten() .collect::, _>>() .map_err(|e| api::Error::InternalError(Box::new(e)))?; @@ -113,7 +154,10 @@ impl<'a, S: ShaleStore + Send + Sync, T> Stream for MerkleKeyValueStream<' if let Some(found_node) = found_node { let value = match found_node.inner() { - NodeType::Branch(branch) => branch.value.as_ref(), + NodeType::Branch(branch) => { + check_child_nibble = true; + branch.value.as_ref() + } NodeType::Leaf(leaf) => Some(&leaf.data), NodeType::Extension(_) => None, }; @@ -126,7 +170,10 @@ impl<'a, S: ShaleStore + Send + Sync, T> 
Stream for MerkleKeyValueStream<' visited_node_path.push((found_node, 0)); - self.key_state = IteratorState::Iterating { visited_node_path }; + self.key_state = IteratorState::Iterating { + check_child_nibble, + visited_node_path, + }; return Poll::Ready(next_result); } @@ -135,16 +182,23 @@ impl<'a, S: ShaleStore + Send + Sync, T> Stream for MerkleKeyValueStream<' let found_key = key_from_nibble_iter(found_key); if found_key > *key { + check_child_nibble = false; visited_node_path.pop(); } - self.key_state = IteratorState::Iterating { visited_node_path }; + self.key_state = IteratorState::Iterating { + check_child_nibble, + visited_node_path, + }; self.poll_next(_cx) } - IteratorState::Iterating { visited_node_path } => { - let next = find_next_result(merkle, visited_node_path) + IteratorState::Iterating { + check_child_nibble, + visited_node_path, + } => { + let next = find_next_result(merkle, visited_node_path, check_child_nibble) .map_err(|e| api::Error::InternalError(Box::new(e))) .transpose(); @@ -184,20 +238,27 @@ impl<'a> NodeRef<'a> { fn find_next_result<'a, S: ShaleStore, T>( merkle: &'a Merkle, visited_path: &mut Vec<(NodeObjRef<'a>, u8)>, + check_child_nibble: &mut bool, ) -> Result, super::MerkleError> { - let next = find_next_node_with_data(merkle, visited_path)?.map(|(next_node, value)| { - let partial_path = match next_node.inner() { - NodeType::Leaf(leaf) => leaf.path.iter().copied(), - NodeType::Extension(extension) => extension.path.iter().copied(), - _ => [].iter().copied(), - }; + let next = find_next_node_with_data(merkle, visited_path, *check_child_nibble)?.map( + |(next_node, value)| { + let partial_path = match next_node.inner() { + NodeType::Leaf(leaf) => leaf.path.iter().copied(), + NodeType::Extension(extension) => extension.path.iter().copied(), + NodeType::Branch(branch) => branch.path.iter().copied(), + }; - let key = key_from_nibble_iter(nibble_iter_from_parents(visited_path).chain(partial_path)); + // always check the child for 
branch nodes with data + *check_child_nibble = next_node.inner().is_branch(); - visited_path.push((next_node, 0)); + let key = + key_from_nibble_iter(nibble_iter_from_parents(visited_path).chain(partial_path)); - (key, value) - }); + visited_path.push((next_node, 0)); + + (key, value) + }, + ); Ok(next) } @@ -205,6 +266,7 @@ fn find_next_result<'a, S: ShaleStore, T>( fn find_next_node_with_data<'a, S: ShaleStore, T>( merkle: &'a Merkle, visited_path: &mut Vec<(NodeObjRef<'a>, u8)>, + check_child_nibble: bool, ) -> Result, Vec)>, super::MerkleError> { use InnerNode::*; @@ -244,7 +306,7 @@ fn find_next_node_with_data<'a, S: ShaleStore, T>( Visited(NodeType::Branch(branch)) => { // if the first node that we check is a visited branch, that means that the branch had a value // and we need to visit the first child, for all other cases, we need to visit the next child - let compare_op = if first_loop { + let compare_op = if first_loop && check_child_nibble { ::ge // >= } else { ::gt @@ -328,7 +390,13 @@ fn nibble_iter_from_parents<'a>(parents: &'a [(NodeObjRef, u8)]) -> impl Iterato .iter() .skip(1) // always skip the sentinal node .flat_map(|(parent, child_nibble)| match parent.inner() { - NodeType::Branch(_) => Either::Left(std::iter::once(*child_nibble)), + NodeType::Branch(branch) => Either::Left( + branch + .path + .iter() + .copied() + .chain(std::iter::once(*child_nibble)), + ), NodeType::Extension(extension) => Either::Right(extension.path.iter().copied()), NodeType::Leaf(leaf) => Either::Right(leaf.path.iter().copied()), }) @@ -711,6 +779,56 @@ mod tests { check_stream_is_done(stream).await; } + #[tokio::test] + async fn start_at_key_overlapping_with_extension_but_greater() { + let start_key = 0x0a; + let shared_path = 0x09; + // 0x0900, 0x0901, ... 
0x090f + // path extension is 0x090 + let children = (0..=0x0f).map(|val| vec![shared_path, val]); + + let mut merkle = create_test_merkle(); + let root = merkle.init_root().unwrap(); + + children.for_each(|key| { + merkle.insert(&key, key.clone(), root).unwrap(); + }); + + let stream = merkle.iter_from(root, vec![start_key].into_boxed_slice()); + + check_stream_is_done(stream).await; + } + + #[tokio::test] + async fn start_at_key_overlapping_with_extension_but_smaller() { + let start_key = 0x00; + let shared_path = 0x09; + // 0x0900, 0x0901, ... 0x090f + // path extension is 0x090 + let children = (0..=0x0f).map(|val| vec![shared_path, val]); + + let mut merkle = create_test_merkle(); + let root = merkle.init_root().unwrap(); + + let keys: Vec<_> = children + .map(|key| { + merkle.insert(&key, key.clone(), root).unwrap(); + key + }) + .collect(); + + let mut stream = merkle.iter_from(root, vec![start_key].into_boxed_slice()); + + for key in keys { + let next = stream.next().await.unwrap().unwrap(); + + assert_eq!(&*next.0, &*next.1); + assert_eq!(&*next.0, key); + } + + check_stream_is_done(stream).await; + } + #[tokio::test] async fn start_at_key_between_siblings() { let missing = 0xaa; diff --git a/firewood/src/merkle_util.rs b/firewood/src/merkle_util.rs index 328ae03cd..fd53908be 100644 --- a/firewood/src/merkle_util.rs +++ b/firewood/src/merkle_util.rs @@ -1,10 +1,15 @@ // Copyright (C) 2023, Ava Labs, Inc. All rights reserved. // See the file LICENSE.md for licensing terms. 
-use crate::merkle::{BinarySerde, Bincode, Merkle, Node, Proof, ProofError, Ref, RefMut, TrieHash}; -use crate::shale::{ - self, cached::DynamicMem, compact::CompactSpace, disk_address::DiskAddress, CachedStore, - ShaleStore, StoredView, +use crate::{ + merkle::{ + proof::{Proof, ProofError}, + BinarySerde, Bincode, Merkle, Node, Ref, RefMut, TrieHash, + }, + shale::{ + self, cached::DynamicMem, compact::CompactSpace, disk_address::DiskAddress, CachedStore, + ShaleStore, StoredView, + }, }; use std::num::NonZeroUsize; use thiserror::Error; diff --git a/firewood/src/shale/compact.rs b/firewood/src/shale/compact.rs index 9b0f4b148..15306bceb 100644 --- a/firewood/src/shale/compact.rs +++ b/firewood/src/shale/compact.rs @@ -1,6 +1,7 @@ // Copyright (C) 2023, Ava Labs, Inc. All rights reserved. // See the file LICENSE.md for licensing terms. +use crate::logger::trace; use crate::merkle::Node; use crate::shale::ObjCache; use crate::storage::{StoreRevMut, StoreRevShared}; @@ -13,17 +14,18 @@ use std::io::{Cursor, Write}; use std::num::NonZeroUsize; use std::sync::RwLock; -use crate::logger::trace; +type PayLoadSize = u64; #[derive(Debug)] pub struct CompactHeader { - payload_size: u64, + payload_size: PayLoadSize, is_freed: bool, desc_addr: DiskAddress, } impl CompactHeader { pub const MSIZE: u64 = 17; + pub const fn is_freed(&self) -> bool { self.is_freed } @@ -71,11 +73,11 @@ impl Storable for CompactHeader { #[derive(Debug)] struct CompactFooter { - payload_size: u64, + payload_size: PayLoadSize, } impl CompactFooter { - const MSIZE: u64 = 8; + const MSIZE: u64 = std::mem::size_of::() as u64; } impl Storable for CompactFooter { @@ -103,7 +105,7 @@ impl Storable for CompactFooter { #[derive(Clone, Copy, Debug)] struct CompactDescriptor { - payload_size: u64, + payload_size: PayLoadSize, haddr: usize, // disk address of the free space } diff --git a/firewood/src/shale/mod.rs b/firewood/src/shale/mod.rs index fc519d9b5..81fbd3b1e 100644 --- a/firewood/src/shale/mod.rs 
+++ b/firewood/src/shale/mod.rs @@ -318,6 +318,7 @@ impl StoredView { #[inline(always)] fn new(offset: usize, len_limit: u64, space: &U) -> Result { let decoded = T::deserialize(offset, space)?; + Ok(Self { offset, decoded, diff --git a/firewood/tests/merkle.rs b/firewood/tests/merkle.rs index 62b1ad006..34f314c67 100644 --- a/firewood/tests/merkle.rs +++ b/firewood/tests/merkle.rs @@ -8,17 +8,19 @@ use firewood::{ shale::{cached::DynamicMem, compact::CompactSpace}, }; use rand::Rng; -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Write}; type Store = CompactSpace; -fn merkle_build_test + std::cmp::Ord + Clone, V: AsRef<[u8]> + Clone>( +fn merkle_build_test< + K: AsRef<[u8]> + std::cmp::Ord + Clone + std::fmt::Debug, + V: AsRef<[u8]> + Clone, +>( items: Vec<(K, V)>, meta_size: u64, compact_size: u64, ) -> Result, DataStoreError> { let mut merkle = new_merkle(meta_size, compact_size); - for (k, v) in items.iter() { merkle.insert(k, v.as_ref().to_vec())?; } @@ -66,10 +68,12 @@ fn test_root_hash_fuzz_insertions() -> Result<(), DataStoreError> { for _ in 0..10 { let mut items = Vec::new(); + for _ in 0..10 { let val: Vec = (0..8).map(|_| rng.borrow_mut().gen()).collect(); items.push((keygen(), val)); } + merkle_build_test(items, 0x1000000, 0x1000000)?; } @@ -97,54 +101,51 @@ fn test_root_hash_reversed_deletions() -> Result<(), DataStoreError> { .collect(); key }; - for i in 0..10 { - let mut items = std::collections::HashMap::new(); - for _ in 0..10 { - let val: Vec = (0..8).map(|_| rng.borrow_mut().gen()).collect(); - items.insert(keygen(), val); - } - let mut items: Vec<_> = items.into_iter().collect(); + + for _ in 0..10 { + let mut items: Vec<_> = (0..10) + .map(|_| keygen()) + .map(|key| { + let val: Vec = (0..8).map(|_| rng.borrow_mut().gen()).collect(); + (key, val) + }) + .collect(); + items.sort(); + let mut merkle = new_merkle(0x100000, 0x100000); + let mut hashes = Vec::new(); - let mut dumps = Vec::new(); + for (k, v) in items.iter() { 
- dumps.push(merkle.dump()); + hashes.push((merkle.root_hash()?, merkle.dump()?)); merkle.insert(k, v.to_vec())?; - hashes.push(merkle.root_hash()); } - hashes.pop(); - println!("----"); - let mut prev_dump = merkle.dump()?; - for (((k, _), h), d) in items - .iter() - .rev() - .zip(hashes.iter().rev()) - .zip(dumps.iter().rev()) - { + + let mut new_hashes = Vec::new(); + + for (k, _) in items.iter().rev() { + let before = merkle.dump()?; merkle.remove(k)?; - let h0 = merkle.root_hash()?.0; - if h.as_ref().unwrap().0 != h0 { - for (k, _) in items.iter() { - println!("{}", hex::encode(k)); - } - println!( - "{} != {}", - hex::encode(**h.as_ref().unwrap()), - hex::encode(h0) - ); - println!("== before {} ===", hex::encode(k)); - print!("{prev_dump}"); - println!("== after {} ===", hex::encode(k)); - print!("{}", merkle.dump()?); - println!("== should be ==="); - print!("{:?}", d); - panic!(); - } - prev_dump = merkle.dump()?; + new_hashes.push((merkle.root_hash()?, k, before, merkle.dump()?)); + } + + hashes.reverse(); + + for i in 0..hashes.len() { + #[allow(clippy::indexing_slicing)] + let (new_hash, key, before_removal, after_removal) = &new_hashes[i]; + #[allow(clippy::indexing_slicing)] + let (expected_hash, expected_dump) = &hashes[i]; + let key = key.iter().fold(String::new(), |mut s, b| { + let _ = write!(s, "{:02x}", b); + s + }); + + assert_eq!(new_hash, expected_hash, "\n\nkey: {key}\nbefore:\n{before_removal}\nafter:\n{after_removal}\n\nexpected:\n{expected_dump}\n"); } - println!("i = {i}"); } + Ok(()) } @@ -169,38 +170,56 @@ fn test_root_hash_random_deletions() -> Result<(), DataStoreError> { .collect(); key }; + for i in 0..10 { let mut items = std::collections::HashMap::new(); + for _ in 0..10 { let val: Vec = (0..8).map(|_| rng.borrow_mut().gen()).collect(); items.insert(keygen(), val); } + let mut items_ordered: Vec<_> = items.iter().map(|(k, v)| (k.clone(), v.clone())).collect(); items_ordered.sort(); items_ordered.shuffle(&mut *rng.borrow_mut()); 
let mut merkle = new_merkle(0x100000, 0x100000); + for (k, v) in items.iter() { merkle.insert(k, v.to_vec())?; } + + for (k, v) in items.iter() { + assert_eq!(&*merkle.get(k)?.unwrap(), &v[..]); + assert_eq!(&*merkle.get_mut(k)?.unwrap().get(), &v[..]); + } + for (k, _) in items_ordered.into_iter() { assert!(merkle.get(&k)?.is_some()); assert!(merkle.get_mut(&k)?.is_some()); + merkle.remove(&k)?; + assert!(merkle.get(&k)?.is_none()); assert!(merkle.get_mut(&k)?.is_none()); + items.remove(&k); + for (k, v) in items.iter() { assert_eq!(&*merkle.get(k)?.unwrap(), &v[..]); assert_eq!(&*merkle.get_mut(k)?.unwrap().get(), &v[..]); } + let h = triehash::trie_root::, _, _>( items.iter().collect(), ); + let h0 = merkle.root_hash()?; + if h[..] != *h0 { println!("{} != {}", hex::encode(h), hex::encode(*h0)); } } + println!("i = {i}"); } Ok(()) @@ -343,7 +362,7 @@ fn test_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(items[end - 1].0)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); let mut keys = Vec::new(); let mut vals = Vec::new(); @@ -379,7 +398,7 @@ fn test_bad_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(items[end - 1].0)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); let mut keys: Vec<[u8; 32]> = Vec::new(); let mut vals: Vec<[u8; 20]> = Vec::new(); @@ -476,7 +495,7 @@ fn test_range_proof_with_non_existent_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(last)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); let mut keys: Vec<[u8; 32]> = Vec::new(); let mut vals: Vec<[u8; 20]> = Vec::new(); @@ -495,7 +514,7 @@ fn test_range_proof_with_non_existent_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(last)?; 
assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); let (keys, vals): (Vec<&[u8; 32]>, Vec<&[u8; 20]>) = items.into_iter().unzip(); merkle.verify_range_proof(&proof, &first, &last, keys, vals)?; @@ -523,7 +542,7 @@ fn test_range_proof_with_invalid_non_existent_proof() -> Result<(), ProofError> assert!(!proof.0.is_empty()); let end_proof = merkle.prove(items[end - 1].0)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); start = 105; // Gap created let mut keys: Vec<[u8; 32]> = Vec::new(); @@ -546,7 +565,7 @@ fn test_range_proof_with_invalid_non_existent_proof() -> Result<(), ProofError> assert!(!proof.0.is_empty()); let end_proof = merkle.prove(last)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); end = 195; // Capped slice let mut keys: Vec<[u8; 32]> = Vec::new(); @@ -593,7 +612,7 @@ fn test_one_element_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(items[start].0)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); merkle.verify_range_proof( &proof, @@ -609,7 +628,7 @@ fn test_one_element_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(last)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); merkle.verify_range_proof( &proof, @@ -624,7 +643,7 @@ fn test_one_element_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(last)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); merkle.verify_range_proof( &proof, @@ -644,7 +663,7 @@ fn test_one_element_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(key)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + 
proof.extend(end_proof); merkle.verify_range_proof(&proof, first, key, vec![key], vec![val])?; @@ -683,7 +702,7 @@ fn test_all_elements_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(items[end].0)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); merkle.verify_range_proof( &proof, @@ -700,7 +719,7 @@ fn test_all_elements_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(last)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); merkle.verify_range_proof(&proof, &first, &last, keys, vals)?; @@ -762,7 +781,7 @@ fn test_gapped_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(items[last - 1].0)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); let middle = (first + last) / 2 - first; let (keys, vals): (Vec<&[u8; 32]>, Vec<&[u8; 4]>) = items[first..last] @@ -797,7 +816,7 @@ fn test_same_side_proof() -> Result<(), DataStoreError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(last)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); assert!(merkle .verify_range_proof(&proof, first, last, vec![*items[pos].0], vec![items[pos].1]) @@ -811,7 +830,7 @@ fn test_same_side_proof() -> Result<(), DataStoreError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(last)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); assert!(merkle .verify_range_proof(&proof, first, last, vec![*items[pos].0], vec![items[pos].1]) @@ -842,7 +861,7 @@ fn test_single_side_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(items[case].0)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); let item_iter = 
items.clone().into_iter().take(case + 1); let keys = item_iter.clone().map(|item| *item.0).collect(); @@ -876,7 +895,7 @@ fn test_reverse_single_side_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(end)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); let item_iter = items.clone().into_iter().skip(case); let keys = item_iter.clone().map(|item| item.0).collect(); @@ -909,7 +928,7 @@ fn test_both_sides_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(end)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); let (keys, vals): (Vec<&[u8; 32]>, Vec<&[u8; 20]>) = items.into_iter().unzip(); merkle.verify_range_proof(&proof, &start, &end, keys, vals)?; @@ -941,7 +960,7 @@ fn test_empty_value_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(items[end - 1].0)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); let item_iter = items.clone().into_iter().skip(start).take(end - start); let keys = item_iter.clone().map(|item| item.0).collect(); @@ -977,7 +996,7 @@ fn test_all_elements_empty_value_range_proof() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(items[end].0)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); let item_iter = items.clone().into_iter(); let keys = item_iter.clone().map(|item| item.0).collect(); @@ -1014,7 +1033,7 @@ fn test_range_proof_keys_with_shared_prefix() -> Result<(), ProofError> { assert!(!proof.0.is_empty()); let end_proof = merkle.prove(&end)?; assert!(!end_proof.0.is_empty()); - proof.concat_proofs(end_proof); + proof.extend(end_proof); let item_iter = items.into_iter(); let keys = item_iter.clone().map(|item| item.0).collect(); @@ -1051,7 +1070,7 @@ fn 
test_bloadted_range_proof() -> Result<(), ProofError> { for (i, item) in items.iter().enumerate() { let cur_proof = merkle.prove(item.0)?; assert!(!cur_proof.0.is_empty()); - proof.concat_proofs(cur_proof); + proof.extend(cur_proof); if i == 50 { keys.push(item.0); vals.push(item.1);