diff --git a/Cargo.toml b/Cargo.toml index 15aff42e9..02cb41baf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ name = "files" [dependencies] -egglog = { git = "https://github.com/egraphs-good/egglog", rev = "8cd5fda" } +egglog = { git = "https://github.com/egraphs-good/egglog", rev = "cec39af" } log = "0.4.19" thiserror = "1" lalrpop-util = { version = "0.19.8", features = ["lexer"] } diff --git a/src/lib.rs b/src/lib.rs index 08b0ce9c9..c9c59020d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -260,7 +260,10 @@ impl Optimizer { pub fn make_optimizer_for(&mut self, program: &str) -> String { //let schedule = "(run 3)"; - let schedule = format!("(run {})", self.num_iters); + let schedule = format!( + "(run-schedule (repeat {} cfold (saturate subst)))", + self.num_iters + ); format!( " (datatype Type @@ -328,8 +331,8 @@ impl Optimizer { ;; name, arguments, and body (Func String ArgList StructuredBlock)) - (rewrite (add ty (Int ty a) (Int ty b)) (Int ty (+ a b))) - (rewrite (sub ty (Int ty a) (Int ty b)) (Int ty (- a b))) + (rewrite (add ty (Int ty a) (Int ty b)) (Int ty (+ a b)) :ruleset cfold) + (rewrite (sub ty (Int ty a) (Int ty b)) (Int ty (- a b)) :ruleset cfold) {program} {schedule} diff --git a/src/rvsdg/schema.egg b/src/rvsdg/schema.egg index 07e3e3caa..b7ae88c89 100644 --- a/src/rvsdg/schema.egg +++ b/src/rvsdg/schema.egg @@ -3,10 +3,11 @@ (datatype Operand) (datatype Body) -(sort VecOperand (Vec Operand)) -(datatype VecOperandWrapper - (VO VecOperand)) -(sort VecVecOperand (Vec VecOperandWrapper)) +(sort VecOperandBase (Vec Operand)) +(datatype VecOperand (VO VecOperandBase)) + +(sort VecVecOperandBase (Vec VecOperand)) +(datatype VecVecOperand (VVO VecVecOperandBase)) ;; Type (datatype Type @@ -48,6 +49,98 @@ (function Gamma (Operand VecOperand VecVecOperand) Body) ;; branching (function Theta (Operand VecOperand VecOperand) Body) ;; loop +;; Substitution +(ruleset subst) + +;; e [ x -> v ] +(function SubstExpr (Expr i64 Operand) Expr) +(function SubstOperand (Operand i64 Operand) Operand) +(function SubstBody (Body i64 Operand) Body) +(function SubstVecOperand (VecOperand i64 Operand) VecOperand) +(function SubstVecVecOperand (VecVecOperand i64 Operand) VecVecOperand) + +(rewrite (SubstExpr (Const ty ops lit) x v) (Const ty ops lit) :ruleset subst) +(rewrite (SubstExpr (Call ty f args) x v) (Call ty f (SubstVecOperand args x v)) :ruleset subst) +(rewrite (SubstExpr (add ty a b) x v) (add ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) +(rewrite (SubstExpr (sub ty a b) x v) (sub ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) +(rewrite (SubstExpr (mul ty a b) x v) (mul ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) +(rewrite (SubstExpr (div ty a b) x v) (div ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) +(rewrite (SubstExpr (eq ty a b) x v) (eq ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) +(rewrite (SubstExpr (lt ty a b) x v) (lt ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) +(rewrite (SubstExpr (gt ty a b) x v) (gt ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) +(rewrite (SubstExpr (le ty a b) x v) (le ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) +(rewrite (SubstExpr (ge ty a b) x v) (ge ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) +(rewrite (SubstExpr (not ty a b) x v) (not ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) +(rewrite (SubstExpr (and ty a b) x v) (and ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) +(rewrite (SubstExpr (or ty a b) x v) (or ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst) + +(rewrite (SubstOperand (Arg x) x v) v :ruleset subst) +(rule ((= f (SubstOperand (Arg y) x v)) (!= y x)) + ((union f (Arg y))) :ruleset subst) +(rewrite (SubstOperand (Node b) x v) (Node (SubstBody b x v)) :ruleset subst) +(rewrite (SubstOperand (Project i b) x v) (Project i (SubstBody b x v)) :ruleset subst) + +(rewrite (SubstBody (PureOp e) x v) (PureOp (SubstExpr e x v)) :ruleset subst) +;; Subst doesn't cross regions - so we subst into the inputs but not outputs +;; Node that a Gamma node's idx is on the outside, so it gets affected, but not +;; a Theta node's predicate +(rewrite (SubstBody (Gamma pred inputs outputs) x v) (Gamma (SubstOperand pred x v) (SubstVecOperand inputs x v) outputs) :ruleset subst) +(rewrite (SubstBody (Theta pred inputs outputs) x v) (Theta pred (SubstVecOperand inputs x v) outputs) :ruleset subst) + +;; params: vec, var, op, next index to subst +;; rtjoa: TODO: implement by mapping internally so they're not O(n^2) time +(function SubstVecOperandHelper (VecOperand i64 Operand i64) VecOperand) +(rewrite (SubstVecOperand vec x v) (SubstVecOperandHelper vec x v 0) :ruleset subst) +(rule + ( + (= f (SubstVecOperandHelper (VO vec) x v i)) + (< i (vec-length vec)) + ) + ( + (union + (SubstVecOperandHelper (VO vec) x v i) + (SubstVecOperandHelper + (VO (vec-set vec i (SubstOperand (vec-get vec i) x v))) + x v (+ i 1) + )) + ) :ruleset subst) + +(rule + ( + (= f (SubstVecOperandHelper (VO vec) x v i)) + (= i (vec-length vec)) + ) + ( + (union (SubstVecOperandHelper (VO vec) x v i) (VO vec)) + ) :ruleset subst) + + +;; params: vec, var, op, next index to subst +;; rtjoa: TODO: implement by mapping internally so they're not O(n^2) time +(function SubstVecVecOperandHelper (VecVecOperand i64 Operand i64) VecVecOperand) +(rewrite (SubstVecVecOperand vec x v) (SubstVecVecOperandHelper vec x v 0) :ruleset subst) +(rule + ( + (= f (SubstVecVecOperandHelper (VVO vec) x v i)) + (< i (vec-length vec)) + ) + ( + (union + (SubstVecVecOperandHelper (VVO vec) x v i) + (SubstVecVecOperandHelper + (VVO (vec-set vec i (SubstVecOperand (vec-get vec i) x v))) + x v (+ i 1) + )) + ) :ruleset subst) + +(rule + ( + (= f (SubstVecVecOperandHelper (VVO vec) x v i)) + (= i (vec-length vec)) + ) + ( + (union (SubstVecVecOperandHelper (VVO vec) x v i) (VVO vec)) + ) :ruleset subst) ;; procedure f(n): ;; i = 0 diff --git a/src/rvsdg/tests.rs b/src/rvsdg/tests.rs index 5b74db936..191fe6729 100644 --- a/src/rvsdg/tests.rs +++ b/src/rvsdg/tests.rs @@ -176,7 +176,7 @@ fn rvsdg_state_gamma() { .C: call @other_func; jmp .End; - .End: + .End: } @other_func() { @@ -417,30 +417,30 @@ fn rvsdg_odd_branch_egg_roundtrip() { (const) (Num 1))))))) (Arg 3)))) - (vec-of (Arg 1) + (VO (vec-of (Arg 1) (Node (PureOp (Const (IntT) (const) (Num 0)))) (Node (PureOp (Const (IntT) (const) (Num 0)))) - (Arg 0)) - (vec-of (Arg 0) + (Arg 0))) + (VO (vec-of (Arg 0) (Node (PureOp (add (IntT) (Arg 1) (Arg 2)))) (Node (PureOp (add (IntT) (Arg 2) (Node (PureOp (Const (IntT) (const) (Num 1))))))) - (Arg 3)))) + (Arg 3))))) (let rescaled (Gamma (Node (PureOp (lt (BoolT) (Project 1 loop) (Node (PureOp (Const (IntT) (const) (Num 5))))))) - (vec-of + (VO (vec-of (Project 0 loop) - (Project 1 loop)) - (vec-of (VO (vec-of (Arg 0) (Arg 1))) + (Project 1 loop))) + (VVO (vec-of (VO (vec-of (Arg 0) (Arg 1))) (VO (vec-of (Arg 0) (Node (PureOp (mul (IntT) (Arg 1) (Node (PureOp (Const (IntT) (const) - (Num 2)))))))))))) + (Num 2))))))))))))) (let expected-result (Project 0 rescaled)) (let expected-state (Project 1 rescaled)) "#; @@ -648,3 +648,102 @@ fn deep_equal(f1: &RvsdgFunction, f2: &RvsdgFunction) -> bool { (None, Some(_)) | (Some(_), None) => false, } } + +#[test] +fn rvsdg_subst() { + const EGGLOG_PROGRAM: &str = r#" + (let unsubsted + (Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2) + (Node (PureOp (Const (IntT) + (const) + (Num 1))))))) + (Arg 3))))) + (let substed (SubstOperand unsubsted 3 (Arg 7))) + (run-schedule (saturate subst)) + (let expected + (Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2) + (Node (PureOp (Const (IntT) + (const) + (Num 1))))))) + (Arg 7))))) + (check (= substed expected)) + "#; + let mut egraph = new_rvsdg_egraph(); + egraph.parse_and_run_program(EGGLOG_PROGRAM).unwrap(); + + const EGGLOG_THETA_PROGRAM: &str = r#" + (let unsubsted + (Theta + (Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2) + (Node (PureOp (Const (IntT) + (const) + (Num 1))))))) + (Arg 3)))) + (VO (vec-of (Arg 0) + (Node (PureOp (Const (IntT) (const) (Num 0)))) + (Node (PureOp (Const (IntT) (const) (Num 0)))) + (Arg 1))) + (VO (vec-of (Arg 0) + (Node (PureOp (add (IntT) (Arg 1) (Arg 2)))) + (Node (PureOp (add (IntT) (Arg 2) + (Node (PureOp (Const (IntT) (const) (Num 1))))))) + (Arg 3))))) + (let substed (SubstBody unsubsted 1 (Arg 7))) + (run-schedule (saturate subst)) + (let expected + (Theta + (Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2) + (Node (PureOp (Const (IntT) + (const) + (Num 1))))))) + (Arg 3)))) + (VO (vec-of (Arg 0) + (Node (PureOp (Const (IntT) (const) (Num 0)))) + (Node (PureOp (Const (IntT) (const) (Num 0)))) + (Arg 7))) + (VO (vec-of (Arg 0) + (Node (PureOp (add (IntT) (Arg 1) (Arg 2)))) + (Node (PureOp (add (IntT) (Arg 2) + (Node (PureOp (Const (IntT) (const) (Num 1))))))) + (Arg 3))))) + (check (= substed expected)) + "#; + let mut egraph = new_rvsdg_egraph(); + egraph.parse_and_run_program(EGGLOG_THETA_PROGRAM).unwrap(); + + const EGGLOG_GAMMA_PROGRAM: &str = r#" + (let unsubsted + (Gamma + (Node + (PureOp + (lt (BoolT) (Arg 0) (Arg 0)))) + (VO (vec-of + (Arg 1) + (Arg 0))) + (VVO (vec-of (VO (vec-of (Arg 0) (Arg 1))) + (VO (vec-of (Arg 0) + (Node (PureOp (mul (IntT) (Arg 1) + (Node (PureOp (Const (IntT) + (const) + (Num 2))))))))))))) + (let substed (SubstBody unsubsted 0 (Arg 7))) + (run-schedule (saturate subst)) + (let expected + (Gamma + (Node + (PureOp + (lt (BoolT) (Arg 7) (Arg 7)))) + (VO (vec-of + (Arg 1) + (Arg 7))) + (VVO (vec-of (VO (vec-of (Arg 0) (Arg 1))) + (VO (vec-of (Arg 0) + (Node (PureOp (mul (IntT) (Arg 1) + (Node (PureOp (Const (IntT) + (const) + (Num 2))))))))))))) + (check (= substed expected)) + "#; + let mut egraph = new_rvsdg_egraph(); + egraph.parse_and_run_program(EGGLOG_GAMMA_PROGRAM).unwrap(); +}