Skip to content

Commit

Permalink
chore: Use Option in branch nodes (#5)
Browse files Browse the repository at this point in the history
* use Option<Box<..>>

* Apply suggestions from code review

Co-authored-by: Rami <[email protected]>

* use sum

---------

Co-authored-by: Rami <[email protected]>
  • Loading branch information
Wollac and hashcashier authored Aug 9, 2023
1 parent a020f13 commit 6009e4d
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 152 deletions.
70 changes: 33 additions & 37 deletions lib/src/host/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::{
block_builder::BlockBuilder,
consts::{ChainSpec, MAINNET},
host::{
mpt::{load_pointers, orphaned_pointers, resolve_pointers, shorten_key},
mpt::{orphaned_digests, resolve_digests, shorten_key},
provider::{new_provider, BlockQuery},
},
mem_db::MemDb,
Expand Down Expand Up @@ -208,17 +208,13 @@ pub fn verify_state(
let fini_storage_keys = fini_db.storage_keys();

// Construct expected tries from fini proofs
let (mut nodes_by_pointer, mut storage) =
proofs_to_tries(fini_proofs.values().cloned().collect());
storage_deltas.values().for_each(|storage_trie| {
load_pointers(storage_trie, &mut nodes_by_pointer);
});
let (nodes_by_pointer, mut storage) = proofs_to_tries(fini_proofs.values().cloned().collect());
storage
.values_mut()
.for_each(|(n, _)| *n = resolve_pointers(n, &nodes_by_pointer));
.for_each(|(n, _)| *n = resolve_digests(n, &nodes_by_pointer));
storage_deltas
.values_mut()
.for_each(|n| *n = resolve_pointers(n, &nodes_by_pointer));
.for_each(|n| *n = resolve_digests(n, &nodes_by_pointer));

for (address, indices) in fini_storage_keys {
let mut address_errors = Vec::new();
Expand Down Expand Up @@ -333,13 +329,13 @@ fn proofs_to_tries(
HashMap<B160, StorageEntry>,
) {
// construct the proof tries
let mut nodes_by_pointer = HashMap::new();
let mut nodes_by_reference = HashMap::new();
let mut storage = HashMap::new();
for proof in proofs {
// parse the nodes of the account proof
for bytes in &proof.account_proof {
let mpt_node = MptNode::decode(bytes).expect("Failed to decode state proof");
nodes_by_pointer.insert(mpt_node.reference(), mpt_node);
nodes_by_reference.insert(mpt_node.reference(), mpt_node);
}

// process the proof for each storage entry
Expand All @@ -351,7 +347,7 @@ fn proofs_to_tries(
.iter()
.rev()
.map(|bytes| MptNode::decode(bytes).expect("Failed to decode storage proof"))
.inspect(|node| drop(nodes_by_pointer.insert(node.reference(), node.clone())))
.inspect(|node| drop(nodes_by_reference.insert(node.reference(), node.clone())))
.last();
// the hash of the root node should match the proof's storage hash
assert_eq!(
Expand All @@ -378,21 +374,20 @@ fn proofs_to_tries(

storage.insert(proof.address.into(), (root_node, slots));
}
(nodes_by_pointer, storage)
(nodes_by_reference, storage)
}

fn resolve_orphans(
nodes: &Vec<Bytes>,
orphan_set: &mut HashSet<MptNodeReference>,
nodes_by_pointer: &mut HashMap<MptNodeReference, MptNode>,
orphans: &mut HashSet<MptNodeReference>,
nodes_by_reference: &mut HashMap<MptNodeReference, MptNode>,
) {
for node in nodes {
let mpt_node = MptNode::decode(node).expect("Failed to decode state proof");
for potential_orphan in shorten_key(mpt_node) {
let potential_orphan_hash = potential_orphan.reference();
if orphan_set.contains(&potential_orphan_hash) {
orphan_set.remove(&potential_orphan_hash);
nodes_by_pointer.insert(potential_orphan_hash, potential_orphan);
if orphans.remove(&potential_orphan_hash) {
nodes_by_reference.insert(potential_orphan_hash, potential_orphan);
}
}
}
Expand All @@ -404,16 +399,11 @@ fn resolve_orphans(
impl Into<Input> for Init {
fn into(self) -> Input {
// construct the proof tries
let (mut nodes_by_pointer, mut storage) =
let (mut nodes_by_reference, mut storage) =
proofs_to_tries(self.init_proofs.values().cloned().collect());
// there should be a trie and a list of storage slots for every account
assert_eq!(storage.len(), self.db.accounts_len());

info!(
"Constructed proof tries with {} nodes",
nodes_by_pointer.len(),
);

// collect the code from each account
let mut contracts = HashMap::new();
for account in self.db.accounts.values() {
Expand All @@ -425,37 +415,43 @@ impl Into<Input> for Init {

// extract the state trie
let state_root = self.init_block.state_root;
let state_trie = nodes_by_pointer
let state_trie = nodes_by_reference
.remove(&MptNodeReference::Digest(state_root))
.expect("State root node not found");
assert_eq!(state_root, state_trie.hash());

// Find orphaned pointers during deletion
let mut orphan_set = HashSet::new();
// identify orphaned digests, which could lead to issues when deleting nodes
let mut orphans = HashSet::new();
for root in storage.values().map(|v| &v.0).chain(once(&state_trie)) {
let root = resolve_pointers(root, &nodes_by_pointer);
orphan_set.extend(
orphaned_pointers(&root)
.iter()
.map(|node: &MptNode| node.reference()),
);
let root = resolve_digests(root, &nodes_by_reference);
orphans.extend(orphaned_digests(&root));
}
// resolve those orphans using the proofs of the final state
for fini_proof in self.fini_proofs.values() {
resolve_orphans(
&fini_proof.account_proof,
&mut orphan_set,
&mut nodes_by_pointer,
&mut orphans,
&mut nodes_by_reference,
);
for storage_proof in &fini_proof.storage_proof {
resolve_orphans(&storage_proof.proof, &mut orphan_set, &mut nodes_by_pointer);
resolve_orphans(&storage_proof.proof, &mut orphans, &mut nodes_by_reference);
}
}

// resolve the digests in the state root node and all storage root nodes
let state_trie = resolve_pointers(&state_trie, &nodes_by_pointer);
let state_trie = resolve_digests(&state_trie, &nodes_by_reference);
storage
.values_mut()
.for_each(|(n, _)| *n = resolve_pointers(n, &nodes_by_pointer));
.for_each(|(n, _)| *n = resolve_digests(n, &nodes_by_reference));

info!(
"The partial state trie consists of {} nodes",
state_trie.size()
);
info!(
"The partial storage tries consist of {} nodes",
storage.values().map(|(n, _)| n.size()).sum::<usize>()
);

// Create the block builder input
Input {
Expand Down
110 changes: 37 additions & 73 deletions lib/src/host/mpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,103 +13,67 @@
// limitations under the License.

use hashbrown::HashMap;
use zeth_primitives::{
trie::{to_encoded_path, MptNode, MptNodeData, MptNodeReference},
RlpBytes,
};
use zeth_primitives::trie::{to_encoded_path, MptNode, MptNodeData, MptNodeReference};

pub fn load_pointers(
root: &MptNode,
node_store: &mut HashMap<MptNodeReference, MptNode>,
) -> MptNode {
let compact_node = match root.as_data() {
MptNodeData::Null | MptNodeData::Digest(_) | MptNodeData::Leaf(_, _) => root.clone(),
/// Creates a new MPT trie where all the digests contained in `node_store` are resolved.
pub fn resolve_digests(trie: &MptNode, node_store: &HashMap<MptNodeReference, MptNode>) -> MptNode {
let result: MptNode = match trie.as_data() {
MptNodeData::Null | MptNodeData::Leaf(_, _) => trie.clone(),
MptNodeData::Branch(children) => {
let compact_children: Vec<Box<MptNode>> = children
let children: Vec<_> = children
.iter()
.map(|child| Box::new(load_pointers(child, node_store)))
.map(|child| {
child
.as_ref()
.map(|node| Box::new(resolve_digests(node, node_store)))
})
.collect();
MptNodeData::Branch(compact_children.try_into().unwrap()).into()
MptNodeData::Branch(children.try_into().unwrap()).into()
}
MptNodeData::Extension(prefix, target) => MptNodeData::Extension(
prefix.clone(),
Box::new(MptNodeData::Digest(target.hash()).into()),
Box::new(resolve_digests(target, node_store)),
)
.into(),
};
if let MptNodeData::Digest(_) = compact_node.as_data() {
// do nothing
} else {
node_store.insert(compact_node.reference(), compact_node.clone());
}
compact_node
}

pub fn resolve_pointers(
root: &MptNode,
node_store: &HashMap<MptNodeReference, MptNode>,
) -> MptNode {
let result: MptNode = match root.as_data() {
MptNodeData::Null | MptNodeData::Leaf(_, _) => root.clone(),
MptNodeData::Branch(nodes) => {
let node_list: Vec<_> = nodes
.iter()
.map(|n| Box::new(resolve_pointers(n, node_store)))
.collect();
MptNodeData::Branch(
node_list
.try_into()
.expect("Could not convert vector to 16-element array."),
)
.into()
}
MptNodeData::Extension(prefix, node) => {
MptNodeData::Extension(prefix.clone(), Box::new(resolve_pointers(node, node_store)))
.into()
}
MptNodeData::Digest(digest) => {
if let Some(node) = node_store.get(&MptNodeReference::Digest(*digest)) {
resolve_pointers(node, node_store)
resolve_digests(node, node_store)
} else {
root.clone()
trie.clone()
}
}
};
assert_eq!(
root.hash(),
result.hash(),
"Invalid node resolution! {:?} ({:?})",
root.to_rlp(),
result.to_rlp(),
);
assert_eq!(trie.hash(), result.hash());
result
}

pub fn orphaned_pointers(node: &MptNode) -> Vec<MptNode> {
/// Returns all orphaned digests in the trie.
pub fn orphaned_digests(trie: &MptNode) -> Vec<MptNodeReference> {
let mut result = Vec::new();
_orphaned_pointers(node, &mut result);
orphaned_digests_internal(trie, &mut result);
result
}

fn _orphaned_pointers(node: &MptNode, res: &mut Vec<MptNode>) {
match node.as_data() {
MptNodeData::Null => {}
fn orphaned_digests_internal(trie: &MptNode, orphans: &mut Vec<MptNodeReference>) {
match trie.as_data() {
MptNodeData::Branch(children) => {
let unresolved_count = children.iter().filter(|n| !n.is_resolved()).count();
if unresolved_count == 1 {
let unresolved_index = children.iter().position(|n| !n.is_resolved()).unwrap();
res.push(*children[unresolved_index].clone());
}
// Continue descent
for child in children {
_orphaned_pointers(child, res);
}
// iterate over all digest children
let mut digests = children.iter().flatten().filter(|node| node.is_digest());
// if there is exactly one digest child, it is an orphan
if let Some(orphan_digest) = digests.next() {
if digests.next().is_none() {
orphans.push(orphan_digest.reference());
}
};
// recurse
children.iter().flatten().for_each(|child| {
orphaned_digests_internal(child, orphans);
});
}
MptNodeData::Leaf(_, _) => {}
MptNodeData::Extension(_, target) => {
_orphaned_pointers(target, res);
orphaned_digests_internal(target, orphans);
}
MptNodeData::Digest(_) => {}
MptNodeData::Null | MptNodeData::Leaf(_, _) | MptNodeData::Digest(_) => {}
}
}

Expand All @@ -121,12 +85,12 @@ pub fn shorten_key(node: MptNode) -> Vec<MptNode> {
res.push(node.clone())
}
MptNodeData::Leaf(_, value) => {
for i in 0..nibs.len() {
for i in 0..=nibs.len() {
res.push(MptNodeData::Leaf(to_encoded_path(&nibs[i..], true), value.clone()).into())
}
}
MptNodeData::Extension(_, target) => {
for i in 0..nibs.len() {
for i in 0..=nibs.len() {
res.push(
MptNodeData::Extension(to_encoded_path(&nibs[i..], false), target.clone())
.into(),
Expand Down
2 changes: 1 addition & 1 deletion lib/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ fn proof_internal(node: &MptNode, key_nibs: &[u8]) -> Result<Vec<Vec<u8>>, anyho
MptNodeData::Null | MptNodeData::Leaf(_, _) => vec![],
MptNodeData::Branch(children) => {
let mut path = Vec::new();
for node in children {
for node in children.iter().flatten() {
path.extend(proof_internal(&node, &key_nibs[1..])?);
}
path
Expand Down
Loading

0 comments on commit 6009e4d

Please sign in to comment.