diff --git a/tree_unique_args/src/ast.rs b/tree_unique_args/src/ast.rs index def00792d..662e04ada 100644 --- a/tree_unique_args/src/ast.rs +++ b/tree_unique_args/src/ast.rs @@ -55,10 +55,6 @@ pub fn tfalse() -> Expr { Boolean(false) } -pub fn unit() -> Expr { - Unit -} - pub fn add(a: Expr, b: Expr) -> Expr { Add(Box::new(a), Box::new(b)) } @@ -99,12 +95,12 @@ pub fn print(a: Expr) -> Expr { Print(Box::new(a)) } -pub fn sequence(args: Vec) -> Expr { - All(Order::Sequential, args) +pub fn sequence(id: Id, args: Vec) -> Expr { + All(id, Order::Sequential, args) } -pub fn parallel(args: Vec) -> Expr { - All(Order::Parallel, args) +pub fn parallel(id: Id, args: Vec) -> Expr { + All(id, Order::Parallel, args) } pub fn switch(arg: Expr, cases: Vec) -> Expr { diff --git a/tree_unique_args/src/body_contains.rs b/tree_unique_args/src/body_contains.rs index a59e6ac3a..de5a3dc8a 100644 --- a/tree_unique_args/src/body_contains.rs +++ b/tree_unique_args/src/body_contains.rs @@ -66,7 +66,7 @@ fn test_body_contains() -> Result<(), egglog::Error> { (let loop (Loop id1 (Num id-outer 1) - (All (Sequential) (Pair + (All id1 (Sequential) (Pair ; pred (LessThan (Num id1 2) (Num id1 3)) ; output diff --git a/tree_unique_args/src/deep_copy.rs b/tree_unique_args/src/deep_copy.rs index 41296cca7..0258cfa7f 100644 --- a/tree_unique_args/src/deep_copy.rs +++ b/tree_unique_args/src/deep_copy.rs @@ -60,13 +60,13 @@ fn test_deep_copy() -> Result<(), egglog::Error> { (let id-outer (Id (i64-fresh!))) (let loop (Loop id1 - (All (Parallel) (Pair (Arg id-outer) (Num id-outer 0))) - (All (Sequential) (Pair + (All id-outer (Parallel) (Pair (Arg id-outer) (Num id-outer 0))) + (All id1 (Sequential) (Pair ; pred (LessThan (Get (Arg id1) 0) (Get (Arg id1) 1)) ; output (Let id2 - (All (Parallel) (Pair + (All id1 (Parallel) (Pair (Add (Get (Arg id1) 0) (Num id1 1)) (Sub (Get (Arg id1) 1) (Num id1 1)))) (Arg id2)))))) @@ -75,13 +75,13 @@ fn test_deep_copy() -> Result<(), egglog::Error> { let check = " (let loop-copied-expected (Loop (Id 4) - (All (Parallel) (Pair (Arg (Id 3)) (Num (Id 3) 0))) - (All (Sequential) (Pair + (All (Id 3) (Parallel) (Pair (Arg (Id 3)) (Num (Id 3) 0))) + (All (Id 4) (Sequential) (Pair ; pred (LessThan (Get (Arg (Id 4)) 0) (Get (Arg (Id 4)) 1)) ; output (Let (Id 5) - (All (Parallel) (Pair + (All (Id 4) (Parallel) (Pair (Add (Get (Arg (Id 4)) 0) (Num (Id 4) 1)) (Sub (Get (Arg (Id 4)) 1) (Num (Id 4) 1)))) (Arg (Id 5))))))) diff --git a/tree_unique_args/src/id_analysis.rs b/tree_unique_args/src/id_analysis.rs index 9c7275b00..3d840778f 100644 --- a/tree_unique_args/src/id_analysis.rs +++ b/tree_unique_args/src/id_analysis.rs @@ -54,21 +54,21 @@ fn test_id_analysis() -> Result<(), egglog::Error> { (ExprIsValid (Let let0-id (Num outer-id 0) - (All + (All let0-id (Parallel) (Pair (Let let1-id (Num let0-id 3) (Boolean let1-id true)) - (UnitExpr let0-id) + (All let0-id (Parallel) (Nil)) )))) "; let check = " (check (ExprHasRefId (Num outer-id 0) outer-id)) (check (ExprHasRefId (Boolean let1-id true) let1-id)) - (check (ExprHasRefId (UnitExpr let0-id) let0-id)) + (check (ExprHasRefId (All let0-id (Parallel) (Nil)) let0-id)) (check (ExprHasRefId (Let let1-id @@ -81,7 +81,7 @@ fn test_id_analysis() -> Result<(), egglog::Error> { let1-id (Num let0-id 3) (Boolean let1-id true)) - (UnitExpr let0-id) + (All let0-id (Parallel) (Nil)) ) let0-id )) @@ -90,13 +90,14 @@ fn test_id_analysis() -> Result<(), egglog::Error> { let0-id (Num outer-id 0) (All + let0-id (Parallel) (Pair (Let let1-id (Num let0-id 3) (Boolean let1-id true)) - (UnitExpr let0-id) + (All let0-id (Parallel) (Nil)) ))) outer-id )) @@ -135,7 +136,7 @@ fn test_id_analysis_listexpr_id_conflict_panics() { let build = " (let id1 (Id (i64-fresh!))) (let id2 (Id (i64-fresh!))) - (let conflict-expr (Cons (Num id1 3) (Cons (UnitExpr id2) (Nil)))) + (let conflict-expr (Cons (Num id1 3) (Cons (All id2 (Parallel) (Nil)) (Nil)))) (ListExprIsValid conflict-expr)"; let check = ""; diff --git a/tree_unique_args/src/interpreter.rs b/tree_unique_args/src/interpreter.rs index 4158a827b..b85b0329d 100644 --- a/tree_unique_args/src/interpreter.rs +++ b/tree_unique_args/src/interpreter.rs @@ -21,7 +21,6 @@ pub fn typecheck(e: &Expr, arg_ty: &Option) -> Result { Expr::Program(_) => panic!("Found non top level program."), Expr::Num(_) => Ok(Type::Num), Expr::Boolean(_) => Ok(Type::Boolean), - Expr::Unit => Ok(Type::Unit), Expr::Add(e1, e2) | Expr::Sub(e1, e2) | Expr::Mul(e1, e2) => { expect_type(e1, Type::Num)?; expect_type(e2, Type::Num)?; @@ -71,7 +70,7 @@ pub fn typecheck(e: &Expr, arg_ty: &Option) -> Result { Expr::Print(e) => { // right now, only print nums expect_type(e, Type::Num)?; - Ok(Type::Unit) + Ok(Type::Tuple(vec![])) } Expr::Read(addr) => { // right now, all memory holds nums. @@ -83,9 +82,9 @@ pub fn typecheck(e: &Expr, arg_ty: &Option) -> Result { Expr::Write(addr, data) => { expect_type(addr, Type::Num)?; expect_type(data, Type::Num)?; - Ok(Type::Unit) + Ok(Type::Tuple(vec![])) } - Expr::All(_, exprs) => { + Expr::All(_, _, exprs) => { let tys = exprs .iter() .map(|expr| typecheck(expr, arg_ty)) @@ -138,7 +137,6 @@ pub fn interpret(e: &Expr, arg: &Option, vm: &mut VirtualMachine) -> Valu Expr::Program(_) => todo!("interpret programs"), Expr::Num(x) => Value::Num(*x), Expr::Boolean(x) => Value::Boolean(*x), - Expr::Unit => Value::Unit, Expr::Add(e1, e2) => { let Value::Num(n1) = interpret(e1, arg, vm) else { panic!("add") @@ -219,7 +217,7 @@ pub fn interpret(e: &Expr, arg: &Option, vm: &mut VirtualMachine) -> Valu panic!("print") }; vm.log.push(n); - Value::Unit + Value::Tuple(vec![]) } Expr::Read(e_addr) => { let Value::Num(addr) = interpret(e_addr, arg, vm) else { @@ -233,9 +231,9 @@ pub fn interpret(e: &Expr, arg: &Option, vm: &mut VirtualMachine) -> Valu }; let data = interpret(e_data, arg, vm); vm.mem.insert(addr as usize, data); - Value::Unit + Value::Tuple(vec![]) } - Expr::All(_, exprs) => { + Expr::All(_, _, exprs) => { // this always executes sequentially (which is a valid way to // execute parallel tuples) let vals = exprs @@ -285,6 +283,7 @@ fn test_interpreter() { Id(0), Box::new(Expr::Num(1)), Box::new(Expr::All( + Id(0), Order::Parallel, vec![ // pred: i < 10 @@ -292,6 +291,7 @@ fn test_interpreter() { // output Expr::Get( Box::new(Expr::All( + Id(0), Order::Parallel, vec![ // i = i + 1 @@ -319,6 +319,7 @@ fn test_interpreter_fib_using_memory() { let nth = 10; let fib_nth = 55; let e = Expr::All( + Id(-1), Order::Sequential, vec![ Expr::Write(Box::new(Expr::Num(0)), Box::new(Expr::Num(0))), @@ -327,6 +328,7 @@ fn test_interpreter_fib_using_memory() { Id(0), Box::new(Expr::Num(2)), Box::new(Expr::All( + Id(0), Order::Parallel, vec![ // pred: i < nth @@ -334,6 +336,7 @@ fn test_interpreter_fib_using_memory() { // output Expr::Get( Box::new(Expr::All( + Id(0), Order::Parallel, vec![ // i = i + 1 @@ -370,8 +373,8 @@ fn test_interpreter_fib_using_memory() { assert_eq!( res, Value::Tuple(vec![ - Value::Unit, - Value::Unit, + Value::Tuple(vec![]), + Value::Tuple(vec![]), Value::Num(nth + 1), Value::Num(fib_nth) ]) @@ -443,7 +446,6 @@ impl std::str::FromStr for Expr { ("Boolean", [_id, egglog::ast::Expr::Lit(egglog::ast::Literal::Bool(b))]) => { Ok(Expr::Boolean(*b)) } - ("UnitExpr", [_id]) => Ok(Expr::Unit), ("Add", [x, y]) => Ok(Expr::Add( Box::new(egglog_expr_to_expr(x)?), Box::new(egglog_expr_to_expr(y)?), @@ -482,7 +484,7 @@ impl std::str::FromStr for Expr { Box::new(egglog_expr_to_expr(x)?), Box::new(egglog_expr_to_expr(y)?), )), - ("All", [egglog::ast::Expr::Call(order, empty), xs]) => { + ("All", [id, egglog::ast::Expr::Call(order, empty), xs]) => { if !empty.is_empty() { return Err(ExprParseError::InvalidOrderArguments); } @@ -491,7 +493,11 @@ impl std::str::FromStr for Expr { "Sequential" => Ok(Order::Sequential), s => Err(ExprParseError::InvalidOrder(s.to_owned())), }?; - Ok(Expr::All(order, list_expr_to_vec(xs)?)) + Ok(Expr::All( + egglog_expr_to_id(id)?, + order, + list_expr_to_vec(xs)?, + )) } ("Switch", [pred, branches]) => Ok(Expr::Switch( Box::new(egglog_expr_to_expr(pred)?), @@ -533,7 +539,7 @@ fn test_expr_parser() { let s = "(Loop (Id 1) (Num (Id 0) 1) -(All (Sequential) +(All (Id 1) (Sequential) (Cons (LessThan (Num (Id 1) 2) (Num (Id 1) 3)) (Cons (Switch (Boolean (Id 1) true) (Cons (Num (Id 1) 4) (Cons (Num (Id 1) 5) (Nil)))) (Nil))))) @@ -543,6 +549,7 @@ fn test_expr_parser() { Id(1), Box::new(Expr::Num(1)), Box::new(Expr::All( + Id(1), Order::Sequential, vec![ Expr::LessThan(Box::new(Expr::Num(2)), Box::new(Expr::Num(3))), diff --git a/tree_unique_args/src/ir.rs b/tree_unique_args/src/ir.rs index 5e5c9e148..234c679dc 100644 --- a/tree_unique_args/src/ir.rs +++ b/tree_unique_args/src/ir.rs @@ -48,7 +48,6 @@ impl ESort { pub(crate) enum Constructor { Num, Boolean, - UnitExpr, Add, Sub, Mul, @@ -123,7 +122,6 @@ impl Constructor { match self { Constructor::Num => "Num", Constructor::Boolean => "Boolean", - Constructor::UnitExpr => "UnitExpr", Constructor::Add => "Add", Constructor::Sub => "Sub", Constructor::Mul => "Mul", @@ -152,7 +150,6 @@ impl Constructor { match self { Constructor::Num => vec![f(ReferencingId, "id"), f(Static(Sort::I64), "n")], Constructor::Boolean => vec![f(ReferencingId, "id"), f(Static(Sort::Bool), "b")], - Constructor::UnitExpr => vec![f(ReferencingId, "id")], Constructor::Add => vec![f(SubExpr, "x"), f(SubExpr, "y")], Constructor::Sub => vec![f(SubExpr, "x"), f(SubExpr, "y")], Constructor::Mul => vec![f(SubExpr, "x"), f(SubExpr, "y")], @@ -166,7 +163,11 @@ impl Constructor { Constructor::Print => vec![f(SubExpr, "printee")], Constructor::Read => vec![f(SubExpr, "addr")], Constructor::Write => vec![f(SubExpr, "addr"), f(SubExpr, "data")], - Constructor::All => vec![f(Static(Sort::Order), "order"), f(SubListExpr, "exprs")], + Constructor::All => vec![ + f(ReferencingId, "id"), + f(Static(Sort::Order), "order"), + f(SubListExpr, "exprs"), + ], Constructor::Switch => vec![f(SubExpr, "pred"), f(SubListExpr, "branches")], Constructor::Loop => vec![ f(CapturingId, "id"), @@ -209,7 +210,6 @@ impl Constructor { match self { Constructor::Num => ESort::Expr, Constructor::Boolean => ESort::Expr, - Constructor::UnitExpr => ESort::Expr, Constructor::Add => ESort::Expr, Constructor::Sub => ESort::Expr, Constructor::Mul => ESort::Expr, diff --git a/tree_unique_args/src/is_valid.rs b/tree_unique_args/src/is_valid.rs index 66e0e4d16..832ceeda9 100644 --- a/tree_unique_args/src/is_valid.rs +++ b/tree_unique_args/src/is_valid.rs @@ -36,12 +36,12 @@ fn test_is_valid() -> Result<(), egglog::Error> { (let id-outer (Id (i64-fresh!))) (let loop (Loop id1 - (All (Parallel) (Pair (Num id-outer 0) (Num id-outer 0))) - (All (Sequential) (Pair + (All id-outer (Parallel) (Pair (Num id-outer 0) (Num id-outer 0))) + (All id1 (Sequential) (Pair ; pred (LessThan (Get (Arg id1) 0) (Get (Arg id1) 1)) ; output - (All (Parallel) (Pair + (All id1 (Parallel) (Pair (Add (Get (Arg id1) 0) (Num id1 1)) (Sub (Get (Arg id1) 1) (Num id1 1)))))))) (ExprIsValid loop) diff --git a/tree_unique_args/src/ivt.egg b/tree_unique_args/src/ivt.egg index b8fe2d171..bbd077393 100644 --- a/tree_unique_args/src/ivt.egg +++ b/tree_unique_args/src/ivt.egg @@ -29,7 +29,7 @@ (rule ((= loop (Loop id in out)) (ExprHasRefId loop outer-id) (ExprIsValid loop) - (= out (All ord (Cons pred (Cons switch (Nil))))) + (= out (All id ord (Cons pred (Cons switch (Nil))))) (= switch (Switch pred branches)) (ExprIsPure pred)) ((LiftSwitch switch in outer-id)) :ruleset ivt) @@ -37,13 +37,13 @@ ; Apply the rule (rule ((= loop (Loop id in out)) (ExprIsValid loop) - (= out (All ord (Cons pred (Cons switch (Nil))))) + (= out (All id ord (Cons pred (Cons switch (Nil))))) (= switch (Switch pred (Cons thn* (Cons els* (Nil))))) ; NB: we don't constrain 'outer-id' here in part because there can only ; be _one_ outer-id that is referenced by it. (See the panic in the rules ; for *HasRefId). (= (Switch pred_ (Cons thn (Cons els (Nil)))) (LiftSwitch switch in outer-id))) ((let new-id (Id (i64-fresh!))) - (let inner (NewLoop new-id in (All ord (Pair pred thn*)))) + (let inner (NewLoop new-id in (All new-id ord (Pair pred thn*)))) (let outer (Switch pred_ (Cons inner (Cons els (Nil))))) (union loop outer)) :ruleset ivt) \ No newline at end of file diff --git a/tree_unique_args/src/ivt.rs b/tree_unique_args/src/ivt.rs index 384c917f5..e47ba7e3d 100644 --- a/tree_unique_args/src/ivt.rs +++ b/tree_unique_args/src/ivt.rs @@ -28,7 +28,7 @@ fn basic_ivt() -> Result { (let switch (Switch pred (Pair (Print (Num loop-id 0)) (Print (Num loop-id 1))))) - (let loop (Loop loop-id (Arg outer-id) (All (Sequential) (Pair pred switch)))) + (let loop (Loop loop-id (Arg outer-id) (All loop-id (Sequential) (Pair pred switch)))) (ExprIsValid loop)"; let check = " (check (= loop @@ -36,7 +36,7 @@ fn basic_ivt() -> Result { (LessThan (Arg outer-id) (Num outer-id 1)) (Cons (Loop new-id (Arg outer-id) - (All (Sequential) + (All new-id (Sequential) (Cons (LessThan (Arg new-id) (Num new-id 1)) (Cons (Print (Num new-id 0)) (Nil))))) (Cons (Print (Num outer-id 1)) (Nil))))))"; diff --git a/tree_unique_args/src/lib.rs b/tree_unique_args/src/lib.rs index d5def9048..566304ba6 100644 --- a/tree_unique_args/src/lib.rs +++ b/tree_unique_args/src/lib.rs @@ -31,7 +31,6 @@ pub struct Id(i64); pub enum Expr { Num(i64), Boolean(bool), - Unit, Add(Box, Box), Sub(Box, Box), Mul(Box, Box), @@ -48,7 +47,7 @@ pub enum Expr { Print(Box), Read(Box), Write(Box, Box), - All(Order, Vec), + All(Id, Order, Vec), Switch(Box, Vec), Loop(Id, Box, Box), Let(Id, Box, Box), @@ -64,7 +63,7 @@ impl Expr { /// Runs `func` on every child of this expression. pub fn for_each_child(&mut self, mut func: impl FnMut(&mut Expr)) { match self { - Expr::Num(_) | Expr::Boolean(_) | Expr::Unit | Expr::Arg(_) => {} + Expr::Num(_) | Expr::Boolean(_) | Expr::Arg(_) => {} Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) @@ -82,7 +81,7 @@ impl Expr { Expr::Get(a, _) | Expr::Function(_, a) | Expr::Call(_, a) => { func(a); } - Expr::All(_, children) => { + Expr::All(_, _, children) => { for child in children { func(child); } @@ -110,7 +109,6 @@ impl Expr { pub enum Value { Num(i64), Boolean(bool), - Unit, Tuple(Vec), } @@ -118,7 +116,6 @@ pub enum Value { pub enum Type { Num, Boolean, - Unit, Tuple(Vec), } @@ -141,7 +138,7 @@ pub fn run_test(build: &str, check: &str) -> Result { &subst::subst_rules().join("\n"), &deep_copy::deep_copy_rules().join("\n"), include_str!("sugar.egg"), - include_str!("util.egg"), + &util::rules().join("\n"), &id_analysis::id_analysis_rules().join("\n"), // optimizations include_str!("simple.egg"), diff --git a/tree_unique_args/src/loop_invariant.rs b/tree_unique_args/src/loop_invariant.rs index 2d5509a9f..c4b7d35c1 100644 --- a/tree_unique_args/src/loop_invariant.rs +++ b/tree_unique_args/src/loop_invariant.rs @@ -1,6 +1,6 @@ +use crate::ir::{Constructor, Purpose}; use std::iter; use strum::IntoEnumIterator; -use crate::ir::{Constructor, Purpose}; fn is_inv_base_case_for_ctor(ctor: Constructor) -> Option { let br = "\n "; @@ -13,7 +13,7 @@ fn is_inv_base_case_for_ctor(ctor: Constructor) -> Option { {br} (arg-inv loop i)) \ {br}((set (is-inv-Expr loop expr) true)){ruleset})" )), - Constructor::Num | Constructor::Boolean | Constructor::UnitExpr => { + Constructor::Num | Constructor::Boolean => { let ctor_pattern = ctor.construct(|field| field.var()); Some(format!( "(rule ((BodyContainsExpr loop expr) \ @@ -37,7 +37,6 @@ fn is_invariant_rule_for_ctor(ctor: Constructor) -> Option { // assume Arg as whole is not invariant Constructor::Cons | Constructor::Nil - | Constructor::UnitExpr | Constructor::Num | Constructor::Boolean | Constructor::Print @@ -88,12 +87,12 @@ fn loop_invariant_detection1() -> Result<(), egglog::Error> { (let id-outer (Id (i64-fresh!))) (let loop (Loop id1 - (All (Parallel) (Pair (Num id-outer 0) (Num id-outer 5))) - (All (Sequential) (Pair + (All id-outer (Parallel) (Pair (Num id-outer 0) (Num id-outer 5))) + (All id1 (Sequential) (Pair ; pred (LessThan (Get (Arg id1) 0) (Get (Arg id1) 1)) ; output - (All (Parallel) + (All id1 (Parallel) (Pair (Get (Arg id1) 0) (Sub (Get (Arg id1) 1) (Add (Num id1 1) (Get (Arg id1) 0))) )))))) @@ -129,16 +128,16 @@ fn loop_invariant_detection2() -> Result<(), egglog::Error> { (let loop (Loop id1 - (All (Parallel) (list5 (Num id-outer 0) + (All id-outer (Parallel) (list5 (Num id-outer 0) (Num id-outer 1) (Num id-outer 2) (Num id-outer 3) (Num id-outer 4))) - (All (Sequential) (Pair + (All id1 (Sequential) (Pair ; pred (LessThan (Get (Arg id1) 0) (Get (Arg id1) 4)) ; output - (All (Parallel) + (All id1 (Parallel) (list5 (Add (Get (Arg id1) 0) inv diff --git a/tree_unique_args/src/purity_analysis.rs b/tree_unique_args/src/purity_analysis.rs index 8a864b03f..44ae9b2fd 100644 --- a/tree_unique_args/src/purity_analysis.rs +++ b/tree_unique_args/src/purity_analysis.rs @@ -6,8 +6,8 @@ use strum::IntoEnumIterator; fn is_pure(ctor: &Constructor) -> bool { use Constructor::*; match ctor { - Num | Boolean | UnitExpr | Add | Sub | Mul | LessThan | And | Or | Not | Get | All - | Switch | Loop | Let | Arg | Call | Cons | Nil => true, + Num | Boolean | Add | Sub | Mul | LessThan | And | Or | Not | Get | All | Switch | Loop + | Let | Arg | Call | Cons | Nil => true, Print | Read | Write => false, } } @@ -67,26 +67,26 @@ fn test_purity_analysis() -> Result<(), egglog::Error> { (let id-outer (Id (i64-fresh!))) (let pure-loop (Loop id1 - (All (Parallel) (Pair (Num id-outer 0) (Num id-outer 0))) - (All (Sequential) (Pair + (All id-outer (Parallel) (Pair (Num id-outer 0) (Num id-outer 0))) + (All id1 (Sequential) (Pair ; pred (LessThan (Get (Arg id1) 0) (Get (Arg id1) 1)) ; output - (All (Parallel) (Pair + (All id1 (Parallel) (Pair (Add (Get (Arg id1) 0) (Num id1 1)) (Sub (Get (Arg id1) 1) (Num id1 1)))))))) (let id2 (Id (i64-fresh!))) (let impure-loop (Loop id2 - (All (Parallel) (Pair (Num id-outer 0) (Num id-outer 0))) - (All (Sequential) (Pair + (All id-outer (Parallel) (Pair (Num id-outer 0) (Num id-outer 0))) + (All id2 (Sequential) (Pair ; pred (LessThan (Get (Arg id2) 0) (Get (Arg id2) 1)) ; output - (IgnoreFirst + (IgnoreFirst id2 (Print (Num id2 1)) - (All (Parallel) (Pair + (All id2 (Parallel) (Pair (Add (Get (Arg id2) 0) (Num id2 1)) (Sub (Get (Arg id2) 1) (Num id2 1))))))))) " @@ -117,7 +117,7 @@ fn test_purity_function() -> Result<(), egglog::Error> { (let f2 (Function id_fun2 (Get - (All (Sequential) + (All id_fun2 (Sequential) (Pair (Print (Get (Arg id_fun2) 0)) (Add @@ -126,25 +126,25 @@ fn test_purity_function() -> Result<(), egglog::Error> { 1))) (let pure-loop (Loop id1 - (All (Parallel) (Pair (Num id-outer 0) (Num id-outer 0))) - (All (Sequential) (Pair + (All id-outer (Parallel) (Pair (Num id-outer 0) (Num id-outer 0))) + (All id1 (Sequential) (Pair ; pred (LessThan (Get (Arg id1) 0) (Get (Arg id1) 1)) ; output - (All (Parallel) + (All id1 (Parallel) (Pair - (Add (Call id_fun1 (All (Sequential) (Cons (Get (Arg id1) 0) (Nil)))) (Num id1 1)) + (Add (Call id_fun1 (All id1 (Sequential) (Cons (Get (Arg id1) 0) (Nil)))) (Num id1 1)) (Sub (Get (Arg id1) 1) (Num id1 1)))))))) (let impure-loop (Loop id2 - (All (Parallel) (Pair (Num id-outer 0) (Num id-outer 0))) - (All (Sequential) (Pair + (All id-outer (Parallel) (Pair (Num id-outer 0) (Num id-outer 0))) + (All id2 (Sequential) (Pair ; pred (LessThan (Get (Arg id2) 0) (Get (Arg id2) 1)) ; output - (All (Parallel) + (All id2 (Parallel) (Pair - (Add (Call id_fun2 (All (Sequential) (Cons (Get (Arg id2) 0) (Nil)))) (Num id2 1)) + (Add (Call id_fun2 (All id2 (Sequential) (Cons (Get (Arg id2) 0) (Nil)))) (Num id2 1)) (Sub (Get (Arg id2) 1) (Num id2 1)))))))) " .to_string(); diff --git a/tree_unique_args/src/schema.egg b/tree_unique_args/src/schema.egg index c0f553385..a22698660 100644 --- a/tree_unique_args/src/schema.egg +++ b/tree_unique_args/src/schema.egg @@ -41,7 +41,7 @@ (datatype Order (Parallel) (Sequential)) ; Perform a list of operations. Only way to create a tuple! -(function All (Order ListExpr) Expr) +(function All (IdSort Order ListExpr) Expr) ; Switch on a list of lazily-evaluated branches. Doesn't create context ; pred branches chosen diff --git a/tree_unique_args/src/subst.rs b/tree_unique_args/src/subst.rs index 0a1916224..ffccb7bda 100644 --- a/tree_unique_args/src/subst.rs +++ b/tree_unique_args/src/subst.rs @@ -52,12 +52,12 @@ fn test_subst() -> Result<(), egglog::Error> { (let id-outer (Id (i64-fresh!))) (let loop1 (Loop id1 - (All (Parallel) (Pair (Arg id-outer) (Num id-outer 0))) - (All (Sequential) (Pair + (All id-outer (Parallel) (Pair (Arg id-outer) (Num id-outer 0))) + (All id1 (Sequential) (Pair ; pred (LessThan (Get (Arg id1) 0) (Get (Arg id1) 1)) ; output - (All (Parallel) (Pair + (All id1 (Parallel) (Pair (Add (Get (Arg id1) 0) (Num id1 1)) (Sub (Get (Arg id1) 1) (Num id1 1)))))))) (let loop1-substed (SubstExpr loop1 (Num id-outer 7))) @@ -66,12 +66,12 @@ fn test_subst() -> Result<(), egglog::Error> { let check = " (let loop1-substed-expected (Loop id1 - (All (Parallel) (Pair (Num id-outer 7) (Num id-outer 0))) - (All (Sequential) (Pair + (All id-outer (Parallel) (Pair (Num id-outer 7) (Num id-outer 0))) + (All id1 (Sequential) (Pair ; pred (LessThan (Get (Arg id1) 0) (Get (Arg id1) 1)) ; output - (All (Parallel) (Pair + (All id1 (Parallel) (Pair (Add (Get (Arg id1) 0) (Num id1 1)) (Sub (Get (Arg id1) 1) (Num id1 1)))))))) (run-schedule (saturate always-run)) diff --git a/tree_unique_args/src/sugar.egg b/tree_unique_args/src/sugar.egg index eda9206bf..4b7321631 100644 --- a/tree_unique_args/src/sugar.egg +++ b/tree_unique_args/src/sugar.egg @@ -16,10 +16,10 @@ (rewrite (list5 a b c d e) (Cons a (Cons b (Cons c (Cons d (Cons e (Nil)))))) :ruleset always-run) -(function IgnoreFirst (Expr Expr) Expr) -(rewrite (IgnoreFirst a b) +(function IgnoreFirst (IdSort Expr Expr) Expr) +(rewrite (IgnoreFirst id a b) (Get - (All (Sequential) (Cons a (Cons b (Nil)))) + (All id (Sequential) (Cons a (Cons b (Nil)))) 1) :ruleset always-run) diff --git a/tree_unique_args/src/switch_rewrites.rs b/tree_unique_args/src/switch_rewrites.rs index e3c9e45ea..c416ce594 100644 --- a/tree_unique_args/src/switch_rewrites.rs +++ b/tree_unique_args/src/switch_rewrites.rs @@ -33,8 +33,8 @@ pub(crate) fn rules() -> String { :ruleset switch-rewrites) ; (if E then S1 else S2); S3 ==> if E then S1;S3 else S2;S3 - (rewrite (All ord (Cons (Switch e (Cons S1 (Cons S2 (Nil)))) S3)) - (Switch e (Cons (All ord (Cons S1 S3)) (Cons (All ord (Cons S2 S3)) (Nil)))) + (rewrite (All id ord (Cons (Switch e (Cons S1 (Cons S2 (Nil)))) S3)) + (Switch e (Cons (All id ord (Cons S1 S3)) (Cons (All id ord (Cons S2 S3)) (Nil)))) :ruleset switch-rewrites) {rules_needing_purity}" @@ -80,7 +80,7 @@ fn switch_rewrite_purity() -> crate::Result { let build = " (let switch-id (Id (i64-fresh!))) (let let-id (Id (i64-fresh!))) -(let impure (Let let-id (UnitExpr switch-id) (All (Sequential) (Pair (Boolean let-id true) (Print (Num let-id 1)))))) +(let impure (Let let-id (All switch-id (Parallel) (Nil)) (All let-id (Sequential) (Pair (Boolean let-id true) (Print (Num let-id 1)))))) (let switch (Switch (And (Boolean switch-id false) (Get impure 0)) (Pair (Num switch-id 1) (Num switch-id 2)))) (ExprIsValid switch) @@ -96,7 +96,7 @@ fn switch_rewrite_purity() -> crate::Result { let build = " (let switch-id (Id (i64-fresh!))) (let let-id (Id (i64-fresh!))) -(let impure (Let let-id (UnitExpr switch-id) (All (Sequential) (Cons (Boolean let-id true) (Nil))))) +(let impure (Let let-id (All switch-id (Parallel) (Nil)) (All let-id (Sequential) (Cons (Boolean let-id true) (Nil))))) (let switch (Switch (And (Boolean switch-id false) (Get impure 0)) (Pair (Num switch-id 1) (Num switch-id 2)))) (ExprIsValid switch) @@ -140,11 +140,11 @@ fn switch_pull_in_below() -> Result<(), egglog::Error> { (let s3 (Read (Num id 6))) (let switch (Switch c (Cons s1 (Cons s2 (Nil))))) - (let lhs (All (Sequential) (Cons switch (Cons s3 (Nil))))) + (let lhs (All id (Sequential) (Cons switch (Cons s3 (Nil))))) "; let check = " - (let s1s3 (All (Sequential) (Cons s1 (Cons s3 (Nil))))) - (let s2s3 (All (Sequential) (Cons s2 (Cons s3 (Nil))))) + (let s1s3 (All id (Sequential) (Cons s1 (Cons s3 (Nil))))) + (let s2s3 (All id (Sequential) (Cons s2 (Cons s3 (Nil))))) (let expected (Switch c (Cons s1s3 (Cons s2s3 (Nil))))) (check (= lhs expected)) "; diff --git a/tree_unique_args/src/type_analysis.egg b/tree_unique_args/src/type_analysis.egg index 1049e8fba..caed2d306 100644 --- a/tree_unique_args/src/type_analysis.egg +++ b/tree_unique_args/src/type_analysis.egg @@ -5,7 +5,6 @@ (datatype Type (IntT) (BoolT) - (UnitT) (FuncT Type Type) (TupleT TypeList) ) @@ -26,71 +25,20 @@ (rule ((= (TypeList-suffix list n) (TNil))) ((set (TypeList-length list) n)) :ruleset type-analysis) -(relation HasTypeDemand (Expr)) - (relation HasType (Expr Type)) ; Primitives -(rule ((HasTypeDemand (Num id n))) +(rule ((Num id n)) ((HasType (Num id n) (IntT))) :ruleset type-analysis) -(rule ((HasTypeDemand (Boolean id b))) +(rule ((Boolean id b)) ((HasType (Boolean id b) (BoolT))) :ruleset type-analysis) -(rule ((HasTypeDemand (UnitExpr id))) - ((HasType (UnitExpr id) (UnitT))) - :ruleset type-analysis) - - -; Pure Op Demand -(rule ((HasTypeDemand (Add x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (Sub x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (Mul x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (LessThan x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (And x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (Or x y))) - ( - (HasTypeDemand x) - (HasTypeDemand y) - ) - :ruleset type-analysis) -(rule ((HasTypeDemand (Not x))) - ((HasTypeDemand x)) - :ruleset type-analysis) -(rule ((HasTypeDemand (Get e idx))) - ((HasTypeDemand e)) - :ruleset type-analysis) - ; Pure Op Compute (rule ( - (HasTypeDemand (Add x y)) + (Add x y) (HasType x (IntT)) (HasType y (IntT)) ) @@ -99,7 +47,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (Sub x y)) + (Sub x y) (HasType x (IntT)) (HasType y (IntT)) ) @@ -108,7 +56,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (Mul x y)) + (Mul x y) (HasType x (IntT)) (HasType y (IntT)) ) @@ -117,7 +65,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (LessThan x y)) + (LessThan x y) (HasType x (IntT)) (HasType y (IntT)) ) @@ -126,7 +74,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (And x y)) + (And x y) (HasType x (BoolT)) (HasType y (BoolT)) ) @@ -135,7 +83,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (Or x y)) + (Or x y) (HasType x (BoolT)) (HasType y (BoolT)) ) @@ -144,7 +92,7 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (Not x)) + (Not x) (HasType x (BoolT)) ) ( @@ -152,30 +100,22 @@ ) :ruleset type-analysis) (rule ( - (HasTypeDemand (Get e n)) + (Get e n) (HasType e (TupleT tylist)) ) ((HasType (Get e n) (TypeList-ith tylist n))) :ruleset type-analysis) ; Effectful Ops -(rule ((HasTypeDemand (Print e))) - ((HasType (Print e) (UnitT))) +(rule ((Print e)) + ((HasType (Print e) (TupleT (TNil)))) :ruleset type-analysis) ; TODO: Read and Write (requires type annotations) ; Switch ; if the condition is a boolean, it must have exactly two branches -(rule ((HasTypeDemand (Switch cond (Cons A (Cons B (Nil)))))) - ( - (HasTypeDemand cond) - (HasTypeDemand A) - (HasTypeDemand B) - ) - :ruleset type-analysis) (rule ( (= switch (Switch cond (Cons A (Cons B (Nil))))) - (HasTypeDemand switch) (HasType cond (BoolT)) (HasType A ty) (HasType B ty) @@ -185,21 +125,13 @@ ; Otherwise, the condition must be an integer, and we can have any number of branches. -; peel off a branch and demand type -(rule ((HasTypeDemand (Switch cond (Cons branch rest)))) - ( - (HasTypeDemand branch) - (HasTypeDemand (Switch cond rest)) - ) - :ruleset type-analysis) -; base case- demand the type of the condition -(rule ((HasTypeDemand (Switch cond Nil))) - ((HasTypeDemand cond)) +(rule ((Switch cond (Cons branch rest))) + ((Switch cond rest)) ; peel off a branch for type checking :ruleset type-analysis) ; base case- single branch switch has type of branch (rule ( - (HasTypeDemand (Switch cond (Cons branch (Nil)))) + (Switch cond (Cons branch (Nil))) ; boolean condition handled above, now we must have an integer condition (HasType cond (IntT)) (HasType branch ty) @@ -208,7 +140,7 @@ :ruleset type-analysis) ; recursive case (rule ( - (HasTypeDemand (Switch cond (Cons branch rest))) + (Switch cond (Cons branch rest)) (HasType (Switch cond rest) ty) ; make sure the condition is an integer ; (prevents us from typing boolean switches with >2 branches) @@ -219,25 +151,20 @@ :ruleset type-analysis) ; Sequencing -(rule ((HasTypeDemand (All ord (Cons hd tl)))) - ( - (HasTypeDemand hd) - (HasTypeDemand (All ord tl)) - ) +(rule ((All id ord (Cons hd tl))) + ((All id ord tl)) ; peel off a layer for type checking :ruleset type-analysis) ; base case: Nil -(rule ( - (HasTypeDemand (All ord (Nil))) - ) - ((HasType (All ord (Nil)) (TupleT (TNil)))) +(rule ((All id ord (Nil))) + ((HasType (All id ord (Nil)) (TupleT (TNil)))) :ruleset type-analysis) ; rec case (rule ( - (HasTypeDemand (All ord (Cons hd tl))) + (All id ord (Cons hd tl)) (HasType hd ty) - (HasType (All ord tl) (TupleT tylist)) + (HasType (All id ord tl) (TupleT tylist)) ) - ((HasType (All ord (Cons hd tl)) (TupleT (TCons ty tylist)))) + ((HasType (All id ord (Cons hd tl)) (TupleT (TCons ty tylist)))) :ruleset type-analysis) ; If an expr has two different types, panic @@ -251,23 +178,57 @@ ; Lets -(rule ((HasTypeDemand (Let id in out))) - ((HasTypeDemand in)) - :ruleset type-analysis) (rule ( - (HasTypeDemand (Let id in out)) + (Let id in out) (HasType in ty) ) ( (HasType (Arg id) ty) ; assert the let's argument has type ty in the let's context - (HasTypeDemand out) ; demand the type of out in the let's context ) :ruleset type-analysis) (rule ( - (HasTypeDemand (Let id in out)) + (Let id in out) (HasType out ty) ) ((HasType (Let id in out) ty)) + :ruleset type-analysis) + +; Loops + +(rule ( + (Loop id in pred-out) + (HasType in ty) + ) + ( + (HasType (Arg id) ty) ; assert the argument has type ty in the loop's context + ) + :ruleset type-analysis) + +(rule ( + (Loop id in pred-out) + (HasType in ty) ; input type + ; pred-out must be a two-element tuple. + ; pred must be boolean, output type must match input type + (HasType pred-out (TupleT (TCons (BoolT) (TCons ty (TNil))))) + ) + ((HasType (Loop id in pred-out) ty)) ; whole loop has type of output + :ruleset type-analysis) + +(rule ( + (Loop id in pred-out) + (HasType pred-out (TupleT (TCons pred-ty rest))) + (!= pred-ty (BoolT)) + ) + ((panic "Loop predicate was not a boolean")) + :ruleset type-analysis) + +(rule ( + (Loop id in pred-out) + (HasType in in-ty) + (HasType pred-out (TupleT lst)) + (!= (TypeList-length lst) 2) + ) + ((panic "Loop did not get two arguments (predicate and output)")) :ruleset type-analysis) \ No newline at end of file diff --git a/tree_unique_args/src/type_analysis.rs b/tree_unique_args/src/type_analysis.rs index 070f18f88..2b950998a 100644 --- a/tree_unique_args/src/type_analysis.rs +++ b/tree_unique_args/src/type_analysis.rs @@ -9,8 +9,6 @@ fn simple_types() -> Result<(), egglog::Error> { (let x (LessThan m n)) (let y (Not x)) (let z (And x (Or y y))) - (HasTypeDemand s) - (HasTypeDemand z) "; let check = " (run-schedule (saturate type-analysis)) @@ -36,8 +34,6 @@ fn switch_boolean() -> Result<(), egglog::Error> { (Cons (Add n1 n1) (Cons (Sub n1 n2) (Nil))))) (let wrong_switch (Switch b1 (Cons n1 (Cons n2 (Cons n1 (Nil)))))) - (HasTypeDemand switch) - (HasTypeDemand wrong_switch) "; let check = " (run-schedule (saturate type-analysis)) @@ -62,16 +58,13 @@ fn switch_int() -> Result<(), egglog::Error> { (Switch (Mul n1 n2) (Cons (LessThan n3 n4) (Nil)))) (let s3 (Switch (Sub n2 n2) (Cons (Print n1) (Cons (Print n4) (Cons (Print n3) (Nil)))))) - (HasTypeDemand s1) - (HasTypeDemand s2) - (HasTypeDemand s3) "; let check = " (run-schedule (saturate type-analysis)) (check (HasType s1 (IntT))) (check (HasType s2 (BoolT))) - (check (HasType s3 (UnitT))) + (check (HasType s3 (TupleT (TNil)))) "; crate::run_test(build, check) } @@ -79,28 +72,22 @@ fn switch_int() -> Result<(), egglog::Error> { #[test] fn tuple() -> Result<(), egglog::Error> { let build = " - (let n (Add (Num (Id (i64-fresh!)) 1) (Num (Id (i64-fresh!)) 2))) + (let id (Id (i64-fresh!))) + (let n (Add (Num id 1) (Num id 2))) (let m (Mul n n)) (let s (Sub n m)) (let x (LessThan m n)) (let y (Not x)) (let z (And x (Or y y))) - (let tup1 (All (Sequential) (Nil))) - (let tup2 (All (Sequential) (Cons z (Nil)))) - (let tup3 (All (Parallel) (Cons x (Cons m (Nil))))) - (let tup4 (All (Parallel) (Cons tup2 (Cons tup3 (Nil))))) - (HasTypeDemand tup1) - (HasTypeDemand tup2) - (HasTypeDemand tup3) - (HasTypeDemand tup4) + (let tup1 (All id (Sequential) (Nil))) + (let tup2 (All id (Sequential) (Cons z (Nil)))) + (let tup3 (All id (Parallel) (Cons x (Cons m (Nil))))) + (let tup4 (All id (Parallel) (Cons tup2 (Cons tup3 (Nil))))) (let get1 (Get tup3 0)) (let get2 (Get tup3 1)) (let get3 (Get (Get tup4 1) 1)) - (HasTypeDemand get1) - (HasTypeDemand get2) - (HasTypeDemand get3) "; let check = " (run-schedule (saturate type-analysis)) @@ -126,15 +113,13 @@ fn lets() -> Result<(), egglog::Error> { (let let-id (Id (i64-fresh!))) (let outer-ctx (Id (i64-fresh!))) (let l (Let let-id (Num outer-ctx 5) (Add (Arg let-id) (Arg let-id)))) - (HasTypeDemand l) (let outer (Id (i64-fresh!))) (let inner (Id (i64-fresh!))) (let ctx (Id (i64-fresh!))) (let nested (Let outer (Num ctx 3) - (Let inner (All (Parallel) (Cons (Arg outer) (Cons (Num outer 2) (Nil)))) + (Let inner (All ctx (Parallel) (Cons (Arg outer) (Cons (Num outer 2) (Nil)))) (Add (Get (Arg inner) 0) (Get (Arg inner) 1))))) - (HasTypeDemand nested) "; let check = " (run-schedule (saturate type-analysis)) @@ -143,3 +128,71 @@ fn lets() -> Result<(), egglog::Error> { "; crate::run_test(build, check) } + +#[test] +fn loops() -> Result<(), egglog::Error> { + let build = " + (let ctx (Id 0)) + (let loop-id (Id 1)) + (let l (Loop loop-id (Num ctx 1) + (All loop-id (Sequential) + (Cons (LessThan (Num loop-id 2) (Num loop-id 3)) + (Cons (Switch (Boolean loop-id true) + (Cons (Num loop-id 4) (Cons (Num loop-id 5) (Nil)))) + (Nil)))))) + "; + let check = " + (run-schedule (saturate type-analysis)) + (check (HasType l (IntT))) + "; + crate::run_test(build, check) +} + +#[test] +#[should_panic] +fn loop_pred_boolean() { + let build = " + (let ctx (Id 0)) + (let loop-id (Id 1)) + (let l (Loop loop-id (Num ctx 1) + (All loop-id (Sequential) + (Cons (Add (Num loop-id 2) (Num loop-id 3)) + (Cons (Switch (Boolean loop-id true) + (Cons (Num loop-id 4) (Cons (Num loop-id 5) (Nil)))) + (Nil)))))) + (run-schedule (saturate type-analysis))"; + let check = ""; + + let _ = crate::run_test(build, check); +} + +#[test] +#[should_panic] +fn loop_args1() { + let build = " + (let ctx (Id 0)) + (let loop-id (Id 1)) + (let l (Loop loop-id (Num ctx 1) (All loop-id (Sequential) (Nil)))) + (run-schedule (saturate type-analysis))"; + let check = ""; + + let _ = crate::run_test(build, check); +} + +#[test] +#[should_panic] +fn loop_args3() { + let build = " + (let ctx (Id 0)) + (let loop-id (Id 1)) + (let l (Loop loop-id (Num ctx 1) + (All loop-id (Sequential) + (Cons (LessThan (Num loop-id 2) (Num loop-id 3)) + (Cons (Switch (Boolean loop-id true) + (Cons (Num loop-id 4) (Cons (Num loop-id 5) (Nil)))) + (Cons (Num loop-id 1) (Nil))))))) + (run-schedule (saturate type-analysis))"; + let check = ""; + + let _ = crate::run_test(build, check); +} diff --git a/tree_unique_args/src/util.egg b/tree_unique_args/src/util.egg index b5699c04e..f99394d49 100644 --- a/tree_unique_args/src/util.egg +++ b/tree_unique_args/src/util.egg @@ -3,8 +3,8 @@ (function ListExpr-suffix (ListExpr i64) ListExpr :unextractable) (function Append (ListExpr Expr) ListExpr :unextractable) -(rule ((All order top)) ((union (ListExpr-suffix top 0) top)) :ruleset always-run) (rule ((Switch pred top)) ((union (ListExpr-suffix top 0) top)) :ruleset always-run) +(rule ((All id order top)) ((union (ListExpr-suffix top 0) top)) :ruleset always-run) (rule ((= (ListExpr-suffix top n) (Cons hd tl))) ((union (ListExpr-ith top n) hd) @@ -23,7 +23,9 @@ ;; get the ith output of a loop (function get-loop-outputs-ith (Expr i64) Expr :unextractable) (rule ((= loop (Loop id inputs pred-outputs)) - (= pred-outputs (All ord1 pred-out-list)) - (= (All ord2 outputs-list) (ListExpr-ith pred-out-list 1)) + (= pred-outputs (All id1 ord1 pred-out-list)) + (= (All id2 ord2 outputs-list) (ListExpr-ith pred-out-list 1)) (= ith-outputs (ListExpr-ith outputs-list i))) - ((union (get-loop-outputs-ith loop i) ith-outputs)) :ruleset always-run) \ No newline at end of file + ((union (get-loop-outputs-ith loop i) ith-outputs)) :ruleset always-run) +(function Expr-size (Expr) i64 :merge (min old new)) +(function ListExpr-size (ListExpr) i64 :merge (min old new)) diff --git a/tree_unique_args/src/util.rs b/tree_unique_args/src/util.rs index a5380555c..e4c6bd809 100644 --- a/tree_unique_args/src/util.rs +++ b/tree_unique_args/src/util.rs @@ -1,9 +1,56 @@ +use crate::ir::{Constructor, Purpose}; +use std::iter; +use strum::IntoEnumIterator; + +fn ast_size_for_ctor(ctor: Constructor) -> String { + let ctor_pattern = ctor.construct(|field| field.var()); + let ruleset = " :ruleset always-run"; + match ctor { + // List itself don't count size + Constructor::Nil => format!("(rule ({ctor_pattern}) ((set (ListExpr-size {ctor_pattern}) 0)) {ruleset})"), + Constructor::Cons => format!("(rule ((= list (Cons expr xs)) (= a (Expr-size expr)) (= b (ListExpr-size xs))) ((set (ListExpr-size list) (+ a b))){ruleset})"), + // let Get and All's size = children's size (I prefer not +1 here) + Constructor::Get => format!("(rule ((= expr (Get tup i)) (= n (Expr-size tup))) ((set (Expr-size expr) n)) {ruleset})"), + Constructor::All => format!("(rule ((= expr (All id ord list)) (= n (ListExpr-size list))) ((set (Expr-size expr) n)) {ruleset})"), + _ => { + let field_pattern = ctor.filter_map_fields(|field| { + let sort = field.sort().name(); + let var = field.var(); + match field.purpose { + Purpose::CapturedExpr + | Purpose::SubExpr + | Purpose::SubListExpr => + Some(format!("({sort}-size {var})")), + _ => None + } + }); + + let len = field_pattern.len(); + let result_str = field_pattern.join(" "); + + match len { + // Num, Bool Arg, UnitExpr for 0 + 0 => format!("(rule ((= expr {ctor_pattern})) ((set (Expr-size expr) 1)) {ruleset})"), + 1 => format!("(rule ((= expr {ctor_pattern}) (= n {result_str})) ((set (Expr-size expr) (+ 1 n))){ruleset})"), + 2 => format!("(rule ((= expr {ctor_pattern}) (= sum (+ {result_str}))) ((set (Expr-size expr) (+ 1 sum))){ruleset})"), + _ => panic!("Unimplemented") // we don't have ast take three Expr + } + }, + } +} + +pub(crate) fn rules() -> Vec { + iter::once(include_str!("util.egg").to_string()) + .chain(Constructor::iter().map(ast_size_for_ctor)) + .collect::>() +} + #[test] fn test_list_util() -> Result<(), egglog::Error> { let build = &*" (let id (Id 1)) (let list (Cons (Num id 0) (Cons (Num id 1) (Cons (Num id 2) (Cons (Num id 3) (Cons (Num id 4) (Nil))))))) - (let t (All (Sequential) list)) + (let t (All id (Sequential) list)) ".to_string(); let check = &*" (check (= (ListExpr-ith list 1) (Num id 1))) @@ -43,12 +90,12 @@ fn get_loop_output_ith_test() -> Result<(), egglog::Error> { (let id-outer (Id (i64-fresh!))) (let loop1 (Loop id1 - (All (Parallel) (Pair (Arg id-outer) (Num id-outer 0))) - (All (Sequential) (Pair + (All id-outer (Parallel) (Pair (Arg id-outer) (Num id-outer 0))) + (All id1 (Sequential) (Pair ; pred (LessThan (Get (Arg id1) 0) (Get (Arg id1) 1)) ; output - (All (Parallel) (Pair + (All id1 (Parallel) (Pair (Add (Get (Arg id1) 0) (Num id1 1)) (Sub (Get (Arg id1) 1) (Num id1 1)))))))) (let out0 (Add (Get (Arg id1) 0) (Num id1 1))) @@ -65,7 +112,51 @@ fn get_loop_output_ith_test() -> Result<(), egglog::Error> { = (get-loop-outputs-ith loop 1) out1 + ))"; + crate::run_test(build, check) +} + +#[test] +fn ast_size_test() -> Result<(), egglog::Error> { + let build = " + (let id1 (Id (i64-fresh!))) + (let id-outer (Id (i64-fresh!))) + (let inv + (Sub (Get (Arg id1) 4) + (Mul (Get (Arg id1) 2) + (Switch (Num id1 1) (list4 (Num id1 1) + (Num id1 2) + (Num id1 3) + (Num id1 4)) + ) + ) )) + + (let loop + (Loop id1 + (All id-outer (Parallel) (list5 (Num id-outer 0) + (Num id-outer 1) + (Num id-outer 2) + (Num id-outer 3) + (Num id-outer 4))) + (All id1 (Sequential) (Pair + ; pred + (LessThan (Get (Arg id1) 0) (Get (Arg id1) 4)) + ; output + (All id1 (Parallel) + (list5 + (Add (Get (Arg id1) 0) + inv + ) + (Get (Arg id1) 1) + (Get (Arg id1) 2) + (Get (Arg id1) 3) + (Get (Arg id1) 4) )))))) + "; + + let check = " + (check (= 10 (Expr-size inv))) + (check (= 25 (Expr-size loop))) "; crate::run_test(build, check)