Skip to content

Commit

Permalink
Convert NodePtr to a newtype rather than alias (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rigidity authored Aug 10, 2023
1 parent f7b6331 commit 41b9bd1
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 31 deletions.
31 changes: 16 additions & 15 deletions src/allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use crate::number::{node_from_number, number_from_u8, Number};
use crate::reduction::EvalErr;
use bls12_381::{G1Affine, G1Projective, G2Affine, G2Projective};

pub type NodePtr = i32;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NodePtr(pub i32);

pub enum SExp {
Atom,
Expand Down Expand Up @@ -140,7 +141,7 @@ impl Allocator {
self.u8_vec.extend_from_slice(v);
let end = self.u8_vec.len() as u32;
self.atom_vec.push(AtomBuf { start, end });
Ok(-(self.atom_vec.len() as i32))
Ok(NodePtr(-(self.atom_vec.len() as i32)))
}

pub fn new_number(&mut self, v: Number) -> Result<NodePtr, EvalErr> {
Expand All @@ -163,17 +164,17 @@ impl Allocator {
return err(self.null(), "too many pairs");
}
self.pair_vec.push(IntPair { first, rest });
Ok(r)
Ok(NodePtr(r))
}

pub fn new_substr(&mut self, node: NodePtr, start: u32, end: u32) -> Result<NodePtr, EvalErr> {
if node >= 0 {
if node.0 >= 0 {
return err(node, "(internal error) substr expected atom, got pair");
}
if self.atom_vec.len() == self.atom_limit {
return err(self.null(), "too many atoms");
}
let atom = self.atom_vec[(-node - 1) as usize];
let atom = self.atom_vec[(-node.0 - 1) as usize];
let atom_len = atom.end - atom.start;
if start > atom_len {
return err(node, "substr start out of bounds");
Expand All @@ -188,7 +189,7 @@ impl Allocator {
start: atom.start + start,
end: atom.start + end,
});
Ok(-(self.atom_vec.len() as i32))
Ok(NodePtr(-(self.atom_vec.len() as i32)))
}

pub fn new_concat(&mut self, new_size: usize, nodes: &[NodePtr]) -> Result<NodePtr, EvalErr> {
Expand All @@ -203,12 +204,12 @@ impl Allocator {

let mut counter: usize = 0;
for node in nodes {
if *node >= 0 {
if node.0 >= 0 {
self.u8_vec.truncate(start);
return err(*node, "(internal error) concat expected atom, got pair");
}

let term = self.atom_vec[(-node - 1) as usize];
let term = self.atom_vec[(-node.0 - 1) as usize];
if counter + term.len() > new_size {
self.u8_vec.truncate(start);
return err(*node, "(internal error) concat passed invalid new_size");
Expand All @@ -229,16 +230,16 @@ impl Allocator {
start: (start as u32),
end,
});
Ok(-(self.atom_vec.len() as i32))
Ok(NodePtr(-(self.atom_vec.len() as i32)))
}

pub fn atom_eq(&self, lhs: NodePtr, rhs: NodePtr) -> bool {
self.atom(lhs) == self.atom(rhs)
}

pub fn atom(&self, node: NodePtr) -> &[u8] {
assert!(node < 0, "expected atom, got pair");
let atom = self.atom_vec[(-node - 1) as usize];
assert!(node.0 < 0, "expected atom, got pair");
let atom = self.atom_vec[(-node.0 - 1) as usize];
&self.u8_vec[atom.start as usize..atom.end as usize]
}

Expand Down Expand Up @@ -289,8 +290,8 @@ impl Allocator {
}

pub fn sexp(&self, node: NodePtr) -> SExp {
if node >= 0 {
let pair = self.pair_vec[node as usize];
if node.0 >= 0 {
let pair = self.pair_vec[node.0 as usize];
SExp::Pair(pair.first, pair.rest)
} else {
SExp::Atom
Expand All @@ -310,11 +311,11 @@ impl Allocator {
}

pub fn null(&self) -> NodePtr {
-1
NodePtr(-1)
}

pub fn one(&self) -> NodePtr {
-2
NodePtr(-2)
}

#[cfg(feature = "counters")]
Expand Down
12 changes: 6 additions & 6 deletions src/op_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub fn get_args<const N: usize>(
) -> Result<[NodePtr; N], EvalErr> {
let mut next = args;
let mut counter = 0;
let mut ret: [NodePtr; N] = [0; N];
let mut ret: [NodePtr; N] = [NodePtr(0); N];

while let Some((first, rest)) = a.next(next) {
next = rest;
Expand Down Expand Up @@ -92,7 +92,7 @@ pub fn get_varargs<const N: usize>(
) -> Result<([NodePtr; N], usize), EvalErr> {
let mut next = args;
let mut counter = 0;
let mut ret: [NodePtr; N] = [0; N];
let mut ret: [NodePtr; N] = [NodePtr(0); N];

while let Some((first, rest)) = a.next(next) {
next = rest;
Expand Down Expand Up @@ -132,19 +132,19 @@ fn test_get_varargs() {
);
assert_eq!(
get_varargs::<4>(&a, args3, "test").unwrap(),
([a1, a2, a3, 0], 3)
([a1, a2, a3, NodePtr(0)], 3)
);
assert_eq!(
get_varargs::<4>(&a, args2, "test").unwrap(),
([a2, a3, 0, 0], 2)
([a2, a3, NodePtr(0), NodePtr(0)], 2)
);
assert_eq!(
get_varargs::<4>(&a, args1, "test").unwrap(),
([a3, 0, 0, 0], 1)
([a3, NodePtr(0), NodePtr(0), NodePtr(0)], 1)
);
assert_eq!(
get_varargs::<4>(&a, args0, "test").unwrap(),
([0, 0, 0, 0], 0)
([NodePtr(0), NodePtr(0), NodePtr(0), NodePtr(0)], 0)
);

let r = get_varargs::<3>(&a, args4, "test").unwrap_err();
Expand Down
18 changes: 9 additions & 9 deletions src/serde/object_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ pub struct ObjectCache<'a, T> {
/// and negative values become odd indices.

fn node_to_index(node: &NodePtr) -> usize {
let node = *node;
if node < 0 {
(-node - node - 1) as usize
let value = node.0;
if value < 0 {
(-value - value - 1) as usize
} else {
(node + node) as usize
(value + value) as usize
}
}

Expand Down Expand Up @@ -264,11 +264,11 @@ fn test_serialized_length() {

#[test]
fn test_node_to_index() {
assert_eq!(node_to_index(&0), 0);
assert_eq!(node_to_index(&1), 2);
assert_eq!(node_to_index(&2), 4);
assert_eq!(node_to_index(&-1), 1);
assert_eq!(node_to_index(&-2), 3);
assert_eq!(node_to_index(&NodePtr(0)), 0);
assert_eq!(node_to_index(&NodePtr(1)), 2);
assert_eq!(node_to_index(&NodePtr(2)), 4);
assert_eq!(node_to_index(&NodePtr(-1)), 1);
assert_eq!(node_to_index(&NodePtr(-2)), 3);
}

// this test takes a very long time (>60s) in debug mode, so it only runs in release mode
Expand Down
2 changes: 1 addition & 1 deletion tools/src/bin/benchmark-clvm-cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ enum OpArgs {

// special argument to indicate it should be substituted for varied in the FreeBytes test to
// measure cost per byte
const VARIABLE: NodePtr = 999;
const VARIABLE: NodePtr = NodePtr(999);

// builds calls in the form:
// (<op> arg arg ...)
Expand Down

0 comments on commit 41b9bd1

Please sign in to comment.