Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Nibbles to reduce memory allocations #207

Merged
merged 4 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions firewood/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ pub mod storage;

pub mod api;
pub(crate) mod config;
pub mod nibbles;
pub mod service;

pub mod v2;
Expand Down
41 changes: 23 additions & 18 deletions firewood/src/merkle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE.md for licensing terms.

use crate::proof::Proof;
use crate::{nibbles::Nibbles, proof::Proof};
use enum_as_inner::EnumAsInner;
use sha3::Digest;
use shale::{disk_address::DiskAddress, CachedStore, ObjRef, ShaleError, ShaleStore, Storable};
Expand Down Expand Up @@ -1170,10 +1170,10 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
let mut deleted = Vec::new();
let mut parents = Vec::new();

// TODO: Explain why this always starts with a 0 chunk
// I think this may have to do with avoiding moving the root
let mut chunked_key = vec![0];
chunked_key.extend(key.as_ref().iter().copied().flat_map(to_nibble_array));
// we use Nibbles::<1> so that 1 zero nibble is at the front
// this is for the sentinel node, which avoids moving the root
// and always only has one child
let key_nibbles = Nibbles::<1>(key.as_ref());

let mut next_node = Some(self.get_node(root)?);
let mut nskip = 0;
Expand All @@ -1184,7 +1184,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
let mut val = Some(val);

