Skip to content

Commit

Permalink
Merge pull request #781 from Chia-Network/merkle_blob-tree_index_newtype
Browse files Browse the repository at this point in the history
use newtype pattern for `TreeIndex`
  • Loading branch information
altendky authored Nov 6, 2024
2 parents 3ffd27c + a40df03 commit 5899478
Showing 1 changed file with 67 additions and 35 deletions.
102 changes: 67 additions & 35 deletions crates/chia-datalayer/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use pyo3::{
buffer::PyBuffer,
exceptions::{PyAttributeError, PyValueError},
pyclass, pymethods, PyResult, Python,
pyclass, pymethods, FromPyObject, IntoPy, PyObject, PyResult, Python,
};

use clvmr::sha2::Sha256;
Expand All @@ -14,7 +14,33 @@ use std::mem::size_of;
use std::ops::Range;
use thiserror::Error;

type TreeIndex = u32;
/// Newtype wrapper around the `u32` used to address a node within the merkle blob.
///
/// When the `py-bindings` feature is enabled the type is extracted from Python
/// transparently, i.e. as its inner `u32`.
#[cfg_attr(feature = "py-bindings", derive(FromPyObject), pyo3(transparent))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TreeIndex(u32);

impl TreeIndex {
    /// Decodes an index from its 4-byte big-endian representation.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is not exactly 4 bytes long. Callers are expected to
    /// pass a slice of the correct width, so a mismatch is a programming error.
    fn from_bytes(bytes: &[u8]) -> Self {
        // State the invariant explicitly so a violation produces an actionable
        // panic message instead of a bare `TryFromSliceError`.
        let raw: [u8; 4] = bytes
            .try_into()
            .expect("TreeIndex::from_bytes requires exactly 4 bytes");
        Self(u32::from_be_bytes(raw))
    }

    /// Encodes the index as 4 big-endian bytes, the inverse of `from_bytes`.
    fn to_bytes(self) -> [u8; 4] {
        self.0.to_be_bytes()
    }
}

#[cfg(feature = "py-bindings")]
impl IntoPy<PyObject> for TreeIndex {
    /// Converts the index into a Python object by converting its inner `u32`.
    fn into_py(self, py: Python<'_>) -> PyObject {
        self.0.into_py(py)
    }
}

impl std::fmt::Display for TreeIndex {
    /// Formats the index exactly like its inner `u32`, honoring formatter flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully qualified so the call stays unambiguous even if `Debug` is in scope.
        std::fmt::Display::fmt(&self.0, f)
    }
}

type Parent = Option<TreeIndex>;
type Hash = [u8; 32];
// key and value ids are provided from outside of this code and are implemented as
Expand Down Expand Up @@ -171,7 +197,7 @@ pub enum InsertLocation {
Leaf { index: TreeIndex, side: Side },
}

const NULL_PARENT: TreeIndex = 0xffff_ffffu32;
const NULL_PARENT: TreeIndex = TreeIndex(0xffff_ffffu32);

