Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better select rule #653

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 13 additions & 65 deletions dag_in_context/src/optimizations/switch_rewrites.egg
Original file line number Diff line number Diff line change
Expand Up @@ -31,78 +31,26 @@
((union (Get if_e k) (Bop (Smax) a b)))
:ruleset switch_rewrite)

; if pred then a else b ~~> (select pred a b)
; where a and b are inputs to the region
(rule (
(= if_e (If pred inputs thn els))
(= a (Get inputs i))
(= b (Get inputs j))

; if pred then a else b
(= (Get thn k) (Get (Arg ty (InIf true pred inputs)) i))
(= (Get els k) (Get (Arg ty (InIf false pred inputs)) j))

; If i = j, then the arg is just passed through the if, and we
; don't need a select. This will get handled by the passthrough rules.
(!= i j)
)
(rule
(
(union (Get if_e k) (Top (Select) pred a b))
)
:ruleset switch_rewrite)

(rule (
(= if_e (If pred inputs thn els))
(ContextOf if_e ctx)
(HasArgType if_e ty)
(= (Get thn i) (Const x _ty (InIf true pred inputs)))
(= (Get els i) (Const y _ty (InIf false pred inputs)))
)
((union (Get if_e i) (Top (Select) pred (Const x ty ctx) (Const y ty ctx))))
:ruleset switch_rewrite)

; if pred then A else Const -> select pred A Const
; where A is an input to the region
(rule (
(= if_e (If pred inputs thn els))
(ContextOf if_e ctx)
(HasArgType if_e ty)

; input to the if
(= a (Get inputs i))
(= (Get thn k) (Get (Arg _ty (InIf true pred inputs)) i))

(= els_out (Get els k))
(= (IntB y) (lo-bound els_out))
(= (IntB y) (hi-bound els_out))
(= if_e (If pred inputs thn els))
(ContextOf if_e ctx)

(= thn_out (Get thn i))
(= els_out (Get els i))
(ExprIsPure thn_out)
(ExprIsPure els_out)

(> 10 (Expr-size thn_out)) ; TODO: Tune these size limits
(> 10 (Expr-size els_out))
)
(
(union (Get if_e k) (Top (Select) pred a (Const (Int y) ty ctx)))
(union (Get if_e i)
(Top (Select) pred (Subst ctx inputs thn_out) (Subst ctx inputs els_out)))
)
:ruleset switch_rewrite
)

; if pred then Const else B -> select pred Const B
; where B is an input to the region
(rule (
(= if_e (If pred inputs thn els))
(ContextOf if_e ctx)
(HasArgType if_e ty)

(= thn_out (Get thn k))
(= (IntB y) (lo-bound thn_out))
(= (IntB y) (hi-bound thn_out))

; input to the if
(= b (Get inputs i))
(= (Get els k) (Get (Arg _ty (InIf false pred inputs)) i))
)
(
(union (Get if_e k) (Top (Select) pred (Const (Int y) ty ctx) b))
)
:ruleset switch_rewrite
)

; if (a and b) X Y ~~> if a (if b X Y) Y
(rule ((= lhs (If (Bop (And) a b) ins X Y))
(HasType ins (TupleT ins_ty))
Expand Down
23 changes: 23 additions & 0 deletions tests/passing/small/select_simple.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,33 @@
// ARGS: 20 30
fn main(x: i64, y: i64) {
let res: i64 = 0;
// if P then A else B where A and B are inputs to the region
if (x * y < 20) {
res = x;
} else {
res = y;
}
println!("{}", res);

// if P then C1 else C2
if (x * y > 10) {
res = 4;
} else {
res = 5;
}
println!("{}", res);

// if P then C1 (and implicitly, the else is a passthrough)
if (x * y > 20) {
res = 10;
}
println!("{}", res);

// if P then X else Y where X and Y are small, pure expressions
if (x * y == 40) {
res = x * 2;
} else {
res = x + 5;
}
println!("{}", res);
}
32 changes: 6 additions & 26 deletions tests/snapshots/files__block-diamond-optimize-sequential.snap
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,11 @@ expression: visualization.result
c2_: int = const 1;
c3_: int = const 2;
v4_: bool = lt v0 c3_;
c5_: int = const 0;
c6_: int = const 5;
v7_: int = id c2_;
v8_: int = id c2_;
v9_: int = id c3_;
br v4_ .b10_ .b11_;
.b10_:
c12_: int = const 4;
v7_: int = id c12_;
v8_: int = id c2_;
v9_: int = id c3_;
v13_: int = id v7_;
v14_: int = id v8_;
v15_: int = add c2_ v13_;
print v15_;
c5_: int = const 4;
v6_: int = select v4_ c5_ c2_;
v7_: int = add c3_ v6_;
v8_: int = select v4_ v6_ v7_;
v9_: int = add c2_ v8_;
print v9_;
ret;
jmp .b16_;
.b11_:
v13_: int = id v7_;
v14_: int = id v8_;
v17_: int = add v7_ v9_;
v13_: int = id v17_;
v14_: int = id v8_;
v15_: int = add c2_ v13_;
print v15_;
ret;
.b16_:
}
29 changes: 9 additions & 20 deletions tests/snapshots/files__block-diamond-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,14 @@ expression: visualization.result
# ARGS: 1
@main(v0: int) {
.b1_:
c2_: int = const 1;
c3_: int = const 2;
v4_: bool = lt v0 c3_;
c5_: int = const 4;
v6_: int = select v4_ c5_ c2_;
v7_: int = id v6_;
v8_: int = id c2_;
br v4_ .b9_ .b10_;
.b9_:
v11_: int = add c2_ v7_;
print v11_;
c2_: int = const 2;
v3_: bool = lt v0 c2_;
c4_: int = const 4;
c5_: int = const 1;
v6_: int = select v3_ c4_ c5_;
v7_: int = add c2_ v6_;
v8_: int = select v3_ v6_ v7_;
v9_: int = add c5_ v8_;
print v9_;
ret;
jmp .b12_;
.b10_:
v13_: int = add c3_ v6_;
v7_: int = id v13_;
v8_: int = id c2_;
v11_: int = add c2_ v7_;
print v11_;
ret;
.b12_:
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,9 @@ expression: visualization.result
.b1_:
c2_: int = const 2;
v3_: bool = lt v0 c2_;
br v3_ .b4_ .b5_;
.b4_:
v6_: int = add v0 v0;
v7_: int = id v6_;
print v7_;
v4_: int = add v0 v0;
v5_: int = mul c2_ v4_;
v6_: int = select v3_ v4_ v5_;
print v6_;
ret;
jmp .b8_;
.b5_:
v9_: int = add v0 v0;
v10_: int = mul c2_ v9_;
v7_: int = id v10_;
print v7_;
ret;
.b8_:
}
17 changes: 4 additions & 13 deletions tests/snapshots/files__branch_duplicate_work-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,9 @@ expression: visualization.result
.b1_:
c2_: int = const 2;
v3_: bool = lt v0 c2_;
br v3_ .b4_ .b5_;
.b4_:
v6_: int = add v0 v0;
v7_: int = id v6_;
print v7_;
v4_: int = add v0 v0;
v5_: int = mul c2_ v4_;
v6_: int = select v3_ v4_ v5_;
print v6_;
ret;
jmp .b8_;
.b5_:
v9_: int = add v0 v0;
v10_: int = mul c2_ v9_;
v7_: int = id v10_;
print v7_;
ret;
.b8_:
}
29 changes: 8 additions & 21 deletions tests/snapshots/files__branch_hoisting-optimize-sequential.snap
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,14 @@ expression: visualization.result
.b16_:
v18_: bool = eq v6_ v7_;
c19_: int = const 2;
br v18_ .b20_ .b21_;
.b20_:
v22_: int = mul c19_ v5_;
v23_: int = id v22_;
v24_: int = id v5_;
v25_: int = id v6_;
v26_: int = id v6_;
v27_: int = id v8_;
.b28_:
c29_: int = const 1;
v30_: int = add c29_ v5_;
v20_: int = mul c19_ v5_;
c21_: int = const 3;
v22_: int = mul c21_ v5_;
v23_: int = select v18_ v20_ v22_;
c24_: int = const 1;
v25_: int = add c24_ v5_;
v11_: int = id v23_;
v12_: int = id v30_;
v12_: int = id v25_;
v13_: int = id v6_;
v14_: int = id v7_;
v15_: int = id v8_;
Expand All @@ -45,20 +40,12 @@ expression: visualization.result
v7_: int = id v14_;
v8_: int = id v15_;
jmp .b9_;
.b21_:
c31_: int = const 3;
v32_: int = mul c31_ v5_;
v23_: int = id v32_;
v24_: int = id v5_;
v25_: int = id v6_;
v26_: int = id v7_;
v27_: int = id v8_;
jmp .b28_;
.b17_:
v4_: int = id v11_;
v5_: int = id v12_;
v6_: int = id v13_;
v7_: int = id v14_;
v8_: int = id v15_;
print v4_;
ret;
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,48 +32,23 @@ expression: visualization.result
v25_: bool = eq v11_ v24_;
v26_: int = mul v7_ v9_;
v27_: int = add v26_ v8_;
c28_: bool = const true;
v29_: int = id v6_;
v30_: bool = id c28_;
v31_: int = id v27_;
v32_: int = id v8_;
v33_: int = id v9_;
v34_: int = id v10_;
v35_: int = id v11_;
br v25_ .b36_ .b37_;
.b36_:
c38_: bool = const true;
v29_: int = id v6_;
v30_: bool = id c38_;
v31_: int = id v22_;
v32_: int = id v8_;
v33_: int = id v9_;
v34_: int = id v10_;
v35_: int = id v11_;
v28_: int = select v25_ v22_ v27_;
v14_: int = id v6_;
v15_: int = id v31_;
v15_: int = id v28_;
v16_: int = id v8_;
v17_: int = id v9_;
v18_: int = id v10_;
v19_: int = id v11_;
.b20_:
v39_: bool = not v13_;
v29_: bool = not v13_;
v6_: int = id v14_;
v7_: int = id v15_;
v8_: int = id v16_;
v9_: int = id v17_;
v10_: int = id v18_;
v11_: int = id v19_;
br v39_ .b12_ .b40_;
.b37_:
v14_: int = id v6_;
v15_: int = id v31_;
v16_: int = id v8_;
v17_: int = id v9_;
v18_: int = id v10_;
v19_: int = id v11_;
jmp .b20_;
.b40_:
br v29_ .b12_ .b30_;
.b30_:
print v0;
ret;
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ expression: visualization.result
.b1_:
c2_: int = const 20;
print c2_;
ret;
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ expression: visualization.result
.b1_:
c2_: int = const 20;
print c2_;
ret;
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ expression: visualization.result
v8_: int = id v14_;
v9_: int = id v15_;
print c1_;
ret;
}
Loading
Loading