Skip to content

Commit

Permalink
get rid of concat
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Jan 30, 2024
1 parent 0e63305 commit c50e1a8
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 50 deletions.
36 changes: 29 additions & 7 deletions src/rvsdg/tree_unique/to_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use bril_rs::{Literal, ValueOps};
use hashbrown::HashMap;
use tree_optimizer::{
ast::{
add, arg, concat, function, get, getarg, lessthan, num, parallel, parallel_vec,
program_vec, tfalse, tlet, tloop, tprint, ttrue,
add, function, get, getarg, lessthan, num, parallel, parallel_vec, program_vec,
tfalse, tlet, tloop, tprint, ttrue,
},
expr::{Expr, TreeType},
};
Expand Down Expand Up @@ -53,8 +53,13 @@ struct RegionTranslator<'a> {

/// helper that binds a new expression, adding it
/// to the environment using concat
fn cbind(expr: Expr, body: Expr) -> Expr {
tlet(concat(arg(), expr), body)
fn cbind(index: usize, expr: Expr, body: Expr) -> Expr {
let mut concatted = vec![];
for i in 0..index {
concatted.push(getarg(i as usize));
}
concatted.push(expr);
tlet(parallel_vec(concatted), body)
}

impl<'a> RegionTranslator<'a> {
Expand Down Expand Up @@ -87,8 +92,8 @@ impl<'a> RegionTranslator<'a> {
fn build_translation(&self, inner: Expr) -> Expr {
let mut expr = inner;

for binding in self.bindings.iter().rev() {
expr = cbind(binding.clone(), expr);
for (i, binding) in self.bindings.iter().rev().enumerate() {
expr = cbind(i, binding.clone(), expr);
}
expr
}
Expand Down Expand Up @@ -244,13 +249,17 @@ fn translate_simple_loop() {
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
TreeType::Tuple(vec![TreeType::Bril(Type::Int), TreeType::Tuple(vec![])]),
cbind(
1,
num(1), // [(), 1]
cbind(
2,
num(2), // [(), 1, 2]
cbind(
3,
tloop(
parallel!(getarg(0), getarg(1), getarg(2)), // [(), 1, 2]
cbind(
4,
lessthan(getarg(1), getarg(2)), // [(), 1, 2, 1<2]
parallel!(getarg(3), parallel!(getarg(0), getarg(1), getarg(2)))
)
Expand Down Expand Up @@ -289,17 +298,23 @@ fn translate_loop() {
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
cbind(
1,
num(0), // [(), 0]
cbind(
2,
tloop(
parallel!(getarg(0), getarg(1)),
cbind(
3,
num(1), // [(), i, 1]
cbind(
4,
add(getarg(1), getarg(2)), // [(), i, 1, i+1]
cbind(
5,
num(10), // [(), i, 1, i+1, 10]
cbind(
6,
lessthan(getarg(3), getarg(4)), // [(), i, 1, i+1, 10, i<10]
parallel!(getarg(5), parallel!(getarg(0), getarg(3)))
)
Expand All @@ -308,6 +323,7 @@ fn translate_loop() {
)
),
cbind(
2,
tprint(get(getarg(2), 1)), // [(), 0, [() i]]
parallel!(getarg(3))
)
Expand Down Expand Up @@ -337,8 +353,10 @@ fn simple_translation() {
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
TreeType::Tuple(vec![TreeType::Bril(Type::Int), TreeType::Tuple(vec![])]),
cbind(
1,
num(1),
cbind(
2,
add(get(arg(), 1), get(arg(), 1)),
parallel!(get(arg(), 2), get(arg(), 0)), // returns res and print state (unit)
),
Expand Down Expand Up @@ -369,14 +387,18 @@ fn two_print_translation() {
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
cbind(
1,
num(2),
cbind(
2,
num(1),
cbind(
3,
add(get(arg(), 2), get(arg(), 1)),
cbind(
4,
tprint(get(arg(), 3)),
cbind(tprint(get(arg(), 1)), parallel!(get(arg(), 5))),
cbind(5, tprint(get(arg(), 1)), parallel!(get(arg(), 5))),
),
),
),
Expand Down
4 changes: 0 additions & 4 deletions tree_optimizer/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ pub fn get(a: Expr, i: usize) -> Expr {
Get(Box::new(a), i)
}

pub fn concat(a: Expr, b: Expr) -> Expr {
Concat(Box::new(a), Box::new(b))
}

pub fn tprint(a: Expr) -> Expr {
Print(Box::new(a))
}
Expand Down
14 changes: 4 additions & 10 deletions tree_optimizer/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,6 @@ pub enum Expr {
BOp(PureBOp, Box<Expr>, Box<Expr>),
UOp(PureUOp, Box<Expr>),
Get(Box<Expr>, usize),
/// Concat is a convenient built-in way
/// to put two tuples together.
/// It's not strictly necessary, but
/// doing it by constructing a new big tuple is tedius and slow.
Concat(Box<Expr>, Box<Expr>),
Print(Box<Expr>),
Read(Box<Expr>),
Write(Box<Expr>, Box<Expr>),
Expand Down Expand Up @@ -211,9 +206,9 @@ impl Expr {
pub fn is_pure(&self) -> bool {
use Expr::*;
match self {
Num(..) | Boolean(..) | Arg(..) | BOp(..) | UOp(..) | Get(..) | Concat(..)
| Read(..) | All(..) | Switch(..) | Branch(..) | Loop(..) | Let(..) | Function(..)
| Program(..) | Call(..) => true,
Num(..) | Boolean(..) | Arg(..) | BOp(..) | UOp(..) | Get(..) | Read(..) | All(..)
| Switch(..) | Branch(..) | Loop(..) | Let(..) | Function(..) | Program(..)
| Call(..) => true,
Print(..) | Write(..) => false,
}
}
Expand All @@ -225,7 +220,6 @@ impl Expr {
Expr::BOp(_, _, _) => "BOp",
Expr::UOp(_, _) => "UOp",
Expr::Get(_, _) => "Get",
Expr::Concat(_, _) => todo!("Remove concat from ast"),
Expr::Print(_) => "Print",
Expr::Read(_) => "Read",
Expr::Write(_, _) => "Write",
Expand All @@ -252,7 +246,7 @@ impl Expr {
Expr::UOp(_, a) => {
func(a);
}
Expr::Concat(a, b) | Expr::Write(a, b) => {
Expr::Write(a, b) => {
func(a);
func(b);
}
Expand Down
26 changes: 0 additions & 26 deletions tree_optimizer/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,6 @@ pub fn typecheck(e: &Expr, arg_ty: &Option<TreeType>) -> Result<TreeType, TypeEr
)),
}
}
Expr::Concat(tuple_1, tuple_2) => {
let ty_tuple_1 = typecheck(tuple_1, arg_ty)?;
let ty_tuple_2 = typecheck(tuple_2, arg_ty)?;
match (ty_tuple_1.clone(), ty_tuple_2.clone()) {
(TreeType::Tuple(tys_1), TreeType::Tuple(tys_2)) => {
Ok(TreeType::Tuple(tys_1.into_iter().chain(tys_2).collect()))
}
(TreeType::Tuple(_tys_1), _) => Err(TypeError::ExpectedTupleType(
*tuple_2.clone(),
ty_tuple_2.clone(),
)),
_ => Err(TypeError::ExpectedTupleType(
*tuple_1.clone(),
ty_tuple_1.clone(),
)),
}
}
Expr::Print(e) => {
// right now, only print nums
expect_type(e, Bril(Int))?;
Expand Down Expand Up @@ -218,15 +201,6 @@ pub fn interpret(e: &Expr, arg: &Option<Value>, vm: &mut VirtualMachine) -> Valu
};
vals[*i].clone()
}
Expr::Concat(tuple_1, tuple_2) => {
let Value::Tuple(t1) = interpret(tuple_1, arg, vm) else {
panic!("concat")
};
let Value::Tuple(t2) = interpret(tuple_2, arg, vm) else {
panic!("concat")
};
Value::Tuple(t1.into_iter().chain(t2).collect())
}
Expr::Print(e) => {
let Value::Num(n) = interpret(e, arg, vm) else {
panic!("print")
Expand Down
3 changes: 0 additions & 3 deletions tree_optimizer/src/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ impl Constructor {
use Purpose::{CapturedExpr, CapturingId, ReferencingId, Static, SubExpr, SubListExpr};
let f = |purpose, name| Field { purpose, name };
match self {
Constructor::Expr(Expr::Concat(..)) => {
todo!("Remove concat from enum")
}
Constructor::Expr(Expr::Function(..)) => {
vec![
f(Static(Sort::IdSort), "id"),
Expand Down

0 comments on commit c50e1a8

Please sign in to comment.