diff --git a/dag_in_context/src/from_egglog.rs b/dag_in_context/src/from_egglog.rs index a0e29634..7b07b84a 100644 --- a/dag_in_context/src/from_egglog.rs +++ b/dag_in_context/src/from_egglog.rs @@ -143,7 +143,10 @@ impl<'a> FromEgglog<'a> { }; Assumption::InIf(boolean, self.expr_from_egglog(self.termdag.get(*pred_expr)), self.expr_from_egglog(self.termdag.get(*input_expr))) } - _ => panic!("Invalid assumption: {:?}", assumption), + (name, _) => { + eprintln!("Invalid assumption: {:?}", assumption); + Assumption::WildCard(name.into()) + } }) } diff --git a/dag_in_context/src/optimizations/loop_invariant.rs b/dag_in_context/src/optimizations/loop_invariant.rs index bf55804e..84886bc4 100644 --- a/dag_in_context/src/optimizations/loop_invariant.rs +++ b/dag_in_context/src/optimizations/loop_invariant.rs @@ -209,3 +209,69 @@ fn test_invariant_detect() -> crate::Result { vec![], ) } + +#[test] +fn test_invariant_hoist() -> crate::Result { + use crate::add_context::ContextCache; + use crate::ast::*; + use crate::schema::Assumption; + + let mut cache = ContextCache::new_dummy_ctx(); + let output_ty = tuplet!(intt(), intt(), intt(), statet()); + let inner_inv = sub(getat(2), getat(1)); + let inv = add(inner_inv.clone(), int(1)); + let print = tprint(inv.clone(), getat(3)); + + let my_loop = dowhile( + parallel!(getat(0), getat(1), getat(2), getat(3)), + parallel!( + less_than(getat(0), getat(1)), + int(3), + getat(1), + getat(2), + print, + ), + ) + .with_arg_types(output_ty.clone(), output_ty.clone()) + .add_ctx_with_cache(Assumption::dummy(), &mut cache); + + let new_out_ty = tuplet!(intt(), intt(), intt(), statet(), intt()); + let mut cache = ContextCache::new_symbolic_ctx(); + + let hoisted_loop = dowhile( + parallel!( + getat(0), + getat(1), + getat(2), + getat(3), + add(int(1), sub(getat(2), getat(1))) + ), + parallel!( + less_than(getat(0), getat(1)), + int(3), + getat(1), + getat(2), + tprint(getat(4), getat(3)), + getat(4) + ), + ) + .with_arg_types(output_ty.clone(), new_out_ty) + .add_ctx_with_cache(Assumption::dummy(), &mut cache); + + let build = format!("(let loop {}) \n", my_loop); + let check = format!( + "(check {}) + (check (= loop (SubTuple {} 0 4)))", + hoisted_loop.clone(), + hoisted_loop + ); + + egglog_test( + &build, + &check, + vec![], + Value::Tuple(vec![]), + Value::Tuple(vec![]), + vec![], + ) +} diff --git a/dag_in_context/src/schedule.rs b/dag_in_context/src/schedule.rs index 77a38f32..381a6c4f 100644 --- a/dag_in_context/src/schedule.rs +++ b/dag_in_context/src/schedule.rs @@ -60,7 +60,7 @@ pub fn mk_schedule() -> String { cheap-optimizations switch_rewrite - ;loop-inv-motion + loop-inv-motion loop-strength-reduction loop-peel ) diff --git a/tests/passing/small/loop_hoist.bril b/tests/passing/small/loop_hoist.bril index ecfeffeb..2774af7f 100644 --- a/tests/passing/small/loop_hoist.bril +++ b/tests/passing/small/loop_hoist.bril @@ -1,8 +1,5 @@ -@main() { - arg1: int = const 1; - arg2: int = const 2; - arg3: int = const 3; - arg4: int = const 4; +# ARGS: 1 2 3 4 +@main(arg1 : int, arg2 : int, arg3 : int, arg4 : int) { .entry: zero: int = const 0; sub: int = sub arg3 arg2; diff --git a/tests/snapshots/files__if_in_loop-optimize.snap b/tests/snapshots/files__if_in_loop-optimize.snap index 15e1998c..187187d1 100644 --- a/tests/snapshots/files__if_in_loop-optimize.snap +++ b/tests/snapshots/files__if_in_loop-optimize.snap @@ -7,39 +7,41 @@ expression: visualization.result c2_: int = const 0; c3_: int = const 1; c4_: int = const 10; - v5_: int = id c2_; - v6_: int = id c3_; - v7_: int = id v0; - v8_: int = id c2_; - v9_: int = id c4_; -.b10_: - v11_: bool = le v7_ v8_; - v12_: bool = lt v5_ v9_; - v13_: bool = id v12_; - v14_: int = id v5_; + v5_: bool = lt v0 c3_; + v6_: int = id c2_; + v7_: int = id c3_; + v8_: int = id v0; + v9_: int = id c2_; + v10_: int = id c4_; + v11_: bool = id v5_; +.b12_: + v13_: bool = lt v6_ v10_; + v14_: bool = id v13_; v15_: int = id v6_; - v16_: int = id v8_; - v17_: int = id v7_; + v16_: int = id v7_; + v17_: int = id v9_; v18_: int = id v8_; v19_: int = id v9_; - br v11_ .b20_ .b21_; -.b20_: - v13_: bool = id v12_; - v14_: int = id v5_; + v20_: int = id v10_; + br v11_ .b21_ .b22_; +.b21_: + v14_: bool = id v13_; v15_: int = id v6_; - v16_: int = id v6_; + v16_: int = id v7_; v17_: int = id v7_; v18_: int = id v8_; v19_: int = id v9_; -.b21_: - print v16_; + v20_: int = id v10_; +.b22_: + print v17_; print v11_; - v22_: int = add v5_ v6_; - v5_: int = id v22_; - v6_: int = id v6_; + v23_: int = add v6_ v7_; + v6_: int = id v23_; v7_: int = id v7_; v8_: int = id v8_; v9_: int = id v9_; - br v12_ .b10_ .b23_; -.b23_: + v10_: int = id v10_; + v11_: bool = id v11_; + br v13_ .b12_ .b24_; +.b24_: } diff --git a/tests/snapshots/files__implicit-return-optimize.snap b/tests/snapshots/files__implicit-return-optimize.snap index da3565f3..091e3490 100644 --- a/tests/snapshots/files__implicit-return-optimize.snap +++ b/tests/snapshots/files__implicit-return-optimize.snap @@ -5,38 +5,41 @@ expression: visualization.result @pow(v0: int, v1: int) { .b2_: c3_: int = const 0; - v4_: int = id v0; - v5_: int = id c3_; + c4_: int = const 1; + v5_: int = sub v1 c4_; v6_: int = id v0; - v7_: int = id v1; -.b8_: - c9_: int = const 1; - v10_: int = sub v7_ c9_; - v11_: bool = lt v5_ v10_; - v12_: int = id v4_; - v13_: int = id v5_; - v14_: int = id v6_; - v15_: int = id v7_; - br v11_ .b16_ .b17_; -.b16_: - v18_: int = mul v4_ v6_; - c19_: int = const 1; - v20_: int = add c19_ v5_; - v12_: int = id v18_; - v13_: int = id v20_; - v14_: int = id v6_; - v15_: int = id v7_; - v4_: int = id v12_; - v5_: int = id v13_; - v6_: int = id v14_; - v7_: int = id v15_; - jmp .b8_; + v7_: int = id c3_; + v8_: int = id v0; + v9_: int = id v1; + v10_: int = id v5_; +.b11_: + v12_: bool = lt v7_ v10_; + v13_: int = id v6_; + v14_: int = id v7_; + v15_: int = id v8_; + v16_: int = id v9_; + br v12_ .b17_ .b18_; .b17_: - v4_: int = id v12_; - v5_: int = id v13_; - v6_: int = id v14_; - v7_: int = id v15_; - print v4_; + v19_: int = mul v6_ v8_; + c20_: int = const 1; + v21_: int = add c20_ v7_; + v13_: int = id v19_; + v14_: int = id v21_; + v15_: int = id v8_; + v16_: int = id v9_; + v6_: int = id v13_; + v7_: int = id v14_; + v8_: int = id v15_; + v9_: int = id v16_; + v10_: int = id v10_; + jmp .b11_; +.b18_: + v6_: int = id v13_; + v7_: int = id v14_; + v8_: int = id v15_; + v9_: int = id v16_; + v10_: int = id v10_; + print v6_; } @main { .b0_: diff --git a/tests/snapshots/files__loop_hoist-optimize.snap b/tests/snapshots/files__loop_hoist-optimize.snap index 9096261c..ff0e4905 100644 --- a/tests/snapshots/files__loop_hoist-optimize.snap +++ b/tests/snapshots/files__loop_hoist-optimize.snap @@ -2,27 +2,25 @@ source: tests/files.rs expression: visualization.result --- -@main { -.b0_: - c1_: int = const 1; - c2_: int = const 4; - c3_: int = const 3; - c4_: int = const 2; - v5_: int = id c1_; - v6_: int = id c2_; - v7_: int = id c3_; - v8_: int = id c4_; -.b9_: - c10_: int = const 1; - print c10_; - v11_: int = add c10_ v5_; - v12_: bool = lt v11_ v6_; - v13_: bool = not v12_; - v5_: int = id v11_; - v6_: int = id v6_; +@main(v0: int, v1: int, v2: int, v3: int) { +.b4_: + v5_: int = sub v2 v1; + v6_: int = id v0; + v7_: int = id v3; + v8_: int = id v2; + v9_: int = id v1; + v10_: int = id v5_; +.b11_: + print v10_; + v12_: int = add v10_ v6_; + v13_: bool = lt v12_ v7_; + v14_: bool = not v13_; + v6_: int = id v12_; v7_: int = id v7_; v8_: int = id v8_; - br v13_ .b9_ .b14_; -.b14_: - print v5_; + v9_: int = id v9_; + v10_: int = id v10_; + br v14_ .b11_ .b15_; +.b15_: + print v6_; }