Skip to content

Commit

Permalink
Merge pull request #295 from egraphs-good/oflatt-tree-translation-4
Browse files Browse the repository at this point in the history
Translate RVSDG loops to tree encoding
  • Loading branch information
oflatt authored Jan 29, 2024
2 parents 98accff + cd38d0a commit 3da0a81
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 54 deletions.
200 changes: 164 additions & 36 deletions src/rvsdg/tree_unique/to_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,26 @@
//! These shared nodes need to be let-bound so that they are only
//! computed once in the tree encoded
//! program.

#[cfg(test)]
use crate::{cfg::program_to_cfg, rvsdg::cfg_to_rvsdg, util::parse_from_string};
#[cfg(test)]
use tree_unique_args::ast::program;

use crate::rvsdg::{BasicExpr, Id, Operand, RvsdgBody, RvsdgFunction, RvsdgProgram};
use bril_rs::{Literal, ValueOps};
use hashbrown::HashMap;
use tree_unique_args::{
ast::{add, arg, concat, function, get, num, print, program, sequence, tfalse, tlet, ttrue},
ast::{
add, arg, concat, function, get, getarg, lessthan, num, parallel, parallel_vec, print,
program_vec, tfalse, tlet, tloop, ttrue,
},
Expr,
};

impl RvsdgProgram {
pub fn to_tree_encoding(&self) -> Expr {
program(
program_vec(
self.functions
.iter()
.map(|f| f.to_tree_encoding())
Expand Down Expand Up @@ -63,42 +69,48 @@ impl<'a> RegionTranslator<'a> {
res
}

/// Build a translator and translate
/// the operands to the tree encoding.
/// Produces a tree-encoded term that evaluates
/// to a tuple containing results.
fn translate(num_args: usize, nodes: &'a Vec<RvsdgBody>, results: Vec<Operand>) -> Expr {
let mut translator = RegionTranslator {
/// Make a new translator for a region with
/// num_args and the given nodes.
fn new(num_args: usize, nodes: &'a Vec<RvsdgBody>) -> RegionTranslator {
RegionTranslator {
num_args,
bindings: Vec::new(),
index_of: HashMap::new(),
nodes,
};

let mut result_indices = Vec::new();
for result in results {
result_indices.push(translator.translate_operand(result));
}
}

let mut expr = sequence(result_indices.iter().map(|i| get(arg(), *i)).collect());
/// Wrap the given expression in all the
/// bindings that have been generated.
fn build_translation(&self, inner: Expr) -> Expr {
let mut expr = inner;

for binding in translator.bindings.into_iter().rev() {
expr = cbind(binding, expr);
for binding in self.bindings.iter().rev() {
expr = cbind(binding.clone(), expr);
}
expr
}

fn translate_operand(&mut self, operand: Operand) -> usize {
/// Returns a pure expression (e.g. `getarg(0)`) that
/// returns the value for this operand.
/// The value of the operand is let-bound
/// and the expression refers to it.
fn translate_operand(&mut self, operand: Operand) -> Expr {
match operand {
Operand::Arg(index) => index,
Operand::Id(id) => self.translate_node(id),
Operand::Project(_id, _indexx) => {
todo!("Doesn't handle subregions yet");
Operand::Arg(index) => getarg(index),
Operand::Id(id) => getarg(self.translate_node(id)),
Operand::Project(p_index, id) => {
// Translated region becomes a tuple in the environment.
// This is the index of that tuple.
let index = self.translate_node(id);
get(getarg(index), p_index)
}
}
}

/// Translate a node or return the index of the already evaluated node.
/// Translate a node or return the index of the already-translated node.
/// For regions, translates the region and returns the index of the
/// tuple containing the results.
/// It's important not to evaluate a node twice, instead using the cached index
/// in `self.index_of`
fn translate_node(&mut self, id: Id) -> usize {
Expand All @@ -108,7 +120,29 @@ impl<'a> RegionTranslator<'a> {
let node = &self.nodes[id];
match node {
RvsdgBody::BasicOp(expr) => self.translate_basic_expr(expr.clone(), id),
_ => todo!("Doesn't handle subregions yet"),
RvsdgBody::Gamma { .. } => todo!("Doesn't handle gamma yet"),
RvsdgBody::Theta {
pred,
inputs,
outputs,
} => {
let mut translated_inputs = vec![];
// for loop instead of iterator because of lifetimes
for input in inputs {
translated_inputs.push(self.translate_operand(*input));
}

let mut sub_translator = RegionTranslator::new(inputs.len(), self.nodes);
let pred_translated = sub_translator.translate_operand(*pred);
let outputs_translated =
outputs.iter().map(|o| sub_translator.translate_operand(*o));
let pred_and_outputs =
parallel!(pred_translated, parallel_vec(outputs_translated.collect()));
let loop_translated = sub_translator.build_translation(pred_and_outputs);

let loop_expr = tloop(parallel_vec(translated_inputs), loop_translated);
self.add_binding(loop_expr, id)
}
}
}
}
Expand All @@ -120,10 +154,11 @@ impl<'a> RegionTranslator<'a> {
BasicExpr::Op(op, children, _ty) => {
let children = children
.iter()
.map(|c| get(arg(), self.translate_operand(*c)))
.map(|c| self.translate_operand(*c))
.collect::<Vec<_>>();
let expr = match (op, children.as_slice()) {
(ValueOps::Add, [a, b]) => add(a.clone(), b.clone()),
(ValueOps::Lt, [a, b]) => lessthan(a.clone(), b.clone()),
_ => todo!("handle other ops"),
};
self.add_binding(expr, id)
Expand All @@ -145,8 +180,11 @@ impl<'a> RegionTranslator<'a> {
BasicExpr::Print(args) => {
assert!(args.len() == 2, "print should have 2 arguments");
let arg1 = self.translate_operand(args[0]);
// argument 2 should have value unit, since it is
// the print buffer value.
let _arg2 = self.translate_operand(args[1]);
let expr = print(get(arg(), arg1));
// print outputs a new unit value
let expr = print(arg1);
self.add_binding(expr, id)
}
}
Expand All @@ -163,14 +201,104 @@ impl RvsdgFunction {
/// In the inner-most scope, the value of
/// all nodes is available.
pub fn to_tree_encoding(&self) -> Expr {
function(RegionTranslator::translate(
self.args.len(),
&self.nodes,
self.results.iter().map(|r| r.1).collect(),
))
let mut translator = RegionTranslator::new(self.args.len(), &self.nodes);
let translated_results = self
.results
.iter()
.map(|r| translator.translate_operand(r.1))
.collect::<Vec<_>>();

function(translator.build_translation(parallel_vec(translated_results)))
}
}

#[test]
fn translate_simple_loop() {
const PROGRAM: &str = r#"
@myfunc(): int {
.entry:
one: int = const 1;
two: int = const 2;
.loop:
cond: bool = lt one two;
br cond .loop .exit;
.exit:
ret one;
}
"#;
let prog = parse_from_string(PROGRAM);
let cfg = program_to_cfg(&prog);
let rvsdg = cfg_to_rvsdg(&cfg).unwrap();

rvsdg
.to_tree_encoding()
.assert_eq_ignoring_ids(&program!(function(cbind(
num(1), // [(), 1]
cbind(
num(2), // [(), 1, 2]
cbind(
tloop(
parallel!(getarg(0), getarg(1), getarg(2)), // [(), 1, 2]
cbind(
lessthan(getarg(1), getarg(2)), // [(), 1, 2, 1<2]
parallel!(getarg(3), parallel!(getarg(0), getarg(1), getarg(2)))
)
), // [(), 1, 2, [(), 1, 2]]
parallel!(get(getarg(3), 1), get(getarg(3), 0)) // return [1, ()]
),
)
))));
}

#[test]
fn translate_loop() {
const PROGRAM: &str = r#"
@main {
.entry:
i: int = const 0;
.loop:
max: int = const 10;
one: int = const 1;
i: int = add i one;
cond: bool = lt i max;
br cond .loop .exit;
.exit:
print i;
}
"#;
let prog = parse_from_string(PROGRAM);
let cfg = program_to_cfg(&prog);
let rvsdg = cfg_to_rvsdg(&cfg).unwrap();

rvsdg
.to_tree_encoding()
.assert_eq_ignoring_ids(&program!(function(cbind(
num(0), // [(), 0]
cbind(
tloop(
parallel!(getarg(0), getarg(1)),
cbind(
num(1), // [(), i, 1]
cbind(
add(getarg(1), getarg(2)), // [(), i, 1, i+1]
cbind(
num(10), // [(), i, 1, i+1, 10]
cbind(
lessthan(getarg(3), getarg(4)), // [(), i, 1, i+1, 10, i<10]
parallel!(getarg(5), parallel!(getarg(0), getarg(3)))
)
)
)
)
),
cbind(
print(get(getarg(2), 1)), // [(), 0, [() i]]
parallel!(getarg(3))
)
),
))));
}

#[test]
fn simple_translation() {
const PROGRAM: &str = r#"
Expand All @@ -187,13 +315,13 @@ fn simple_translation() {

rvsdg
.to_tree_encoding()
.assert_eq_ignoring_ids(&program(vec![function(cbind(
.assert_eq_ignoring_ids(&program!(function(cbind(
num(1),
cbind(
add(get(arg(), 1), get(arg(), 1)),
sequence(vec![get(arg(), 2), get(arg(), 0)]), // returns res and print state (unit)
parallel!(get(arg(), 2), get(arg(), 0)), // returns res and print state (unit)
),
))]));
))));
}

#[test]
Expand All @@ -214,17 +342,17 @@ fn two_print_translation() {

rvsdg
.to_tree_encoding()
.assert_eq_ignoring_ids(&program(vec![function(cbind(
.assert_eq_ignoring_ids(&program!(function(cbind(
num(2),
cbind(
num(1),
cbind(
add(get(arg(), 2), get(arg(), 1)),
cbind(
print(get(arg(), 3)),
cbind(print(get(arg(), 1)), sequence(vec![get(arg(), 5)])),
cbind(print(get(arg(), 1)), parallel!(get(arg(), 5))),
),
),
),
))]));
))));
}
Loading

0 comments on commit 3da0a81

Please sign in to comment.