Skip to content

Commit

Permalink
Implement Nibbles to reduce memory allocations (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkuris authored Aug 21, 2023
1 parent 4485c89 commit 906034d
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 22 deletions.
1 change: 1 addition & 0 deletions firewood/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ pub mod storage;

pub mod api;
pub(crate) mod config;
pub mod nibbles;
pub mod service;

pub mod v2;
Expand Down
54 changes: 32 additions & 22 deletions firewood/src/merkle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE.md for licensing terms.

use crate::proof::Proof;
use crate::{nibbles::Nibbles, proof::Proof};
use enum_as_inner::EnumAsInner;
use sha3::Digest;
use shale::{disk_address::DiskAddress, CachedStore, ObjRef, ShaleError, ShaleStore, Storable};
Expand Down Expand Up @@ -1170,10 +1170,10 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
let mut deleted = Vec::new();
let mut parents = Vec::new();

// TODO: Explain why this always starts with a 0 chunk
// I think this may have to do with avoiding moving the root
let mut chunked_key = vec![0];
chunked_key.extend(key.as_ref().iter().copied().flat_map(to_nibble_array));
// we use Nibbles::<1> so that 1 zero nibble is at the front
// this is for the sentinel node, which avoids moving the root
// and always only has one child
let key_nibbles = Nibbles::<1>(key.as_ref());

let mut next_node = Some(self.get_node(root)?);
let mut nskip = 0;
Expand All @@ -1184,7 +1184,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
let mut val = Some(val);

// walk down the merkle tree starting from next_node, currently the root
for (key_nib_offset, key_nib) in chunked_key.iter().enumerate() {
for (key_nib_offset, key_nib) in key_nibbles.iter().enumerate() {
// special handling for extension nodes
if nskip > 0 {
nskip -= 1;
Expand All @@ -1200,21 +1200,21 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
// For a Branch node, we look at the child pointer. If it points
// to another node, we walk down that. Otherwise, we can store our
// value as a leaf and we're done
NodeType::Branch(n) => match n.chd[*key_nib as usize] {
NodeType::Branch(n) => match n.chd[key_nib as usize] {
Some(c) => c,
None => {
// insert the leaf to the empty slot
// create a new leaf
let leaf_ptr = self
.new_node(Node::new(NodeType::Leaf(LeafNode(
PartialPath(chunked_key[key_nib_offset + 1..].to_vec()),
PartialPath(key_nibbles.skip(key_nib_offset + 1).iter().collect()),
Data(val.take().unwrap()),
))))?
.as_ptr();
// set the current child to point to this leaf
node.write(|u| {
let uu = u.inner.as_branch_mut().unwrap();
uu.chd[*key_nib as usize] = Some(leaf_ptr);
uu.chd[key_nib as usize] = Some(leaf_ptr);
u.rehash();
})
.unwrap();
Expand All @@ -1226,10 +1226,11 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
// of the stored key to pass into split
let n_path = n.0.to_vec();
let n_value = Some(n.1.clone());
let rem_path = key_nibbles.skip(key_nib_offset).iter().collect::<Vec<_>>();
self.split(
node,
&mut parents,
&chunked_key[key_nib_offset..],
&rem_path,
n_path,
n_value,
val.take().unwrap(),
Expand All @@ -1241,10 +1242,12 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
let n_path = n.0.to_vec();
let n_ptr = n.1;
nskip = n_path.len() - 1;
let rem_path = key_nibbles.skip(key_nib_offset).iter().collect::<Vec<_>>();

if let Some(v) = self.split(
node,
&mut parents,
&chunked_key[key_nib_offset..],
&rem_path,
n_path,
None,
val.take().unwrap(),
Expand All @@ -1263,7 +1266,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
}
};
// push another parent, and follow the next pointer
parents.push((node, *key_nib));
parents.push((node, key_nib));
next_node = Some(self.get_node(next_node_ptr)?);
}
if val.is_some() {
Expand Down Expand Up @@ -1835,8 +1838,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
K: AsRef<[u8]>,
T: ValueTransformer,
{
let mut chunks = Vec::new();
chunks.extend(key.as_ref().iter().copied().flat_map(to_nibble_array));
let key_nibbles = Nibbles::<0>(key.as_ref());

let mut proofs: HashMap<[u8; TRIE_HASH_LEN], Vec<u8>> = HashMap::new();
if root.is_null() {
Expand All @@ -1857,29 +1859,37 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {

let mut nskip = 0;
let mut nodes: Vec<DiskAddress> = Vec::new();
for (i, nib) in chunks.iter().enumerate() {
for (i, nib) in key_nibbles.iter().enumerate() {
if nskip > 0 {
nskip -= 1;
continue;
}
nodes.push(u_ref.as_ptr());
let next_ptr: DiskAddress = match &u_ref.inner {
NodeType::Branch(n) => match n.chd[*nib as usize] {
NodeType::Branch(n) => match n.chd[nib as usize] {
Some(c) => c,
None => break,
},
NodeType::Leaf(_) => break,
NodeType::Extension(n) => {
// the key passed in must match the entire remainder of this
// extension node, otherwise we break out
let n_path = &*n.0;
let remaining_path = &chunks[i..];
if remaining_path.len() < n_path.len()
|| &remaining_path[..n_path.len()] != n_path
let remaining_path = key_nibbles.skip(i);
if remaining_path.len() < n_path.len() {
// all bytes aren't there
break;
}
if !remaining_path
.iter()
.take(n_path.len())
.eq(n_path.iter().cloned())
{
// contents aren't the same
break;
} else {
nskip = n_path.len() - 1;
n.1
}
nskip = n_path.len() - 1;
n.1
}
};
u_ref = self.get_node(next_ptr)?;
Expand Down
190 changes: 190 additions & 0 deletions firewood/src/nibbles.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
use std::ops::Index;

// Lookup table holding every possible nibble value. `Index::index` has to
// return a reference, so the computed nibble is used as an index into this
// static table to obtain a `&u8` that outlives the call.
static NIBBLES: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];

/// Nibbles is a newtype that wraps a reference to a `[u8]` and exposes its
/// contents as a sequence of nibbles (4-bit values), with the high nibble of
/// each byte coming first. Nibbles can be indexed using `nib[x]`, or you can
/// get an iterator with `iter()`.
///
/// Nibbles can be constructed with a number of leading zeroes. This is used
/// in firewood because there is a sentinel node, so we always want the first
/// nibble to be 0.
///
/// When creating a Nibbles object, use the syntax `Nibbles::<N>(r)` where
/// `N` is the number of leading zero nibbles you need and `r` is a reference
/// to a `[u8]`. The leading zeroes are never stored; they are synthesized
/// when indexing or iterating.
///
/// # Examples
///
/// ```
/// # use firewood::nibbles;
/// # fn main() {
/// let nib = nibbles::Nibbles::<0>(&[0x56, 0x78]);
/// assert_eq!(nib.iter().collect::<Vec<_>>(), [0x5, 0x6, 0x7, 0x8]);
///
/// // nibbles can be efficiently advanced without rendering the
/// // intermediate values
/// assert_eq!(nib.skip(3).iter().collect::<Vec<_>>(), [0x8]);
///
/// // nibbles can also be indexed
///
/// assert_eq!(nib[1], 0x6);
/// # }
/// ```
#[derive(Debug)]
pub struct Nibbles<'a, const LEADING_ZEROES: usize>(pub &'a [u8]);

impl<'a, const LEADING_ZEROES: usize> Index<usize> for Nibbles<'a, LEADING_ZEROES> {
    type Output = u8;

    /// Returns the nibble at position `index`.
    ///
    /// Positions below `LEADING_ZEROES` always yield zero. After that, even
    /// offsets select the high nibble of a byte and odd offsets the low one.
    ///
    /// # Panics
    ///
    /// Panics if `index` refers to a byte past the end of the wrapped slice.
    fn index(&self, index: usize) -> &Self::Output {
        // The synthetic leading zeroes are not backed by any byte.
        if index < LEADING_ZEROES {
            return &NIBBLES[0];
        }

        let offset = index - LEADING_ZEROES;
        let byte = self.0[offset / 2];
        let nibble = if offset % 2 == 0 {
            byte >> 4
        } else {
            byte & 0xf
        };

        // Index::index must hand back a reference, so borrow from the table.
        &NIBBLES[nibble as usize]
    }
}

impl<'a, const LEADING_ZEROES: usize> Nibbles<'a, LEADING_ZEROES> {
    /// Returns an iterator over every nibble, beginning with the leading
    /// zeroes (if any).
    #[must_use]
    pub fn iter(&self) -> NibblesIterator<'_, LEADING_ZEROES> {
        NibblesIterator { pos: 0, data: self }
    }

    /// Efficiently skip the first `at` nibbles without rendering them.
    ///
    /// The returned [NibblesSlice] never contains leading zeroes, which is
    /// why `at` must cover all of them.
    ///
    /// # Panics
    ///
    /// Panics if `at` is less than `LEADING_ZEROES`, or if
    /// `(at - LEADING_ZEROES) / 2` exceeds the number of wrapped bytes.
    #[must_use]
    pub fn skip(&self, at: usize) -> NibblesSlice<'a> {
        assert!(at >= LEADING_ZEROES, "Cannot split before LEADING_ZEROES (requested split at {at} is less than the {LEADING_ZEROES} leading zero(es)");

        // Position relative to the first real (byte-backed) nibble.
        let offset = at - LEADING_ZEROES;
        let remaining = &self.0[offset / 2..];

        NibblesSlice {
            nibbles: Nibbles(remaining),
            // An odd offset lands on the low nibble of the first byte.
            skipfirst: offset % 2 != 0,
        }
    }

    /// Returns the total number of nibbles, leading zeroes included.
    #[must_use]
    pub fn len(&self) -> usize {
        self.0.len() * 2 + LEADING_ZEROES
    }

    /// Returns `true` when there are no leading zeroes and no wrapped bytes.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty() && LEADING_ZEROES == 0
    }
}

/// NibblesSlice is created by [Nibbles::skip]. This is
/// used to create an iterator that starts at some particular
/// nibble
#[derive(Debug)]
pub struct NibblesSlice<'a> {
    // the remaining bytes, starting at the byte that holds the first
    // nibble of this slice; a slice never has leading zeroes
    nibbles: Nibbles<'a, 0>,
    // true when the slice starts at the low nibble of the first byte,
    // so the high nibble of that byte must be skipped
    skipfirst: bool,
}

impl<'a> NibblesSlice<'a> {
    /// Returns an iterator over this subset of nibbles
    #[must_use]
    pub fn iter(&self) -> NibblesIterator<'_, 0> {
        NibblesIterator {
            data: &self.nibbles,
            // start one position in when the first byte's high nibble
            // is not part of this slice
            pos: usize::from(self.skipfirst),
        }
    }

    /// Returns the number of nibbles this slice yields when iterated.
    #[must_use]
    pub fn len(&self) -> usize {
        // A slice that starts exactly one nibble past the end of the data
        // has no bytes but still has `skipfirst` set; saturate so that case
        // reports zero instead of underflowing.
        self.nibbles
            .len()
            .saturating_sub(usize::from(self.skipfirst))
    }

    /// Returns `true` when iterating this slice yields no nibbles.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// An iterator returned by [Nibbles::iter] or [NibblesSlice::iter].
/// See their documentation for details.
#[derive(Debug)]
pub struct NibblesIterator<'a, const LEADING_ZEROES: usize> {
    // the nibbles being iterated over
    data: &'a Nibbles<'a, LEADING_ZEROES>,
    // index (into `data`) of the next nibble to yield
    pos: usize,
}

impl<'a, const LEADING_ZEROES: usize> Iterator for NibblesIterator<'a, LEADING_ZEROES> {
    type Item = u8;

    /// Yields the nibble at the current position and advances, or `None`
    /// once every nibble (leading zeroes included) has been produced.
    fn next(&mut self) -> Option<Self::Item> {
        if self.pos >= self.data.len() {
            // Exhausted: leave `pos` untouched so repeated calls keep
            // returning `None` without growing the counter.
            return None;
        }
        let nibble = self.data[self.pos];
        self.pos += 1;
        Some(nibble)
    }

    /// Reports the exact number of nibbles left, which lets `collect()`
    /// and `extend()` reserve their storage once instead of growing
    /// repeatedly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `pos` can start past the end (an empty slice created at an odd
        // offset), so saturate rather than subtract.
        let remaining = self.data.len().saturating_sub(self.pos);
        (remaining, Some(remaining))
    }
}

#[cfg(test)]
mod test {
    use super::*;

    // Shared fixture: four bytes, i.e. eight nibbles.
    static TEST_BYTES: [u8; 4] = [0xdeu8, 0xad, 0xbe, 0xef];

    #[test]
    fn happy_regular_nibbles() {
        let nib = Nibbles::<0>(&TEST_BYTES);
        let expected = [0xdu8, 0xe, 0xa, 0xd, 0xb, 0xe, 0xe, 0xf];
        for v in expected.into_iter().enumerate() {
            assert_eq!(nib[v.0], v.1, "{v:?}");
        }
    }

    #[test]
    fn leadingzero_nibbles_index() {
        let nib = Nibbles::<1>(&TEST_BYTES);
        let expected = [0u8, 0xd, 0xe, 0xa, 0xd, 0xb, 0xe, 0xe, 0xf];
        for v in expected.into_iter().enumerate() {
            assert_eq!(nib[v.0], v.1, "{v:?}");
        }
    }

    #[test]
    fn leading_zero_nibbles_iter() {
        let nib = Nibbles::<1>(&TEST_BYTES);
        let expected: [u8; 9] = [0u8, 0xd, 0xe, 0xa, 0xd, 0xb, 0xe, 0xe, 0xf];
        // The comparison result must be asserted; without `assert!` the
        // boolean is discarded and this test can never fail.
        assert!(expected.into_iter().eq(nib.iter()));
    }

    #[test]
    fn skip_zero() {
        let nib = Nibbles::<0>(&TEST_BYTES);
        let slice = nib.skip(0);
        assert!(nib.iter().eq(slice.iter()));
    }

    #[test]
    fn skip_one() {
        let nib = Nibbles::<0>(&TEST_BYTES);
        let slice = nib.skip(1);
        assert!(nib.iter().skip(1).eq(slice.iter()));
    }

    #[test]
    fn skip_skips_zeroes() {
        let nib = Nibbles::<1>(&TEST_BYTES);
        let slice = nib.skip(1);
        assert!(nib.iter().skip(1).eq(slice.iter()));
    }

    #[test]
    #[should_panic]
    fn test_out_of_bounds_panics() {
        let nib = Nibbles::<0>(&TEST_BYTES);
        let _ = nib[8];
    }

    #[test]
    fn test_last_nibble() {
        let nib = Nibbles::<0>(&TEST_BYTES);
        assert_eq!(nib[7], 0xf);
    }

    #[test]
    #[should_panic]
    fn test_skip_before_zeroes_panics() {
        let nib = Nibbles::<1>(&TEST_BYTES);
        let _ = nib.skip(0);
    }
}

0 comments on commit 906034d

Please sign in to comment.