diff --git a/crates/chia-datalayer/src/merkle.rs b/crates/chia-datalayer/src/merkle.rs index 4e0cd878..9911118a 100644 --- a/crates/chia-datalayer/src/merkle.rs +++ b/crates/chia-datalayer/src/merkle.rs @@ -1,5 +1,5 @@ #[cfg(feature = "py-bindings")] -use pyo3::{buffer::PyBuffer, pyclass, pymethods, PyResult}; +use pyo3::{buffer::PyBuffer, exceptions::PyValueError, pyclass, pymethods, PyResult}; use clvmr::sha2::Sha256; use num_traits::ToBytes; @@ -17,6 +17,11 @@ type KvId = i64; const fn range_by_length(start: usize, length: usize) -> Range { start..start + length } +const fn max(left: usize, right: usize) -> usize { + [left, right][(left < right) as usize] +} +// TODO: once not experimental... something closer to this +// const fn max(left: T, right: T) -> T { if left < right {right} else {left} } // TODO: consider in more detail other serialization tools such as serde and streamable // define the serialized block format @@ -36,9 +41,7 @@ const RIGHT_RANGE: Range = range_by_length(LEFT_RANGE.end, size_of:: = range_by_length(PARENT_RANGE.end, size_of::()); const VALUE_RANGE: Range = range_by_length(KEY_RANGE.end, size_of::()); -// TODO: clearly shouldn't be hard coded -// TODO: max of RIGHT_RANGE.end and VALUE_RANGE.end -const DATA_SIZE: usize = VALUE_RANGE.end; +const DATA_SIZE: usize = max(RIGHT_RANGE.end, VALUE_RANGE.end); const BLOCK_SIZE: usize = METADATA_SIZE + DATA_SIZE; type BlockBytes = [u8; BLOCK_SIZE]; type MetadataBytes = [u8; METADATA_SIZE]; @@ -54,7 +57,6 @@ pub enum NodeType { impl NodeType { pub fn from_u8(value: u8) -> Result { - // TODO: identify some useful structured serialization tooling we use match value { // ha! feel free to laugh at this x if (NodeType::Internal as u8 == x) => Ok(NodeType::Internal), @@ -119,7 +121,6 @@ pub struct NodeMetadata { impl NodeMetadata { pub fn from_bytes(blob: MetadataBytes) -> Result { // TODO: could save 1-2% of tree space by packing (and maybe don't do that) - // TODO: identify some useful structured serialization tooling we use Ok(Self { node_type: Self::node_type_from_bytes(blob)?, dirty: Self::dirty_from_bytes(blob)?, @@ -176,8 +177,6 @@ impl NodeSpecific { } impl Node { - // TODO: talk through whether this is good practice. being prepared for an error even though - // presently it won't happen #[allow(clippy::unnecessary_wraps)] pub fn from_bytes(metadata: &NodeMetadata, blob: DataBytes) -> Result { Ok(Self { @@ -417,8 +416,9 @@ impl MerkleBlob { side: &Side, ) -> Result<(), String> { self.clear(); - // TODO: just handling the nodes below being out of order. this all still smells a bit - self.blob.resize(BLOCK_SIZE * 3, 0); + let root_index = self.get_new_index(); + let left_index = self.get_new_index(); + let right_index = self.get_new_index(); let new_internal_block = Block { metadata: NodeMetadata { @@ -427,12 +427,15 @@ impl MerkleBlob { }, node: Node { parent: None, - specific: NodeSpecific::Internal { left: 1, right: 2 }, + specific: NodeSpecific::Internal { + left: left_index, + right: right_index, + }, hash: *internal_node_hash, }, }; - self.insert_entry_to_blob(0, &new_internal_block)?; + self.insert_entry_to_blob(root_index, &new_internal_block)?; let NodeSpecific::Leaf { key: old_leaf_key, @@ -447,8 +450,8 @@ impl MerkleBlob { let nodes = [ ( match side { - Side::Left => 2, - Side::Right => 1, + Side::Left => right_index, + Side::Right => left_index, }, Node { parent: Some(0), @@ -461,8 +464,8 @@ impl MerkleBlob { ), ( match side { - Side::Left => 1, - Side::Right => 2, + Side::Left => left_index, + Side::Right => right_index, }, node, ), @@ -739,9 +742,11 @@ impl MerkleBlob { fn get_new_index(&mut self) -> TreeIndex { match self.free_indexes.iter().next().copied() { None => { - // TODO: should this update free indexes...? let index = self.extend_index(); self.blob.extend_from_slice(&[0; BLOCK_SIZE]); + // NOTE: explicitly not marking index as free since that would hazard two + // sequential calls to this function through this path to both return + // the same index index } Some(new_index) => { @@ -864,17 +869,7 @@ impl MerkleBlob { } pub fn get_node(&self, index: TreeIndex) -> Result { - // TODO: use Block::from_bytes() - // TODO: handle invalid indexes? - // TODO: handle overflows? - let block = self.get_block_bytes(index)?; - let metadata_blob: MetadataBytes = block[METADATA_RANGE].try_into().unwrap(); - let data_blob: DataBytes = block[DATA_RANGE].try_into().unwrap(); - let metadata = NodeMetadata::from_bytes(metadata_blob) - .map_err(|message| format!("failed loading metadata: {message})"))?; - - Node::from_bytes(&metadata, data_blob) - .map_err(|message| format!("failed loading node: {message}")) + Ok(self.get_block(index)?.node) } pub fn get_parent_index(&self, index: TreeIndex) -> Result { @@ -886,7 +881,6 @@ impl MerkleBlob { } pub fn get_lineage(&self, index: TreeIndex) -> Result, String> { - // TODO: what about an index that happens to be the null index? a question for everywhere i guess let mut next_index = Some(index); let mut lineage = vec![]; @@ -900,8 +894,6 @@ impl MerkleBlob { } pub fn get_lineage_indexes(&self, index: TreeIndex) -> Result, String> { - // TODO: yep, this 'optimization' might be overkill, and should be speed compared regardless - // TODO: what about an index that happens to be the null index? a question for everywhere i guess let mut next_index = Some(index); let mut lineage: Vec = vec![]; @@ -932,7 +924,6 @@ impl MerkleBlob { // an iteration that's already doing that let left_hash = self.get_hash(left)?; let right_hash = self.get_hash(right)?; - // TODO: wrap this up in Block maybe? just to have 'control' of dirty being 'accurate' block.update_hash(&left_hash, &right_hash); self.insert_entry_to_blob(index, &block)?; } @@ -1013,15 +1004,20 @@ impl MerkleBlob { impl PartialEq for MerkleBlob { fn eq(&self, other: &Self) -> bool { - // TODO: should we check the indexes? + // NOTE: this is checking tree structure equality, not serialized bytes equality for ((_, self_block), (_, other_block)) in zip(self, other) { if (self_block.metadata.dirty || other_block.metadata.dirty) || self_block.node.hash != other_block.node.hash - // TODO: isn't only a leaf supposed to check this? - || self_block.node.specific != other_block.node.specific { return false; } + match self_block.node.specific { + // NOTE: this is effectively checked by the controlled overall traversal + NodeSpecific::Internal { .. } => {} + NodeSpecific::Leaf { .. } => { + return self_block.node.specific == other_block.node.specific + } + } } true @@ -1033,8 +1029,7 @@ impl<'a> IntoIterator for &'a MerkleBlob { type IntoIter = MerkleBlobLeftChildFirstIterator<'a>; fn into_iter(self) -> Self::IntoIter { - // TODO: review types around this to avoid copying - MerkleBlobLeftChildFirstIterator::new(&self.blob[..]) + MerkleBlobLeftChildFirstIterator::new(&self.blob) } } @@ -1052,25 +1047,28 @@ impl MerkleBlob { let slice = unsafe { std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes()) }; - Ok(Self::new(Vec::from(slice)).unwrap()) + match Self::new(Vec::from(slice)) { + Ok(blob) => Ok(blob), + Err(message) => Err(PyValueError::new_err(message)), + } } #[pyo3(name = "insert")] pub fn py_insert(&mut self, key: KvId, value: KvId, hash: Hash) -> PyResult<()> { - // TODO: consider the error - // TODO: expose insert location - self.insert(key, value, &hash, InsertLocation::Auto) - .unwrap(); - - Ok(()) + if let Err(message) = self.insert(key, value, &hash, InsertLocation::Auto) { + Err(PyValueError::new_err(message)) + } else { + Ok(()) + } } #[pyo3(name = "delete")] pub fn py_delete(&mut self, key: KvId) -> PyResult<()> { - // TODO: consider the error - self.delete(key).unwrap(); - - Ok(()) + if let Err(message) = self.delete(key) { + Err(PyValueError::new_err(message)) + } else { + Ok(()) + } } #[pyo3(name = "__len__")] @@ -1238,7 +1236,6 @@ mod tests { #[test] fn test_node_type_serialized_values() { - // TODO: can i make sure we cover all variants? assert_eq!(NodeType::Internal as u8, 0); assert_eq!(NodeType::Leaf as u8, 1); @@ -1271,7 +1268,6 @@ mod tests { #[rstest] fn test_node_metadata_from_to( #[values(false, true)] dirty: bool, - // TODO: can we make sure we cover all variants #[values(NodeType::Internal, NodeType::Leaf)] node_type: NodeType, ) { let bytes: [u8; 2] = [node_type.to_u8(), dirty as u8]; @@ -1403,7 +1399,7 @@ mod tests { for i in 0..100_000 { let start = Instant::now(); merkle_blob - // TODO: yeah this hash is garbage + // NOTE: yeah this hash is garbage .insert(i, i, &sha256_num(i), InsertLocation::Auto) .unwrap(); let end = Instant::now();