diff --git a/src/rvsdg/tree_unique/mod.rs b/src/rvsdg/tree_unique/mod.rs index eaa944fa5..525e29854 100644 --- a/src/rvsdg/tree_unique/mod.rs +++ b/src/rvsdg/tree_unique/mod.rs @@ -1,9 +1 @@ -//! Convert RVSDG programs to the tree -//! encoding of programs. -//! RVSDGs are close to this encoding, -//! but use a DAG-based semantics. -//! This means that nodes that are shared -//! are only computed once. -//! These shared nodes need to be let-bound so that they are only -//! computed once in the tree encoded -//! program. +pub mod to_tree; diff --git a/src/rvsdg/tree_unique/to_tree.rs b/src/rvsdg/tree_unique/to_tree.rs new file mode 100644 index 000000000..eaa944fa5 --- /dev/null +++ b/src/rvsdg/tree_unique/to_tree.rs @@ -0,0 +1,9 @@ +//! Convert RVSDG programs to the tree +//! encoding of programs. +//! RVSDGs are close to this encoding, +//! but use a DAG-based semantics. +//! This means that nodes that are shared +//! are only computed once. +//! These shared nodes need to be let-bound so that they are only +//! computed once in the tree encoded +//! program. diff --git a/tree_unique_args/src/ast.rs b/tree_unique_args/src/ast.rs new file mode 100644 index 000000000..def00792d --- /dev/null +++ b/tree_unique_args/src/ast.rs @@ -0,0 +1,204 @@ +use crate::{Expr, Expr::*, Id, Order}; + +pub fn give_fresh_ids(expr: &mut Expr) { + let mut id = 1; + give_fresh_ids_helper(expr, 0, &mut id); +} + +fn give_fresh_ids_helper(expr: &mut Expr, current_id: i64, fresh_id: &mut i64) { + match expr { + Loop(id, input, body) => { + let new_id = *fresh_id; + *fresh_id += 1; + *id = Id(new_id); + give_fresh_ids_helper(input, current_id, fresh_id); + give_fresh_ids_helper(body, new_id, fresh_id); + } + Let(id, arg, body) => { + let new_id = *fresh_id; + *fresh_id += 1; + *id = Id(new_id); + give_fresh_ids_helper(arg, current_id, fresh_id); + give_fresh_ids_helper(body, new_id, fresh_id); + } + Arg(id) => { + *id = Id(current_id); + } + Function(id, body) => { + let new_id = *fresh_id; + *fresh_id += 1; + *id = Id(new_id); + give_fresh_ids_helper(body, new_id, fresh_id); + } + Call(id, arg) => { + *id = Id(current_id); + give_fresh_ids_helper(arg, current_id, fresh_id); + } + _ => expr.for_each_child(move |child| give_fresh_ids_helper(child, current_id, fresh_id)), + } +} + +pub fn program(args: Vec) -> Expr { + let mut prog = Program(args); + give_fresh_ids(&mut prog); + prog +} + +pub fn num(n: i64) -> Expr { + Num(n) +} + +pub fn ttrue() -> Expr { + Boolean(true) +} +pub fn tfalse() -> Expr { + Boolean(false) +} + +pub fn unit() -> Expr { + Unit +} + +pub fn add(a: Expr, b: Expr) -> Expr { + Add(Box::new(a), Box::new(b)) +} + +pub fn sub(a: Expr, b: Expr) -> Expr { + Sub(Box::new(a), Box::new(b)) +} + +pub fn mul(a: Expr, b: Expr) -> Expr { + Mul(Box::new(a), Box::new(b)) +} + +pub fn lessthan(a: Expr, b: Expr) -> Expr { + LessThan(Box::new(a), Box::new(b)) +} + +pub fn and(a: Expr, b: Expr) -> Expr { + And(Box::new(a), Box::new(b)) +} + +pub fn or(a: Expr, b: Expr) -> Expr { + Or(Box::new(a), Box::new(b)) +} + +pub fn not(a: Expr) -> Expr { + Not(Box::new(a)) +} + +pub fn get(a: Expr, i: usize) -> Expr { + Get(Box::new(a), i) +} + +pub fn concat(a: Expr, b: Expr) -> Expr { + Concat(Box::new(a), Box::new(b)) +} + +pub fn print(a: Expr) -> Expr { + Print(Box::new(a)) +} + +pub fn sequence(args: Vec) -> Expr { + All(Order::Sequential, args) +} + +pub fn parallel(args: Vec) -> Expr { + All(Order::Parallel, args) +} + +pub fn switch(arg: Expr, cases: Vec) -> Expr { + Switch(Box::new(arg), cases) +} + +pub fn tloop(input: Expr, body: Expr) -> Expr { + Loop(Id(0), Box::new(input), Box::new(body)) +} + +pub fn tlet(arg: Expr, body: Expr) -> Expr { + Let(Id(0), Box::new(arg), Box::new(body)) +} + +pub fn arg() -> Expr { + Arg(Id(0)) +} + +pub fn function(arg: Expr) -> Expr { + Function(Id(0), Box::new(arg)) +} + +pub fn call(arg: Expr) -> Expr { + Call(Id(0), Box::new(arg)) +} + +#[test] +fn test_gives_nested_ids() { + let mut prog = tlet(num(0), tlet(num(1), num(2))); + give_fresh_ids(&mut prog); + assert_eq!( + prog, + Let( + Id(1), + Box::new(Num(0)), + Box::new(Let(Id(2), Box::new(Num(1)), Box::new(Num(2)))) + ) + ); +} + +#[test] +fn test_gives_loop_ids() { + let mut prog = tlet(num(0), tloop(num(1), num(2))); + give_fresh_ids(&mut prog); + assert_eq!( + prog, + Let( + Id(1), + Box::new(Num(0)), + Box::new(Loop(Id(2), Box::new(Num(1)), Box::new(Num(2)))) + ) + ); +} + +#[test] +fn test_complex_program_ids() { + // test a program that includes + // a let, a loop, a switch, and a call + let prog = program(vec![function(tlet( + num(0), + tloop( + num(1), + switch( + arg(), + vec![ + num(2), + call(num(3)), + tlet(num(4), num(5)), + tloop(num(6), num(7)), + ], + ), + ), + ))]); + assert_eq!( + prog, + Program(vec![Function( + Id(1), + Box::new(Let( + Id(2), + Box::new(Num(0)), + Box::new(Loop( + Id(3), + Box::new(Num(1)), + Box::new(Switch( + Box::new(Arg(Id(3))), + vec![ + Num(2), + Call(Id(3), Box::new(Num(3))), + Let(Id(4), Box::new(Num(4)), Box::new(Num(5))), + Loop(Id(5), Box::new(Num(6)), Box::new(Num(7))), + ] + )) + )) + )) + )]) + ); +} diff --git a/tree_unique_args/src/interpreter.rs b/tree_unique_args/src/interpreter.rs index 464394f94..4158a827b 100644 --- a/tree_unique_args/src/interpreter.rs +++ b/tree_unique_args/src/interpreter.rs @@ -18,6 +18,7 @@ pub fn typecheck(e: &Expr, arg_ty: &Option) -> Result { } }; match e { + Expr::Program(_) => panic!("Found non top level program."), Expr::Num(_) => Ok(Type::Num), Expr::Boolean(_) => Ok(Type::Boolean), Expr::Unit => Ok(Type::Unit), @@ -50,6 +51,23 @@ pub fn typecheck(e: &Expr, arg_ty: &Option) -> Result { )), } } + Expr::Concat(tuple_1, tuple_2) => { + let ty_tuple_1 = typecheck(tuple_1, arg_ty)?; + let ty_tuple_2 = typecheck(tuple_2, arg_ty)?; + match (ty_tuple_1.clone(), ty_tuple_2.clone()) { + (Type::Tuple(tys_1), Type::Tuple(tys_2)) => { + Ok(Type::Tuple(tys_1.into_iter().chain(tys_2).collect())) + } + (Type::Tuple(_tys_1), _) => Err(TypeError::ExpectedTupleType( + *tuple_2.clone(), + ty_tuple_2.clone(), + )), + _ => Err(TypeError::ExpectedTupleType( + *tuple_1.clone(), + ty_tuple_1.clone(), + )), + } + } Expr::Print(e) => { // right now, only print nums expect_type(e, Type::Num)?; @@ -117,6 +135,7 @@ pub struct VirtualMachine { // assumes e typechecks and that memory is written before read pub fn interpret(e: &Expr, arg: &Option, vm: &mut VirtualMachine) -> Value { match e { + Expr::Program(_) => todo!("interpret programs"), Expr::Num(x) => Value::Num(*x), Expr::Boolean(x) => Value::Boolean(*x), Expr::Unit => Value::Unit, @@ -186,6 +205,15 @@ pub fn interpret(e: &Expr, arg: &Option, vm: &mut VirtualMachine) -> Valu }; vals[*i].clone() } + Expr::Concat(tuple_1, tuple_2) => { + let Value::Tuple(t1) = interpret(tuple_1, arg, vm) else { + panic!("concat") + }; + let Value::Tuple(t2) = interpret(tuple_2, arg, vm) else { + panic!("concat") + }; + Value::Tuple(t1.into_iter().chain(t2).collect()) + } Expr::Print(e) => { let Value::Num(n) = interpret(e, arg, vm) else { panic!("print") diff --git a/tree_unique_args/src/lib.rs b/tree_unique_args/src/lib.rs index 6fd2847e5..1caeddce1 100644 --- a/tree_unique_args/src/lib.rs +++ b/tree_unique_args/src/lib.rs @@ -1,3 +1,4 @@ +pub mod ast; pub(crate) mod body_contains; pub(crate) mod conditional_invariant_code_motion; pub(crate) mod deep_copy; @@ -38,6 +39,11 @@ pub enum Expr { Or(Box, Box), Not(Box), Get(Box, usize), + /// Concat is a convenient built-in way + /// to put two tuples together. + /// It's not strictly necessary, but + /// doing it by constructing a new big tuple is tedius and slow. + Concat(Box, Box), Print(Box), Read(Box), Write(Box, Box), @@ -47,9 +53,58 @@ pub enum Expr { Let(Id, Box, Box), Arg(Id), Function(Id, Box), + /// A list of functions, with the first + /// being the main function. + Program(Vec), Call(Id, Box), } +impl Expr { + /// Runs `func` on every child of this expression. + pub fn for_each_child(&mut self, mut func: impl FnMut(&mut Expr)) { + match self { + Expr::Num(_) | Expr::Boolean(_) | Expr::Unit | Expr::Arg(_) => {} + Expr::Add(a, b) + | Expr::Sub(a, b) + | Expr::Mul(a, b) + | Expr::LessThan(a, b) + | Expr::And(a, b) + | Expr::Or(a, b) + | Expr::Concat(a, b) + | Expr::Write(a, b) => { + func(a); + func(b); + } + Expr::Not(a) | Expr::Print(a) | Expr::Read(a) => { + func(a); + } + Expr::Get(a, _) | Expr::Function(_, a) | Expr::Call(_, a) => { + func(a); + } + Expr::All(_, children) => { + for child in children { + func(child); + } + } + Expr::Switch(input, children) => { + func(input); + for child in children { + func(child); + } + } + Expr::Loop(_, pred, output) | Expr::Let(_, pred, output) => { + func(pred); + func(output); + } + Expr::Program(functions) => { + for function in functions { + func(function); + } + } + } + } +} + #[derive(Clone, Debug, PartialEq)] pub enum Value { Num(i64), @@ -85,7 +140,7 @@ pub fn run_test(build: &str, check: &str) -> Result { &subst::subst_rules().join("\n"), &deep_copy::deep_copy_rules().join("\n"), include_str!("sugar.egg"), - include_str!("util.egg"), + &util::rules().join("\n"), &id_analysis::id_analysis_rules().join("\n"), // optimizations include_str!("simple.egg"), diff --git a/tree_unique_args/src/schema.egg b/tree_unique_args/src/schema.egg index bda9b308c..c0f553385 100644 --- a/tree_unique_args/src/schema.egg +++ b/tree_unique_args/src/schema.egg @@ -68,7 +68,7 @@ ; ========================== ; f output -(relation Function (IdSort Expr)) +(function Function (IdSort Expr) Expr) ; f arg (function Call (IdSort Expr) Expr) diff --git a/tree_unique_args/src/sugar.egg b/tree_unique_args/src/sugar.egg index bf53ed250..eda9206bf 100644 --- a/tree_unique_args/src/sugar.egg +++ b/tree_unique_args/src/sugar.egg @@ -4,6 +4,18 @@ (Cons a (Cons b (Nil))) :ruleset always-run) +(function list3 (Expr Expr Expr) ListExpr) +(rewrite (list3 a b c) + (Cons a (Cons b (Cons c (Nil)))) :ruleset always-run) + +(function list4 (Expr Expr Expr Expr) ListExpr) +(rewrite (list4 a b c d) + (Cons a (Cons b (Cons c (Cons d (Nil))))) :ruleset always-run) + +(function list5 (Expr Expr Expr Expr Expr) ListExpr) +(rewrite (list5 a b c d e) + (Cons a (Cons b (Cons c (Cons d (Cons e (Nil)))))) :ruleset always-run) + (function IgnoreFirst (Expr Expr) Expr) (rewrite (IgnoreFirst a b) (Get diff --git a/tree_unique_args/src/type_analysis.egg b/tree_unique_args/src/type_analysis.egg index 3412183d3..eaab2f245 100644 --- a/tree_unique_args/src/type_analysis.egg +++ b/tree_unique_args/src/type_analysis.egg @@ -26,71 +26,24 @@ (rule ((= (TypeList-suffix list n) (TNil))) ((set (TypeList-length list) n)) :ruleset type-analysis) -(relation HasTypeDemand (Expr)) - (relation HasType (Expr Type)) ; Primitives -(rule ((HasTypeDemand (Num id n))) +(rule ((Num id n)) ((HasType (Num id n) (IntT))) :ruleset type-analysis) -(rule ((HasTypeDemand (Boolean id b))) +(rule ((Boolean id b)) ((HasType (Boolean id b) (BoolT))) :ruleset type-analysis) -(rule ((HasTypeDemand (UnitExpr id))) +(rule ((UnitExpr id)) ((HasType (UnitExpr id) (UnitT))) :ruleset type-analysis) - -; Pure Op Demand -(rule ((HasTypeDemand (Add x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (Sub x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (Mul x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (LessThan x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (And x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (Or x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (Not x))) - ((HasTypeDemand x)) - :ruleset type-analysis) -(rule ((HasTypeDemand (Get e idx))) - ((HasTypeDemand e)) - :ruleset type-analysis) - ; Pure Op Compute (rule ( - (HasTypeDemand (Add x y)) + (Add x y) (HasType x (IntT)) (HasType y (IntT)) ) @@ -99,7 +52,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (Sub x y)) + (Sub x y) (HasType x (IntT)) (HasType y (IntT)) ) @@ -108,7 +61,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (Mul x y)) + (Mul x y) (HasType x (IntT)) (HasType y (IntT)) ) @@ -117,7 +70,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (LessThan x y)) + (LessThan x y) (HasType x (IntT)) (HasType y (IntT)) ) @@ -126,7 +79,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (And x y)) + (And x y) (HasType x (BoolT)) (HasType y (BoolT)) ) @@ -135,7 +88,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (Or x y)) + (Or x y) (HasType x (BoolT)) (HasType y (BoolT)) ) @@ -144,7 +97,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (Not x)) + (Not x) (HasType x (BoolT)) ) ( @@ -152,30 +105,22 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (Get e n)) + (Get e n) (HasType e (TupleT tylist)) ) ((HasType (Get e n) (TypeList-ith tylist n))) :ruleset type-analysis) ; Effectful Ops -(rule ((HasTypeDemand (Print e))) +(rule ((Print e)) ((HasType (Print e) (UnitT))) :ruleset type-analysis) ; TODO: Read and Write (requires type annotations) ; Switch ; if the condition is a boolean, it must have exactly two branches -(rule ((HasTypeDemand (Switch cond (Cons A (Cons B (Nil)))))) - ( - (HasTypeDemand cond) - (HasTypeDemand A) - (HasTypeDemand B) - ) - :ruleset type-analysis) (rule ( (= switch (Switch cond (Cons A (Cons B (Nil))))) - (HasTypeDemand switch) (HasType cond (BoolT)) (HasType A ty) (HasType B ty) @@ -185,21 +130,13 @@ ; Otherwise, the condition must be an integer, and we can have any number of branches. -; peel off a branch and demand type -(rule ((HasTypeDemand (Switch cond (Cons branch rest)))) - ( - (HasTypeDemand branch) - (HasTypeDemand (Switch cond rest)) - ) - :ruleset type-analysis) -; base case- demand the type of the condition -(rule ((HasTypeDemand (Switch cond Nil))) - ((HasTypeDemand cond)) +(rule ((Switch cond (Cons branch rest))) + ((Switch cond rest)) ; peel off a branch for type checking :ruleset type-analysis) ; base case- single branch switch has type of branch (rule ( - (HasTypeDemand (Switch cond (Cons branch (Nil)))) + (Switch cond (Cons branch (Nil))) ; boolean condition handled above, now we must have an integer condition (HasType cond (IntT)) (HasType branch ty) @@ -208,7 +145,7 @@ :ruleset type-analysis) ; recursive case (rule ( - (HasTypeDemand (Switch cond (Cons branch rest))) + (Switch cond (Cons branch rest)) (HasType (Switch cond rest) ty) ; make sure the condition is an integer ; (prevents us from typing boolean switches with >2 branches) @@ -219,21 +156,16 @@ :ruleset type-analysis) ; Sequencing -(rule ((HasTypeDemand (All ord (Cons hd tl)))) - ( - (HasTypeDemand hd) - (HasTypeDemand (All ord tl)) - ) +(rule ((All ord (Cons hd tl))) + ((All ord tl)) ; peel off a layer for type checking :ruleset type-analysis) ; base case: Nil -(rule ( - (HasTypeDemand (All ord (Nil))) - ) +(rule ((All ord (Nil))) ((HasType (All ord (Nil)) (TupleT (TNil)))) :ruleset type-analysis) ; rec case (rule ( - (HasTypeDemand (All ord (Cons hd tl))) + (All ord (Cons hd tl)) (HasType hd ty) (HasType (All ord tl) (TupleT tylist)) ) @@ -247,4 +179,61 @@ (!= t1 t2) ) ((panic "Type Mismatch!")) + :ruleset type-analysis) + + +; Lets + +(rule ( + (Let id in out) + (HasType in ty) + ) + ( + (HasType (Arg id) ty) ; assert the let's argument has type ty in the let's context + ) + :ruleset type-analysis) + +(rule ( + (Let id in out) + (HasType out ty) + ) + ((HasType (Let id in out) ty)) + :ruleset type-analysis) + +; Loops + +(rule ( + (Loop id in pred-out) + (HasType in ty) + ) + ( + (HasType (Arg id) ty) ; assert the argument has type ty in the loop's context + ) + :ruleset type-analysis) + +(rule ( + (Loop id in pred-out) + (HasType in ty) ; input type + ; pred-out must be a two-element tuple. + ; pred must be boolean, output type must match input type + (HasType pred-out (TupleT (TCons (BoolT) (TCons ty (TNil))))) + ) + ((HasType (Loop id in pred-out) ty)) ; whole loop has type of output + :ruleset type-analysis) + +(rule ( + (Loop id in pred-out) + (HasType pred-out (TupleT (TCons pred-ty rest))) + (!= pred-ty (BoolT)) + ) + ((panic "Loop predicate was not a boolean")) + :ruleset type-analysis) + +(rule ( + (Loop id in pred-out) + (HasType in in-ty) + (HasType pred-out (TupleT lst)) + (!= (TypeList-length lst) 2) + ) + ((panic "Loop did not get two arguments (predicate and output)")) :ruleset type-analysis) \ No newline at end of file diff --git a/tree_unique_args/src/type_analysis.rs b/tree_unique_args/src/type_analysis.rs index c601dda23..982425038 100644 --- a/tree_unique_args/src/type_analysis.rs +++ b/tree_unique_args/src/type_analysis.rs @@ -9,8 +9,6 @@ fn simple_types() -> Result<(), egglog::Error> { (let x (LessThan m n)) (let y (Not x)) (let z (And x (Or y y))) - (HasTypeDemand s) - (HasTypeDemand z) "; let check = " (run-schedule (saturate type-analysis)) @@ -36,8 +34,6 @@ fn switch_boolean() -> Result<(), egglog::Error> { (Cons (Add n1 n1) (Cons (Sub n1 n2) (Nil))))) (let wrong_switch (Switch b1 (Cons n1 (Cons n2 (Cons n1 (Nil)))))) - (HasTypeDemand switch) - (HasTypeDemand wrong_switch) "; let check = " (run-schedule (saturate type-analysis)) @@ -62,9 +58,6 @@ fn switch_int() -> Result<(), egglog::Error> { (Switch (Mul n1 n2) (Cons (LessThan n3 n4) (Nil)))) (let s3 (Switch (Sub n2 n2) (Cons (Print n1) (Cons (Print n4) (Cons (Print n3) (Nil)))))) - (HasTypeDemand s1) - (HasTypeDemand s2) - (HasTypeDemand s3) "; let check = " (run-schedule (saturate type-analysis)) @@ -90,17 +83,10 @@ fn tuple() -> Result<(), egglog::Error> { (let tup2 (All (Sequential) (Cons z (Nil)))) (let tup3 (All (Parallel) (Cons x (Cons m (Nil))))) (let tup4 (All (Parallel) (Cons tup2 (Cons tup3 (Nil))))) - (HasTypeDemand tup1) - (HasTypeDemand tup2) - (HasTypeDemand tup3) - (HasTypeDemand tup4) (let get1 (Get tup3 0)) (let get2 (Get tup3 1)) (let get3 (Get (Get tup4 1) 1)) - (HasTypeDemand get1) - (HasTypeDemand get2) - (HasTypeDemand get3) "; let check = " (run-schedule (saturate type-analysis)) @@ -119,3 +105,93 @@ fn tuple() -> Result<(), egglog::Error> { "; crate::run_test(build, check) } + +#[test] +fn lets() -> Result<(), egglog::Error> { + let build = " + (let let-id (Id (i64-fresh!))) + (let outer-ctx (Id (i64-fresh!))) + (let l (Let let-id (Num outer-ctx 5) (Add (Arg let-id) (Arg let-id)))) + (let outer (Id (i64-fresh!))) + (let inner (Id (i64-fresh!))) + (let ctx (Id (i64-fresh!))) + (let nested + (Let outer (Num ctx 3) + (Let inner (All (Parallel) (Cons (Arg outer) (Cons (Num outer 2) (Nil)))) + (Add (Get (Arg inner) 0) (Get (Arg inner) 1))))) + "; + let check = " + (run-schedule (saturate type-analysis)) + (check (HasType l (IntT))) + (check (HasType nested (IntT))) + "; + crate::run_test(build, check) +} + +#[test] +fn loops() -> Result<(), egglog::Error> { + let build = " + (let ctx (Id 0)) + (let loop-id (Id 1)) + (let l (Loop loop-id (Num ctx 1) + (All (Sequential) + (Cons (LessThan (Num loop-id 2) (Num loop-id 3)) + (Cons (Switch (Boolean loop-id true) + (Cons (Num loop-id 4) (Cons (Num loop-id 5) (Nil)))) + (Nil)))))) + "; + let check = " + (run-schedule (saturate type-analysis)) + (check (HasType l (IntT))) + "; + crate::run_test(build, check) +} + +#[test] +#[should_panic] +fn loop_pred_boolean() { + let build = " + (let ctx (Id 0)) + (let loop-id (Id 1)) + (let l (Loop loop-id (Num ctx 1) + (All (Sequential) + (Cons (Add (Num loop-id 2) (Num loop-id 3)) + (Cons (Switch (Boolean loop-id true) + (Cons (Num loop-id 4) (Cons (Num loop-id 5) (Nil)))) + (Nil)))))) + (run-schedule (saturate type-analysis))"; + let check = ""; + + let _ = crate::run_test(build, check); +} + +#[test] +#[should_panic] +fn loop_args1() { + let build = " + (let ctx (Id 0)) + (let loop-id (Id 1)) + (let l (Loop loop-id (Num ctx 1) (All (Sequential) (Nil)))) + (run-schedule (saturate type-analysis))"; + let check = ""; + + let _ = crate::run_test(build, check); +} + +#[test] +#[should_panic] +fn loop_args3() { + let build = " + (let ctx (Id 0)) + (let loop-id (Id 1)) + (let l (Loop loop-id (Num ctx 1) + (All (Sequential) + (Cons (LessThan (Num loop-id 2) (Num loop-id 3)) + (Cons (Switch (Boolean loop-id true) + (Cons (Num loop-id 4) (Cons (Num loop-id 5) (Nil)))) + (Cons (Num loop-id 1) (Nil))))))) + (run-schedule (saturate type-analysis))"; + let check = ""; + + let _ = crate::run_test(build, check); +} diff --git a/tree_unique_args/src/util.egg b/tree_unique_args/src/util.egg index 3a1ca4b49..fdf6c45fb 100644 --- a/tree_unique_args/src/util.egg +++ b/tree_unique_args/src/util.egg @@ -33,3 +33,6 @@ (Remove b (- i 1)) :when ((> i 0)) :ruleset always-run) + +(function Expr-size (Expr) i64 :merge (min old new)) +(function ListExpr-size (ListExpr) i64 :merge (min old new)) diff --git a/tree_unique_args/src/util.rs b/tree_unique_args/src/util.rs index 146204abe..a8df7d90e 100644 --- a/tree_unique_args/src/util.rs +++ b/tree_unique_args/src/util.rs @@ -1,3 +1,50 @@ +use crate::ir::{Constructor, Purpose}; +use std::iter; +use strum::IntoEnumIterator; + +fn ast_size_for_ctor(ctor: Constructor) -> String { + let ctor_pattern = ctor.construct(|field| field.var()); + let ruleset = " :ruleset always-run"; + match ctor { + // List itself don't count size + Constructor::Nil => format!("(rule ({ctor_pattern}) ((set (ListExpr-size {ctor_pattern}) 0)) {ruleset})"), + Constructor::Cons => format!("(rule ((= list (Cons expr xs)) (= a (Expr-size expr)) (= b (ListExpr-size xs))) ((set (ListExpr-size list) (+ a b))){ruleset})"), + // let Get and All's size = children's size (I prefer not +1 here) + Constructor::Get => format!("(rule ((= expr (Get tup i)) (= n (Expr-size tup))) ((set (Expr-size expr) n)) {ruleset})"), + Constructor::All => format!("(rule ((= expr (All ord list)) (= n (ListExpr-size list))) ((set (Expr-size expr) n)) {ruleset})"), + _ => { + let field_pattern = ctor.filter_map_fields(|field| { + let sort = field.sort().name(); + let var = field.var(); + match field.purpose { + Purpose::CapturedExpr + | Purpose::SubExpr + | Purpose::SubListExpr => + Some(format!("({sort}-size {var})")), + _ => None + } + }); + + let len = field_pattern.len(); + let result_str = field_pattern.join(" "); + + match len { + // Num, Bool Arg, UnitExpr for 0 + 0 => format!("(rule ((= expr {ctor_pattern})) ((set (Expr-size expr) 1)) {ruleset})"), + 1 => format!("(rule ((= expr {ctor_pattern}) (= n {result_str})) ((set (Expr-size expr) (+ 1 n))){ruleset})"), + 2 => format!("(rule ((= expr {ctor_pattern}) (= sum (+ {result_str}))) ((set (Expr-size expr) (+ 1 sum))){ruleset})"), + _ => panic!("Unimplemented") // we don't have ast take three Expr + } + }, + } +} + +pub(crate) fn rules() -> Vec { + iter::once(include_str!("util.egg").to_string()) + .chain(Constructor::iter().map(ast_size_for_ctor)) + .collect::>() +} + #[test] fn test_list_util() -> Result<(), egglog::Error> { let build = &*" @@ -35,3 +82,49 @@ fn append_test() -> Result<(), egglog::Error> { crate::run_test(build, check) } + +#[test] +fn ast_size_test() -> Result<(), egglog::Error> { + let build = " + (let id1 (Id (i64-fresh!))) + (let id-outer (Id (i64-fresh!))) + (let inv + (Sub (Get (Arg id1) 4) + (Mul (Get (Arg id1) 2) + (Switch (Num id1 1) (list4 (Num id1 1) + (Num id1 2) + (Num id1 3) + (Num id1 4)) + ) + ) + )) + + (let loop + (Loop id1 + (All (Parallel) (list5 (Num id-outer 0) + (Num id-outer 1) + (Num id-outer 2) + (Num id-outer 3) + (Num id-outer 4))) + (All (Sequential) (Pair + ; pred + (LessThan (Get (Arg id1) 0) (Get (Arg id1) 4)) + ; output + (All (Parallel) + (list5 + (Add (Get (Arg id1) 0) + inv + ) + (Get (Arg id1) 1) + (Get (Arg id1) 2) + (Get (Arg id1) 3) + (Get (Arg id1) 4) )))))) + "; + + let check = " + (check (= 10 (Expr-size inv))) + (check (= 25 (Expr-size loop))) + "; + + crate::run_test(build, check) +}