diff --git a/tree_optimizer/src/ast.rs b/tree_optimizer/src/ast.rs index 5553eab99..75364d941 100644 --- a/tree_optimizer/src/ast.rs +++ b/tree_optimizer/src/ast.rs @@ -86,7 +86,7 @@ impl Expr { } } BOp(..) | UOp(..) | Print(..) | Write(..) | Read(..) | Get(..) | Program(..) - | Switch(..) => { + | Switch(..) | If(..) => { self.for_each_child(move |child| child.give_fresh_ids_helper(current_id, fresh_id)) } } diff --git a/tree_optimizer/src/error_checking.rs b/tree_optimizer/src/error_checking.rs index bc6902f02..ebe771680 100644 --- a/tree_optimizer/src/error_checking.rs +++ b/tree_optimizer/src/error_checking.rs @@ -19,6 +19,10 @@ pub(crate) fn error_checking_rules() -> Vec { (rule ((IsBranchList (Cons a rest))) ((IsBranchList rest)) :ruleset error-checking) + +(rule ((If pred then else)) + ((IsBranchList (Cons then (Cons else (Nil))))) + :ruleset error-checking) " )]; @@ -58,3 +62,14 @@ fn test_switch_with_num_child() { "; let _ = crate::run_test(build, ""); } + +#[test] +#[should_panic(expected = "Expected Branch, got Switch")] +fn test_if_switch_child() { + let build = " +(If (Boolean (Shared) true) + (Switch (Num (Shared) 0) (Nil)) + (Branch (Shared) (Num (Shared) 1))) + "; + let _ = crate::run_test(build, ""); +} diff --git a/tree_optimizer/src/expr.rs b/tree_optimizer/src/expr.rs index bfb2eac50..3faee664b 100644 --- a/tree_optimizer/src/expr.rs +++ b/tree_optimizer/src/expr.rs @@ -179,8 +179,10 @@ pub enum Expr { Read(Box, TreeType), Write(Box, Box), All(Id, Order, Vec), - /// A pred and a list of branches + /// An integer pred and a list of branches Switch(Box, Vec), + /// A boolean pred, then and else + If(Box, Box, Box), /// Should only be a child of `Switch` /// Represents a single branch of a switch, giving /// it a unique id @@ -208,7 +210,7 @@ impl Expr { match self { Num(..) | Boolean(..) | Arg(..) | BOp(..) | UOp(..) | Get(..) | Read(..) | All(..) | Switch(..) | Branch(..) | Loop(..) | Let(..) | Function(..) | Program(..) - | Call(..) => true, + | If(..) | Call(..) => true, Print(..) | Write(..) => false, } } @@ -232,6 +234,7 @@ impl Expr { Expr::Function(_, _, _, _, _) => "Function", Expr::Program(_) => "Program", Expr::Call(_, _, _) => "Call", + Expr::If(_, _, _) => "If", } } @@ -267,6 +270,11 @@ impl Expr { func(child); } } + Expr::If(pred, then, els) => { + func(pred); + func(then); + func(els); + } Expr::Branch(_id, child) => { func(child); } diff --git a/tree_optimizer/src/interpreter.rs b/tree_optimizer/src/interpreter.rs index b38d3c5f0..4432acc28 100644 --- a/tree_optimizer/src/interpreter.rs +++ b/tree_optimizer/src/interpreter.rs @@ -89,6 +89,20 @@ pub fn typecheck(e: &Expr, arg_ty: &Option) -> Result { + expect_type(pred, Bril(Bool))?; + let then_ty = typecheck(then, arg_ty)?; + let else_ty = typecheck(els, arg_ty)?; + if then_ty == else_ty { + Ok(then_ty) + } else { + Err(TypeError::ExpectedType( + *els.clone(), + then_ty.clone(), + else_ty, + )) + } + } Expr::Branch(_id, child) => typecheck(child, arg_ty), Expr::Loop(_, input, pred_output) => { let input_ty = typecheck(input, arg_ty)?; @@ -240,6 +254,16 @@ pub fn interpret(e: &Expr, arg: &Option, vm: &mut VirtualMachine) -> Valu }; interpret(&branches[pred as usize], arg, vm) } + Expr::If(pred, then, els) => { + let Value::Boolean(pred) = interpret(pred, arg, vm) else { + panic!("if") + }; + if pred { + interpret(then, arg, vm) + } else { + interpret(els, arg, vm) + } + } Expr::Branch(_id, child) => interpret(child, arg, vm), Expr::Loop(_, input, pred_output) => { let mut vals = interpret(input, arg, vm); diff --git a/tree_optimizer/src/ir.rs b/tree_optimizer/src/ir.rs index 71555a869..ae6d71342 100644 --- a/tree_optimizer/src/ir.rs +++ b/tree_optimizer/src/ir.rs @@ -133,6 +133,9 @@ impl Constructor { Constructor::Expr(Expr::Switch(..)) => { vec![f(SubExpr, "pred"), f(SubListExpr, "branches")] } + Constructor::Expr(Expr::If(..)) => { + vec![f(SubExpr, "pred"), f(SubExpr, "then"), f(SubExpr, "else")] + } Constructor::Expr(Expr::Branch(..)) => { vec![f(CapturingId, "id"), f(SubExpr, "expr")] } diff --git a/tree_optimizer/src/schema.egg b/tree_optimizer/src/schema.egg index e8aa678b5..570e3b199 100644 --- a/tree_optimizer/src/schema.egg +++ b/tree_optimizer/src/schema.egg @@ -65,9 +65,14 @@ ; Perform a list of operations. Only way to create a tuple! (function All (IdSort Order ListExpr) Expr) -; Switch on a list of lazily-evaluated branches. Doesn't create context +; Switch on a list of lazily-evaluated branches. +; branches must be constructed with `(Branch id expr)`. +; pred must be an integer ; pred branches chosen (function Switch (Expr ListExpr) Expr) +; If is like switch, but with a boolean predicate +; pred then else +(function If (Expr Expr Expr) Expr) (function Branch (IdSort Expr) Expr) ; ========================== diff --git a/tree_optimizer/src/util.rs b/tree_optimizer/src/util.rs index 3117d7e39..5d7af7fba 100644 --- a/tree_optimizer/src/util.rs +++ b/tree_optimizer/src/util.rs @@ -32,20 +32,23 @@ fn ast_size_for_ctor(ctor: Constructor) -> String { } }); - let len = field_pattern.len(); - let result_str = field_pattern.join(" "); - - match len { - // Num, Bool Arg, UnitExpr for 0 - 0 => format!("(rule ((= expr {ctor_pattern})) ((set (Expr-size expr) 1)) {ruleset})"), - 1 => format!("(rule ((= expr {ctor_pattern}) (= n {result_str})) ((set (Expr-size expr) (+ 1 n))){ruleset})"), - 2 => format!("(rule ((= expr {ctor_pattern}) (= sum (+ {result_str}))) ((set (Expr-size expr) (+ 1 sum))){ruleset})"), - _ => panic!("Unimplemented") // we don't have ast take three Expr - } + let result_str = sum_vars(field_pattern); + format!("(rule ((= expr {ctor_pattern}) (= sum {result_str})) ((set (Expr-size expr) (+ 1 sum))){ruleset})") }, } } +fn sum_vars(vars: Vec) -> String { + let len = vars.len(); + match len { + 0 => "0".to_string(), + 1 => vars[0].clone(), + 2 => format!("(+ {} {})", vars[0], vars[1]), + 3 => format!("(+ {} (+ {} {}))", vars[0], vars[1], vars[2]), + _ => panic!("Unimplemented"), + } +} + pub(crate) fn rules() -> Vec { iter::once(include_str!("util.egg").to_string()) .chain(Constructor::iter().map(ast_size_for_ctor))