Skip to content

Commit

Permalink
Merge branch 'main' of github.com:egraphs-good/eggcc into yihozhang-t…
Browse files Browse the repository at this point in the history
…u-args-used-analysis
  • Loading branch information
yihozhang committed Jan 29, 2024
2 parents e09fbdd + 2dc98fe commit 2e6bb70
Show file tree
Hide file tree
Showing 19 changed files with 116 additions and 119 deletions.
12 changes: 4 additions & 8 deletions tree_unique_args/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ pub fn tfalse() -> Expr {
Boolean(false)
}

pub fn unit() -> Expr {
Unit
}

pub fn add(a: Expr, b: Expr) -> Expr {
Add(Box::new(a), Box::new(b))
}
Expand Down Expand Up @@ -99,12 +95,12 @@ pub fn print(a: Expr) -> Expr {
Print(Box::new(a))
}

pub fn sequence(args: Vec<Expr>) -> Expr {
All(Order::Sequential, args)
pub fn sequence(id: Id, args: Vec<Expr>) -> Expr {
All(id, Order::Sequential, args)
}

pub fn parallel(args: Vec<Expr>) -> Expr {
All(Order::Parallel, args)
pub fn parallel(id: Id, args: Vec<Expr>) -> Expr {
All(id, Order::Parallel, args)
}

pub fn switch(arg: Expr, cases: Vec<Expr>) -> Expr {
Expand Down
2 changes: 1 addition & 1 deletion tree_unique_args/src/body_contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ fn test_body_contains() -> Result<(), egglog::Error> {
(let loop
(Loop id1
(Num id-outer 1)
(All (Sequential) (Pair
(All id1 (Sequential) (Pair
; pred
(LessThan (Num id1 2) (Num id1 3))
; output
Expand Down
12 changes: 6 additions & 6 deletions tree_unique_args/src/deep_copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ fn test_deep_copy() -> Result<(), egglog::Error> {
(let id-outer (Id (i64-fresh!)))
(let loop
(Loop id1
(All (Parallel) (Pair (Arg id-outer) (Num id-outer 0)))
(All (Sequential) (Pair
(All id-outer (Parallel) (Pair (Arg id-outer) (Num id-outer 0)))
(All id1 (Sequential) (Pair
; pred
(LessThan (Get (Arg id1) 0) (Get (Arg id1) 1))
; output
(Let id2
(All (Parallel) (Pair
(All id1 (Parallel) (Pair
(Add (Get (Arg id1) 0) (Num id1 1))
(Sub (Get (Arg id1) 1) (Num id1 1))))
(Arg id2))))))
Expand All @@ -75,13 +75,13 @@ fn test_deep_copy() -> Result<(), egglog::Error> {
let check = "
(let loop-copied-expected
(Loop (Id 4)
(All (Parallel) (Pair (Arg (Id 3)) (Num (Id 3) 0)))
(All (Sequential) (Pair
(All (Id 3) (Parallel) (Pair (Arg (Id 3)) (Num (Id 3) 0)))
(All (Id 4) (Sequential) (Pair
; pred
(LessThan (Get (Arg (Id 4)) 0) (Get (Arg (Id 4)) 1))
; output
(Let (Id 5)
(All (Parallel) (Pair
(All (Id 4) (Parallel) (Pair
(Add (Get (Arg (Id 4)) 0) (Num (Id 4) 1))
(Sub (Get (Arg (Id 4)) 1) (Num (Id 4) 1))))
(Arg (Id 5)))))))
Expand Down
13 changes: 7 additions & 6 deletions tree_unique_args/src/id_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,21 @@ fn test_id_analysis() -> Result<(), egglog::Error> {
(ExprIsValid (Let
let0-id
(Num outer-id 0)
(All
(All let0-id
(Parallel)
(Pair
(Let
let1-id
(Num let0-id 3)
(Boolean let1-id true))
(UnitExpr let0-id)
(All let0-id (Parallel) (Nil))
))))
";

let check = "
(check (ExprHasRefId (Num outer-id 0) outer-id))
(check (ExprHasRefId (Boolean let1-id true) let1-id))
(check (ExprHasRefId (UnitExpr let0-id) let0-id))
(check (ExprHasRefId (All let0-id (Parallel) (Nil)) let0-id))
(check (ExprHasRefId
(Let
let1-id
Expand All @@ -81,7 +81,7 @@ fn test_id_analysis() -> Result<(), egglog::Error> {
let1-id
(Num let0-id 3)
(Boolean let1-id true))
(UnitExpr let0-id)
(All let0-id (Parallel) (Nil))
)
let0-id
))
Expand All @@ -90,13 +90,14 @@ fn test_id_analysis() -> Result<(), egglog::Error> {
let0-id
(Num outer-id 0)
(All
let0-id
(Parallel)
(Pair
(Let
let1-id
(Num let0-id 3)
(Boolean let1-id true))
(UnitExpr let0-id)
(All let0-id (Parallel) (Nil))
)))
outer-id
))
Expand Down Expand Up @@ -135,7 +136,7 @@ fn test_id_analysis_listexpr_id_conflict_panics() {
let build = "
(let id1 (Id (i64-fresh!)))
(let id2 (Id (i64-fresh!)))
(let conflict-expr (Cons (Num id1 3) (Cons (UnitExpr id2) (Nil))))
(let conflict-expr (Cons (Num id1 3) (Cons (All id2 (Parallel) (Nil)) (Nil))))
(ListExprIsValid conflict-expr)";
let check = "";

Expand Down
35 changes: 21 additions & 14 deletions tree_unique_args/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ pub fn typecheck(e: &Expr, arg_ty: &Option<Type>) -> Result<Type, TypeError> {
Expr::Program(_) => panic!("Found non top level program."),
Expr::Num(_) => Ok(Type::Num),
Expr::Boolean(_) => Ok(Type::Boolean),
Expr::Unit => Ok(Type::Unit),
Expr::Add(e1, e2) | Expr::Sub(e1, e2) | Expr::Mul(e1, e2) => {
expect_type(e1, Type::Num)?;
expect_type(e2, Type::Num)?;
Expand Down Expand Up @@ -71,7 +70,7 @@ pub fn typecheck(e: &Expr, arg_ty: &Option<Type>) -> Result<Type, TypeError> {
Expr::Print(e) => {
// right now, only print nums
expect_type(e, Type::Num)?;
Ok(Type::Unit)
Ok(Type::Tuple(vec![]))
}
Expr::Read(addr) => {
// right now, all memory holds nums.
Expand All @@ -83,9 +82,9 @@ pub fn typecheck(e: &Expr, arg_ty: &Option<Type>) -> Result<Type, TypeError> {
Expr::Write(addr, data) => {
expect_type(addr, Type::Num)?;
expect_type(data, Type::Num)?;
Ok(Type::Unit)
Ok(Type::Tuple(vec![]))
}
Expr::All(_, exprs) => {
Expr::All(_, _, exprs) => {
let tys = exprs
.iter()
.map(|expr| typecheck(expr, arg_ty))
Expand Down Expand Up @@ -138,7 +137,6 @@ pub fn interpret(e: &Expr, arg: &Option<Value>, vm: &mut VirtualMachine) -> Valu
Expr::Program(_) => todo!("interpret programs"),
Expr::Num(x) => Value::Num(*x),
Expr::Boolean(x) => Value::Boolean(*x),
Expr::Unit => Value::Unit,
Expr::Add(e1, e2) => {
let Value::Num(n1) = interpret(e1, arg, vm) else {
panic!("add")
Expand Down Expand Up @@ -219,7 +217,7 @@ pub fn interpret(e: &Expr, arg: &Option<Value>, vm: &mut VirtualMachine) -> Valu
panic!("print")
};
vm.log.push(n);
Value::Unit
Value::Tuple(vec![])
}
Expr::Read(e_addr) => {
let Value::Num(addr) = interpret(e_addr, arg, vm) else {
Expand All @@ -233,9 +231,9 @@ pub fn interpret(e: &Expr, arg: &Option<Value>, vm: &mut VirtualMachine) -> Valu
};
let data = interpret(e_data, arg, vm);
vm.mem.insert(addr as usize, data);
Value::Unit
Value::Tuple(vec![])
}
Expr::All(_, exprs) => {
Expr::All(_, _, exprs) => {
// this always executes sequentially (which is a valid way to
// execute parallel tuples)
let vals = exprs
Expand Down Expand Up @@ -285,13 +283,15 @@ fn test_interpreter() {
Id(0),
Box::new(Expr::Num(1)),
Box::new(Expr::All(
Id(0),
Order::Parallel,
vec![
// pred: i < 10
Expr::LessThan(Box::new(Expr::Arg(Id(0))), Box::new(Expr::Num(10))),
// output
Expr::Get(
Box::new(Expr::All(
Id(0),
Order::Parallel,
vec![
// i = i + 1
Expand Down Expand Up @@ -319,6 +319,7 @@ fn test_interpreter_fib_using_memory() {
let nth = 10;
let fib_nth = 55;
let e = Expr::All(
Id(-1),
Order::Sequential,
vec![
Expr::Write(Box::new(Expr::Num(0)), Box::new(Expr::Num(0))),
Expand All @@ -327,13 +328,15 @@ fn test_interpreter_fib_using_memory() {
Id(0),
Box::new(Expr::Num(2)),
Box::new(Expr::All(
Id(0),
Order::Parallel,
vec![
// pred: i < nth
Expr::LessThan(Box::new(Expr::Arg(Id(0))), Box::new(Expr::Num(nth))),
// output
Expr::Get(
Box::new(Expr::All(
Id(0),
Order::Parallel,
vec![
// i = i + 1
Expand Down Expand Up @@ -370,8 +373,8 @@ fn test_interpreter_fib_using_memory() {
assert_eq!(
res,
Value::Tuple(vec![
Value::Unit,
Value::Unit,
Value::Tuple(vec![]),
Value::Tuple(vec![]),
Value::Num(nth + 1),
Value::Num(fib_nth)
])
Expand Down Expand Up @@ -443,7 +446,6 @@ impl std::str::FromStr for Expr {
("Boolean", [_id, egglog::ast::Expr::Lit(egglog::ast::Literal::Bool(b))]) => {
Ok(Expr::Boolean(*b))
}
("UnitExpr", [_id]) => Ok(Expr::Unit),
("Add", [x, y]) => Ok(Expr::Add(
Box::new(egglog_expr_to_expr(x)?),
Box::new(egglog_expr_to_expr(y)?),
Expand Down Expand Up @@ -482,7 +484,7 @@ impl std::str::FromStr for Expr {
Box::new(egglog_expr_to_expr(x)?),
Box::new(egglog_expr_to_expr(y)?),
)),
("All", [egglog::ast::Expr::Call(order, empty), xs]) => {
("All", [id, egglog::ast::Expr::Call(order, empty), xs]) => {
if !empty.is_empty() {
return Err(ExprParseError::InvalidOrderArguments);
}
Expand All @@ -491,7 +493,11 @@ impl std::str::FromStr for Expr {
"Sequential" => Ok(Order::Sequential),
s => Err(ExprParseError::InvalidOrder(s.to_owned())),
}?;
Ok(Expr::All(order, list_expr_to_vec(xs)?))
Ok(Expr::All(
egglog_expr_to_id(id)?,
order,
list_expr_to_vec(xs)?,
))
}
("Switch", [pred, branches]) => Ok(Expr::Switch(
Box::new(egglog_expr_to_expr(pred)?),
Expand Down Expand Up @@ -533,7 +539,7 @@ fn test_expr_parser() {
let s = "(Loop
(Id 1)
(Num (Id 0) 1)
(All (Sequential)
(All (Id 1) (Sequential)
(Cons (LessThan (Num (Id 1) 2) (Num (Id 1) 3))
(Cons (Switch (Boolean (Id 1) true) (Cons (Num (Id 1) 4) (Cons (Num (Id 1) 5) (Nil))))
(Nil)))))
Expand All @@ -543,6 +549,7 @@ fn test_expr_parser() {
Id(1),
Box::new(Expr::Num(1)),
Box::new(Expr::All(
Id(1),
Order::Sequential,
vec![
Expr::LessThan(Box::new(Expr::Num(2)), Box::new(Expr::Num(3))),
Expand Down
10 changes: 5 additions & 5 deletions tree_unique_args/src/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ impl ESort {
pub(crate) enum Constructor {
Num,
Boolean,
UnitExpr,
Add,
Sub,
Mul,
Expand Down Expand Up @@ -123,7 +122,6 @@ impl Constructor {
match self {
Constructor::Num => "Num",
Constructor::Boolean => "Boolean",
Constructor::UnitExpr => "UnitExpr",
Constructor::Add => "Add",
Constructor::Sub => "Sub",
Constructor::Mul => "Mul",
Expand Down Expand Up @@ -152,7 +150,6 @@ impl Constructor {
match self {
Constructor::Num => vec![f(ReferencingId, "id"), f(Static(Sort::I64), "n")],
Constructor::Boolean => vec![f(ReferencingId, "id"), f(Static(Sort::Bool), "b")],
Constructor::UnitExpr => vec![f(ReferencingId, "id")],
Constructor::Add => vec![f(SubExpr, "x"), f(SubExpr, "y")],
Constructor::Sub => vec![f(SubExpr, "x"), f(SubExpr, "y")],
Constructor::Mul => vec![f(SubExpr, "x"), f(SubExpr, "y")],
Expand All @@ -166,7 +163,11 @@ impl Constructor {
Constructor::Print => vec![f(SubExpr, "printee")],
Constructor::Read => vec![f(SubExpr, "addr")],
Constructor::Write => vec![f(SubExpr, "addr"), f(SubExpr, "data")],
Constructor::All => vec![f(Static(Sort::Order), "order"), f(SubListExpr, "exprs")],
Constructor::All => vec![
f(ReferencingId, "id"),
f(Static(Sort::Order), "order"),
f(SubListExpr, "exprs"),
],
Constructor::Switch => vec![f(SubExpr, "pred"), f(SubListExpr, "branches")],
Constructor::Loop => vec![
f(CapturingId, "id"),
Expand Down Expand Up @@ -209,7 +210,6 @@ impl Constructor {
match self {
Constructor::Num => ESort::Expr,
Constructor::Boolean => ESort::Expr,
Constructor::UnitExpr => ESort::Expr,
Constructor::Add => ESort::Expr,
Constructor::Sub => ESort::Expr,
Constructor::Mul => ESort::Expr,
Expand Down
6 changes: 3 additions & 3 deletions tree_unique_args/src/is_valid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ fn test_is_valid() -> Result<(), egglog::Error> {
(let id-outer (Id (i64-fresh!)))
(let loop
(Loop id1
(All (Parallel) (Pair (Num id-outer 0) (Num id-outer 0)))
(All (Sequential) (Pair
(All id-outer (Parallel) (Pair (Num id-outer 0) (Num id-outer 0)))
(All id1 (Sequential) (Pair
; pred
(LessThan (Get (Arg id1) 0) (Get (Arg id1) 1))
; output
(All (Parallel) (Pair
(All id1 (Parallel) (Pair
(Add (Get (Arg id1) 0) (Num id1 1))
(Sub (Get (Arg id1) 1) (Num id1 1))))))))
(ExprIsValid loop)
Expand Down
6 changes: 3 additions & 3 deletions tree_unique_args/src/ivt.egg
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,21 @@
(rule ((= loop (Loop id in out))
(ExprHasRefId loop outer-id)
(ExprIsValid loop)
(= out (All ord (Cons pred (Cons switch (Nil)))))
(= out (All id ord (Cons pred (Cons switch (Nil)))))
(= switch (Switch pred branches))
(ExprIsPure pred))
((LiftSwitch switch in outer-id)) :ruleset ivt)

; Apply the rule
(rule ((= loop (Loop id in out))
(ExprIsValid loop)
(= out (All ord (Cons pred (Cons switch (Nil)))))
(= out (All id ord (Cons pred (Cons switch (Nil)))))
(= switch (Switch pred (Cons thn* (Cons els* (Nil)))))
; NB: we don't constrain 'outer-id' here in part because there can only
; be _one_ outer-id that is referenced by it. (See the panic in the rules
; for *HasRefId).
(= (Switch pred_ (Cons thn (Cons els (Nil)))) (LiftSwitch switch in outer-id)))
((let new-id (Id (i64-fresh!)))
(let inner (NewLoop new-id in (All ord (Pair pred thn*))))
(let inner (NewLoop new-id in (All new-id ord (Pair pred thn*))))
(let outer (Switch pred_ (Cons inner (Cons els (Nil)))))
(union loop outer)) :ruleset ivt)
4 changes: 2 additions & 2 deletions tree_unique_args/src/ivt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ fn basic_ivt() -> Result {
(let switch (Switch pred
(Pair (Print (Num loop-id 0))
(Print (Num loop-id 1)))))
(let loop (Loop loop-id (Arg outer-id) (All (Sequential) (Pair pred switch))))
(let loop (Loop loop-id (Arg outer-id) (All loop-id (Sequential) (Pair pred switch))))
(ExprIsValid loop)";
let check = "
(check (= loop
(Switch
(LessThan (Arg outer-id) (Num outer-id 1))
(Cons
(Loop new-id (Arg outer-id)
(All (Sequential)
(All new-id (Sequential)
(Cons (LessThan (Arg new-id) (Num new-id 1))
(Cons (Print (Num new-id 0)) (Nil)))))
(Cons (Print (Num outer-id 1)) (Nil))))))";
Expand Down
Loading

0 comments on commit 2e6bb70

Please sign in to comment.