Skip to content

Commit

Permalink
iter: introduce right-to-left post-order iterator, use throughout cod…
Browse files Browse the repository at this point in the history
…ebase

We have several algorithms throughout the codebase which "translate" a
recursive object by running a post-order iterator over it, building a
modified copy node-by-node.

We frequently do this by iterating over an Arc structure, pushing each
created node onto a stack, and then using the `child_indices` member of
the `PostOrderIterItem` struct to index into the stack. We copy elements
out of the stack using Arc::clone and then push a new element.

The result is that for an object with N nodes, we construct a stack with
N elements, call Arc::clone N-1 times, and often we need to bend over
backward to turn &self into an Arc<Self> before starting.

There is a much more efficient way: with a post-order iterator, each
node appears directly after its children. So we can just pop children
off of the stack, construct the new node, and push that onto the stack.
As long as we always pop all of the children, our stack will never grow
beyond the depth of the object in question, and we can avoid some
Arc::clones.

Using a right-to-left iterator means that we can call .pop() in the
"natural" way rather than having to muck about reordering the children.

This commit converts every single use of post_order_iter in the library
to use this new algorithm.

In the case of Miniscript::substitute_raw_pkh, the old algorithm was
actually completely wrong. The next commit adds a unit test.
  • Loading branch information
apoelstra committed Aug 20, 2024
1 parent 5da250a commit e31d52b
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 111 deletions.
55 changes: 55 additions & 0 deletions src/iter/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ pub trait TreeLike: Clone + Sized {
fn post_order_iter(self) -> PostOrderIter<Self> {
PostOrderIter { index: 0, stack: vec![IterStackItem::unprocessed(self, None)] }
}

/// Obtains an iterator of all the nodes rooted at the DAG, in right-to-left post order.
///
/// This ordering is useful for "translation" algorithms which iterate over a
/// structure, pushing translated nodes and popping children.
fn rtl_post_order_iter(self) -> RtlPostOrderIter<Self> {
RtlPostOrderIter { inner: Rtl(self).post_order_iter() }
}
}

/// Element stored internally on the stack of a [`PostOrderIter`].
Expand Down Expand Up @@ -202,6 +210,53 @@ impl<T: TreeLike> Iterator for PostOrderIter<T> {
}
}

/// Adaptor structure to allow iterating in right-to-left order.
#[derive(Clone, Debug)]
struct Rtl<T>(pub T);

impl<T: TreeLike> TreeLike for Rtl<T> {
type NaryChildren = T::NaryChildren;

fn nary_len(tc: &Self::NaryChildren) -> usize { T::nary_len(tc) }
fn nary_index(tc: Self::NaryChildren, idx: usize) -> Self {
let rtl_idx = T::nary_len(&tc) - idx - 1;
Rtl(T::nary_index(tc, rtl_idx))
}

fn as_node(&self) -> Tree<Self, Self::NaryChildren> {
match self.0.as_node() {
Tree::Nullary => Tree::Nullary,
Tree::Unary(a) => Tree::Unary(Rtl(a)),
Tree::Binary(a, b) => Tree::Binary(Rtl(b), Rtl(a)),
Tree::Ternary(a, b, c) => Tree::Ternary(Rtl(c), Rtl(b), Rtl(a)),
Tree::Nary(data) => Tree::Nary(data),
}
}
}

/// Iterates over a DAG in _right-to-left post order_.
///
/// That means nodes are yielded in the order (right child, left child, parent).
#[derive(Clone, Debug)]
pub struct RtlPostOrderIter<T> {
inner: PostOrderIter<Rtl<T>>,
}

impl<T: TreeLike> Iterator for RtlPostOrderIter<T> {
type Item = PostOrderIterItem<T>;

fn next(&mut self) -> Option<Self::Item> {
self.inner.next().map(|mut item| {
item.child_indices.reverse();
PostOrderIterItem {
child_indices: item.child_indices,
index: item.index,
node: item.node.0,
}
})
}
}

