Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use nibbles more #267

Merged
merged 3 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions firewood/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
let mut val = Some(val);

// walk down the merkle tree starting from next_node, currently the root
for (key_nib_offset, key_nib) in key_nibbles.iter().enumerate() {
for (key_nib_offset, key_nib) in key_nibbles.into_iter().enumerate() {
// special handling for extension nodes
if nskip > 0 {
nskip -= 1;
Expand All @@ -393,7 +393,9 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
// create a new leaf
let leaf_ptr = self
.new_node(Node::new(NodeType::Leaf(LeafNode(
PartialPath(key_nibbles.iter().skip(key_nib_offset + 1).collect()),
PartialPath(
key_nibbles.into_iter().skip(key_nib_offset + 1).collect(),
),
Data(val.take().unwrap()),
))))?
.as_ptr();
Expand All @@ -412,7 +414,10 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
// of the stored key to pass into split
let n_path = n.0.to_vec();
let n_value = Some(n.1.clone());
let rem_path = key_nibbles.iter().skip(key_nib_offset).collect::<Vec<_>>();
let rem_path = key_nibbles
.into_iter()
.skip(key_nib_offset)
.collect::<Vec<_>>();
self.split(
node,
&mut parents,
Expand All @@ -428,7 +433,10 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
let n_path = n.0.to_vec();
let n_ptr = n.1;
nskip = n_path.len() - 1;
let rem_path = key_nibbles.iter().skip(key_nib_offset).collect::<Vec<_>>();
let rem_path = key_nibbles
.into_iter()
.skip(key_nib_offset)
.collect::<Vec<_>>();

if let Some(v) = self.split(
node,
Expand Down Expand Up @@ -1044,7 +1052,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {

let mut nskip = 0;
let mut nodes: Vec<DiskAddress> = Vec::new();
for (i, nib) in key_nibbles.iter().enumerate() {
for (i, nib) in key_nibbles.into_iter().enumerate() {
if nskip > 0 {
nskip -= 1;
continue;
Expand All @@ -1060,7 +1068,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
// the key passed in must match the entire remainder of this
// extension node, otherwise we break out
let n_path = &*n.0;
let remaining_path = key_nibbles.iter().skip(i);
let remaining_path = key_nibbles.into_iter().skip(i);
if remaining_path.size_hint().0 < n_path.len() {
// all bytes aren't there
break;
Expand Down Expand Up @@ -1115,7 +1123,7 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
let mut u_ref = self.get_node(root)?;
let mut nskip = 0;

for (i, nib) in key_nibbles.iter().enumerate() {
for (i, nib) in key_nibbles.into_iter().enumerate() {
if nskip > 0 {
nskip -= 1;
continue;
Expand All @@ -1126,14 +1134,14 @@ impl<S: ShaleStore<Node> + Send + Sync> Merkle<S> {
None => return Ok(None),
},
NodeType::Leaf(n) => {
if !key_nibbles.iter().skip(i).eq(n.0.iter().cloned()) {
if !key_nibbles.into_iter().skip(i).eq(n.0.iter().cloned()) {
return Ok(None);
}
return Ok(Some(Ref(u_ref)));
}
NodeType::Extension(n) => {
let n_path = &*n.0;
let rem_path = key_nibbles.iter().skip(i);
let rem_path = key_nibbles.into_iter().skip(i);
if rem_path.size_hint().0 < n_path.len() {
return Ok(None);
}
Expand Down Expand Up @@ -1287,7 +1295,7 @@ mod test {
#[test]
fn test_partial_path_encoding() {
let check = |steps: &[u8], term| {
let (d, t) = PartialPath::decode(PartialPath(steps.to_vec()).encode(term));
let (d, t) = PartialPath::decode(&PartialPath(steps.to_vec()).encode(term));
assert_eq!(d.0, steps);
assert_eq!(t, term);
};
Expand Down
4 changes: 2 additions & 2 deletions firewood/src/merkle/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ impl Storable for Node {
.flat_map(to_nibble_array)
.collect();

let (path, _) = PartialPath::decode(nibbles);
let (path, _) = PartialPath::decode(&nibbles);

let mut buff = [0_u8; 1];
let rlp_len_raw = mem
Expand Down Expand Up @@ -598,7 +598,7 @@ impl Storable for Node {
.flat_map(to_nibble_array)
.collect();

let (path, _) = PartialPath::decode(nibbles);
let (path, _) = PartialPath::decode(&nibbles);
let value = Data(remainder.as_deref()[path_len as usize..].to_vec());
Ok(Self::new_from_hash(
root_hash,
Expand Down
32 changes: 20 additions & 12 deletions firewood/src/merkle/partial_path.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE.md for licensing terms.

use crate::nibbles::NibblesIterator;
use std::fmt::{self, Debug};

/// PartialPath keeps a list of nibbles to represent a path on the Trie.
Expand Down Expand Up @@ -40,18 +41,25 @@ impl PartialPath {
res
}

pub fn decode<R: AsRef<[u8]>>(raw: R) -> (Self, bool) {
let raw = raw.as_ref();
let term = raw[0] > 1;
let odd_len = raw[0] & 1;
(
Self(if odd_len == 1 {
raw[1..].to_vec()
} else {
raw[2..].to_vec()
}),
term,
)
// TODO: remove all non `Nibbles` usages and delete this function.
// I also think `PartialPath` could probably borrow instead of own data.
//
/// returns a tuple of the decoded partial path and whether the path is terminal
pub fn decode(raw: &[u8]) -> (Self, bool) {
let prefix = raw[0];
let is_odd = (prefix & 1) as usize;
Comment on lines +49 to +50
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let prefix = raw[0];
let is_odd = (prefix & 1) as usize;
let i = raw.iter();
let is_odd = (i.next() & 1) as usize;
let decoded = i.skip(1-is_odd).copied().collect();

let decoded = raw.iter().skip(1).skip(1 - is_odd).copied().collect();

(Self(decoded), prefix > 1)
}

/// returns a tuple of the decoded partial path and whether the path is terminal
pub fn from_nibbles<const N: usize>(mut nibbles: NibblesIterator<'_, N>) -> (Self, bool) {
let prefix = nibbles.next().unwrap();
let is_odd = (prefix & 1) as usize;
let decoded = nibbles.skip(1 - is_odd).collect();

(Self(decoded), prefix > 1)
}

pub(super) fn dehydrated_len(&self) -> u64 {
Expand Down
33 changes: 19 additions & 14 deletions firewood/src/nibbles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ static NIBBLES: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15

/// Nibbles is a newtype that contains only a reference to a [u8], and produces
/// nibbles. Nibbles can be indexed using nib\[x\] or you can get an iterator
/// with iter()
/// with `into_iter()`
///
/// Nibbles can be constructed with a number of leading zeroes. This is used
/// in firewood because there is a sentinel node, so we always want the first
Expand All @@ -23,21 +23,21 @@ static NIBBLES: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
/// # use firewood::nibbles;
/// # fn main() {
/// let nib = nibbles::Nibbles::<0>::new(&[0x56, 0x78]);
/// assert_eq!(nib.iter().collect::<Vec<_>>(), [0x5, 0x6, 0x7, 0x8]);
/// assert_eq!(nib.into_iter().collect::<Vec<_>>(), [0x5, 0x6, 0x7, 0x8]);
///
/// // nibbles can be efficiently advanced without rendering the
/// // intermediate values
/// assert_eq!(nib.iter().skip(3).collect::<Vec<_>>(), [0x8]);
/// assert_eq!(nib.into_iter().skip(3).collect::<Vec<_>>(), [0x8]);
///
/// // nibbles can also be indexed
///
/// assert_eq!(nib[1], 0x6);
///
/// // or reversed
/// assert_eq!(nib.iter().rev().next(), Some(0x8));
/// assert_eq!(nib.into_iter().rev().next(), Some(0x8));
/// # }
/// ```
#[derive(Debug)]
#[derive(Debug, Copy, Clone)]
pub struct Nibbles<'a, const LEADING_ZEROES: usize>(&'a [u8]);

impl<'a, const LEADING_ZEROES: usize> Index<usize> for Nibbles<'a, LEADING_ZEROES> {
Expand All @@ -54,16 +54,21 @@ impl<'a, const LEADING_ZEROES: usize> Index<usize> for Nibbles<'a, LEADING_ZEROE
}
}

impl<'a, const LEADING_ZEROES: usize> Nibbles<'a, LEADING_ZEROES> {
impl<'a, const LEADING_ZEROES: usize> IntoIterator for Nibbles<'a, LEADING_ZEROES> {
type Item = u8;
type IntoIter = NibblesIterator<'a, LEADING_ZEROES>;

#[must_use]
pub fn iter(&self) -> NibblesIterator<'_, LEADING_ZEROES> {
fn into_iter(self) -> Self::IntoIter {
NibblesIterator {
data: self,
head: 0,
tail: self.len(),
}
}
}

impl<'a, const LEADING_ZEROES: usize> Nibbles<'a, LEADING_ZEROES> {
#[must_use]
pub fn len(&self) -> usize {
LEADING_ZEROES + 2 * self.0.len()
Expand All @@ -79,11 +84,11 @@ impl<'a, const LEADING_ZEROES: usize> Nibbles<'a, LEADING_ZEROES> {
}
}

/// An interator returned by [Nibbles::iter]
/// An interator returned by [Nibbles::into_iter]
/// See their documentation for details.
#[derive(Clone, Debug)]
pub struct NibblesIterator<'a, const LEADING_ZEROES: usize> {
data: &'a Nibbles<'a, LEADING_ZEROES>,
data: Nibbles<'a, LEADING_ZEROES>,
head: usize,
tail: usize,
}
Expand Down Expand Up @@ -161,14 +166,14 @@ mod test {
fn leading_zero_nibbles_iter() {
let nib = Nibbles::<1>(&TEST_BYTES);
let expected: [u8; 9] = [0u8, 0xd, 0xe, 0xa, 0xd, 0xb, 0xe, 0xe, 0xf];
expected.into_iter().eq(nib.iter());
expected.into_iter().eq(nib.into_iter());
}

#[test]
fn skip_skips_zeroes() {
let nib1 = Nibbles::<1>(&TEST_BYTES);
let nib0 = Nibbles::<0>(&TEST_BYTES);
assert!(nib1.iter().skip(1).eq(nib0.iter()));
assert!(nib1.into_iter().skip(1).eq(nib0.into_iter()));
}

#[test]
Expand All @@ -187,7 +192,7 @@ mod test {
#[test]
fn size_hint_0() {
let nib = Nibbles::<0>(&TEST_BYTES);
let mut nib_iter = nib.iter();
let mut nib_iter = nib.into_iter();
assert_eq!((8, Some(8)), nib_iter.size_hint());
let _ = nib_iter.next();
assert_eq!((7, Some(7)), nib_iter.size_hint());
Expand All @@ -196,7 +201,7 @@ mod test {
#[test]
fn size_hint_1() {
let nib = Nibbles::<1>(&TEST_BYTES);
let mut nib_iter = nib.iter();
let mut nib_iter = nib.into_iter();
assert_eq!((9, Some(9)), nib_iter.size_hint());
let _ = nib_iter.next();
assert_eq!((8, Some(8)), nib_iter.size_hint());
Expand All @@ -205,7 +210,7 @@ mod test {
#[test]
fn backwards() {
let nib = Nibbles::<1>(&TEST_BYTES);
let nib_iter = nib.iter().rev();
let nib_iter = nib.into_iter().rev();
let expected = [0xf, 0xe, 0xe, 0xb, 0xd, 0xa, 0xe, 0xd, 0x0];

assert!(nib_iter.eq(expected));
Expand Down
Loading