Skip to content

Commit

Permalink
fix up more small issues
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Jan 30, 2024
1 parent c50e1a8 commit 821dbd2
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 133 deletions.
29 changes: 15 additions & 14 deletions src/rvsdg/tree_unique/to_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ use crate::{cfg::program_to_cfg, rvsdg::cfg_to_rvsdg, util::parse_from_string};
#[cfg(test)]
use bril_rs::Type;
#[cfg(test)]
use tree_optimizer::ast::program;
use tree_optimizer::ast::{arg, program};

use crate::rvsdg::{BasicExpr, Id, Operand, RvsdgBody, RvsdgFunction, RvsdgProgram};
use bril_rs::{Literal, ValueOps};
use hashbrown::HashMap;
use tree_optimizer::{
ast::{
add, function, get, getarg, lessthan, num, parallel, parallel_vec, program_vec,
tfalse, tlet, tloop, tprint, ttrue,
add, function, get, getarg, lessthan, num, parallel, parallel_vec, program_vec, tfalse,
tlet, tloop, tprint, ttrue,
},
expr::{Expr, TreeType},
};
Expand Down Expand Up @@ -52,11 +52,12 @@ struct RegionTranslator<'a> {
}

/// helper that binds a new expression, adding it
/// to the environment using concat
/// to the environment by concatenating all previous values
/// with the new one
fn cbind(index: usize, expr: Expr, body: Expr) -> Expr {
let mut concatted = vec![];
for i in 0..index {
concatted.push(getarg(i as usize));
concatted.push(getarg(i));
}
concatted.push(expr);
tlet(parallel_vec(concatted), body)
Expand Down Expand Up @@ -92,8 +93,8 @@ impl<'a> RegionTranslator<'a> {
fn build_translation(&self, inner: Expr) -> Expr {
let mut expr = inner;

for (i, binding) in self.bindings.iter().rev().enumerate() {
expr = cbind(i, binding.clone(), expr);
for (i, binding) in self.bindings.iter().enumerate().rev() {
expr = cbind(i + self.num_args, binding.clone(), expr);
}
expr
}
Expand Down Expand Up @@ -259,7 +260,7 @@ fn translate_simple_loop() {
tloop(
parallel!(getarg(0), getarg(1), getarg(2)), // [(), 1, 2]
cbind(
4,
3,
lessthan(getarg(1), getarg(2)), // [(), 1, 2, 1<2]
parallel!(getarg(3), parallel!(getarg(0), getarg(1), getarg(2)))
)
Expand Down Expand Up @@ -305,16 +306,16 @@ fn translate_loop() {
tloop(
parallel!(getarg(0), getarg(1)),
cbind(
3,
2,
num(1), // [(), i, 1]
cbind(
4,
3,
add(getarg(1), getarg(2)), // [(), i, 1, i+1]
cbind(
5,
4,
num(10), // [(), i, 1, i+1, 10]
cbind(
6,
5,
lessthan(getarg(3), getarg(4)), // [(), i, 1, i+1, 10, i<10]
parallel!(getarg(5), parallel!(getarg(0), getarg(3)))
)
Expand All @@ -323,8 +324,8 @@ fn translate_loop() {
)
),
cbind(
2,
tprint(get(getarg(2), 1)), // [(), 0, [() i]]
3,
tprint(get(getarg(2), 1)), // [(), 0, [() i], ()]
parallel!(getarg(3))
)
),
Expand Down
94 changes: 72 additions & 22 deletions tree_optimizer/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ fn give_fresh_ids_helper(expr: &mut Expr, current_id: i64, fresh_id: &mut i64) {
/// a vec for program
#[macro_export]
macro_rules! program {
($($x:expr),*) => ($crate::ast::program_vec(vec![$($x),*]))
($($x:expr),* $(,)?) => ($crate::ast::program_vec(vec![$($x),*]))
}
use bril_rs::Type;
pub use program;

pub fn program_vec(args: Vec<Expr>) -> Expr {
Expand All @@ -96,14 +97,30 @@ pub fn program_vec(args: Vec<Expr>) -> Expr {
}

pub fn num(n: i64) -> Expr {
Num(n)
Num(Shared, n)
}

pub fn ttrue() -> Expr {
Boolean(true)
Boolean(Shared, true)
}
pub fn tfalse() -> Expr {
Boolean(false)
Boolean(Shared, false)
}

pub fn tint() -> TreeType {
TreeType::Bril(Type::Int)
}

pub fn tbool() -> TreeType {
TreeType::Bril(Type::Bool)
}

pub fn twrite(addr: Expr, data: Expr) -> Expr {
Write(Box::new(addr), Box::new(data))
}

pub fn tread(addr: Expr, ty: TreeType) -> Expr {
Read(Box::new(addr), ty)
}

pub fn add(a: Expr, b: Expr) -> Expr {
Expand Down Expand Up @@ -142,6 +159,10 @@ pub fn get(a: Expr, i: usize) -> Expr {
Get(Box::new(a), i)
}

pub fn branch(a: Expr) -> Expr {
Branch(Shared, Box::new(a))
}

pub fn tprint(a: Expr) -> Expr {
Print(Box::new(a))
}
Expand All @@ -153,7 +174,7 @@ pub fn sequence_vec(args: Vec<Expr>) -> Expr {
#[macro_export]
macro_rules! sequence {
// use crate::ast::sequence_vec to resolve import errors
($($x:expr),*) => ($crate::ast::sequence_vec(vec![$($x),*]))
($($x:expr),* $(,)?) => ($crate::ast::sequence_vec(vec![$($x),*]))
}
pub use sequence;

Expand All @@ -163,13 +184,13 @@ pub fn parallel_vec(args: Vec<Expr>) -> Expr {

#[macro_export]
macro_rules! parallel {
($($x:expr),*) => ($crate::ast::parallel_vec(vec![$($x),*]))
($($x:expr),* $(,)?) => ($crate::ast::parallel_vec(vec![$($x),*]))
}
pub use parallel;

#[macro_export]
macro_rules! switch {
($arg:expr, $($x:expr),*) => ($crate::ast::switch_vec($arg, vec![$($x),*]))
($arg:expr, $($x:expr),* $(,)?) => ($crate::ast::switch_vec($arg, vec![$($x),*]))
}
pub use switch;

Expand Down Expand Up @@ -205,8 +226,12 @@ fn test_gives_nested_ids() {
prog,
Let(
Unique(1),
Box::new(Num(0)),
Box::new(Let(Unique(2), Box::new(Num(1)), Box::new(Num(2))))
Box::new(Num(Unique(1), 0)),
Box::new(Let(
Unique(2),
Box::new(Num(Unique(2), 1)),
Box::new(Num(Unique(2), 2))
))
)
);
}
Expand All @@ -219,8 +244,12 @@ fn test_gives_loop_ids() {
prog,
Let(
Unique(1),
Box::new(Num(0)),
Box::new(Loop(Unique(2), Box::new(Num(1)), Box::new(Num(2))))
Box::new(Num(Unique(1), 0)),
Box::new(Loop(
Unique(2),
Box::new(Num(Unique(2), 1)),
Box::new(Num(Unique(2), 2))
))
)
);
}
Expand All @@ -238,11 +267,11 @@ fn test_complex_program_ids() {
tloop(
num(1),
switch!(
arg(),
num(2),
call("otherfunc", num(3)),
tlet(num(4), num(5)),
tloop(num(6), num(7))
branch(arg()),
branch(num(2)),
branch(call("otherfunc", num(3))),
branch(tlet(num(4), num(5))),
branch(tloop(num(6), num(7))),
),
),
)
Expand All @@ -256,17 +285,38 @@ fn test_complex_program_ids() {
TreeType::Tuple(vec![]),
Box::new(Let(
Unique(2),
Box::new(Num(0)),
Box::new(Num(Unique(2), 0)),
Box::new(Loop(
Unique(3),
Box::new(Num(1)),
Box::new(Num(Unique(3), 1)),
Box::new(Switch(
Box::new(Arg(Unique(3))),
vec![
Num(2),
Call(Unique(3), "otherfunc".into(), Box::new(Num(3))),
Let(Unique(4), Box::new(Num(4)), Box::new(Num(5))),
Loop(Unique(5), Box::new(Num(6)), Box::new(Num(7))),
Branch(Unique(4), Box::new(Num(Unique(4), 2))),
Branch(
Unique(5),
Box::new(Call(
Unique(5),
"otherfunc".into(),
Box::new(Num(Unique(5), 3))
))
),
Branch(
Unique(6),
Box::new(Let(
Unique(7),
Box::new(Num(Unique(7), 4)),
Box::new(Num(Unique(7), 5))
))
),
Branch(
Unique(8),
Box::new(Loop(
Unique(9),
Box::new(Num(Unique(9), 6)),
Box::new(Num(Unique(9), 7))
))
),
]
))
))
Expand Down
18 changes: 9 additions & 9 deletions tree_optimizer/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,13 @@ impl ESort {

#[derive(Clone, Debug, PartialEq, EnumIter)]
pub enum Expr {
Num(i64),
Boolean(bool),
Num(Id, i64),
Boolean(Id, bool),
BOp(PureBOp, Box<Expr>, Box<Expr>),
UOp(PureUOp, Box<Expr>),
Get(Box<Expr>, usize),
Print(Box<Expr>),
Read(Box<Expr>),
Read(Box<Expr>, TreeType),
Write(Box<Expr>, Box<Expr>),
All(Id, Order, Vec<Expr>),
/// A pred and a list of branches
Expand All @@ -198,7 +198,7 @@ pub enum Expr {

impl Default for Expr {
fn default() -> Self {
Expr::Num(0)
Expr::Num(Id::Shared, 0)
}
}

Expand All @@ -215,13 +215,13 @@ impl Expr {

pub fn name(&self) -> &'static str {
match self {
Expr::Num(_) => "Num",
Expr::Boolean(_) => "Boolean",
Expr::Num(..) => "Num",
Expr::Boolean(..) => "Boolean",
Expr::BOp(_, _, _) => "BOp",
Expr::UOp(_, _) => "UOp",
Expr::Get(_, _) => "Get",
Expr::Print(_) => "Print",
Expr::Read(_) => "Read",
Expr::Read(..) => "Read",
Expr::Write(_, _) => "Write",
Expr::All(_, _, _) => "All",
Expr::Switch(_, _) => "Switch",
Expand All @@ -238,7 +238,7 @@ impl Expr {
/// Runs `func` on every child of this expression.
pub fn for_each_child(&mut self, mut func: impl FnMut(&mut Expr)) {
match self {
Expr::Num(_) | Expr::Boolean(_) | Expr::Arg(_) => {}
Expr::Num(..) | Expr::Boolean(..) | Expr::Arg(..) => {}
Expr::BOp(_, a, b) => {
func(a);
func(b);
Expand All @@ -250,7 +250,7 @@ impl Expr {
func(a);
func(b);
}
Expr::Print(a) | Expr::Read(a) => {
Expr::Print(a) | Expr::Read(a, _) => {
func(a);
}
Expr::Get(a, _) | Expr::Function(_, _, _, _, a) | Expr::Call(_, _, a) => {
Expand Down
Loading

0 comments on commit 821dbd2

Please sign in to comment.