diff --git a/tree_unique_args/src/loop_invariant.egg b/tree_unique_args/src/loop_invariant.egg index b2bf2e52c..f62f7a0a2 100644 --- a/tree_unique_args/src/loop_invariant.egg +++ b/tree_unique_args/src/loop_invariant.egg @@ -1,28 +1,28 @@ ;; Loop Invariant ;; bool: whether the term in the Expr is an invariant. -(function is-inv-Expr (Expr Expr) bool :unextractable :merge (or old new)) -(function is-inv-ListExpr (Expr ListExpr) bool :unextractable :merge (or old new)) +(function is-inv-Expr (IdSort Expr) bool :unextractable :merge (or old new)) +(function is-inv-ListExpr (IdSort ListExpr) bool :unextractable :merge (or old new)) -(relation arg-inv (Expr i64)) +(relation arg-inv (IdSort i64)) ;; in default, when there is a find, set is-inv to false -(rule ((BodyContainsExpr loop term) (= loop (Loop id in pred_out))) ((set (is-inv-Expr loop term) false)) :ruleset always-run) -(rule ((BodyContainsListExpr loop term) (= loop (Loop id in pred_out))) ((set (is-inv-ListExpr loop term) false)) :ruleset always-run) +(rule ((BodyContainsExpr loop_id term) (Loop loop_id in pred_out)) ((set (is-inv-Expr loop_id term) false)) :ruleset always-run) +(rule ((BodyContainsListExpr loop_id term) (Loop loop_id in pred_out)) ((set (is-inv-ListExpr loop_id term) false)) :ruleset always-run) ;; I assume input is tuple here (rule ((= loop (Loop id inputs outputs)) - (= (Get (Arg id) i) (get-loop-outputs-ith loop i))) - ((arg-inv loop i)) :ruleset always-run) + (= (Get (Arg id) i) (get-loop-outputs-ith id i))) + ((arg-inv id i)) :ruleset always-run) -(relation is-inv-ListExpr-helper (Expr ListExpr i64)) -(rule ((BodyContainsListExpr loop list)) ((is-inv-ListExpr-helper loop list 0)) :ruleset always-run) +(relation is-inv-ListExpr-helper (IdSort ListExpr i64)) +(rule ((BodyContainsListExpr loop_id list) (Loop loop_id in out)) ((is-inv-ListExpr-helper loop_id list 0)) :ruleset always-run) -(rule ((is-inv-ListExpr-helper loop list i) - (= true (is-inv-Expr loop expr)) +(rule ((is-inv-ListExpr-helper loop_id list i) + (= true (is-inv-Expr loop_id expr)) (= expr (ListExpr-ith list i))) - ((is-inv-ListExpr-helper loop list (+ i 1))) :ruleset always-run) + ((is-inv-ListExpr-helper loop_id list (+ i 1))) :ruleset always-run) -(rule ((is-inv-ListExpr-helper loop list i) +(rule ((is-inv-ListExpr-helper loop_id list i) (= i (ListExpr-length list))) - ((set (is-inv-ListExpr loop list) true)) :ruleset always-run) \ No newline at end of file + ((set (is-inv-ListExpr loop_id list) true)) :ruleset always-run) \ No newline at end of file diff --git a/tree_unique_args/src/loop_invariant.rs b/tree_unique_args/src/loop_invariant.rs index c4b7d35c1..465b44ff6 100644 --- a/tree_unique_args/src/loop_invariant.rs +++ b/tree_unique_args/src/loop_invariant.rs @@ -8,17 +8,19 @@ fn is_inv_base_case_for_ctor(ctor: Constructor) -> Option { match ctor { Constructor::Get => Some(format!( - "(rule ((BodyContainsExpr loop expr) \ - {br} (= expr (Get (Arg id) i)) \ - {br} (arg-inv loop i)) \ - {br}((set (is-inv-Expr loop expr) true)){ruleset})" + "(rule ((BodyContainsExpr loop_id expr) \ + {br} (Loop loop_id in out) \ + {br} (= expr (Get (Arg loop_id) i)) \ + {br} (arg-inv loop_id i)) \ + {br}((set (is-inv-Expr loop_id expr) true)){ruleset})" )), Constructor::Num | Constructor::Boolean => { let ctor_pattern = ctor.construct(|field| field.var()); Some(format!( - "(rule ((BodyContainsExpr loop expr) \ + "(rule ((BodyContainsExpr loop_id expr) \ + {br} (Loop loop_id in out) \ {br} (= expr {ctor_pattern})) \ - {br}((set (is-inv-Expr loop expr) true)){ruleset})" + {br}((set (is-inv-Expr loop_id expr) true)){ruleset})" )) } _ => None, @@ -53,7 +55,7 @@ fn is_invariant_rule_for_ctor(ctor: Constructor) -> Option { Purpose::SubExpr | Purpose::SubListExpr => { let var = field.var(); let sort = field.sort().name(); - Some(format!("(= true (is-inv-{sort} loop {var}))")) + Some(format!("(= true (is-inv-{sort} loop_id {var}))")) } }) .join(" "); @@ -64,10 +66,11 @@ fn is_invariant_rule_for_ctor(ctor: Constructor) -> Option { _ => String::new(), }; Some(format!( - "(rule ((BodyContainsExpr loop expr) \ + "(rule ((BodyContainsExpr loop_id expr) \ + {br} (Loop loop_id in out) \ {br} (= expr {ctor_pattern}) \ {br} {is_inv_ctor} {is_pure}) \ - {br}((set (is-inv-Expr loop expr) true)){ruleset})" + {br}((set (is-inv-Expr loop_id expr) true)){ruleset})" )) } } @@ -99,12 +102,12 @@ fn loop_invariant_detection1() -> Result<(), egglog::Error> { "; let check = " - (check (arg-inv loop 0)) - (fail (check (arg-inv loop 1))) - (check (= true (is-inv-Expr loop (Get (Arg id1) 0)))) - (check (= false (is-inv-Expr loop (Get (Arg id1) 1)))) - (check (= true (is-inv-Expr loop (Add (Num id1 1) (Get (Arg id1) 0))))) - (check (= false (is-inv-Expr loop (Sub (Get (Arg id1) 1) (Add (Num id1 1) (Get (Arg id1) 0))) ))) + (check (arg-inv id1 0)) + (fail (check (arg-inv id1 1))) + (check (= true (is-inv-Expr id1 (Get (Arg id1) 0)))) + (check (= false (is-inv-Expr id1 (Get (Arg id1) 1)))) + (check (= true (is-inv-Expr id1 (Add (Num id1 1) (Get (Arg id1) 0))))) + (check (= false (is-inv-Expr id1 (Sub (Get (Arg id1) 1) (Add (Num id1 1) (Get (Arg id1) 0))) ))) "; crate::run_test(build, check) @@ -149,19 +152,19 @@ fn loop_invariant_detection2() -> Result<(), egglog::Error> { "; let check = " - (check (arg-inv loop 1)) - (check (arg-inv loop 2)) - (check (arg-inv loop 3)) - (check (arg-inv loop 4)) - (fail (check (arg-inv loop 0))) + (check (arg-inv id1 1)) + (check (arg-inv id1 2)) + (check (arg-inv id1 3)) + (check (arg-inv id1 4)) + (fail (check (arg-inv id1 0))) (let l4 (list4 (Num id1 1) (Num id1 2) (Num id1 3) (Num id1 4))) - (check (is-inv-ListExpr-helper loop l4 4)) - (check (= true (is-inv-ListExpr loop l4))) - (check (= true (is-inv-Expr loop (Switch (Num id1 1) l4)))) - (check (= true (is-inv-Expr loop inv))) - (check (= false (is-inv-Expr loop (Add (Get (Arg id1) 0) inv)))) + (check (is-inv-ListExpr-helper id1 l4 4)) + (check (= true (is-inv-ListExpr id1 l4))) + (check (= true (is-inv-Expr id1 (Switch (Num id1 1) l4)))) + (check (= true (is-inv-Expr id1 inv))) + (check (= false (is-inv-Expr id1 (Add (Get (Arg id1) 0) inv)))) ;; a non exist expr should fail - (fail (check (is-inv-Expr loop (Switch (Num id1 2) l4)))) + (fail (check (is-inv-Expr id1 (Switch (Num id1 2) l4)))) "; crate::run_test(build, check) diff --git a/tree_unique_args/src/util.egg b/tree_unique_args/src/util.egg index 36b5b627a..292c8bce1 100644 --- a/tree_unique_args/src/util.egg +++ b/tree_unique_args/src/util.egg @@ -21,12 +21,22 @@ :ruleset always-run) ;; get the ith output of a loop -(function get-loop-outputs-ith (Expr i64) Expr :unextractable) -(rule ((= loop (Loop id inputs pred-outputs)) - (= pred-outputs (All id1 ord1 pred-out-list)) - (= (All id2 ord2 outputs-list) (ListExpr-ith pred-out-list 1)) - (= ith-outputs (ListExpr-ith outputs-list i))) - ((union (get-loop-outputs-ith loop i) ith-outputs)) :ruleset always-run) + +(function get-loop-output (IdSort) Expr :unextractable) +(function get-loop-pred (IdSort) Expr :unextractable) +(rule ((= loop (Loop loop_id inputs pred_outputs)) + (= pred_outputs (All id_1 ord1 pred_out_list)) + (= pred (ListExpr-ith pred_out_list 0)) + (= out (ListExpr-ith pred_out_list 1))) + ((union (get-loop-output loop_id) out) + (union (get-loop-pred loop_id) pred)) :ruleset always-run) + +;; get the ith output of a loop +(function get-loop-outputs-ith (IdSort i64) Expr :unextractable) +(rule ((= (All id1 ord2 outputs-list) (get-loop-output loop_id)) + (= ith_outputs (ListExpr-ith outputs-list i))) + ((union (get-loop-outputs-ith loop_id i) ith_outputs)) :ruleset always-run) + (function Expr-size (Expr) i64 :merge (min old new)) (function ListExpr-size (ListExpr) i64 :merge (min old new)) diff --git a/tree_unique_args/src/util.rs b/tree_unique_args/src/util.rs index e4c6bd809..fe6651f80 100644 --- a/tree_unique_args/src/util.rs +++ b/tree_unique_args/src/util.rs @@ -105,12 +105,12 @@ fn get_loop_output_ith_test() -> Result<(), egglog::Error> { let check = " (check ( = - (get-loop-outputs-ith loop 0) + (get-loop-outputs-ith id1 0) out0 )) (check ( = - (get-loop-outputs-ith loop 1) + (get-loop-outputs-ith id1 1) out1 ))"; crate::run_test(build, check)