Skip to content

Commit

Permalink
fix type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Nov 8, 2024
1 parent e887882 commit d80e687
Showing 1 changed file with 59 additions and 23 deletions.
82 changes: 59 additions & 23 deletions dag_in_context/src/optimizations/rec_to_loop.egg
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,26 @@
(rule
((Function name in out body)
(= body (If pred always-runs (Call name rec_case) base-case))
(HasType always-runs inputs-ty)
(HasType always-runs start-ty)
(HasType body func-ty))
((let loop
(DoWhile (Arg inputs-ty (InIf true pred always-runs))
(Concat
(Single (Subst (InIf true pred always-runs) rec_case pred))
(Subst (InIf true pred always-runs) rec_case always-runs))))
((panic "bad")
(let loop-inputs (Arg start-ty (InIf true pred always-runs)))
(let loop-outputs
(Concat
(Single (Subst (TmpCtx) rec_case pred))
(Subst (TmpCtx) rec_case always-runs)))
(union (TmpCtx) (InLoop loop-inputs loop-outputs))
(delete (TmpCtx))

(let loop
(DoWhile loop-inputs loop-outputs))


;; initial start value
(let outer-if
(If pred always-runs
loop
(Arg inputs-ty (InIf false pred always-runs))))
(Arg start-ty (InIf false pred always-runs))))
(union body (Subst (InFunc name) outer-if base-case)))
:ruleset rec-to-loop)

Expand All @@ -64,7 +71,7 @@
;; if (start[0]) {
;; do {
;; start = always_runs(rec_case(start));
;; acc = acc + f(start);
;; acc = acc + extra(start);
;; } while (start[0]);
;; }
;; ret base_case(start) + acc;
Expand All @@ -74,25 +81,54 @@
(= body (If pred always-runs then-case base-case))
(= call (Call name rec-case))
(= then-case
(Concat (BinOp (Add) (Get call 0) op-arg)
(Concat (Bop (Add) (Get call 0) extra)
(Get call 1)))
(HasType inputs inputs-ty)
(= inputs-ty (TupleT inputs-ty-list))
(HasType always-runs start-ty)
(= always-runs-len (tuple-length always-runs))
(= start-ty (TupleT start-ty-list))
(HasType body func-ty))
((let loop-ty
(TupleT (TLConcat inputs-ty-list (TCons (IntT) (TNil)))))
((panic "good")
(let loop-ty
(TupleT (TLConcat start-ty-list (TCons (IntT) (TNil)))))
;; recursive case in the loop
(let new-rec-case
(Subst (TmpCtx)
(SubTuple (Arg loop-ty (TmpCtx)) 0 always-runs-len) rec-case))
;; extra computation in the loop
(let new-extra
(Subst (TmpCtx)
(SubTuple (Arg loop-ty (TmpCtx)) 0 always-runs-len) extra))
;; acc starts at 0
(let loop-inputs
(Concat (Arg start-ty (InIf true pred always-runs)) (Single (Const (Int 0) start-ty (InIf true pred always-runs)))))
(let loop-outputs
(Concat
(Single (Subst (TmpCtx) new-rec-case pred))
(Concat
(Subst (TmpCtx) new-rec-case always-runs)
;; add extra to acc
(Single (Bop (Add) (Get (Arg loop-ty (TmpCtx)) always-runs-len) new-extra)))))
;; loop starts at zero, adds extra each iteration
(let loop
(DoWhile (Concat (Arg inputs-ty (InIf true pred inputs)) (Single (Const 0 (IntT))))
(Concat
(Single (Subst (InIf true pred inputs) rec_case pred))
(Subst (InIf true pred inputs) rec_case inputs))
))

;; initial start value
(DoWhile loop-inputs loop-outputs))
;; union tmpctx
(union (TmpCtx) (InLoop loop-inputs loop-outputs))
(delete (TmpCtx))

(let outer-if
(If pred inputs
(If pred always-runs
loop
(Arg inputs-ty (InIf false pred inputs))))
(union body (Subst (InFunc name) outer-if base-case)))
(Concat
(Arg start-ty (InIf false pred always-runs))
;; otherwise acc is 0
(Const (Int 0) func-ty (InIf false pred always-runs)))))
;; base case over latest start value
(let new-base-case
(Subst (InFunc name) (SubTuple outer-if 0 always-runs-len) base-case))
;; add base case to acc
(let res
(Concat
(Bop (Add) (Get new-base-case 0) (Get outer-if always-runs-len))
(Get new-base-case 1)))
(union body res))
:ruleset rec-to-loop)

0 comments on commit d80e687

Please sign in to comment.