Skip to content

Commit

Permalink
add If to schema
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Jan 31, 2024
1 parent 492040c commit 498a972
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 14 deletions.
2 changes: 1 addition & 1 deletion tree_optimizer/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl Expr {
}
}
BOp(..) | UOp(..) | Print(..) | Write(..) | Read(..) | Get(..) | Program(..)
| Switch(..) => {
| Switch(..) | If(..) => {
self.for_each_child(move |child| child.give_fresh_ids_helper(current_id, fresh_id))
}
}
Expand Down
15 changes: 15 additions & 0 deletions tree_optimizer/src/error_checking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ pub(crate) fn error_checking_rules() -> Vec<String> {
(rule ((IsBranchList (Cons a rest)))
((IsBranchList rest))
:ruleset error-checking)
(rule ((If pred then else))
((IsBranchList (Cons then (Cons else (Nil)))))
:ruleset error-checking)
"
)];

Expand Down Expand Up @@ -58,3 +62,14 @@ fn test_switch_with_num_child() {
";
let _ = crate::run_test(build, "");
}

#[test]
#[should_panic(expected = "Expected Branch, got Switch")]
fn test_if_switch_child() {
let build = "
(If (Boolean (Shared) true)
(Switch (Num (Shared) 0) (Nil))
(Branch (Shared) (Num (Shared) 1)))
";
let _ = crate::run_test(build, "");
}
12 changes: 10 additions & 2 deletions tree_optimizer/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,10 @@ pub enum Expr {
Read(Box<Expr>, TreeType),
Write(Box<Expr>, Box<Expr>),
All(Id, Order, Vec<Expr>),
/// A pred and a list of branches
/// An integer pred and a list of branches
Switch(Box<Expr>, Vec<Expr>),
/// A boolean pred, then and else
If(Box<Expr>, Box<Expr>, Box<Expr>),
/// Should only be a child of `Switch`
/// Represents a single branch of a switch, giving
/// it a unique id
Expand Down Expand Up @@ -208,7 +210,7 @@ impl Expr {
match self {
Num(..) | Boolean(..) | Arg(..) | BOp(..) | UOp(..) | Get(..) | Read(..) | All(..)
| Switch(..) | Branch(..) | Loop(..) | Let(..) | Function(..) | Program(..)
| Call(..) => true,
| If(..) | Call(..) => true,
Print(..) | Write(..) => false,
}
}
Expand All @@ -232,6 +234,7 @@ impl Expr {
Expr::Function(_, _, _, _, _) => "Function",
Expr::Program(_) => "Program",
Expr::Call(_, _, _) => "Call",
Expr::If(_, _, _) => "If",
}
}

Expand Down Expand Up @@ -267,6 +270,11 @@ impl Expr {
func(child);
}
}
Expr::If(pred, then, els) => {
func(pred);
func(then);
func(els);
}
Expr::Branch(_id, child) => {
func(child);
}
Expand Down
24 changes: 24 additions & 0 deletions tree_optimizer/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ pub fn typecheck(e: &Expr, arg_ty: &Option<TreeType>) -> Result<TreeType, TypeEr
}
Ok(ty)
}
Expr::If(pred, then, els) => {
expect_type(pred, Bril(Bool))?;
let then_ty = typecheck(then, arg_ty)?;
let else_ty = typecheck(els, arg_ty)?;
if then_ty == else_ty {
Ok(then_ty)
} else {
Err(TypeError::ExpectedType(
*els.clone(),
then_ty.clone(),
else_ty,
))
}
}
Expr::Branch(_id, child) => typecheck(child, arg_ty),
Expr::Loop(_, input, pred_output) => {
let input_ty = typecheck(input, arg_ty)?;
Expand Down Expand Up @@ -240,6 +254,16 @@ pub fn interpret(e: &Expr, arg: &Option<Value>, vm: &mut VirtualMachine) -> Valu
};
interpret(&branches[pred as usize], arg, vm)
}
Expr::If(pred, then, els) => {
let Value::Boolean(pred) = interpret(pred, arg, vm) else {
panic!("if")
};
if pred {
interpret(then, arg, vm)
} else {
interpret(els, arg, vm)
}
}
Expr::Branch(_id, child) => interpret(child, arg, vm),
Expr::Loop(_, input, pred_output) => {
let mut vals = interpret(input, arg, vm);
Expand Down
3 changes: 3 additions & 0 deletions tree_optimizer/src/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ impl Constructor {
Constructor::Expr(Expr::Switch(..)) => {
vec![f(SubExpr, "pred"), f(SubListExpr, "branches")]
}
Constructor::Expr(Expr::If(..)) => {
vec![f(SubExpr, "pred"), f(SubExpr, "then"), f(SubExpr, "else")]
}
Constructor::Expr(Expr::Branch(..)) => {
vec![f(CapturingId, "id"), f(SubExpr, "expr")]
}
Expand Down
7 changes: 6 additions & 1 deletion tree_optimizer/src/schema.egg
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,14 @@
; Perform a list of operations. Only way to create a tuple!
(function All (IdSort Order ListExpr) Expr)

; Switch on a list of lazily-evaluated branches. Doesn't create context
; Switch on a list of lazily-evaluated branches.
; branches must be constructed with `(Branch id expr)`.
; pred must be an integer
; pred branches chosen
(function Switch (Expr ListExpr) Expr)
; If is like switch, but with a boolean predicate
; pred then else
(function If (Expr Expr Expr) Expr)
(function Branch (IdSort Expr) Expr)

; ==========================
Expand Down
23 changes: 13 additions & 10 deletions tree_optimizer/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,23 @@ fn ast_size_for_ctor(ctor: Constructor) -> String {
}
});

let len = field_pattern.len();
let result_str = field_pattern.join(" ");

match len {
// Num, Bool Arg, UnitExpr for 0
0 => format!("(rule ((= expr {ctor_pattern})) ((set (Expr-size expr) 1)) {ruleset})"),
1 => format!("(rule ((= expr {ctor_pattern}) (= n {result_str})) ((set (Expr-size expr) (+ 1 n))){ruleset})"),
2 => format!("(rule ((= expr {ctor_pattern}) (= sum (+ {result_str}))) ((set (Expr-size expr) (+ 1 sum))){ruleset})"),
_ => panic!("Unimplemented") // we don't have ast take three Expr
}
let result_str = sum_vars(field_pattern);
format!("(rule ((= expr {ctor_pattern}) (= sum {result_str})) ((set (Expr-size expr) (+ 1 sum))){ruleset})")
},
}
}

fn sum_vars(vars: Vec<String>) -> String {
let len = vars.len();
match len {
0 => "0".to_string(),
1 => vars[0].clone(),
2 => format!("(+ {} {})", vars[0], vars[1]),
3 => format!("(+ {} (+ {} {}))", vars[0], vars[1], vars[2]),
_ => panic!("Unimplemented"),
}
}

pub(crate) fn rules() -> Vec<String> {
iter::once(include_str!("util.egg").to_string())
.chain(Constructor::iter().map(ast_size_for_ctor))
Expand Down

0 comments on commit 498a972

Please sign in to comment.