#[derive(Debug, PartialEq)]
pub struct NodeMetadata {
Expand Down Expand Up @@ -250,8 +276,8 @@ impl Node {
hash: Self::hash_from_bytes(&blob),
specific: match metadata.node_type {
NodeType::Internal => NodeSpecific::Internal {
left: TreeIndex::from_be_bytes(blob[LEFT_RANGE].try_into().unwrap()),
right: TreeIndex::from_be_bytes(blob[RIGHT_RANGE].try_into().unwrap()),
left: TreeIndex::from_bytes(&blob[LEFT_RANGE]),
right: TreeIndex::from_bytes(&blob[RIGHT_RANGE]),
},
NodeType::Leaf => NodeSpecific::Leaf {
key: KvId::from_be_bytes(blob[KEY_RANGE].try_into().unwrap()),
Expand All @@ -262,7 +288,7 @@ impl Node {
}

fn parent_from_bytes(blob: &DataBytes) -> Parent {
let parent_integer = TreeIndex::from_be_bytes(blob[PARENT_RANGE].try_into().unwrap());
let parent_integer = TreeIndex::from_bytes(&blob[PARENT_RANGE]);
match parent_integer {
NULL_PARENT => None,
_ => Some(parent_integer),
Expand All @@ -286,9 +312,9 @@ impl Node {
Some(parent) => *parent,
};
blob[HASH_RANGE].copy_from_slice(hash);
blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_be_bytes());
blob[LEFT_RANGE].copy_from_slice(&left.to_be_bytes());
blob[RIGHT_RANGE].copy_from_slice(&right.to_be_bytes());
blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_bytes());
blob[LEFT_RANGE].copy_from_slice(&left.to_bytes());
blob[RIGHT_RANGE].copy_from_slice(&right.to_bytes());
}
Node {
parent,
Expand All @@ -300,7 +326,7 @@ impl Node {
Some(parent) => *parent,
};
blob[HASH_RANGE].copy_from_slice(hash);
blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_be_bytes());
blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_bytes());
blob[KEY_RANGE].copy_from_slice(&key.to_be_bytes());
blob[VALUE_RANGE].copy_from_slice(&value.to_be_bytes());
}
Expand Down Expand Up @@ -359,7 +385,7 @@ impl Node {
}

/// Returns the byte range covered by the block at `index`: it starts at
/// `index * BLOCK_SIZE` and spans exactly `BLOCK_SIZE` bytes.
fn block_range(index: TreeIndex) -> Range<usize> {
let block_start = index as usize * BLOCK_SIZE;
let block_start = index.0 as usize * BLOCK_SIZE;
block_start..block_start + BLOCK_SIZE
}

Expand Down Expand Up @@ -403,7 +429,7 @@ fn get_free_indexes_and_keys_values_indexes(
let mut key_to_index: HashMap<KvId, TreeIndex> = HashMap::default();

for (index, block) in MerkleBlobLeftChildFirstIterator::new(blob) {
seen_indexes[index as usize] = true;
seen_indexes[index.0 as usize] = true;

if let NodeSpecific::Leaf { key, .. } = block.node.specific {
key_to_index.insert(key, index);
Expand All @@ -413,7 +439,7 @@ fn get_free_indexes_and_keys_values_indexes(
let mut free_indexes: HashSet<TreeIndex> = HashSet::new();
for (index, seen) in seen_indexes.iter().enumerate() {
if !seen {
free_indexes.insert(index as TreeIndex);
free_indexes.insert(TreeIndex(index as u32));
}
}

Expand Down Expand Up @@ -561,7 +587,7 @@ impl MerkleBlob {
return Err(Error::OldLeafUnexpectedlyNotALeaf);
};

node.parent = Some(0);
node.parent = Some(TreeIndex(0));

let nodes = [
(
Expand All @@ -570,7 +596,7 @@ impl MerkleBlob {
Side::Right => left_index,
},
Node {
parent: Some(0),
parent: Some(TreeIndex(0)),
specific: NodeSpecific::Leaf {
key: old_leaf_key,
value: old_leaf_value,
Expand Down Expand Up @@ -872,11 +898,11 @@ impl MerkleBlob {

let Some(grandparent_index) = parent.parent else {
sibling_block.node.parent = None;
self.insert_entry_to_blob(0, &sibling_block)?;
self.insert_entry_to_blob(TreeIndex(0), &sibling_block)?;

if let NodeSpecific::Internal { left, right } = sibling_block.node.specific {
for child_index in [left, right] {
self.update_parent(child_index, Some(0))?;
self.update_parent(child_index, Some(TreeIndex(0)))?;
}
};

Expand Down Expand Up @@ -977,7 +1003,7 @@ impl MerkleBlob {
let total_count = leaf_count + internal_count + self.free_indexes.len();
let extend_index = self.extend_index();
assert_eq!(
total_count, extend_index as usize,
total_count, extend_index.0 as usize,
"expected total node count {extend_index:?} found: {total_count:?}",
);
assert_eq!(child_to_parent.len(), 0);
Expand Down Expand Up @@ -1047,7 +1073,7 @@ impl MerkleBlob {
} else {
Side::Right
};
let mut next_index: TreeIndex = 0;
let mut next_index = TreeIndex(0);
let mut node = self.get_node(next_index)?;

loop {
Expand Down Expand Up @@ -1080,7 +1106,7 @@ impl MerkleBlob {

fn extend_index(&self) -> TreeIndex {
let blob_length = self.blob.len();
let index: TreeIndex = (blob_length / BLOCK_SIZE) as TreeIndex;
let index: TreeIndex = TreeIndex((blob_length / BLOCK_SIZE) as u32);
let remainder = blob_length % BLOCK_SIZE;
assert_eq!(remainder, 0, "blob length {blob_length:?} not a multiple of {BLOCK_SIZE:?}, remainder: {remainder:?}");

Expand Down Expand Up @@ -1378,7 +1404,7 @@ impl MerkleBlob {
{
use pyo3::conversion::IntoPy;
use pyo3::types::PyListMethods;
list.append((index, node.into_py(py)))?;
list.append((index.into_py(py), node.into_py(py)))?;
}

Ok(list.into())
Expand All @@ -1391,7 +1417,7 @@ impl MerkleBlob {
for (index, block) in MerkleBlobParentFirstIterator::new(&self.blob) {
use pyo3::conversion::IntoPy;
use pyo3::types::PyListMethods;
list.append((index, block.node.into_py(py)))?;
list.append((index.into_py(py), block.node.into_py(py)))?;
}

Ok(list.into())
Expand All @@ -1404,7 +1430,7 @@ impl MerkleBlob {

/// Exposed to Python as `get_root_hash`; delegates to `py_get_hash_at_index`
/// for the node at index 0 and returns its result unchanged.
#[pyo3(name = "get_root_hash")]
pub fn py_get_root_hash(&self) -> PyResult<Option<Hash>> {
self.py_get_hash_at_index(0)
self.py_get_hash_at_index(TreeIndex(0))
}

#[pyo3(name = "get_hash_at_index")]
Expand Down Expand Up @@ -1463,7 +1489,7 @@ impl<'a> MerkleBlobLeftChildFirstIterator<'a> {
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(MerkleBlobLeftChildFirstIteratorItem {
visited: false,
index: 0,
index: TreeIndex(0),
});
}

Expand Down Expand Up @@ -1516,7 +1542,7 @@ impl<'a> MerkleBlobParentFirstIterator<'a> {
fn new(blob: &'a [u8]) -> Self {
let mut deque = VecDeque::new();
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(0);
deque.push_back(TreeIndex(0));
}

Self { blob, deque }
Expand Down Expand Up @@ -1552,7 +1578,7 @@ impl<'a> MerkleBlobBreadthFirstIterator<'a> {
fn new(blob: &'a [u8]) -> Self {
let mut deque = VecDeque::new();
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(0);
deque.push_back(TreeIndex(0));
}

Self { blob, deque }
Expand Down Expand Up @@ -1673,7 +1699,7 @@ mod tests {

#[rstest]
fn test_get_lineage(small_blob: MerkleBlob) {
let lineage = small_blob.get_lineage_with_indexes(2).unwrap();
let lineage = small_blob.get_lineage_with_indexes(TreeIndex(2)).unwrap();
for (_, node) in &lineage {
println!("{node:?}");
}
Expand All @@ -1683,8 +1709,8 @@ mod tests {
}

#[rstest]
#[case::right(0, 2, Side::Left)]
#[case::left(0xff, 1, Side::Right)]
#[case::right(0, TreeIndex(2), Side::Left)]
#[case::left(0xff, TreeIndex(1), Side::Right)]
fn test_get_random_insert_location_by_seed(
#[case] seed: u8,
#[case] expected_index: TreeIndex,
Expand Down Expand Up @@ -1868,7 +1894,10 @@ mod tests {
let index = small_blob.key_to_index[&key];
small_blob.delete(key).unwrap();

assert_eq!(small_blob.free_indexes, HashSet::from([index, 2]));
assert_eq!(
small_blob.free_indexes,
HashSet::from([index, TreeIndex(2)])
);
}

#[rstest]
Expand All @@ -1879,7 +1908,7 @@ mod tests {
small_blob.delete(key).unwrap();
open_dot(small_blob.to_dot().set_note("after delete"));

let expected = HashSet::from([1, 2]);
let expected = HashSet::from([TreeIndex(1), TreeIndex(2)]);
assert_eq!(small_blob.free_indexes, expected);
}

Expand All @@ -1905,20 +1934,23 @@ mod tests {
#[should_panic(expected = "unable to get sibling index from a leaf")]
fn test_node_specific_sibling_index_panics_for_leaf() {
let leaf = NodeSpecific::Leaf { key: 0, value: 0 };
leaf.sibling_index(0);
leaf.sibling_index(TreeIndex(0));
}

// `sibling_index` on an internal node is expected to panic with
// "index not a child: 2" when asked about an index (2) that is neither its
// left (0) nor its right (1) child.
#[test]
#[should_panic(expected = "index not a child: 2")]
fn test_node_specific_sibling_index_panics_for_unknown_sibling() {
let node = NodeSpecific::Internal { left: 0, right: 1 };
node.sibling_index(2);
let node = NodeSpecific::Internal {
left: TreeIndex(0),
right: TreeIndex(1),
};
node.sibling_index(TreeIndex(2));
}

#[rstest]
fn test_get_free_indexes(small_blob: MerkleBlob) {
let mut blob = small_blob.blob.clone();
let expected_free_index = (blob.len() / BLOCK_SIZE) as TreeIndex;
let expected_free_index = TreeIndex((blob.len() / BLOCK_SIZE) as u32);
blob.extend_from_slice(&[0; BLOCK_SIZE]);
let (free_indexes, _) = get_free_indexes_and_keys_values_indexes(&blob);
assert_eq!(free_indexes, HashSet::from([expected_free_index]));
Expand Down

0 comments on commit 5899478

Please sign in to comment.