Skip to content

Commit

Permalink
Merge branch 'main' into yihozhang-tu-let-inline
Browse files Browse the repository at this point in the history
  • Loading branch information
yihozhang committed Jan 26, 2024
2 parents 8801783 + 4f26ab3 commit b21ad84
Show file tree
Hide file tree
Showing 11 changed files with 574 additions and 113 deletions.
10 changes: 1 addition & 9 deletions src/rvsdg/tree_unique/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1 @@
//! Convert RVSDG programs to the tree
//! encoding of programs.
//! RVSDGs are close to this encoding,
//! but use a DAG-based semantics.
//! This means that nodes that are shared
//! are only computed once.
//! These shared nodes need to be let-bound so that they are only
//! computed once in the tree encoded
//! program.
pub mod to_tree;
9 changes: 9 additions & 0 deletions src/rvsdg/tree_unique/to_tree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//! Convert RVSDG programs to the tree
//! encoding of programs.
//! RVSDGs are close to this encoding,
//! but use a DAG-based semantics.
//! This means that nodes that are shared
//! are only computed once.
//! These shared nodes need to be let-bound so that they are only
//! computed once in the tree encoded
//! program.
204 changes: 204 additions & 0 deletions tree_unique_args/src/ast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
use crate::{Expr, Expr::*, Id, Order};

pub fn give_fresh_ids(expr: &mut Expr) {
let mut id = 1;
give_fresh_ids_helper(expr, 0, &mut id);
}

fn give_fresh_ids_helper(expr: &mut Expr, current_id: i64, fresh_id: &mut i64) {
match expr {
Loop(id, input, body) => {
let new_id = *fresh_id;
*fresh_id += 1;
*id = Id(new_id);
give_fresh_ids_helper(input, current_id, fresh_id);
give_fresh_ids_helper(body, new_id, fresh_id);
}
Let(id, arg, body) => {
let new_id = *fresh_id;
*fresh_id += 1;
*id = Id(new_id);
give_fresh_ids_helper(arg, current_id, fresh_id);
give_fresh_ids_helper(body, new_id, fresh_id);
}
Arg(id) => {
*id = Id(current_id);
}
Function(id, body) => {
let new_id = *fresh_id;
*fresh_id += 1;
*id = Id(new_id);
give_fresh_ids_helper(body, new_id, fresh_id);
}
Call(id, arg) => {
*id = Id(current_id);
give_fresh_ids_helper(arg, current_id, fresh_id);
}
_ => expr.for_each_child(move |child| give_fresh_ids_helper(child, current_id, fresh_id)),
}
}

pub fn program(args: Vec<Expr>) -> Expr {
let mut prog = Program(args);
give_fresh_ids(&mut prog);
prog
}

pub fn num(n: i64) -> Expr {
Num(n)
}

pub fn ttrue() -> Expr {
Boolean(true)
}
pub fn tfalse() -> Expr {
Boolean(false)
}

pub fn unit() -> Expr {
Unit
}

pub fn add(a: Expr, b: Expr) -> Expr {
Add(Box::new(a), Box::new(b))
}

pub fn sub(a: Expr, b: Expr) -> Expr {
Sub(Box::new(a), Box::new(b))
}

pub fn mul(a: Expr, b: Expr) -> Expr {
Mul(Box::new(a), Box::new(b))
}

pub fn lessthan(a: Expr, b: Expr) -> Expr {
LessThan(Box::new(a), Box::new(b))
}

pub fn and(a: Expr, b: Expr) -> Expr {
And(Box::new(a), Box::new(b))
}

pub fn or(a: Expr, b: Expr) -> Expr {
Or(Box::new(a), Box::new(b))
}

pub fn not(a: Expr) -> Expr {
Not(Box::new(a))
}

pub fn get(a: Expr, i: usize) -> Expr {
Get(Box::new(a), i)
}

pub fn concat(a: Expr, b: Expr) -> Expr {
Concat(Box::new(a), Box::new(b))
}

pub fn print(a: Expr) -> Expr {
Print(Box::new(a))
}

pub fn sequence(args: Vec<Expr>) -> Expr {
All(Order::Sequential, args)
}

pub fn parallel(args: Vec<Expr>) -> Expr {
All(Order::Parallel, args)
}

pub fn switch(arg: Expr, cases: Vec<Expr>) -> Expr {
Switch(Box::new(arg), cases)
}

