Skip to content

Commit

Permalink
add a version of traverse_path that's optimized for small int
Browse files Browse the repository at this point in the history
  • Loading branch information
arvidn committed Feb 2, 2024
1 parent bc714c7 commit 96ea62b
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/run_program.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::traverse_path::traverse_path;
use crate::allocator::{Allocator, Checkpoint, NodePtr, SExp};
use super::traverse_path::{traverse_path, traverse_path_fast};
use crate::allocator::{Allocator, Checkpoint, NodePtr, NodeVisitor, SExp};
use crate::cost::Cost;
use crate::dialect::{Dialect, OperatorSet};
use crate::err_utils::err;
Expand Down Expand Up @@ -279,7 +279,15 @@ impl<'a, D: Dialect> RunProgramContext<'a, D> {
// put a bunch of ops on op_stack
let SExp::Pair(op_node, op_list) = self.allocator.sexp(program) else {
// the program is just a bitfield path through the env tree
let r: Reduction = traverse_path(self.allocator, self.allocator.atom(program), env)?;
let r: Reduction = self.allocator.visit_node(program, |node| -> Response {
match node {
NodeVisitor::Buffer(buf) => traverse_path(self.allocator, buf, env),
NodeVisitor::U32(val) => traverse_path_fast(self.allocator, *val, env),
NodeVisitor::Pair(_, _) => {
panic!("expected atom, got pair");
}
}
})?;
self.push(r.1)?;
return Ok(r.0);
};
Expand Down
94 changes: 94 additions & 0 deletions src/traverse_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,42 @@ pub fn traverse_path(allocator: &Allocator, node_index: &[u8], args: NodePtr) ->
Ok(Reduction(cost, arg_list))
}

// The cost calculation for this version of traverse_path assumes the node_index has the canonical
// integer representation (which is true for SmallAtom in the allocator). If there are any
// redundant leading zeros, the slow path must be used
pub fn traverse_path_fast(allocator: &Allocator, mut node_index: u32, args: NodePtr) -> Response {
if node_index == 0 {
return Ok(Reduction(
TRAVERSE_BASE_COST + TRAVERSE_COST_PER_BIT,
allocator.nil(),
));
}

let mut arg_list: NodePtr = args;

let mut cost: Cost = TRAVERSE_BASE_COST + TRAVERSE_COST_PER_BIT;
let mut num_bits = 0;
while node_index != 1 {
let SExp::Pair(left, right) = allocator.sexp(arg_list) else {
return Err(EvalErr(arg_list, "path into atom".into()));
};

let is_bit_set: bool = (node_index & 0x01) != 0;
arg_list = if is_bit_set { right } else { left };
node_index >>= 1;
num_bits += 1
}

cost += num_bits * TRAVERSE_COST_PER_BIT;
// since positive numbers sometimes need a leading zero, e.g. 0x80, 0x8000 etc. We also
// need to add the cost of that leading zero byte
if num_bits == 7 || num_bits == 15 || num_bits == 23 || num_bits == 31 {
cost += TRAVERSE_COST_PER_ZERO_BYTE;
}

Ok(Reduction(cost, arg_list))
}

#[test]
fn test_msb_mask() {
assert_eq!(msb_mask(0x0), 0x0);
Expand Down Expand Up @@ -160,3 +196,61 @@ fn test_traverse_path() {
EvalErr(n2, "path into atom".to_string())
);
}

#[test]
fn test_traverse_path_fast_fast() {
use crate::allocator::Allocator;

let mut a = Allocator::new();
let nul = a.nil();
let n1 = a.new_atom(&[0, 1, 2]).unwrap();
let n2 = a.new_atom(&[4, 5, 6]).unwrap();

assert_eq!(traverse_path_fast(&a, 0, n1).unwrap(), Reduction(44, nul));
assert_eq!(traverse_path_fast(&a, 0b1, n1).unwrap(), Reduction(44, n1));
assert_eq!(traverse_path_fast(&a, 0b1, n2).unwrap(), Reduction(44, n2));

let n3 = a.new_pair(n1, n2).unwrap();
assert_eq!(traverse_path_fast(&a, 0b1, n3).unwrap(), Reduction(44, n3));
assert_eq!(traverse_path_fast(&a, 0b10, n3).unwrap(), Reduction(48, n1));
assert_eq!(traverse_path_fast(&a, 0b11, n3).unwrap(), Reduction(48, n2));
assert_eq!(traverse_path_fast(&a, 0b11, n3).unwrap(), Reduction(48, n2));

let list = a.new_pair(n1, nul).unwrap();
let list = a.new_pair(n2, list).unwrap();

assert_eq!(
traverse_path_fast(&a, 0b10, list).unwrap(),
Reduction(48, n2)
);
assert_eq!(
traverse_path_fast(&a, 0b101, list).unwrap(),
Reduction(52, n1)
);
assert_eq!(
traverse_path_fast(&a, 0b111, list).unwrap(),
Reduction(52, nul)
);

// errors
assert_eq!(
traverse_path_fast(&a, 0b1011, list).unwrap_err(),
EvalErr(nul, "path into atom".to_string())
);
assert_eq!(
traverse_path_fast(&a, 0b1101, list).unwrap_err(),
EvalErr(n1, "path into atom".to_string())
);
assert_eq!(
traverse_path_fast(&a, 0b1001, list).unwrap_err(),
EvalErr(n1, "path into atom".to_string())
);
assert_eq!(
traverse_path_fast(&a, 0b1010, list).unwrap_err(),
EvalErr(n2, "path into atom".to_string())
);
assert_eq!(
traverse_path_fast(&a, 0b1110, list).unwrap_err(),
EvalErr(n2, "path into atom".to_string())
);
}

0 comments on commit 96ea62b

Please sign in to comment.