From ed7a15d41f575283b7dc4eb68de157304d264a91 Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Wed, 30 Oct 2024 12:28:27 -0700 Subject: [PATCH] Eliminate common subexpressions in our queries --- dag_in_context/src/optimizations/loop_unroll.egg | 7 +++---- dag_in_context/src/optimizations/memory.egg | 10 ++++++---- dag_in_context/src/optimizations/passthrough.egg | 10 ++++++---- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/dag_in_context/src/optimizations/loop_unroll.egg b/dag_in_context/src/optimizations/loop_unroll.egg index 847bd5262..3120568d9 100644 --- a/dag_in_context/src/optimizations/loop_unroll.egg +++ b/dag_in_context/src/optimizations/loop_unroll.egg @@ -25,7 +25,6 @@ ;; TODO: we could make it work for decrementing loops (rule ((= lhs (DoWhile inputs outputs)) - (= num-inputs (tuple-length inputs)) (= pred (Get outputs 0)) ;; iteration counter starts at start_const (= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i)) @@ -50,18 +49,18 @@ ;; and i is updated after checking pred (rule ((= lhs (DoWhile inputs outputs)) - (= num-inputs (tuple-length inputs)) (= pred (Get outputs 0)) ;; iteration counter starts at start_const (= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i)) + (= body-arg (Get (Arg _ty _ctx) counter_i)) ;; updated counter at counter_i (= next_counter (Get outputs (+ counter_i 1))) ;; increments by a constant each loop - (= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i) + (= next_counter (Bop (Add) body-arg (Const (Int increment) _ty2 _ctx2))) (> increment 0) ;; while this counter less than end_constant - (= pred (Bop (LessThan) (Get (Arg _ty _ctx) counter_i) + (= pred (Bop (LessThan) body-arg (Const (Int end_constant) _ty3 _ctx3))) ;; end constant is at least start constant (>= end_constant start_const) diff --git a/dag_in_context/src/optimizations/memory.egg b/dag_in_context/src/optimizations/memory.egg index 2d0016457..cc1f85e73 100644 --- a/dag_in_context/src/optimizations/memory.egg +++ b/dag_in_context/src/optimizations/memory.egg @@ -249,16 +249,18 @@ :ruleset memory-helpers) ; Compute and propagate PointsToCells -(rewrite (PointsToCells (Concat x y) aps) +(rewrite (PointsToCells concat-x-y aps) (TuplePointsTo (Concat-List (UnwrapTuplePointsTo (PointsToCells x aps)) (UnwrapTuplePointsTo (PointsToCells y aps)))) - :when ((HasType (Concat x y) ty) (PointerishType ty)) + :when ((= concat-x-y (Concat x y)) + (HasType concat-x-y ty) (PointerishType ty)) :ruleset memory-helpers) -(rewrite (PointsToCells (Get x i) aps) +(rewrite (PointsToCells get-x-i aps) (GetPointees (PointsToCells x aps) i) - :when ((HasType (Get x i) ty) (PointerishType ty)) + :when ((= get-x-i (Get x i)) + (HasType get-x-i ty) (PointerishType ty)) :ruleset memory-helpers) (rewrite (PointsToCells (Single x) aps) diff --git a/dag_in_context/src/optimizations/passthrough.egg b/dag_in_context/src/optimizations/passthrough.egg index 0116850cc..c420723d0 100644 --- a/dag_in_context/src/optimizations/passthrough.egg +++ b/dag_in_context/src/optimizations/passthrough.egg @@ -7,7 +7,7 @@ (= (Get pred-outputs (+ i 1)) (Get (Arg _ty _ctx) i)) ;; only pass through pure types, since some loops don't terminate ;; so the state edge must pass through them - (HasType (Get loop i) lhs_ty) + (HasType lhs lhs_ty) (PureType lhs_ty) ) ((union lhs (Get inputs i))) @@ -40,9 +40,11 @@ ;; Pass through if arguments (rule ((= if (If pred inputs then_ else_)) - (= (Get then_ i) (Get (Arg arg_ty _then_ctx) j)) - (= (Get else_ i) (Get (Arg arg_ty _else_ctx) j)) - (HasType (Get then_ i) lhs_ty) + (= then-branch (Get then_ i)) + (= else-branch (Get else_ i)) + (= then-branch (Get (Arg arg_ty _then_ctx) j)) + (= else-branch (Get (Arg arg_ty _else_ctx) j)) + (HasType then-branch lhs_ty) (!= lhs_ty (Base (StateT)))) ((union (Get if i) (Get inputs j))) :ruleset passthrough)