/// Iterates over a [`TreeLike`] in _pre order_.
///
/// Unlike the post-order iterator, this one does not keep track of indices
Expand Down
156 changes: 107 additions & 49 deletions src/miniscript/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ mod private {
/// and they can call `Miniscript::clone`.
fn clone(&self) -> Self {
let mut stack = vec![];
for item in self.post_order_iter() {
let child_n = |n| Arc::clone(&stack[item.child_indices[n]]);

for item in self.rtl_post_order_iter() {
let new_term = match item.node.node {
Terminal::PkK(ref p) => Terminal::PkK(p.clone()),
Terminal::PkH(ref p) => Terminal::PkH(p.clone()),
Expand All @@ -101,23 +99,31 @@ mod private {
Terminal::Hash160(ref x) => Terminal::Hash160(x.clone()),
Terminal::True => Terminal::True,
Terminal::False => Terminal::False,
Terminal::Alt(..) => Terminal::Alt(child_n(0)),
Terminal::Swap(..) => Terminal::Swap(child_n(0)),
Terminal::Check(..) => Terminal::Check(child_n(0)),
Terminal::DupIf(..) => Terminal::DupIf(child_n(0)),
Terminal::Verify(..) => Terminal::Verify(child_n(0)),
Terminal::NonZero(..) => Terminal::NonZero(child_n(0)),
Terminal::ZeroNotEqual(..) => Terminal::ZeroNotEqual(child_n(0)),
Terminal::AndV(..) => Terminal::AndV(child_n(0), child_n(1)),
Terminal::AndB(..) => Terminal::AndB(child_n(0), child_n(1)),
Terminal::AndOr(..) => Terminal::AndOr(child_n(0), child_n(1), child_n(2)),
Terminal::OrB(..) => Terminal::OrB(child_n(0), child_n(1)),
Terminal::OrD(..) => Terminal::OrD(child_n(0), child_n(1)),
Terminal::OrC(..) => Terminal::OrC(child_n(0), child_n(1)),
Terminal::OrI(..) => Terminal::OrI(child_n(0), child_n(1)),
Terminal::Thresh(ref thresh) => Terminal::Thresh(
thresh.map_from_post_order_iter(&item.child_indices, &stack),
Terminal::Alt(..) => Terminal::Alt(stack.pop().unwrap()),
Terminal::Swap(..) => Terminal::Swap(stack.pop().unwrap()),
Terminal::Check(..) => Terminal::Check(stack.pop().unwrap()),
Terminal::DupIf(..) => Terminal::DupIf(stack.pop().unwrap()),
Terminal::Verify(..) => Terminal::Verify(stack.pop().unwrap()),
Terminal::NonZero(..) => Terminal::NonZero(stack.pop().unwrap()),
Terminal::ZeroNotEqual(..) => Terminal::ZeroNotEqual(stack.pop().unwrap()),
Terminal::AndV(..) => {
Terminal::AndV(stack.pop().unwrap(), stack.pop().unwrap())
}
Terminal::AndB(..) => {
Terminal::AndB(stack.pop().unwrap(), stack.pop().unwrap())
}
Terminal::AndOr(..) => Terminal::AndOr(
stack.pop().unwrap(),
stack.pop().unwrap(),
stack.pop().unwrap(),
),
Terminal::OrB(..) => Terminal::OrB(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrD(..) => Terminal::OrD(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrC(..) => Terminal::OrC(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrI(..) => Terminal::OrI(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::Thresh(ref thresh) => {
Terminal::Thresh(thresh.map_ref(|_| stack.pop().unwrap()))
}
Terminal::Multi(ref thresh) => Terminal::Multi(thresh.clone()),
Terminal::MultiA(ref thresh) => Terminal::MultiA(thresh.clone()),
};
Expand All @@ -130,6 +136,7 @@ mod private {
}));
}

assert_eq!(stack.len(), 1);
Arc::try_unwrap(stack.pop().unwrap()).unwrap()
}
}
Expand Down Expand Up @@ -536,9 +543,7 @@ impl<Pk: MiniscriptKey, Ctx: ScriptContext> Miniscript<Pk, Ctx> {
T: Translator<Pk, Q, FuncError>,
{
let mut translated = vec![];
for data in Arc::new(self.clone()).post_order_iter() {
let child_n = |n| Arc::clone(&translated[data.child_indices[n]]);

for data in self.rtl_post_order_iter() {
let new_term = match data.node.node {
Terminal::PkK(ref p) => Terminal::PkK(t.pk(p)?),
Terminal::PkH(ref p) => Terminal::PkH(t.pk(p)?),
Expand All @@ -551,23 +556,39 @@ impl<Pk: MiniscriptKey, Ctx: ScriptContext> Miniscript<Pk, Ctx> {
Terminal::Hash160(ref x) => Terminal::Hash160(t.hash160(x)?),
Terminal::True => Terminal::True,
Terminal::False => Terminal::False,
Terminal::Alt(..) => Terminal::Alt(child_n(0)),
Terminal::Swap(..) => Terminal::Swap(child_n(0)),
Terminal::Check(..) => Terminal::Check(child_n(0)),
Terminal::DupIf(..) => Terminal::DupIf(child_n(0)),
Terminal::Verify(..) => Terminal::Verify(child_n(0)),
Terminal::NonZero(..) => Terminal::NonZero(child_n(0)),
Terminal::ZeroNotEqual(..) => Terminal::ZeroNotEqual(child_n(0)),
Terminal::AndV(..) => Terminal::AndV(child_n(0), child_n(1)),
Terminal::AndB(..) => Terminal::AndB(child_n(0), child_n(1)),
Terminal::AndOr(..) => Terminal::AndOr(child_n(0), child_n(1), child_n(2)),
Terminal::OrB(..) => Terminal::OrB(child_n(0), child_n(1)),
Terminal::OrD(..) => Terminal::OrD(child_n(0), child_n(1)),
Terminal::OrC(..) => Terminal::OrC(child_n(0), child_n(1)),
Terminal::OrI(..) => Terminal::OrI(child_n(0), child_n(1)),
Terminal::Thresh(ref thresh) => Terminal::Thresh(
thresh.map_from_post_order_iter(&data.child_indices, &translated),
Terminal::Alt(..) => Terminal::Alt(translated.pop().unwrap()),
Terminal::Swap(..) => Terminal::Swap(translated.pop().unwrap()),
Terminal::Check(..) => Terminal::Check(translated.pop().unwrap()),
Terminal::DupIf(..) => Terminal::DupIf(translated.pop().unwrap()),
Terminal::Verify(..) => Terminal::Verify(translated.pop().unwrap()),
Terminal::NonZero(..) => Terminal::NonZero(translated.pop().unwrap()),
Terminal::ZeroNotEqual(..) => Terminal::ZeroNotEqual(translated.pop().unwrap()),
Terminal::AndV(..) => {
Terminal::AndV(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::AndB(..) => {
Terminal::AndB(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::AndOr(..) => Terminal::AndOr(
translated.pop().unwrap(),
translated.pop().unwrap(),
translated.pop().unwrap(),
),
Terminal::OrB(..) => {
Terminal::OrB(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::OrD(..) => {
Terminal::OrD(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::OrC(..) => {
Terminal::OrC(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::OrI(..) => {
Terminal::OrI(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::Thresh(ref thresh) => {
Terminal::Thresh(thresh.map_ref(|_| translated.pop().unwrap()))
}
Terminal::Multi(ref thresh) => Terminal::Multi(thresh.translate_ref(|k| t.pk(k))?),
Terminal::MultiA(ref thresh) => {
Terminal::MultiA(thresh.translate_ref(|k| t.pk(k))?)
Expand All @@ -582,22 +603,58 @@ impl<Pk: MiniscriptKey, Ctx: ScriptContext> Miniscript<Pk, Ctx> {

/// Substitutes raw public keys hashes with the public keys as provided by map.
pub fn substitute_raw_pkh(&self, pk_map: &BTreeMap<hash160::Hash, Pk>) -> Miniscript<Pk, Ctx> {
let mut translated = vec![];
for data in Arc::new(self.clone()).post_order_iter() {
let new_term = if let Terminal::RawPkH(ref p) = data.node.node {
match pk_map.get(p) {
Some(pk) => Terminal::PkH(pk.clone()),
None => Terminal::RawPkH(*p),
let mut stack = vec![];
for item in self.rtl_post_order_iter() {
let new_term = match item.node.node {
Terminal::PkK(ref p) => Terminal::PkK(p.clone()),
Terminal::PkH(ref p) => Terminal::PkH(p.clone()),
// This algorithm is identical to Clone::clone except for this line.
Terminal::RawPkH(ref hash) => match pk_map.get(hash) {
Some(p) => Terminal::PkH(p.clone()),
None => Terminal::RawPkH(*hash),
},
Terminal::After(ref n) => Terminal::After(*n),
Terminal::Older(ref n) => Terminal::Older(*n),
Terminal::Sha256(ref x) => Terminal::Sha256(x.clone()),
Terminal::Hash256(ref x) => Terminal::Hash256(x.clone()),
Terminal::Ripemd160(ref x) => Terminal::Ripemd160(x.clone()),
Terminal::Hash160(ref x) => Terminal::Hash160(x.clone()),
Terminal::True => Terminal::True,
Terminal::False => Terminal::False,
Terminal::Alt(..) => Terminal::Alt(stack.pop().unwrap()),
Terminal::Swap(..) => Terminal::Swap(stack.pop().unwrap()),
Terminal::Check(..) => Terminal::Check(stack.pop().unwrap()),
Terminal::DupIf(..) => Terminal::DupIf(stack.pop().unwrap()),
Terminal::Verify(..) => Terminal::Verify(stack.pop().unwrap()),
Terminal::NonZero(..) => Terminal::NonZero(stack.pop().unwrap()),
Terminal::ZeroNotEqual(..) => Terminal::ZeroNotEqual(stack.pop().unwrap()),
Terminal::AndV(..) => Terminal::AndV(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::AndB(..) => Terminal::AndB(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::AndOr(..) => Terminal::AndOr(
stack.pop().unwrap(),
stack.pop().unwrap(),
stack.pop().unwrap(),
),
Terminal::OrB(..) => Terminal::OrB(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrD(..) => Terminal::OrD(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrC(..) => Terminal::OrC(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrI(..) => Terminal::OrI(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::Thresh(ref thresh) => {
Terminal::Thresh(thresh.map_ref(|_| stack.pop().unwrap()))
}
} else {
data.node.node.clone()
Terminal::Multi(ref thresh) => Terminal::Multi(thresh.clone()),
Terminal::MultiA(ref thresh) => Terminal::MultiA(thresh.clone()),
};

let new_ms = Miniscript::from_ast(new_term).expect("typeck");
translated.push(Arc::new(new_ms));
stack.push(Arc::new(Miniscript::from_components_unchecked(
new_term,
item.node.ty,
item.node.ext,
)));
}

Arc::try_unwrap(translated.pop().unwrap()).unwrap()
assert_eq!(stack.len(), 1);
Arc::try_unwrap(stack.pop().unwrap()).unwrap()
}
}

Expand Down Expand Up @@ -822,6 +879,7 @@ mod tests {
}
let roundtrip = Miniscript::from_str(&display).expect("parse string serialization");
assert_eq!(roundtrip, script);
assert_eq!(roundtrip.clone(), script);
}

fn string_display_debug_test<Ctx: ScriptContext>(
Expand Down
Loading

0 comments on commit e31d52b

Please sign in to comment.