From b84812858d8817f632ec2e6205930ec9f5697ebf Mon Sep 17 00:00:00 2001 From: Folkert Date: Wed, 21 Jun 2023 20:56:36 +0200 Subject: [PATCH] search for multiple TRMC opportunities --- .../cli_testing_examples/benchmarks/CFold.roc | 2 +- .../cli_testing_examples/benchmarks/Deriv.roc | 1 - .../benchmarks/NQueens.roc | 9 +- .../benchmarks/QuicksortApp.roc | 1 - .../benchmarks/RBTreeCk.roc | 10 +- .../benchmarks/RBTreeDel.roc | 1 - crates/compiler/collections/src/vec_set.rs | 16 ++ crates/compiler/mono/src/tail_recursion.rs | 216 +++++++++--------- 8 files changed, 139 insertions(+), 117 deletions(-) diff --git a/crates/cli_testing_examples/benchmarks/CFold.roc b/crates/cli_testing_examples/benchmarks/CFold.roc index 8c496c6f9c4..e4b1e2423dd 100644 --- a/crates/cli_testing_examples/benchmarks/CFold.roc +++ b/crates/cli_testing_examples/benchmarks/CFold.roc @@ -21,7 +21,7 @@ main = |> Task.putLine Err GetIntError -> - Task.putLine "Error: Failed to get Integer from stdin." + Task.putLine "Error: Failed to get Integer from stdin." Expr : [ Add Expr Expr, diff --git a/crates/cli_testing_examples/benchmarks/Deriv.roc b/crates/cli_testing_examples/benchmarks/Deriv.roc index 5ad3f3b9d0c..974c1ada668 100644 --- a/crates/cli_testing_examples/benchmarks/Deriv.roc +++ b/crates/cli_testing_examples/benchmarks/Deriv.roc @@ -23,7 +23,6 @@ main = Err GetIntError -> Task.putLine "Error: Failed to get Integer from stdin." - nest : (I64, Expr -> IO Expr), I64, Expr -> IO Expr nest = \f, n, e -> Task.loop { s: n, f, m: n, x: e } nestHelp diff --git a/crates/cli_testing_examples/benchmarks/NQueens.roc b/crates/cli_testing_examples/benchmarks/NQueens.roc index 62b931fb3e6..15593e37216 100644 --- a/crates/cli_testing_examples/benchmarks/NQueens.roc +++ b/crates/cli_testing_examples/benchmarks/NQueens.roc @@ -10,8 +10,8 @@ main = when inputResult is Ok n -> queens n # original koka 13 - |> Num.toStr - |> Task.putLine + |> Num.toStr + |> Task.putLine Err GetIntError -> Task.putLine "Error: Failed to get Integer from stdin." @@ -21,7 +21,8 @@ ConsList a : [Nil, Cons a (ConsList a)] queens = \n -> length (findSolutions n n) findSolutions = \n, k -> - if k <= 0 then # should we use U64 as input type here instead? + if k <= 0 then + # should we use U64 as input type here instead? Cons Nil Nil else extend n Nil (findSolutions n (k - 1)) @@ -40,7 +41,6 @@ appendSafe = \k, soln, solns -> else appendSafe (k - 1) soln solns - safe : I64, I64, ConsList I64 -> Bool safe = \queen, diagonal, xs -> when xs is @@ -51,7 +51,6 @@ safe = \queen, diagonal, xs -> else Bool.false - length : ConsList a -> I64 length = \xs -> lengthHelp xs 0 diff --git a/crates/cli_testing_examples/benchmarks/QuicksortApp.roc b/crates/cli_testing_examples/benchmarks/QuicksortApp.roc index cc38026216c..67766bc9820 100644 --- a/crates/cli_testing_examples/benchmarks/QuicksortApp.roc +++ b/crates/cli_testing_examples/benchmarks/QuicksortApp.roc @@ -23,7 +23,6 @@ main = Err GetIntError -> Task.putLine "Error: Failed to get Integer from stdin." - sort : List I64 -> List I64 sort = \list -> diff --git a/crates/cli_testing_examples/benchmarks/RBTreeCk.roc b/crates/cli_testing_examples/benchmarks/RBTreeCk.roc index 254be6d0ebf..5c685b195e6 100644 --- a/crates/cli_testing_examples/benchmarks/RBTreeCk.roc +++ b/crates/cli_testing_examples/benchmarks/RBTreeCk.roc @@ -93,9 +93,15 @@ ins = \tree, kx, vx -> Node Black a ky vy b -> if lt kx ky then - (if isRed a then balance1 (Node Black Leaf ky vy b) (ins a kx vx) else Node Black (ins a kx vx) ky vy b) + if isRed a then + balance1 (Node Black Leaf ky vy b) (ins a kx vx) + else + Node Black (ins a kx vx) ky vy b else if lt ky kx then - (if isRed b then balance2 (Node Black a ky vy Leaf) (ins b kx vx) else Node Black a ky vy (ins b kx vx)) + if isRed b then + balance2 (Node Black a ky vy Leaf) (ins b kx vx) + else + Node Black a ky vy (ins b kx vx) else Node Black a kx vx b diff --git a/crates/cli_testing_examples/benchmarks/RBTreeDel.roc b/crates/cli_testing_examples/benchmarks/RBTreeDel.roc index 37f0523ecac..7f5e940b73d 100644 --- a/crates/cli_testing_examples/benchmarks/RBTreeDel.roc +++ b/crates/cli_testing_examples/benchmarks/RBTreeDel.roc @@ -26,7 +26,6 @@ main = Err GetIntError -> Task.putLine "Error: Failed to get Integer from stdin." - boom : Str -> a boom = \_ -> boom "" diff --git a/crates/compiler/collections/src/vec_set.rs b/crates/compiler/collections/src/vec_set.rs index 52be9862497..c45a8169248 100644 --- a/crates/compiler/collections/src/vec_set.rs +++ b/crates/compiler/collections/src/vec_set.rs @@ -34,6 +34,12 @@ impl VecSet { self.elements.is_empty() } + pub fn singleton(value: T) -> Self { + Self { + elements: vec![value], + } + } + pub fn swap_remove(&mut self, index: usize) -> T { self.elements.swap_remove(index) } @@ -96,6 +102,16 @@ impl VecSet { { self.elements.retain(f) } + + pub fn keep_if_in_both(&mut self, other: &Self) { + self.elements.retain(|e| other.contains(e)); + } + + pub fn keep_if_in_either(&mut self, other: Self) { + for e in other.elements { + self.insert(e); + } + } } impl Extend for VecSet { diff --git a/crates/compiler/mono/src/tail_recursion.rs b/crates/compiler/mono/src/tail_recursion.rs index f2afbc9c7ee..0a7ab187d1b 100644 --- a/crates/compiler/mono/src/tail_recursion.rs +++ b/crates/compiler/mono/src/tail_recursion.rs @@ -10,7 +10,7 @@ use crate::layout::{ }; use bumpalo::collections::Vec; use bumpalo::Bump; -use roc_collections::MutMap; +use roc_collections::{MutMap, VecSet}; use roc_module::low_level::LowLevel; use roc_module::symbol::{IdentIds, ModuleId, Symbol}; @@ -53,8 +53,11 @@ pub fn apply_trmc<'a, 'i>( for proc in procs.values_mut() { use self::SelfRecursive::*; if let SelfRecursive(id) = proc.is_self_recursive { - if crate::tail_recursion::is_trmc_candidate(env.interner, proc) { - let new_proc = crate::tail_recursion::TrmcEnv::init(env, proc); + let trmc_candidate_symbols = trmc_candidates(env.interner, proc); + + if !trmc_candidate_symbols.is_empty() { + let new_proc = + crate::tail_recursion::TrmcEnv::init(env, proc, trmc_candidate_symbols); *proc = new_proc; } else { let mut args = Vec::with_capacity_in(proc.args.len(), arena); @@ -402,7 +405,49 @@ fn insert_jumps<'a>( } } -pub(crate) fn is_trmc_candidate<'a, I>(interner: &'_ I, proc: &'_ Proc<'a>) -> bool +#[derive(Debug, Clone, Default)] +struct TrmcCandidateSet { + /// Recursive calls for which we have found a TRMC opportunity + confirmed: VecSet, + /// Recursive calls that are (still) considered for TRMC + active: VecSet, + /// Recursive calls that are used in such a way that makes TRMC impossible + invalid: VecSet, +} + +impl TrmcCandidateSet { + fn insert(&mut self, call: Symbol) { + // there really is no way it could have been inserted already + debug_assert!(!self.invalid.contains(&call)); + + self.active.insert(call); + } + + fn extend(&mut self, other: Self) { + self.confirmed.keep_if_in_either(other.confirmed); + self.invalid.keep_if_in_either(other.invalid); + self.active.keep_if_in_either(other.active); + + self.active.retain(|k| !self.invalid.contains(k)); + self.confirmed.retain(|k| !self.invalid.contains(k)); + } + + fn retain(&mut self, keep: F) + where + F: Fn(&Symbol) -> bool, + { + for c in self.active.iter() { + if !keep(c) { + self.invalid.insert(*c); + } + } + + self.active.retain(|k| !self.invalid.contains(k)); + self.confirmed.retain(|k| !self.invalid.contains(k)); + } +} + +fn trmc_candidates<'a, I>(interner: &'_ I, proc: &'_ Proc<'a>) -> VecSet where I: LayoutInterner<'a>, { @@ -411,87 +456,50 @@ where proc.is_self_recursive, crate::ir::SelfRecursive::SelfRecursive(_) ) { - return false; + return VecSet::default(); } // and return a recursive tag union if !matches!(interner.get_repr(proc.ret_layout), LayoutRepr::Union(union_layout) if union_layout.is_recursive()) { - return false; + return VecSet::default(); } - match has_cons_in_tail_position(&proc.body, proc.name, None) { - SymbolUse::NotUsed | SymbolUse::Used => false, - SymbolUse::TrmcOppotunity => true, - } + trmc_candidates_help(proc.name, &proc.body, TrmcCandidateSet::default()).confirmed } -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)] -#[repr(C)] -enum SymbolUse { - #[default] - NotUsed = 0, - TrmcOppotunity = 1, - Used = 2, -} - -impl SymbolUse { - #[must_use] - fn mappend(self, y: Self) -> Self { - debug_assert_eq!(self.mappend_slow(y), Ord::max(self, y)); - - Ord::max(self, y) - } - - fn mappend_slow(self, y: Self) -> Self { - use SymbolUse::*; - - match (self, y) { - (Used, _) | (_, Used) => Used, - (TrmcOppotunity, _) | (_, TrmcOppotunity) => TrmcOppotunity, - (NotUsed, NotUsed) => NotUsed, - } - } -} - -fn has_cons_in_tail_position( - stmt: &Stmt<'_>, +fn trmc_candidates_help<'a>( function_name: LambdaName, - recursive_call: Option, -) -> SymbolUse { - // we are looking for code of the form - // - // let x = Tag a b c - // ret x - + stmt: &'_ Stmt<'a>, + mut candidates: TrmcCandidateSet, +) -> TrmcCandidateSet { // if this stmt is the literal tail tag application and return, then this is a TRMC opportunity if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) { // must use the result of a recursive call directly as an argument - if let Some(recursive_call) = recursive_call { - if cons_info.arguments.contains(&recursive_call) { - return SymbolUse::TrmcOppotunity; + // we pick the (syntactically) first one + for recursive_call in candidates.active.iter() { + if cons_info.arguments.contains(recursive_call) { + return TrmcCandidateSet { + confirmed: VecSet::singleton(*recursive_call), + active: VecSet::default(), + invalid: candidates.invalid, + }; } } } // if the stmt uses the active recursive call, that invalidates the recursive call for this branch - if let Some(recursive_call) = recursive_call { - if stmt_contains_symbol_nonrec(stmt, recursive_call) { - // this means we really only check for the first recursive call (in each branch) - // whether it presents a TRMC opportunity. In theory we can look at all recursive calls - // this is future work. - return SymbolUse::Used; - } - } + candidates.retain(|recursive_call| !stmt_contains_symbol_nonrec(stmt, *recursive_call)); match stmt { Stmt::Let(symbol, expr, _, next) => { // find a new recursive call if we currently have none // that means we generally pick the first recursive call we find - let recursive_call = recursive_call - .or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol)); + if TrmcEnv::is_recursive_expr(expr, function_name).is_some() { + candidates.insert(*symbol); + } - has_cons_in_tail_position(next, function_name, recursive_call) + trmc_candidates_help(function_name, next, candidates) } Stmt::Switch { branches, @@ -503,54 +511,43 @@ fn has_cons_in_tail_position( .map(|(_, _, stmt)| stmt) .chain([default_branch.1]); - let mut accum = SymbolUse::NotUsed; + let mut accum = candidates.clone(); for next in it { - let x = has_cons_in_tail_position(next, function_name, recursive_call); - accum = accum.mappend(x); + let x = trmc_candidates_help(function_name, next, candidates.clone()); - if let SymbolUse::Used = accum { - return SymbolUse::Used; - } + accum.extend(x); } accum } - Stmt::Refcounting(_, next) => { - has_cons_in_tail_position(next, function_name, recursive_call) - } + Stmt::Refcounting(_, next) => trmc_candidates_help(function_name, next, candidates), Stmt::Expect { remainder, .. } | Stmt::ExpectFx { remainder, .. } - | Stmt::Dbg { remainder, .. } => { - has_cons_in_tail_position(remainder, function_name, recursive_call) - } + | Stmt::Dbg { remainder, .. } => trmc_candidates_help(function_name, remainder, candidates), Stmt::Join { body, remainder, .. } => { - let x = has_cons_in_tail_position(body, function_name, recursive_call); + let mut x = trmc_candidates_help(function_name, body, candidates.clone()); + let y = trmc_candidates_help(function_name, remainder, candidates.clone()); - if let SymbolUse::Used = x { - SymbolUse::Used - } else { - let y = has_cons_in_tail_position(remainder, function_name, recursive_call); - x.mappend(y) - } + x.extend(y); + + x } - Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => SymbolUse::NotUsed, + Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => candidates, } } #[derive(Clone)] pub(crate) struct TrmcEnv<'a> { - function_name: LambdaName<'a>, hole_symbol: Symbol, initial_ptr_symbol: Symbol, joinpoint_id: JoinPointId, return_layout: InLayout<'a>, ptr_return_layout: InLayout<'a>, - // the call we are performing TRMC on - recursive_call: Option<(Symbol, Call<'a>)>, + trmc_calls: MutMap>>, } #[derive(Debug)] @@ -634,7 +631,11 @@ impl<'a> TrmcEnv<'a> { ) } - pub fn init<'i>(env: &mut Env<'a, 'i>, proc: &Proc<'a>) -> Proc<'a> { + pub fn init<'i>( + env: &mut Env<'a, 'i>, + proc: &Proc<'a>, + trmc_calls: VecSet, + ) -> Proc<'a> { let arena = env.arena; let return_layout = proc.ret_layout; @@ -682,14 +683,15 @@ impl<'a> TrmcEnv<'a> { let jump_stmt = Stmt::Jump(joinpoint_id, jump_arguments.into_bump_slice()); + let trmc_calls = trmc_calls.iter().map(|s| (*s, None)).collect(); + let mut this = Self { - function_name: proc.name, hole_symbol, initial_ptr_symbol, joinpoint_id, return_layout, ptr_return_layout, - recursive_call: None, + trmc_calls, }; let param = Param { @@ -733,24 +735,30 @@ impl<'a> TrmcEnv<'a> { match stmt { Stmt::Let(symbol, expr, layout, next) => { - if self.recursive_call.is_none() { - if let Some(call) = Self::is_recursive_expr(expr, self.function_name) { - let can_trmc = - has_cons_in_tail_position(next, self.function_name, Some(*symbol)); - - match can_trmc { - SymbolUse::NotUsed => { /* the variable is dead */ } - SymbolUse::TrmcOppotunity => { - self.recursive_call = Some((*symbol, call)); - return self.walk_stmt(env, next); - } - SymbolUse::Used => { /* the variable is used making TRMC invaid */ } - } - } + // if this is a TRMC call, + if let Some(opt_call) = self.trmc_calls.get_mut(symbol) { + debug_assert!(opt_call.is_none()); + + let call = match expr { + Expr::Call(call) => call, + _ => unreachable!(), + }; + + *opt_call = Some(call.clone()); + + return self.walk_stmt(env, next); } if let Some(cons_info) = Self::is_terminal_constructor(stmt) { - match &self.recursive_call { + // figure out which TRMC call to use here. We pick the first one that works + let opt_recursive_call = cons_info.arguments.iter().find_map(|arg| { + self.trmc_calls + .get(arg) + .and_then(|x| x.as_ref()) + .map(|x| (arg, x)) + }); + + match opt_recursive_call { None => { // this control flow path did not encounter a recursive call. Just // write the end result into the hole and we're done. @@ -863,16 +871,12 @@ impl<'a> TrmcEnv<'a> { } => { let mut new_branches = Vec::with_capacity_in(branches.len(), arena); - let opt_recursive_call = self.recursive_call.clone(); - for (id, info, stmt) in branches.iter() { - self.recursive_call = opt_recursive_call.clone(); let new_stmt = self.walk_stmt(env, stmt); new_branches.push((*id, info.clone(), new_stmt)); } - self.recursive_call = opt_recursive_call; let new_default_branch = &*arena.alloc(self.walk_stmt(env, default_branch.1)); Stmt::Switch {