Skip to content

Commit

Permalink
tests passing with function types
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Jan 30, 2024
1 parent d9d7282 commit 1f99e1d
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 102 deletions.
12 changes: 0 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ bril-rs = { git = "https://github.com/sampsyo/bril", rev = "daaff28" }
ordered-float = { version = "3.7" }
serde_json = "1.0.103"

tree-unique-args = { path = "tree_unique_args" }
tree_optimizer = { path = "tree_optimizer" }

# binary dependencies
Expand Down
130 changes: 69 additions & 61 deletions src/rvsdg/tree_unique/to_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
//! computed once in the tree encoded
//! program.

use std::iter;

#[cfg(test)]
use crate::{cfg::program_to_cfg, rvsdg::cfg_to_rvsdg, util::parse_from_string};
#[cfg(test)]
use tree_unique_args::ast::program;
use bril_rs::Type;
#[cfg(test)]
use tree_optimizer::ast::program;

use crate::rvsdg::{BasicExpr, Id, Operand, RvsdgBody, RvsdgFunction, RvsdgProgram};
use bril_rs::{Literal, ValueOps};
Expand Down Expand Up @@ -211,21 +211,9 @@ impl RvsdgFunction {
.collect::<Vec<_>>();

function(
self.name.clone(),
TreeType::Tuple(
self.args
.iter()
.map(|ty| ty.to_tree_type())
.chain(iter::once(TreeType::Unit))
.collect(),
),
TreeType::Tuple(
self.results
.iter()
.map(|r| r.0.to_tree_type())
.chain(iter::once(TreeType::Unit))
.collect(),
),
self.name.as_str(),
TreeType::Tuple(self.args.iter().map(|ty| ty.to_tree_type()).collect()),
TreeType::Tuple(self.results.iter().map(|r| r.0.to_tree_type()).collect()),
translator.build_translation(parallel_vec(translated_results)),
)
}
Expand All @@ -251,22 +239,27 @@ fn translate_simple_loop() {

rvsdg
.to_tree_encoding()
.assert_eq_ignoring_ids(&program!(function(cbind(
num(1), // [(), 1]
.assert_eq_ignoring_ids(&program!(function(
"myfunc",
TreeType::Tuple(vec![TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Bril(Type::Int), TreeType::Unit]),
cbind(
num(2), // [(), 1, 2]
num(1), // [(), 1]
cbind(
tloop(
parallel!(getarg(0), getarg(1), getarg(2)), // [(), 1, 2]
cbind(
lessthan(getarg(1), getarg(2)), // [(), 1, 2, 1<2]
parallel!(getarg(3), parallel!(getarg(0), getarg(1), getarg(2)))
)
), // [(), 1, 2, [(), 1, 2]]
parallel!(get(getarg(3), 1), get(getarg(3), 0)) // return [1, ()]
),
num(2), // [(), 1, 2]
cbind(
tloop(
parallel!(getarg(0), getarg(1), getarg(2)), // [(), 1, 2]
cbind(
lessthan(getarg(1), getarg(2)), // [(), 1, 2, 1<2]
parallel!(getarg(3), parallel!(getarg(0), getarg(1), getarg(2)))
)
), // [(), 1, 2, [(), 1, 2]]
parallel!(get(getarg(3), 1), get(getarg(3), 0)) // return [1, ()]
),
)
)
))));
)));
}

#[test]
Expand All @@ -291,31 +284,36 @@ fn translate_loop() {

rvsdg
.to_tree_encoding()
.assert_eq_ignoring_ids(&program!(function(cbind(
num(0), // [(), 0]
.assert_eq_ignoring_ids(&program!(function(
"main",
TreeType::Tuple(vec![TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Unit]),
cbind(
tloop(
parallel!(getarg(0), getarg(1)),
cbind(
num(1), // [(), i, 1]
num(0), // [(), 0]
cbind(
tloop(
parallel!(getarg(0), getarg(1)),
cbind(
add(getarg(1), getarg(2)), // [(), i, 1, i+1]
num(1), // [(), i, 1]
cbind(
num(10), // [(), i, 1, i+1, 10]
add(getarg(1), getarg(2)), // [(), i, 1, i+1]
cbind(
lessthan(getarg(3), getarg(4)), // [(), i, 1, i+1, 10, i<10]
parallel!(getarg(5), parallel!(getarg(0), getarg(3)))
num(10), // [(), i, 1, i+1, 10]
cbind(
lessthan(getarg(3), getarg(4)), // [(), i, 1, i+1, 10, i<10]
parallel!(getarg(5), parallel!(getarg(0), getarg(3)))
)
)
)
)
),
cbind(
print(get(getarg(2), 1)), // [(), 0, [() i]]
parallel!(getarg(3))
)
),
cbind(
print(get(getarg(2), 1)), // [(), 0, [() i]]
parallel!(getarg(3))
)
),
))));
)
)));
}

#[test]
Expand All @@ -334,13 +332,18 @@ fn simple_translation() {

rvsdg
.to_tree_encoding()
.assert_eq_ignoring_ids(&program!(function(cbind(
num(1),
.assert_eq_ignoring_ids(&program!(function(
"add",
TreeType::Tuple(vec![TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Bril(Type::Int), TreeType::Unit]),
cbind(
add(get(arg(), 1), get(arg(), 1)),
parallel!(get(arg(), 2), get(arg(), 0)), // returns res and print state (unit)
),
))));
num(1),
cbind(
add(get(arg(), 1), get(arg(), 1)),
parallel!(get(arg(), 2), get(arg(), 0)), // returns res and print state (unit)
),
)
)));
}

