Skip to content

Commit

Permalink
Merge pull request #41 from rtjoa/subst
Browse files Browse the repository at this point in the history
Add RVSDG subst functions
  • Loading branch information
oflatt authored Sep 29, 2023
2 parents 6ad9f28 + 24cdd6c commit 9350796
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name = "files"


[dependencies]
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "8cd5fda" }
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "cec39af" }
log = "0.4.19"
thiserror = "1"
lalrpop-util = { version = "0.19.8", features = ["lexer"] }
Expand Down
9 changes: 6 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,10 @@ impl Optimizer {

pub fn make_optimizer_for(&mut self, program: &str) -> String {
//let schedule = "(run 3)";
let schedule = format!("(run {})", self.num_iters);
let schedule = format!(
"(run-schedule (repeat {} cfold (saturate subst)))",
self.num_iters
);
format!(
"
(datatype Type
Expand Down Expand Up @@ -328,8 +331,8 @@ impl Optimizer {
;; name, arguments, and body
(Func String ArgList StructuredBlock))
(rewrite (add ty (Int ty a) (Int ty b)) (Int ty (+ a b)))
(rewrite (sub ty (Int ty a) (Int ty b)) (Int ty (- a b)))
(rewrite (add ty (Int ty a) (Int ty b)) (Int ty (+ a b)) :ruleset cfold)
(rewrite (sub ty (Int ty a) (Int ty b)) (Int ty (- a b)) :ruleset cfold)
{program}
{schedule}
Expand Down
101 changes: 97 additions & 4 deletions src/rvsdg/schema.egg
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
(datatype Operand)
(datatype Body)

(sort VecOperand (Vec Operand))
(datatype VecOperandWrapper
(VO VecOperand))
(sort VecVecOperand (Vec VecOperandWrapper))
(sort VecOperandBase (Vec Operand))
(datatype VecOperand (VO VecOperandBase))

(sort VecVecOperandBase (Vec VecOperand))
(datatype VecVecOperand (VVO VecVecOperandBase))

;; Type
(datatype Type
Expand Down Expand Up @@ -48,6 +49,98 @@
(function Gamma (Operand VecOperand VecVecOperand) Body) ;; branching
(function Theta (Operand VecOperand VecOperand) Body) ;; loop

;; Substitution
(ruleset subst)

;; e [ x -> v ]
(function SubstExpr (Expr i64 Operand) Expr)
(function SubstOperand (Operand i64 Operand) Operand)
(function SubstBody (Body i64 Operand) Body)
(function SubstVecOperand (VecOperand i64 Operand) VecOperand)
(function SubstVecVecOperand (VecVecOperand i64 Operand) VecVecOperand)

(rewrite (SubstExpr (Const ty ops lit) x v) (Const ty ops lit) :ruleset subst)
(rewrite (SubstExpr (Call ty f args) x v) (Call ty f (SubstVecOperand args x v)) :ruleset subst)
(rewrite (SubstExpr (add ty a b) x v) (add ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (sub ty a b) x v) (sub ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (mul ty a b) x v) (mul ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (div ty a b) x v) (div ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (eq ty a b) x v) (eq ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (lt ty a b) x v) (lt ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (gt ty a b) x v) (gt ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (le ty a b) x v) (le ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (ge ty a b) x v) (ge ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (not ty a b) x v) (not ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (and ty a b) x v) (and ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (or ty a b) x v) (or ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)

(rewrite (SubstOperand (Arg x) x v) v :ruleset subst)
(rule ((= f (SubstOperand (Arg y) x v)) (!= y x))
((union f (Arg y))) :ruleset subst)
(rewrite (SubstOperand (Node b) x v) (Node (SubstBody b x v)) :ruleset subst)
(rewrite (SubstOperand (Project i b) x v) (Project i (SubstBody b x v)) :ruleset subst)

(rewrite (SubstBody (PureOp e) x v) (PureOp (SubstExpr e x v)) :ruleset subst)
;; Subst doesn't cross regions - so we subst into the inputs but not outputs
;; Node that a Gamma node's idx is on the outside, so it gets affected, but not
;; a Theta node's predicate
(rewrite (SubstBody (Gamma pred inputs outputs) x v) (Gamma (SubstOperand pred x v) (SubstVecOperand inputs x v) outputs) :ruleset subst)
(rewrite (SubstBody (Theta pred inputs outputs) x v) (Theta pred (SubstVecOperand inputs x v) outputs) :ruleset subst)

;; params: vec, var, op, next index to subst
;; rtjoa: TODO: implement by mapping internally so they're not O(n^2) time
(function SubstVecOperandHelper (VecOperand i64 Operand i64) VecOperand)
(rewrite (SubstVecOperand vec x v) (SubstVecOperandHelper vec x v 0) :ruleset subst)
(rule
(
(= f (SubstVecOperandHelper (VO vec) x v i))
(< i (vec-length vec))
)
(
(union
(SubstVecOperandHelper (VO vec) x v i)
(SubstVecOperandHelper
(VO (vec-set vec i (SubstOperand (vec-get vec i) x v)))
x v (+ i 1)
))
) :ruleset subst)

(rule
(
(= f (SubstVecOperandHelper (VO vec) x v i))
(= i (vec-length vec))
)
(
(union (SubstVecOperandHelper (VO vec) x v i) (VO vec))
) :ruleset subst)


;; params: vec, var, op, next index to subst
;; rtjoa: TODO: implement by mapping internally so they're not O(n^2) time
(function SubstVecVecOperandHelper (VecVecOperand i64 Operand i64) VecVecOperand)
(rewrite (SubstVecVecOperand vec x v) (SubstVecVecOperandHelper vec x v 0) :ruleset subst)
(rule
(
(= f (SubstVecVecOperandHelper (VVO vec) x v i))
(< i (vec-length vec))
)
(
(union
(SubstVecVecOperandHelper (VVO vec) x v i)
(SubstVecVecOperandHelper
(VVO (vec-set vec i (SubstVecOperand (vec-get vec i) x v)))
x v (+ i 1)
))
) :ruleset subst)

(rule
(
(= f (SubstVecVecOperandHelper (VVO vec) x v i))
(= i (vec-length vec))
)
(
(union (SubstVecVecOperandHelper (VVO vec) x v i) (VVO vec))
) :ruleset subst)

;; procedure f(n):
;; i = 0
Expand Down
117 changes: 108 additions & 9 deletions src/rvsdg/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ fn rvsdg_state_gamma() {
.C:
call @other_func;
jmp .End;
.End:
.End:
}
@other_func() {
Expand Down Expand Up @@ -417,30 +417,30 @@ fn rvsdg_odd_branch_egg_roundtrip() {
(const)
(Num 1)))))))
(Arg 3))))
(vec-of (Arg 1)
(VO (vec-of (Arg 1)
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Arg 0))
(vec-of (Arg 0)
(Arg 0)))
(VO (vec-of (Arg 0)
(Node (PureOp (add (IntT) (Arg 1) (Arg 2))))
(Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT) (const) (Num 1)))))))
(Arg 3))))
(Arg 3)))))
(let rescaled
(Gamma
(Node
(PureOp
(lt (BoolT) (Project 1 loop)
(Node (PureOp (Const (IntT) (const) (Num 5)))))))
(vec-of
(VO (vec-of
(Project 0 loop)
(Project 1 loop))
(vec-of (VO (vec-of (Arg 0) (Arg 1)))
(Project 1 loop)))
(VVO (vec-of (VO (vec-of (Arg 0) (Arg 1)))
(VO (vec-of (Arg 0)
(Node (PureOp (mul (IntT) (Arg 1)
(Node (PureOp (Const (IntT)
(const)
(Num 2))))))))))))
(Num 2)))))))))))))
(let expected-result (Project 0 rescaled))
(let expected-state (Project 1 rescaled))
"#;
Expand Down Expand Up @@ -648,3 +648,102 @@ fn deep_equal(f1: &RvsdgFunction, f2: &RvsdgFunction) -> bool {
(None, Some(_)) | (Some(_), None) => false,
}
}

