Commit

only 14 to go
altendky committed Oct 17, 2024
1 parent 6e01fe0 commit ae77486
Showing 1 changed file with 47 additions and 51 deletions.
crates/chia-datalayer/src/merkle.rs: 98 changes (47 additions & 51 deletions)
@@ -1,5 +1,5 @@
 #[cfg(feature = "py-bindings")]
-use pyo3::{buffer::PyBuffer, pyclass, pymethods, PyResult};
+use pyo3::{buffer::PyBuffer, exceptions::PyValueError, pyclass, pymethods, PyResult};

 use clvmr::sha2::Sha256;
 use num_traits::ToBytes;
@@ -17,6 +17,11 @@ type KvId = i64;
 const fn range_by_length(start: usize, length: usize) -> Range<usize> {
     start..start + length
 }
+const fn max(left: usize, right: usize) -> usize {
+    [left, right][(left < right) as usize]
+}
+// TODO: once not experimental... something closer to this
+// const fn max<T: ~const std::cmp::PartialOrd>(left: T, right: T) -> T { if left < right {right} else {left} }

 // TODO: consider in more detail other serialization tools such as serde and streamable
 // define the serialized block format
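The new branchless `max` works because a `bool` cast to `usize` is exactly 0 or 1, selecting the larger element of the two-element array; unlike `std::cmp::max`, it is callable in `const` items today. A runnable sketch of our own to confirm the behavior:

```rust
// `(left < right) as usize` is 0 when left >= right and 1 otherwise,
// so it indexes the larger of the two values out of [left, right].
const fn max(left: usize, right: usize) -> usize {
    [left, right][(left < right) as usize]
}

const LARGER: usize = max(12, 20); // usable in const items, unlike std::cmp::max

fn main() {
    assert_eq!(max(3, 7), 7);
    assert_eq!(max(7, 3), 7);
    assert_eq!(LARGER, 20);
}
```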
@@ -36,9 +41,7 @@ const RIGHT_RANGE: Range<usize> = range_by_length(LEFT_RANGE.end, size_of::<TreeIndex>());
 const KEY_RANGE: Range<usize> = range_by_length(PARENT_RANGE.end, size_of::<KvId>());
 const VALUE_RANGE: Range<usize> = range_by_length(KEY_RANGE.end, size_of::<KvId>());

-// TODO: clearly shouldn't be hard coded
-// TODO: max of RIGHT_RANGE.end and VALUE_RANGE.end
-const DATA_SIZE: usize = VALUE_RANGE.end;
+const DATA_SIZE: usize = max(RIGHT_RANGE.end, VALUE_RANGE.end);
 const BLOCK_SIZE: usize = METADATA_SIZE + DATA_SIZE;
 type BlockBytes = [u8; BLOCK_SIZE];
 type MetadataBytes = [u8; METADATA_SIZE];
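`DATA_SIZE` is now derived instead of hard coded: an internal node stores three tree indexes while a leaf stores one index plus two ids, and a block must fit the larger layout. A worked recomputation, under the assumption (not visible in this diff) that `TreeIndex` is `u32`; `type KvId = i64` is per the file:

```rust
use std::mem::size_of;

type TreeIndex = u32; // assumed width, matching the size_of::<TreeIndex>() usage above
type KvId = i64;

// internal node data: parent + left + right -> 3 indexes = 12 bytes
const INTERNAL_SIZE: usize = 3 * size_of::<TreeIndex>();
// leaf node data:     parent + key + value  -> 1 index + 2 ids = 20 bytes
const LEAF_SIZE: usize = size_of::<TreeIndex>() + 2 * size_of::<KvId>();

fn main() {
    assert_eq!(INTERNAL_SIZE, 12); // corresponds to RIGHT_RANGE.end
    assert_eq!(LEAF_SIZE, 20);     // corresponds to VALUE_RANGE.end
    // DATA_SIZE = max(RIGHT_RANGE.end, VALUE_RANGE.end) = 20
    assert_eq!(INTERNAL_SIZE.max(LEAF_SIZE), 20);
}
```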
@@ -54,7 +57,6 @@ pub enum NodeType {

 impl NodeType {
     pub fn from_u8(value: u8) -> Result<Self, String> {
-        // TODO: identify some useful structured serialization tooling we use
         match value {
             // ha! feel free to laugh at this
             x if (NodeType::Internal as u8 == x) => Ok(NodeType::Internal),
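The guard arms convert a raw discriminant back to the enum without repeating magic numbers. A self-contained round-trip sketch in the same spirit; the error message wording is ours:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum NodeType {
    Internal = 0,
    Leaf = 1,
}

impl NodeType {
    fn from_u8(value: u8) -> Result<Self, String> {
        match value {
            // compare against the enum's own discriminants rather than literals
            x if x == NodeType::Internal as u8 => Ok(NodeType::Internal),
            x if x == NodeType::Leaf as u8 => Ok(NodeType::Leaf),
            other => Err(format!("unknown NodeType value: {other}")),
        }
    }
}

fn main() {
    assert_eq!(NodeType::from_u8(0), Ok(NodeType::Internal));
    assert_eq!(NodeType::from_u8(1), Ok(NodeType::Leaf));
    assert!(NodeType::from_u8(2).is_err());
}
```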
@@ -119,7 +121,6 @@ pub struct NodeMetadata {
 impl NodeMetadata {
     pub fn from_bytes(blob: MetadataBytes) -> Result<Self, String> {
         // TODO: could save 1-2% of tree space by packing (and maybe don't do that)
-        // TODO: identify some useful structured serialization tooling we use
         Ok(Self {
             node_type: Self::node_type_from_bytes(blob)?,
             dirty: Self::dirty_from_bytes(blob)?,
@@ -176,8 +177,6 @@ impl NodeSpecific {
 }

 impl Node {
-    // TODO: talk through whether this is good practice. being prepared for an error even though
-    //       presently it won't happen
     #[allow(clippy::unnecessary_wraps)]
     pub fn from_bytes(metadata: &NodeMetadata, blob: DataBytes) -> Result<Self, String> {
         Ok(Self {
@@ -417,8 +416,9 @@ impl MerkleBlob {
         side: &Side,
     ) -> Result<(), String> {
         self.clear();
-        // TODO: just handling the nodes below being out of order. this all still smells a bit
-        self.blob.resize(BLOCK_SIZE * 3, 0);
+        let root_index = self.get_new_index();
+        let left_index = self.get_new_index();
+        let right_index = self.get_new_index();

         let new_internal_block = Block {
             metadata: NodeMetadata {
@@ -427,12 +427,15 @@
             },
             node: Node {
                 parent: None,
-                specific: NodeSpecific::Internal { left: 1, right: 2 },
+                specific: NodeSpecific::Internal {
+                    left: left_index,
+                    right: right_index,
+                },
                 hash: *internal_node_hash,
             },
         };

-        self.insert_entry_to_blob(0, &new_internal_block)?;
+        self.insert_entry_to_blob(root_index, &new_internal_block)?;

         let NodeSpecific::Leaf {
             key: old_leaf_key,
@@ -447,8 +450,8 @@
         let nodes = [
             (
                 match side {
-                    Side::Left => 2,
-                    Side::Right => 1,
+                    Side::Left => right_index,
+                    Side::Right => left_index,
                 },
                 Node {
                     parent: Some(0),
@@ -461,8 +464,8 @@
             ),
             (
                 match side {
-                    Side::Left => 1,
-                    Side::Right => 2,
+                    Side::Left => left_index,
+                    Side::Right => right_index,
                 },
                 node,
             ),
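The two match blocks encode one rule from opposite ends: the inserted leaf takes the requested side and the displaced leaf takes the remaining slot. Restated as a hypothetical helper (names mirror the diff, the function is ours):

```rust
#[derive(Clone, Copy)]
enum Side {
    Left,
    Right,
}

type TreeIndex = u32;

// Returns (old_leaf_slot, new_leaf_slot): the new leaf goes to `side`,
// the existing leaf is displaced into the opposite child index.
fn child_slots(side: Side, left_index: TreeIndex, right_index: TreeIndex) -> (TreeIndex, TreeIndex) {
    match side {
        Side::Left => (right_index, left_index),
        Side::Right => (left_index, right_index),
    }
}

fn main() {
    assert_eq!(child_slots(Side::Left, 1, 2), (2, 1));
    assert_eq!(child_slots(Side::Right, 1, 2), (1, 2));
}
```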
@@ -739,9 +742,11 @@
     fn get_new_index(&mut self) -> TreeIndex {
         match self.free_indexes.iter().next().copied() {
             None => {
-                // TODO: should this update free indexes...?
                 let index = self.extend_index();
                 self.blob.extend_from_slice(&[0; BLOCK_SIZE]);
+                // NOTE: explicitly not marking index as free since that would hazard two
+                //       sequential calls to this function through this path to both return
+                //       the same index
                 index
             }
             Some(new_index) => {
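The NOTE pins down an allocator invariant: a slot handed out through the growth path must not stay in the free set, or two back-to-back calls would return the same index. A compressed sketch of the policy; `extend_index` is replaced by simplified length math and `BLOCK_SIZE` is an assumed value:

```rust
use std::collections::HashSet;

const BLOCK_SIZE: usize = 22; // assumed METADATA_SIZE + DATA_SIZE, for illustration only

struct MerkleBlobSketch {
    blob: Vec<u8>,
    free_indexes: HashSet<u32>,
}

impl MerkleBlobSketch {
    fn get_new_index(&mut self) -> u32 {
        match self.free_indexes.iter().next().copied() {
            None => {
                // next block index is the current length in blocks
                let index = (self.blob.len() / BLOCK_SIZE) as u32;
                self.blob.extend_from_slice(&[0; BLOCK_SIZE]);
                // deliberately NOT inserted into free_indexes: the caller owns
                // this slot now, so a second call must not hand it out again
                index
            }
            Some(index) => {
                self.free_indexes.remove(&index);
                index
            }
        }
    }
}

fn main() {
    let mut sketch = MerkleBlobSketch { blob: vec![], free_indexes: HashSet::new() };
    assert_eq!(sketch.get_new_index(), 0);
    assert_eq!(sketch.get_new_index(), 1); // no collision on sequential calls
}
```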
@@ -864,17 +869,7 @@
     }

     pub fn get_node(&self, index: TreeIndex) -> Result<Node, String> {
-        // TODO: use Block::from_bytes()
-        // TODO: handle invalid indexes?
-        // TODO: handle overflows?
-        let block = self.get_block_bytes(index)?;
-        let metadata_blob: MetadataBytes = block[METADATA_RANGE].try_into().unwrap();
-        let data_blob: DataBytes = block[DATA_RANGE].try_into().unwrap();
-        let metadata = NodeMetadata::from_bytes(metadata_blob)
-            .map_err(|message| format!("failed loading metadata: {message})"))?;
-
-        Node::from_bytes(&metadata, data_blob)
-            .map_err(|message| format!("failed loading node: {message}"))
+        Ok(self.get_block(index)?.node)
     }

     pub fn get_parent_index(&self, index: TreeIndex) -> Result<Parent, String> {
@@ -886,7 +881,6 @@
     }

     pub fn get_lineage(&self, index: TreeIndex) -> Result<Vec<Node>, String> {
-        // TODO: what about an index that happens to be the null index? a question for everywhere i guess
         let mut next_index = Some(index);
         let mut lineage = vec![];

@@ -900,8 +894,6 @@
     }

     pub fn get_lineage_indexes(&self, index: TreeIndex) -> Result<Vec<TreeIndex>, String> {
-        // TODO: yep, this 'optimization' might be overkill, and should be speed compared regardless
-        // TODO: what about an index that happens to be the null index? a question for everywhere i guess
         let mut next_index = Some(index);
         let mut lineage: Vec<TreeIndex> = vec![];

@@ -932,7 +924,6 @@
                 // an iteration that's already doing that
                 let left_hash = self.get_hash(left)?;
                 let right_hash = self.get_hash(right)?;
-                // TODO: wrap this up in Block maybe? just to have 'control' of dirty being 'accurate'
                 block.update_hash(&left_hash, &right_hash);
                 self.insert_entry_to_blob(index, &block)?;
             }
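For context on `block.update_hash(...)`: the loop recomputes an internal node's hash from its children's hashes. A hypothetical sketch using the `clvmr` hasher imported at the top of the file; whether the real computation adds a domain-separation prefix is not visible in this diff:

```rust
use clvmr::sha2::Sha256;

type Hash = [u8; 32];

// Assumed shape only: combine the two child hashes into the parent hash.
fn internal_hash(left_hash: &Hash, right_hash: &Hash) -> Hash {
    let mut hasher = Sha256::new();
    hasher.update(left_hash);
    hasher.update(right_hash);
    hasher.finalize()
}

fn main() {
    let left = [0u8; 32];
    let right = [1u8; 32];
    // deterministic: same children always produce the same parent hash
    assert_eq!(internal_hash(&left, &right), internal_hash(&left, &right));
}
```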
@@ -1013,15 +1004,20 @@

 impl PartialEq for MerkleBlob {
     fn eq(&self, other: &Self) -> bool {
-        // TODO: should we check the indexes?
+        // NOTE: this is checking tree structure equality, not serialized bytes equality
         for ((_, self_block), (_, other_block)) in zip(self, other) {
             if (self_block.metadata.dirty || other_block.metadata.dirty)
                 || self_block.node.hash != other_block.node.hash
-                // TODO: isn't only a leaf supposed to check this?
-                || self_block.node.specific != other_block.node.specific
             {
                 return false;
             }
+            match self_block.node.specific {
+                // NOTE: this is effectively checked by the controlled overall traversal
+                NodeSpecific::Internal { .. } => {}
+                NodeSpecific::Leaf { .. } => {
+                    return self_block.node.specific == other_block.node.specific
+                }
+            }
         }

         true
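A hedged reading of the new equality: dirty flags and hashes are compared for every yielded pair, while `specific` is only compared at leaves, since two blobs can store the same tree at different physical child indexes. The per-pair check, restated with a loop-free shape that is ours rather than the commit's:

```rust
// Sketch types mirroring the diff's Node/NodeSpecific, trimmed to what eq uses.
#[allow(dead_code)]
#[derive(PartialEq)]
enum NodeSpecific {
    Internal { left: u32, right: u32 },
    Leaf { key: i64, value: i64 },
}

struct NodeView {
    dirty: bool,
    hash: [u8; 32],
    specific: NodeSpecific,
}

fn pair_equal(a: &NodeView, b: &NodeView) -> bool {
    if a.dirty || b.dirty || a.hash != b.hash {
        return false;
    }
    match (&a.specific, &b.specific) {
        // child indexes may legitimately differ between equal trees
        (NodeSpecific::Internal { .. }, NodeSpecific::Internal { .. }) => true,
        // leaves must match key and value; mismatched variants compare unequal
        _ => a.specific == b.specific,
    }
}

fn main() {
    let a = NodeView { dirty: false, hash: [0; 32], specific: NodeSpecific::Leaf { key: 1, value: 2 } };
    let b = NodeView { dirty: false, hash: [0; 32], specific: NodeSpecific::Leaf { key: 1, value: 2 } };
    assert!(pair_equal(&a, &b));
}
```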
@@ -1033,8 +1029,7 @@ impl<'a> IntoIterator for &'a MerkleBlob {
     type IntoIter = MerkleBlobLeftChildFirstIterator<'a>;

     fn into_iter(self) -> Self::IntoIter {
-        // TODO: review types around this to avoid copying
-        MerkleBlobLeftChildFirstIterator::new(&self.blob[..])
+        MerkleBlobLeftChildFirstIterator::new(&self.blob)
     }
 }

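Dropping the explicit `[..]` relies on deref coercion from `&Vec<u8>` to `&[u8]`, assuming `MerkleBlobLeftChildFirstIterator::new` takes a `&[u8]`. A minimal standalone demonstration of the coercion:

```rust
// Both call forms pass the same full-slice view of the vector's bytes.
fn first_byte(bytes: &[u8]) -> Option<u8> {
    bytes.first().copied()
}

fn main() {
    let blob: Vec<u8> = vec![7, 8, 9];
    assert_eq!(first_byte(&blob[..]), Some(7)); // explicit full-range slice
    assert_eq!(first_byte(&blob), Some(7));     // deref coercion, same result
}
```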
@@ -1052,25 +1047,28 @@ impl MerkleBlob {
         let slice =
             unsafe { std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes()) };

-        Ok(Self::new(Vec::from(slice)).unwrap())
+        match Self::new(Vec::from(slice)) {
+            Ok(blob) => Ok(blob),
+            Err(message) => Err(PyValueError::new_err(message)),
+        }
     }

     #[pyo3(name = "insert")]
     pub fn py_insert(&mut self, key: KvId, value: KvId, hash: Hash) -> PyResult<()> {
-        // TODO: consider the error
-        // TODO: expose insert location
-        self.insert(key, value, &hash, InsertLocation::Auto)
-            .unwrap();
-
-        Ok(())
+        if let Err(message) = self.insert(key, value, &hash, InsertLocation::Auto) {
+            Err(PyValueError::new_err(message))
+        } else {
+            Ok(())
+        }
     }

     #[pyo3(name = "delete")]
     pub fn py_delete(&mut self, key: KvId) -> PyResult<()> {
-        // TODO: consider the error
-        self.delete(key).unwrap();
-
-        Ok(())
+        if let Err(message) = self.delete(key) {
+            Err(PyValueError::new_err(message))
+        } else {
+            Ok(())
+        }
     }

     #[pyo3(name = "__len__")]
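All three bindings now surface the internal `Result<_, String>` as a Python `ValueError` instead of unwrapping. Shown below only as a comparison point, not as what the commit uses: the if-let/else collapses to a single `map_err`, since `PyValueError::new_err` accepts the `String` directly (the stand-in `delete` keeps the sketch compilable):

```rust
use pyo3::{exceptions::PyValueError, PyResult};

// Stand-in for the struct's fallible core method.
fn delete(key: i64) -> Result<(), String> {
    if key >= 0 { Ok(()) } else { Err(format!("unknown key: {key}")) }
}

// The if-let/else pattern in the diff, collapsed to a single map_err.
fn py_delete(key: i64) -> PyResult<()> {
    delete(key).map_err(PyValueError::new_err)
}

fn main() {
    assert!(py_delete(1).is_ok());
    assert!(py_delete(-1).is_err());
}
```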
@@ -1238,7 +1236,6 @@ mod tests {

     #[test]
     fn test_node_type_serialized_values() {
-        // TODO: can i make sure we cover all variants?
         assert_eq!(NodeType::Internal as u8, 0);
         assert_eq!(NodeType::Leaf as u8, 1);

Expand Down Expand Up @@ -1271,7 +1268,6 @@ mod tests {
#[rstest]
fn test_node_metadata_from_to(
#[values(false, true)] dirty: bool,
// TODO: can we make sure we cover all variants
#[values(NodeType::Internal, NodeType::Leaf)] node_type: NodeType,
) {
let bytes: [u8; 2] = [node_type.to_u8(), dirty as u8];
@@ -1403,7 +1399,7 @@ mod tests {
         for i in 0..100_000 {
             let start = Instant::now();
             merkle_blob
-                // TODO: yeah this hash is garbage
+                // NOTE: yeah this hash is garbage
                 .insert(i, i, &sha256_num(i), InsertLocation::Auto)
                 .unwrap();
             let end = Instant::now();