Skip to content

Commit

Permalink
Merge pull request #606 from egraphs-good/peel-more
Browse files Browse the repository at this point in the history
Add loop iteration estimates for basic loops and peel more
  • Loading branch information
kirstenmg authored May 27, 2024
2 parents 81cdeef + 73bd671 commit f7b1b09
Show file tree
Hide file tree
Showing 13 changed files with 243 additions and 130 deletions.
70 changes: 67 additions & 3 deletions dag_in_context/src/optimizations/loop_unroll.egg
Original file line number Diff line number Diff line change
@@ -1,24 +1,88 @@
;; Some simple simplifications of loops
(ruleset loop-unroll)
(ruleset loop-peel)
(ruleset loop-iters-analysis)

;; inputs, outputs -> number of iterations
;; The minimum possible guess is 1 because of do-while loops
;; TODO: dead loop deletion can turn loops with a false condition to a body
(function LoopNumItersGuess (Expr Expr) i64 :merge (max 1 (min old new)))

;; by default, guess that all loops run 1000 times
(rule ((DoWhile inputs outputs))
((set (LoopNumItersGuess inputs outputs) 1000))
:ruleset always-run)
:ruleset loop-iters-analysis)

;; For a loop that is false, its num iters is 1
(rule
((= loop (DoWhile inputs outputs))
(= (Const (Bool false) ty ctx) (Get outputs 0)))
((set (LoopNumItersGuess inputs outputs) 1))
:ruleset loop-iters-analysis)

;; Figure out number of iterations for a loop with constant bounds and initial value
;; and i is updated before checking pred
;; TODO: we could make it work for decrementing loops
(rule
((= lhs (DoWhile inputs outputs))
(= num-inputs (tuple-length inputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
;; updated counter at counter_i
(= next_counter (Get outputs (+ counter_i 1)))
;; increments by some constant each loop
(= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i)
(Const (Int increment) _ty2 _ctx2)))
(> increment 0)
;; while next_counter less than end_constant
(= pred (Bop (LessThan) next_counter
(Const (Int end_constant) _ty3 _ctx3)))
;; end constant is at least start constant
(>= end_constant start_const)
)
(
(set (LoopNumItersGuess inputs outputs) (/ (- end_constant start_const) increment))
)
:ruleset loop-iters-analysis)

;; Figure out number of iterations for a loop with constant bounds and initial value
;; and i is updated after checking pred
(rule
((= lhs (DoWhile inputs outputs))
(= num-inputs (tuple-length inputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
;; updated counter at counter_i
(= next_counter (Get outputs (+ counter_i 1)))
;; increments by a constant each loop
(= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i)
(Const (Int increment) _ty2 _ctx2)))
(> increment 0)
;; while this counter less than end_constant
(= pred (Bop (LessThan) (Get (Arg _ty _ctx) counter_i)
(Const (Int end_constant) _ty3 _ctx3)))
;; end constant is at least start constant
(>= end_constant start_const)
)
(
(set (LoopNumItersGuess inputs outputs) (+ (/ (- end_constant start_const) increment) 1))
)
:ruleset loop-iters-analysis)