#[test]
fn rvsdg_subst() {
const EGGLOG_PROGRAM: &str = r#"
(let unsubsted
(Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT)
(const)
(Num 1)))))))
(Arg 3)))))
(let substed (SubstOperand unsubsted 3 (Arg 7)))
(run-schedule (saturate subst))
(let expected
(Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT)
(const)
(Num 1)))))))
(Arg 7)))))
(check (= substed expected))
"#;
let mut egraph = new_rvsdg_egraph();
egraph.parse_and_run_program(EGGLOG_PROGRAM).unwrap();

const EGGLOG_THETA_PROGRAM: &str = r#"
(let unsubsted
(Theta
(Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT)
(const)
(Num 1)))))))
(Arg 3))))
(VO (vec-of (Arg 0)
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Arg 1)))
(VO (vec-of (Arg 0)
(Node (PureOp (add (IntT) (Arg 1) (Arg 2))))
(Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT) (const) (Num 1)))))))
(Arg 3)))))
(let substed (SubstBody unsubsted 1 (Arg 7)))
(run-schedule (saturate subst))
(let expected
(Theta
(Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT)
(const)
(Num 1)))))))
(Arg 3))))
(VO (vec-of (Arg 0)
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Arg 7)))
(VO (vec-of (Arg 0)
(Node (PureOp (add (IntT) (Arg 1) (Arg 2))))
(Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT) (const) (Num 1)))))))
(Arg 3)))))
(check (= substed expected))
"#;
let mut egraph = new_rvsdg_egraph();
egraph.parse_and_run_program(EGGLOG_THETA_PROGRAM).unwrap();

const EGGLOG_GAMMA_PROGRAM: &str = r#"
(let unsubsted
(Gamma
(Node
(PureOp
(lt (BoolT) (Arg 0) (Arg 0))))
(VO (vec-of
(Arg 1)
(Arg 0)))
(VVO (vec-of (VO (vec-of (Arg 0) (Arg 1)))
(VO (vec-of (Arg 0)
(Node (PureOp (mul (IntT) (Arg 1)
(Node (PureOp (Const (IntT)
(const)
(Num 2)))))))))))))
(let substed (SubstBody unsubsted 0 (Arg 7)))
(run-schedule (saturate subst))
(let expected
(Gamma
(Node
(PureOp
(lt (BoolT) (Arg 7) (Arg 7))))
(VO (vec-of
(Arg 1)
(Arg 7)))
(VVO (vec-of (VO (vec-of (Arg 0) (Arg 1)))
(VO (vec-of (Arg 0)
(Node (PureOp (mul (IntT) (Arg 1)
(Node (PureOp (Const (IntT)
(const)
(Num 2)))))))))))))
(check (= substed expected))
"#;
let mut egraph = new_rvsdg_egraph();
egraph.parse_and_run_program(EGGLOG_GAMMA_PROGRAM).unwrap();
}

0 comments on commit 9350796

Please sign in to comment.