// walk down the merkle tree starting from next_node, currently the root
for (key_nib_offset, key_nib) in chunked_key.iter().enumerate() {
for (key_nib_offset, key_nib) in key_nibbles.iter().enumerate() {
// special handling for extension nodes
if nskip > 0 {
nskip -= 1;
Expand All @@ -1200,21 +1200,21 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
// For a Branch node, we look at the child pointer. If it points
// to another node, we walk down that. Otherwise, we can store our
// value as a leaf and we're done
NodeType::Branch(n) => match n.chd[*key_nib as usize] {
NodeType::Branch(n) => match n.chd[key_nib as usize] {
Some(c) => c,
None => {
// insert the leaf to the empty slot
// create a new leaf
let leaf_ptr = self
.new_node(Node::new(NodeType::Leaf(LeafNode(
PartialPath(chunked_key[key_nib_offset + 1..].to_vec()),
PartialPath(key_nibbles.skip(key_nib_offset + 1).iter().collect()),
rkuris marked this conversation as resolved.
Show resolved Hide resolved
Data(val.take().unwrap()),
))))?
.as_ptr();
// set the current child to point to this leaf
node.write(|u| {
let uu = u.inner.as_branch_mut().unwrap();
uu.chd[*key_nib as usize] = Some(leaf_ptr);
uu.chd[key_nib as usize] = Some(leaf_ptr);
u.rehash();
})
.unwrap();
Expand All @@ -1226,10 +1226,11 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
// of the stored key to pass into split
let n_path = n.0.to_vec();
let n_value = Some(n.1.clone());
let rem_path = key_nibbles.skip(key_nib_offset).iter().collect::<Vec<_>>();
self.split(
node,
&mut parents,
&chunked_key[key_nib_offset..],
&rem_path,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a more efficient way to do this, but it requires changes in split. Maybe you can add a todo and we can get rid of the allocation for rem_path.

In any case, it is better since you aren't allocating key_nib_offet bytes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are a LOT of improvements possible here, not sure that this one is worse than some of the others I already know about.

n_path,
n_value,
val.take().unwrap(),
Expand All @@ -1241,10 +1242,12 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
let n_path = n.0.to_vec();
let n_ptr = n.1;
nskip = n_path.len() - 1;
let rem_path = key_nibbles.skip(key_nib_offset).iter().collect::<Vec<_>>();

if let Some(v) = self.split(
node,
&mut parents,
&chunked_key[key_nib_offset..],
&rem_path,
n_path,
None,
val.take().unwrap(),
Expand All @@ -1263,7 +1266,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
}
};
// push another parent, and follow the next pointer
parents.push((node, *key_nib));
parents.push((node, key_nib));
next_node = Some(self.get_node(next_node_ptr)?);
}
if val.is_some() {
Expand Down Expand Up @@ -1835,8 +1838,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
K: AsRef<[u8]>,
T: ValueTransformer,
{
let mut chunks = Vec::new();
chunks.extend(key.as_ref().iter().copied().flat_map(to_nibble_array));
let key_nibbles = Nibbles::<0>(key.as_ref());

let mut proofs: HashMap<[u8; TRIE_HASH_LEN], Vec<u8>> = HashMap::new();
if root.is_null() {
Expand All @@ -1857,23 +1859,26 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {

let mut nskip = 0;
let mut nodes: Vec<DiskAddress> = Vec::new();
for (i, nib) in chunks.iter().enumerate() {
for (i, nib) in key_nibbles.iter().enumerate() {
if nskip > 0 {
nskip -= 1;
continue;
}
nodes.push(u_ref.as_ptr());
let next_ptr: DiskAddress = match &u_ref.inner {
NodeType::Branch(n) => match n.chd[*nib as usize] {
NodeType::Branch(n) => match n.chd[nib as usize] {
Some(c) => c,
None => break,
},
NodeType::Leaf(_) => break,
NodeType::Extension(n) => {
let n_path = &*n.0;
let remaining_path = &chunks[i..];
let remaining_path = key_nibbles.skip(i);
if remaining_path.len() < n_path.len()
|| &remaining_path[..n_path.len()] != n_path
|| !remaining_path
.iter()
.take(n_path.len())
.eq(n_path.iter().cloned())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find these multi-line-if clauses super cumbersome to read; think we could pop this out into a variable (with a name that explains why we are going this too)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I split it into two if statements; I think it's a bit easier to read now.

{
break;
} else {
Expand Down
186 changes: 186 additions & 0 deletions firewood/src/nibbles.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
use std::ops::Index;

static NIBBLES: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];

/// Nibbles is a newtype that contains only a pointer to a u8, and produces
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Nibbles is a newtype that contains only a pointer to a u8, and produces
/// Nibbles is a newtype that contains only a reference to a `[u8]`, and produces

/// nibbles. Nibbles can be indexed using nib\[x\] or you can get an iterator
/// with iter()
///
/// Nibbles can be constructed with a number of leading zeroes. This is used
/// in firewood because there is a sentinel node, so we always want the first
/// byte to be 0
///
/// When creating a Nibbles object, use the syntax `Nibbles::<N>(r)` where
/// `N` is the number of leading zero bytes you need and `r` is a reference to
/// a [u8]
///
/// # Examples
///
/// ```
/// # use firewood::nibbles;
/// # fn main() {
/// let nib = nibbles::Nibbles::<0>(&[0x56, 0x78]);
/// assert_eq!(nib.iter().collect::<Vec<_>>(), [0x5, 0x6, 0x7, 0x8]);
///
/// // nibbles can be efficiently advanced without rendering the
/// // intermediate values
/// assert_eq!(nib.skip(3).iter().collect::<Vec<_>>(), [0x8]);
///
/// // nibbles can also be indexed
///
/// assert_eq!(nib[1], 0x6);
/// # }
/// ```
pub struct Nibbles<'a, const LEADING_ZEROES: usize>(pub &'a [u8]);
impl<'a, const LEADING_ZEROES: usize> Index<usize> for Nibbles<'a, LEADING_ZEROES> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pub struct Nibbles<'a, const LEADING_ZEROES: usize>(pub &'a [u8]);
impl<'a, const LEADING_ZEROES: usize> Index<usize> for Nibbles<'a, LEADING_ZEROES> {
pub struct Nibbles<'a, const LEADING_ZEROES: usize>(pub &'a [u8]);
impl<'a, const LEADING_ZEROES: usize> Index<usize> for Nibbles<'a, LEADING_ZEROES> {

type Output = u8;

fn index(&self, index: usize) -> &Self::Output {
match index {
_ if index < LEADING_ZEROES => &NIBBLES[0],
_ if (index - LEADING_ZEROES) % 2 == 0 => {
&NIBBLES[(self.0[(index - LEADING_ZEROES) / 2] >> 4) as usize]
}
_ => &NIBBLES[(self.0[(index - LEADING_ZEROES) / 2] & 0xf) as usize],
}
}
}

impl<'a, const LEADING_ZEROES: usize> Nibbles<'a, LEADING_ZEROES> {
#[must_use]
pub fn iter(&self) -> NibblesIterator<'_, LEADING_ZEROES> {
NibblesIterator { data: self, pos: 0 }
}

/// Efficently skip some values
#[must_use]
pub fn skip(&self, at: usize) -> NibblesSlice<'a> {
assert!(at >= LEADING_ZEROES, "Cannot split before LEADING_ZEROES (requested split at {at} is less than the {LEADING_ZEROES} leading zero(es)");
NibblesSlice {
skipfirst: (at - LEADING_ZEROES) % 2 != 0,
nibbles: Nibbles(&self.0[(at - LEADING_ZEROES) / 2..]),
}
}

#[must_use]
pub fn len(&self) -> usize {
LEADING_ZEROES + 2 * self.0.len()
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

This might be useful for improving that split function and taking in Nibbles instead of &[u8]


#[must_use]
pub fn is_empty(&self) -> bool {
LEADING_ZEROES == 0 && self.0.is_empty()
}
}

/// NibblesSlice is created by [Nibbles::skip]. This is
/// used to create an interator that starts at some particular
/// nibble
pub struct NibblesSlice<'a> {
nibbles: Nibbles<'a, 0>,
skipfirst: bool,
}

impl<'a> NibblesSlice<'a> {
/// Returns an iterator over this subset of nibbles
pub fn iter(&self) -> NibblesIterator<'_, 0> {
let pos = if self.skipfirst { 1 } else { 0 };
NibblesIterator {
data: &self.nibbles,
pos,
}
}

pub fn len(&self) -> usize {
self.nibbles.len() - if self.skipfirst { 1 } else { 0 }
}

pub fn is_empty(&self) -> bool {
self.len() > 0
}
}

/// An interator returned by [Nibbles::iter] or [NibblesSlice::iter].
/// See their documentation for details.
pub struct NibblesIterator<'a, const LEADING_ZEROES: usize> {
data: &'a Nibbles<'a, LEADING_ZEROES>,
pos: usize,
}

impl<'a, const LEADING_ZEROES: usize> Iterator for NibblesIterator<'a, LEADING_ZEROES> {
type Item = u8;

fn next(&mut self) -> Option<Self::Item> {
let result = if self.pos >= LEADING_ZEROES + self.data.0.len() * 2 {
None
} else {
Some(self.data[self.pos])
};
self.pos += 1;
result
}
}

#[cfg(test)]
mod test {
use super::*;
static TEST_BYTES: [u8; 4] = [0xdeu8, 0xad, 0xbe, 0xef];
#[test]
fn happy_regular_nibbles() {
let nib = Nibbles::<0>(&TEST_BYTES);
let expected = [0xdu8, 0xe, 0xa, 0xd, 0xb, 0xe, 0xe, 0xf];
for v in expected.into_iter().enumerate() {
assert_eq!(nib[v.0], v.1, "{v:?}");
}
}
#[test]
fn leadingzero_nibbles_index() {
let nib = Nibbles::<1>(&TEST_BYTES);
let expected = [0u8, 0xd, 0xe, 0xa, 0xd, 0xb, 0xe, 0xe, 0xf];
for v in expected.into_iter().enumerate() {
assert_eq!(nib[v.0], v.1, "{v:?}");
}
}
#[test]
fn leading_zero_nibbles_iter() {
let nib = Nibbles::<1>(&TEST_BYTES);
let expected: [u8; 9] = [0u8, 0xd, 0xe, 0xa, 0xd, 0xb, 0xe, 0xe, 0xf];
expected.into_iter().eq(nib.iter());
}

#[test]
fn skip_zero() {
let nib = Nibbles::<0>(&TEST_BYTES);
let slice = nib.skip(0);
assert!(nib.iter().eq(slice.iter()));
}
#[test]
fn skip_one() {
let nib = Nibbles::<0>(&TEST_BYTES);
let slice = nib.skip(1);
assert!(nib.iter().skip(1).eq(slice.iter()));
}
#[test]
fn skip_skips_zeroes() {
let nib = Nibbles::<1>(&TEST_BYTES);
let slice = nib.skip(1);
assert!(nib.iter().skip(1).eq(slice.iter()));
}
#[test]
#[should_panic]
fn test_out_of_bounds_panics() {
let nib = Nibbles::<0>(&TEST_BYTES);
let _ = nib[8];
}
#[test]
fn test_last_nibble() {
let nib = Nibbles::<0>(&TEST_BYTES);
assert_eq!(nib[7], 0xf);
}
#[test]
#[should_panic]
fn test_skip_before_zeroes_panics() {
let nib = Nibbles::<1>(&TEST_BYTES);
let _ = nib.skip(0);
}
}
Loading