;; loop peeling rule
;; Only peel loops that we know iterate < 3 times
(rule
((= lhs (DoWhile inputs outputs))
(ContextOf lhs ctx)
(HasType inputs inputs-ty)
(= outputs-len (tuple-length outputs))
(= old_cost (LoopNumItersGuess inputs outputs)))
((let executed-once
(= old_cost (LoopNumItersGuess inputs outputs))
(< old_cost 3)
)
(
(let executed-once
(Subst ctx inputs outputs))
(let executed-once-body
(SubTuple executed-once 1 (- outputs-len 1)))
Expand Down
5 changes: 3 additions & 2 deletions dag_in_context/src/optimizations/switch_rewrites.egg
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
(ruleset switch_rewrite)
(ruleset always-switch-rewrite)

; if (a and b) X Y ~~> if a (if b X Y) Y
(rule ((= lhs (If (Bop (And) a b) ins X Y))
Expand Down Expand Up @@ -44,11 +45,11 @@

(rewrite (If (Const (Bool true) ty ctx) ins thn els)
(Subst ctx ins thn)
:ruleset switch_rewrite)
:ruleset always-switch-rewrite)

(rewrite (If (Const (Bool false) ty ctx) ins thn els)
(Subst ctx ins els)
:ruleset switch_rewrite)
:ruleset always-switch-rewrite)

(rule ((= lhs (If pred ins thn els))
(= (Get thn i) (Const (Bool true) ty ctx1))
Expand Down
4 changes: 3 additions & 1 deletion dag_in_context/src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ pub fn mk_schedule() -> String {
context
interval-analysis
memory-helpers
always-switch-rewrite
loop-iters-analysis
)
Expand All @@ -52,11 +54,11 @@ pub fn mk_schedule() -> String {
switch_rewrite
;loop-inv-motion
loop-strength-reduction
loop-peel
)
(run-schedule
{helpers}
loop-peel
(repeat 2
{helpers}
expensive-optimizations)
Expand Down
12 changes: 12 additions & 0 deletions tests/passing/small/jumping_loop.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
@main {
jump: int = const 4;
i: int = const 0;
n: int = const 18;

.loop:
i: int = add jump i;
pred: bool = lt i n;
br pred .loop .end;

.end:
}
13 changes: 13 additions & 0 deletions tests/passing/small/peel_twice.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
@main {
i: int = const 0;
one: int = const 1;
two: int = const 2;

.loop:
i: int = add one i;
cond: bool = lt i two;
br cond .loop .loop_end;

.loop_end:
print i;
}
12 changes: 12 additions & 0 deletions tests/passing/small/peel_twice_precalc_pred.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
@main {
i: int = const 0;
one: int = const 1;

.loop:
cond: bool = lt i one;
i: int = add one i;
br cond .loop .loop_end;

.loop_end:
print i;
}
71 changes: 29 additions & 42 deletions tests/snapshots/files__implicit-return-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -37,48 +37,35 @@ expression: visualization.result
}
@main {
.v0_:
v1_: bool = const true;
v2_: int = const 16;
v3_: int = const 1;
v4_: int = const 4;
v5_: int = const 15;
v6_: int = id v2_;
v1_: int = const 4;
v2_: int = const 0;
v3_: int = const 15;
v4_: int = id v1_;
v5_: int = id v2_;
v6_: int = id v1_;
v7_: int = id v3_;
v8_: int = id v4_;
v9_: int = id v5_;
br v1_ .v10_ .v11_;
.v10_:
v12_: int = id v2_;
v13_: int = id v3_;
v14_: int = id v4_;
v15_: int = id v5_;
.v16_:
v17_: int = const 14;
v18_: bool = lt v13_ v17_;
v19_: int = id v12_;
v20_: int = id v13_;
v21_: int = id v14_;
v22_: int = id v15_;
br v18_ .v23_ .v24_;
.v23_:
v25_: int = mul v12_ v14_;
v26_: int = const 1;
v27_: int = add v13_ v26_;
v19_: int = id v25_;
v20_: int = id v27_;
v21_: int = id v14_;
v22_: int = id v15_;
.v24_:
.v8_:
v9_: int = const 14;
v10_: bool = lt v5_ v9_;
v11_: int = id v4_;
v12_: int = id v5_;
v13_: int = id v6_;
v14_: int = id v7_;
br v10_ .v15_ .v16_;
.v15_:
v17_: int = mul v4_ v6_;
v18_: int = const 1;
v19_: int = add v18_ v5_;
v11_: int = id v17_;
v12_: int = id v19_;
v13_: int = id v20_;
v14_: int = id v21_;
v15_: int = id v22_;
br v18_ .v16_ .v28_;
.v28_:
v6_: int = id v12_;
v7_: int = id v13_;
v8_: int = id v14_;
v9_: int = id v15_;
.v11_:
print v6_;
v13_: int = id v6_;
v14_: int = id v7_;
.v16_:
v4_: int = id v11_;
v5_: int = id v12_;
v6_: int = id v13_;
v7_: int = id v14_;
br v10_ .v8_ .v20_;
.v20_:
print v4_;
}
21 changes: 21 additions & 0 deletions tests/snapshots/files__jumping_loop-optimize.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
---
source: tests/files.rs
expression: visualization.result
---
@main {
.v0_:
v1_: int = const 0;
v2_: int = const 18;
v3_: int = const 4;
v4_: int = id v1_;
v5_: int = id v2_;
v6_: int = id v3_;
.v7_:
v8_: int = add v4_ v6_;
v9_: bool = lt v8_ v5_;
v4_: int = id v8_;
v5_: int = id v5_;
v6_: int = id v6_;
br v9_ .v7_ .v10_;
.v10_:
}
9 changes: 9 additions & 0 deletions tests/snapshots/files__peel_twice-optimize.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
source: tests/files.rs
expression: visualization.result
---
@main {
.v0_:
v1_: int = const 2;
print v1_;
}
9 changes: 9 additions & 0 deletions tests/snapshots/files__peel_twice_precalc_pred-optimize.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
source: tests/files.rs
expression: visualization.result
---
@main {
.v0_:
v1_: int = const 2;
print v1_;
}
63 changes: 27 additions & 36 deletions tests/snapshots/files__range_check-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,33 @@ expression: visualization.result
---
@main {
.v0_:
v1_: bool = const true;
v2_: int = const 1;
print v2_;
v3_: int = id v2_;
br v1_ .v4_ .v5_;
.v4_:
v6_: int = id v2_;
.v7_:
v8_: int = const 6;
v9_: bool = lt v6_ v8_;
v10_: int = const 5;
v11_: bool = lt v6_ v10_;
br v11_ .v12_ .v13_;
v1_: int = const 0;
v2_: int = id v1_;
.v3_:
v4_: int = const 6;
v5_: bool = lt v2_ v4_;
v6_: int = const 5;
v7_: bool = lt v2_ v6_;
br v7_ .v8_ .v9_;
.v8_:
v10_: int = const 1;
print v10_;
v11_: int = id v2_;
.v12_:
v14_: int = const 1;
print v14_;
v15_: int = id v6_;
v13_: int = const 1;
v14_: int = add v13_ v2_;
v15_: int = id v14_;
br v5_ .v16_ .v17_;
.v16_:
v17_: int = const 1;
v18_: int = add v17_ v6_;
v19_: int = id v18_;
br v9_ .v20_ .v21_;
.v20_:
v19_: int = id v18_;
.v21_:
v6_: int = id v19_;
br v9_ .v7_ .v22_;
.v22_:
v3_: int = id v6_;
print v3_;
ret;
.v13_:
v23_: int = const 2;
print v23_;
v15_: int = id v6_;
jmp .v16_;
.v5_:
print v3_;
v15_: int = id v14_;
.v17_:
v2_: int = id v15_;
br v5_ .v3_ .v18_;
.v9_:
v19_: int = const 2;
print v19_;
v11_: int = id v2_;
jmp .v12_;
.v18_:
print v2_;
}
Loading

0 comments on commit f7b1b09

Please sign in to comment.