Skip to content

Commit

Permalink
Make if then show as tail recursive in can
Browse files Browse the repository at this point in the history
  • Loading branch information
faldor20 committed Feb 13, 2024
1 parent e4a7f11 commit 0f67a42
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 7 deletions.
2 changes: 1 addition & 1 deletion crates/compiler/can/src/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2369,7 +2369,7 @@ fn canonicalize_pending_body<'a>(
// reset the tailcallable_symbol
env.tailcallable_symbol = outer_tailcallable;

// The closure is self tail recursive iff it tail calls itself (by defined name).
// The closure is self tail recursive if it tail calls itself (by defined name).
let is_recursive = match can_output.tail_call {
Some(tail_symbol) if tail_symbol == *defined_symbol => Recursive::TailRecursive,
_ => Recursive::NotRecursive,
Expand Down
14 changes: 8 additions & 6 deletions crates/compiler/can/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,7 @@ pub fn canonicalize_expr<'a>(

output.references.union_mut(&cond_output.references);
output.references.union_mut(&then_output.references);
output.union(then_output);
}

let (loc_else, else_output) = canonicalize_expr(
Expand All @@ -1319,6 +1320,7 @@ pub fn canonicalize_expr<'a>(
);

output.references.union_mut(&else_output.references);
output.union(else_output);

(
If {
Expand Down Expand Up @@ -3077,14 +3079,14 @@ pub enum DeclarationTag {

impl DeclarationTag {
fn len(self) -> usize {
use DeclarationTag::*;
use DeclarationTag as dt;

match self {
Function(_) | Recursive(_) | TailRecursive(_) => 1,
Value => 1,
Expectation | ExpectationFx => 1,
Destructure(_) => 1,
MutualRecursion { length, .. } => length as usize + 1,
dt::Function(_) | dt::Recursive(_) | dt::TailRecursive(_) => 1,
dt::Value => 1,
dt::Expectation | dt::ExpectationFx => 1,
dt::Destructure(_) => 1,
dt::MutualRecursion { length, .. } => length as usize + 1,
}
}
}
Expand Down
56 changes: 56 additions & 0 deletions crates/compiler/can/tests/test_can.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,62 @@ mod test_can {
assert_eq!(p_detected, Recursive::TailRecursive);
}

#[test]
fn recognize_tail_calls_if_else() {
let src = indoc!(
r"
g = \x ->
if x == 0 then
0
else
g (x - 1)
# use parens to force the ordering!
(
h = \x ->
if x == 0 then
0
else
g (x - 1)
(
p = \x ->
if x == 0 then
0
else if x == 1 then
g (x - 1)
else
p (x - 1)
# variables must be (indirectly) referenced in the body for analysis to work
{ x: p, y: h }
)
)
"
);
let arena = Bump::new();
let CanExprOut {
loc_expr, problems, ..
} = can_expr_with(&arena, test_home(), src);

assert_eq!(problems, Vec::new());
assert!(problems
.iter()
.all(|problem| matches!(problem, Problem::UnusedDef(_, _))));

let actual = loc_expr.value;

let g_detected = get_closure(&actual, 0);
let h_detected = get_closure(&actual, 1);
let p_detected = get_closure(&actual, 2);

assert_eq!(g_detected, Recursive::TailRecursive);
assert_eq!(h_detected, Recursive::NotRecursive);
assert_eq!(p_detected, Recursive::TailRecursive);
}

// TODO restore this test! It should report two unused defs (h and p), but only reports 1.
// #[test]
// fn reproduce_incorrect_unused_defs() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
procedure Bool.11 (#Attr.2, #Attr.3):
let Bool.23 : Int1 = lowlevel Eq #Attr.2 #Attr.3;
ret Bool.23;

procedure Num.19 (#Attr.2, #Attr.3):
let Num.297 : I64 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Num.297;

procedure Num.20 (#Attr.2, #Attr.3):
let Num.298 : I64 = lowlevel NumSub #Attr.2 #Attr.3;
ret Num.298;

procedure Test.1 (#Derived_gen.0, #Derived_gen.1):
joinpoint Test.7 Test.2 Test.3:
let Test.14 : I64 = 0i64;
let Test.12 : Int1 = CallByName Bool.11 Test.2 Test.14;
if Test.12 then
ret Test.3;
else
let Test.11 : I64 = 1i64;
let Test.9 : I64 = CallByName Num.20 Test.2 Test.11;
let Test.10 : I64 = CallByName Num.19 Test.2 Test.3;
jump Test.7 Test.9 Test.10;
in
jump Test.7 #Derived_gen.0 #Derived_gen.1;

procedure Test.0 ():
let Test.5 : I64 = 1000000i64;
let Test.6 : I64 = 0i64;
let Test.4 : I64 = CallByName Test.1 Test.5 Test.6;
ret Test.4;
14 changes: 14 additions & 0 deletions crates/compiler/test_mono/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,20 @@ fn tail_call_elimination() {
)
}

#[mono_test]
fn tail_call_elimination_if_else() {
indoc!(
r"
sum = \n, accum ->
if n==0 then
accum
else sum (n - 1) (n + accum)
sum 1_000_000 0
"
)
}

#[mono_test]
fn tail_call_with_same_layout_different_lambda_sets() {
indoc!(
Expand Down

0 comments on commit 0f67a42

Please sign in to comment.