#[test]
Expand All @@ -361,17 +364,22 @@ fn two_print_translation() {

rvsdg
.to_tree_encoding()
.assert_eq_ignoring_ids(&program!(function(cbind(
num(2),
.assert_eq_ignoring_ids(&program!(function(
"add",
TreeType::Tuple(vec![TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Unit]),
cbind(
num(1),
num(2),
cbind(
add(get(arg(), 2), get(arg(), 1)),
num(1),
cbind(
print(get(arg(), 3)),
cbind(print(get(arg(), 1)), parallel!(get(arg(), 5))),
add(get(arg(), 2), get(arg(), 1)),
cbind(
print(get(arg(), 3)),
cbind(print(get(arg(), 1)), parallel!(get(arg(), 5))),
),
),
),
),
))));
)
)));
}
64 changes: 37 additions & 27 deletions tree_optimizer/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,14 @@ macro_rules! parallel {
}
pub use parallel;

pub fn switch(arg: Expr, cases: Vec<Expr>) -> Expr {
let cases_wrapped = cases
.into_iter()
.map(|case| Branch(Shared, Box::new(case)))
.collect();
Switch(Box::new(arg), cases_wrapped)
#[macro_export]
macro_rules! switch {
($arg:expr, $($x:expr),*) => ($crate::ast::switch_vec($arg, vec![$($x),*]))
}
pub use switch;

pub fn switch_vec(arg: Expr, cases: Vec<Expr>) -> Expr {
Switch(Box::new(arg), cases)
}

pub fn tloop(input: Expr, body: Expr) -> Expr {
Expand All @@ -189,8 +191,8 @@ pub fn arg() -> Expr {
Arg(Shared)
}

pub fn function(name: String, in_ty: TreeType, out_ty: TreeType, arg: Expr) -> Expr {
Function(Shared, name, in_ty, out_ty, Box::new(arg))
pub fn function(name: &str, in_ty: TreeType, out_ty: TreeType, arg: Expr) -> Expr {
Function(Shared, name.to_string(), in_ty, out_ty, Box::new(arg))
}

pub fn call(arg: Expr) -> Expr {
Expand Down Expand Up @@ -229,36 +231,44 @@ fn test_gives_loop_ids() {
fn test_complex_program_ids() {
// test a program that includes
// a let, a loop, a switch, and a call
let prog = program!(function(tlet(
num(0),
tloop(
num(1),
switch!(
arg(),
num(2),
call(num(3)),
tlet(num(4), num(5)),
tloop(num(6), num(7))
let prog = program!(function(
"main",
TreeType::Unit,
TreeType::Unit,
tlet(
num(0),
tloop(
num(1),
switch!(
arg(),
num(2),
call(num(3)),
tlet(num(4), num(5)),
tloop(num(6), num(7))
),
),
),
)));
)
));
assert_eq!(
prog,
Program(vec![Function(
Id(1),
Unique(1),
"main".to_string(),
TreeType::Unit,
TreeType::Unit,
Box::new(Let(
Id(2),
Unique(2),
Box::new(Num(0)),
Box::new(Loop(
Id(3),
Unique(3),
Box::new(Num(1)),
Box::new(Switch(
Box::new(Arg(Id(3))),
Box::new(Arg(Unique(3))),
vec![
Num(2),
Call(Id(3), Box::new(Num(3))),
Let(Id(4), Box::new(Num(4)), Box::new(Num(5))),
Loop(Id(5), Box::new(Num(6)), Box::new(Num(7))),
Call(Unique(3), Box::new(Num(3))),
Let(Unique(4), Box::new(Num(4)), Box::new(Num(5))),
Loop(Unique(5), Box::new(Num(6)), Box::new(Num(7))),
]
))
))
Expand Down
2 changes: 1 addition & 1 deletion tree_optimizer/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ impl std::str::FromStr for Expr {
("TupleT", xs) => {
let tys = xs
.iter()
.map(|x| egglog_type_to_type(x))
.map(egglog_type_to_type)
.collect::<Result<Vec<_>, _>>()?;
Ok(Tuple(tys))
}
Expand Down

0 comments on commit 1f99e1d

Please sign in to comment.