Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RVSDG subst functions #41

Merged
merged 5 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name = "files"


[dependencies]
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "8cd5fda" }
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "cec39af" }
log = "0.4.19"
thiserror = "1"
lalrpop-util = { version = "0.19.8", features = ["lexer"] }
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ impl Optimizer {

pub fn make_optimizer_for(&mut self, program: &str) -> String {
//let schedule = "(run 3)";
let schedule = format!("(run {})", self.num_iters);
let schedule = format!("(run-schedule (repeat {} cfold (saturate subst)))", self.num_iters);
format!(
"
(datatype Type
Expand Down Expand Up @@ -328,8 +328,8 @@ impl Optimizer {
;; name, arguments, and body
(Func String ArgList StructuredBlock))

(rewrite (add ty (Int ty a) (Int ty b)) (Int ty (+ a b)))
(rewrite (sub ty (Int ty a) (Int ty b)) (Int ty (- a b)))
(rewrite (add ty (Int ty a) (Int ty b)) (Int ty (+ a b)) :ruleset cfold)
(rewrite (sub ty (Int ty a) (Int ty b)) (Int ty (- a b)) :ruleset cfold)

{program}
{schedule}
Expand Down
101 changes: 97 additions & 4 deletions src/rvsdg/schema.egg
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
(datatype Operand)
(datatype Body)

(sort VecOperand (Vec Operand))
(datatype VecOperandWrapper
(VO VecOperand))
(sort VecVecOperand (Vec VecOperandWrapper))
(sort VecOperandBase (Vec Operand))
(datatype VecOperand (VO VecOperandBase))

(sort VecVecOperandBase (Vec VecOperand))
(datatype VecVecOperand (VVO VecVecOperandBase))

;; Type
(datatype Type
Expand Down Expand Up @@ -48,6 +49,98 @@
(function Gamma (Operand VecOperand VecVecOperand) Body) ;; branching
(function Theta (Operand VecOperand VecOperand) Body) ;; loop

;; Substitution
(ruleset subst)

;; e [ x -> v ]
(function SubstExpr (Expr i64 Operand) Expr)
oflatt marked this conversation as resolved.
Show resolved Hide resolved
(function SubstOperand (Operand i64 Operand) Operand)
(function SubstBody (Body i64 Operand) Body)
(function SubstVecOperand (VecOperand i64 Operand) VecOperand)
(function SubstVecVecOperand (VecVecOperand i64 Operand) VecVecOperand)

(rewrite (SubstExpr (Const ty ops lit) x v) (Const ty ops lit) :ruleset subst)
(rewrite (SubstExpr (Call ty f args) x v) (Call ty f (SubstVecOperand args x v)) :ruleset subst)
(rewrite (SubstExpr (add ty a b) x v) (add ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (sub ty a b) x v) (sub ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (mul ty a b) x v) (mul ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (div ty a b) x v) (div ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (eq ty a b) x v) (eq ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (lt ty a b) x v) (lt ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (gt ty a b) x v) (gt ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (le ty a b) x v) (le ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (ge ty a b) x v) (ge ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (not ty a b) x v) (not ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (and ty a b) x v) (and ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)
(rewrite (SubstExpr (or ty a b) x v) (or ty (SubstOperand a x v) (SubstOperand b x v)) :ruleset subst)

(rewrite (SubstOperand (Arg x) x v) v :ruleset subst)
(rule ((= f (SubstOperand (Arg y) x v)) (!= y x))
((union f (Arg y))) :ruleset subst)
(rewrite (SubstOperand (Node b) x v) (Node (SubstBody b x v)) :ruleset subst)
(rewrite (SubstOperand (Project i b) x v) (Project i (SubstBody b x v)) :ruleset subst)

(rewrite (SubstBody (PureOp e) x v) (PureOp (SubstExpr e x v)) :ruleset subst)
;; Subst doesn't cross regions - so we subst into the inputs but not outputs
;; Node that a Gamma node's idx is on the outside, so it gets affected, but not
;; a Theta node's predicate
(rewrite (SubstBody (Gamma pred inputs outputs) x v) (Gamma (SubstOperand pred x v) (SubstVecOperand inputs x v) outputs) :ruleset subst)
(rewrite (SubstBody (Theta pred inputs outputs) x v) (Theta pred (SubstVecOperand inputs x v) outputs) :ruleset subst)

;; params: vec, var, op, next index to subst
;; rtjoa: TODO: implement by mapping internally so they're not O(n^2) time
oflatt marked this conversation as resolved.
Show resolved Hide resolved
(function SubstVecOperandHelper (VecOperand i64 Operand i64) VecOperand)
(rewrite (SubstVecOperand vec x v) (SubstVecOperandHelper vec x v 0) :ruleset subst)
(rule
(
(= f (SubstVecOperandHelper (VO vec) x v i))
(< i (vec-length vec))
)
(
(union
(SubstVecOperandHelper (VO vec) x v i)
(SubstVecOperandHelper
(VO (vec-set vec i (SubstOperand (vec-get vec i) x v)))
x v (+ i 1)
))
) :ruleset subst)

(rule
(
(= f (SubstVecOperandHelper (VO vec) x v i))
(= i (vec-length vec))
)
(
(union (SubstVecOperandHelper (VO vec) x v i) (VO vec))
) :ruleset subst)


;; params: vec, var, op, next index to subst
;; rtjoa: TODO: implement by mapping internally so they're not O(n^2) time
(function SubstVecVecOperandHelper (VecVecOperand i64 Operand i64) VecVecOperand)
(rewrite (SubstVecVecOperand vec x v) (SubstVecVecOperandHelper vec x v 0) :ruleset subst)
(rule
(
(= f (SubstVecVecOperandHelper (VVO vec) x v i))
(< i (vec-length vec))
)
(
(union
(SubstVecVecOperandHelper (VVO vec) x v i)
(SubstVecVecOperandHelper
(VVO (vec-set vec i (SubstVecOperand (vec-get vec i) x v)))
x v (+ i 1)
))
) :ruleset subst)

(rule
(
(= f (SubstVecVecOperandHelper (VVO vec) x v i))
(= i (vec-length vec))
)
(
(union (SubstVecVecOperandHelper (VVO vec) x v i) (VVO vec))
) :ruleset subst)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a ruleset for the subst rules to make it more usable?
I'd also like the existing egglog encoding to run this ruleset to saturation in-between runs of the normal ruleset.

Copy link
Collaborator Author

@rtjoa rtjoa Sep 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added all rules to the subst ruleset, added the existing rules to constant-fold +/* to the cfold ruleset, and made the schedule the following:

let schedule = format!("(run-schedule (repeat {} cfold (saturate subst)))", self.num_iters);```

;; procedure f(n):
;; i = 0
Expand Down
117 changes: 108 additions & 9 deletions src/rvsdg/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ fn rvsdg_state_gamma() {
.C:
call @other_func;
jmp .End;
.End:
.End:
}

@other_func() {
Expand Down Expand Up @@ -417,30 +417,30 @@ fn rvsdg_odd_branch_egg_roundtrip() {
(const)
(Num 1)))))))
(Arg 3))))
(vec-of (Arg 1)
(VO (vec-of (Arg 1)
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Arg 0))
(vec-of (Arg 0)
(Arg 0)))
(VO (vec-of (Arg 0)
(Node (PureOp (add (IntT) (Arg 1) (Arg 2))))
(Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT) (const) (Num 1)))))))
(Arg 3))))
(Arg 3)))))
(let rescaled
(Gamma
(Node
(PureOp
(lt (BoolT) (Project 1 loop)
(Node (PureOp (Const (IntT) (const) (Num 5)))))))
(vec-of
(VO (vec-of
(Project 0 loop)
(Project 1 loop))
(vec-of (VO (vec-of (Arg 0) (Arg 1)))
(Project 1 loop)))
(VVO (vec-of (VO (vec-of (Arg 0) (Arg 1)))
(VO (vec-of (Arg 0)
(Node (PureOp (mul (IntT) (Arg 1)
(Node (PureOp (Const (IntT)
(const)
(Num 2))))))))))))
(Num 2)))))))))))))
(let expected-result (Project 0 rescaled))
(let expected-state (Project 1 rescaled))
"#;
Expand Down Expand Up @@ -648,3 +648,102 @@ fn deep_equal(f1: &RvsdgFunction, f2: &RvsdgFunction) -> bool {
(None, Some(_)) | (Some(_), None) => false,
}
}

#[test]
fn rvsdg_subst() {
const EGGLOG_PROGRAM: &str = r#"
(let unsubsted
(Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT)
(const)
(Num 1)))))))
(Arg 3)))))
(let substed (SubstOperand unsubsted 3 (Arg 7)))
(run-schedule (saturate subst))
(let expected
(Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT)
(const)
(Num 1)))))))
(Arg 7)))))
(check (= substed expected))
"#;
let mut egraph = new_rvsdg_egraph();
egraph.parse_and_run_program(EGGLOG_PROGRAM).unwrap();

const EGGLOG_THETA_PROGRAM: &str = r#"
(let unsubsted
(Theta
(Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT)
(const)
(Num 1)))))))
(Arg 3))))
(VO (vec-of (Arg 0)
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Arg 1)))
(VO (vec-of (Arg 0)
(Node (PureOp (add (IntT) (Arg 1) (Arg 2))))
(Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT) (const) (Num 1)))))))
(Arg 3)))))
(let substed (SubstBody unsubsted 1 (Arg 7)))
(run-schedule (saturate subst))
(let expected
(Theta
(Node (PureOp (lt (BoolT) (Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT)
(const)
(Num 1)))))))
(Arg 3))))
(VO (vec-of (Arg 0)
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Node (PureOp (Const (IntT) (const) (Num 0))))
(Arg 7)))
(VO (vec-of (Arg 0)
(Node (PureOp (add (IntT) (Arg 1) (Arg 2))))
(Node (PureOp (add (IntT) (Arg 2)
(Node (PureOp (Const (IntT) (const) (Num 1)))))))
(Arg 3)))))
(check (= substed expected))
"#;
let mut egraph = new_rvsdg_egraph();
egraph.parse_and_run_program(EGGLOG_THETA_PROGRAM).unwrap();

const EGGLOG_GAMMA_PROGRAM: &str = r#"
(let unsubsted
(Gamma
(Node
(PureOp
(lt (BoolT) (Arg 0) (Arg 0))))
(VO (vec-of
(Arg 1)
(Arg 0)))
(VVO (vec-of (VO (vec-of (Arg 0) (Arg 1)))
(VO (vec-of (Arg 0)
(Node (PureOp (mul (IntT) (Arg 1)
(Node (PureOp (Const (IntT)
(const)
(Num 2)))))))))))))
(let substed (SubstBody unsubsted 0 (Arg 7)))
(run-schedule (saturate subst))
(let expected
(Gamma
(Node
(PureOp
(lt (BoolT) (Arg 7) (Arg 7))))
(VO (vec-of
(Arg 1)
(Arg 7)))
(VVO (vec-of (VO (vec-of (Arg 0) (Arg 1)))
(VO (vec-of (Arg 0)
(Node (PureOp (mul (IntT) (Arg 1)
(Node (PureOp (Const (IntT)
(const)
(Num 2)))))))))))))
(check (= substed expected))
"#;
let mut egraph = new_rvsdg_egraph();
egraph.parse_and_run_program(EGGLOG_GAMMA_PROGRAM).unwrap();
}
Loading