pub fn tloop(input: Expr, body: Expr) -> Expr {
Loop(Id(0), Box::new(input), Box::new(body))
}

pub fn tlet(arg: Expr, body: Expr) -> Expr {
Let(Id(0), Box::new(arg), Box::new(body))
}

pub fn arg() -> Expr {
Arg(Id(0))
}

pub fn function(arg: Expr) -> Expr {
Function(Id(0), Box::new(arg))
}

pub fn call(arg: Expr) -> Expr {
Call(Id(0), Box::new(arg))
}

#[test]
fn test_gives_nested_ids() {
let mut prog = tlet(num(0), tlet(num(1), num(2)));
give_fresh_ids(&mut prog);
assert_eq!(
prog,
Let(
Id(1),
Box::new(Num(0)),
Box::new(Let(Id(2), Box::new(Num(1)), Box::new(Num(2))))
)
);
}

#[test]
fn test_gives_loop_ids() {
let mut prog = tlet(num(0), tloop(num(1), num(2)));
give_fresh_ids(&mut prog);
assert_eq!(
prog,
Let(
Id(1),
Box::new(Num(0)),
Box::new(Loop(Id(2), Box::new(Num(1)), Box::new(Num(2))))
)
);
}

#[test]
fn test_complex_program_ids() {
// test a program that includes
// a let, a loop, a switch, and a call
let prog = program(vec![function(tlet(
num(0),
tloop(
num(1),
switch(
arg(),
vec![
num(2),
call(num(3)),
tlet(num(4), num(5)),
tloop(num(6), num(7)),
],
),
),
))]);
assert_eq!(
prog,
Program(vec![Function(
Id(1),
Box::new(Let(
Id(2),
Box::new(Num(0)),
Box::new(Loop(
Id(3),
Box::new(Num(1)),
Box::new(Switch(
Box::new(Arg(Id(3))),
vec![
Num(2),
Call(Id(3), Box::new(Num(3))),
Let(Id(4), Box::new(Num(4)), Box::new(Num(5))),
Loop(Id(5), Box::new(Num(6)), Box::new(Num(7))),
]
))
))
))
)])
);
}
28 changes: 28 additions & 0 deletions tree_unique_args/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub fn typecheck(e: &Expr, arg_ty: &Option<Type>) -> Result<Type, TypeError> {
}
};
match e {
Expr::Program(_) => panic!("Found non top level program."),
Expr::Num(_) => Ok(Type::Num),
Expr::Boolean(_) => Ok(Type::Boolean),
Expr::Unit => Ok(Type::Unit),
Expand Down Expand Up @@ -50,6 +51,23 @@ pub fn typecheck(e: &Expr, arg_ty: &Option<Type>) -> Result<Type, TypeError> {
)),
}
}
Expr::Concat(tuple_1, tuple_2) => {
let ty_tuple_1 = typecheck(tuple_1, arg_ty)?;
let ty_tuple_2 = typecheck(tuple_2, arg_ty)?;
match (ty_tuple_1.clone(), ty_tuple_2.clone()) {
(Type::Tuple(tys_1), Type::Tuple(tys_2)) => {
Ok(Type::Tuple(tys_1.into_iter().chain(tys_2).collect()))
}
(Type::Tuple(_tys_1), _) => Err(TypeError::ExpectedTupleType(
*tuple_2.clone(),
ty_tuple_2.clone(),
)),
_ => Err(TypeError::ExpectedTupleType(
*tuple_1.clone(),
ty_tuple_1.clone(),
)),
}
}
Expr::Print(e) => {
// right now, only print nums
expect_type(e, Type::Num)?;
Expand Down Expand Up @@ -117,6 +135,7 @@ pub struct VirtualMachine {
// assumes e typechecks and that memory is written before read
pub fn interpret(e: &Expr, arg: &Option<Value>, vm: &mut VirtualMachine) -> Value {
match e {
Expr::Program(_) => todo!("interpret programs"),
Expr::Num(x) => Value::Num(*x),
Expr::Boolean(x) => Value::Boolean(*x),
Expr::Unit => Value::Unit,
Expand Down Expand Up @@ -186,6 +205,15 @@ pub fn interpret(e: &Expr, arg: &Option<Value>, vm: &mut VirtualMachine) -> Valu
};
vals[*i].clone()
}
Expr::Concat(tuple_1, tuple_2) => {
let Value::Tuple(t1) = interpret(tuple_1, arg, vm) else {
panic!("concat")
};
let Value::Tuple(t2) = interpret(tuple_2, arg, vm) else {
panic!("concat")
};
Value::Tuple(t1.into_iter().chain(t2).collect())
}
Expr::Print(e) => {
let Value::Num(n) = interpret(e, arg, vm) else {
panic!("print")
Expand Down
57 changes: 56 additions & 1 deletion tree_unique_args/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod ast;
pub(crate) mod body_contains;
pub(crate) mod conditional_invariant_code_motion;
pub(crate) mod deep_copy;
Expand Down Expand Up @@ -38,6 +39,11 @@ pub enum Expr {
Or(Box<Expr>, Box<Expr>),
Not(Box<Expr>),
Get(Box<Expr>, usize),
/// Concat is a convenient built-in way
/// to put two tuples together.
/// It's not strictly necessary, but
/// doing it by constructing a new big tuple is tedius and slow.
Concat(Box<Expr>, Box<Expr>),
Print(Box<Expr>),
Read(Box<Expr>),
Write(Box<Expr>, Box<Expr>),
Expand All @@ -47,9 +53,58 @@ pub enum Expr {
Let(Id, Box<Expr>, Box<Expr>),
Arg(Id),
Function(Id, Box<Expr>),
/// A list of functions, with the first
/// being the main function.
Program(Vec<Expr>),
Call(Id, Box<Expr>),
}

impl Expr {
/// Runs `func` on every child of this expression.
pub fn for_each_child(&mut self, mut func: impl FnMut(&mut Expr)) {
match self {
Expr::Num(_) | Expr::Boolean(_) | Expr::Unit | Expr::Arg(_) => {}
Expr::Add(a, b)
| Expr::Sub(a, b)
| Expr::Mul(a, b)
| Expr::LessThan(a, b)
| Expr::And(a, b)
| Expr::Or(a, b)
| Expr::Concat(a, b)
| Expr::Write(a, b) => {
func(a);
func(b);
}
Expr::Not(a) | Expr::Print(a) | Expr::Read(a) => {
func(a);
}
Expr::Get(a, _) | Expr::Function(_, a) | Expr::Call(_, a) => {
func(a);
}
Expr::All(_, children) => {
for child in children {
func(child);
}
}
Expr::Switch(input, children) => {
func(input);
for child in children {
func(child);
}
}
Expr::Loop(_, pred, output) | Expr::Let(_, pred, output) => {
func(pred);
func(output);
}
Expr::Program(functions) => {
for function in functions {
func(function);
}
}
}
}
}

#[derive(Clone, Debug, PartialEq)]
pub enum Value {
Num(i64),
Expand Down Expand Up @@ -85,7 +140,7 @@ pub fn run_test(build: &str, check: &str) -> Result {
&subst::subst_rules().join("\n"),
&deep_copy::deep_copy_rules().join("\n"),
include_str!("sugar.egg"),
include_str!("util.egg"),
&util::rules().join("\n"),
&id_analysis::id_analysis_rules().join("\n"),
// optimizations
include_str!("simple.egg"),
Expand Down
2 changes: 1 addition & 1 deletion tree_unique_args/src/schema.egg
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
; ==========================

; f output
(relation Function (IdSort Expr))
(function Function (IdSort Expr) Expr)

; f arg
(function Call (IdSort Expr) Expr)
Expand Down
12 changes: 12 additions & 0 deletions tree_unique_args/src/sugar.egg
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@
(Cons a (Cons b (Nil)))
:ruleset always-run)

(function list3 (Expr Expr Expr) ListExpr)
(rewrite (list3 a b c)
(Cons a (Cons b (Cons c (Nil)))) :ruleset always-run)

(function list4 (Expr Expr Expr Expr) ListExpr)
(rewrite (list4 a b c d)
(Cons a (Cons b (Cons c (Cons d (Nil))))) :ruleset always-run)

(function list5 (Expr Expr Expr Expr Expr) ListExpr)
(rewrite (list5 a b c d e)
(Cons a (Cons b (Cons c (Cons d (Cons e (Nil)))))) :ruleset always-run)

(function IgnoreFirst (Expr Expr) Expr)
(rewrite (IgnoreFirst a b)
(Get
Expand Down
Loading

0 comments on commit b21ad84

Please sign in to comment.