Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an If variant to Rvsdg struct #306

Merged
merged 4 commits into from
Feb 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/cfg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,11 @@ pub(crate) enum BranchOp {
/// An unconditional branch to a block.
Jmp,
/// A conditional branch to a block.
Cond { arg: Identifier, val: CondVal },
Cond {
arg: Identifier,
val: CondVal,
bril_type: Type,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's document what different "types" here mean. I think this has caused some confusion in the past, and better signposting in the code could help.

},
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -423,10 +427,12 @@ impl SimpleCfgFunction {
BranchOp::Cond {
arg: cond1,
val: CondVal { val: val1, of: 2 },
bril_type: Type::Bool,
},
BranchOp::Cond {
arg: cond2,
val: CondVal { val: val2, of: 2 },
bril_type: Type::Bool,
},
) => {
assert_eq!(cond1, cond2);
Expand Down Expand Up @@ -517,6 +523,7 @@ pub(crate) fn function_to_cfg(func: &Function) -> SimpleCfgFunction {
op: BranchOp::Cond {
arg: arg.into(),
val: true.into(),
bril_type: Type::Bool,
},
pos: pos.clone(),
},
Expand All @@ -528,6 +535,7 @@ pub(crate) fn function_to_cfg(func: &Function) -> SimpleCfgFunction {
op: BranchOp::Cond {
arg: arg.into(),
val: false.into(),
bril_type: Type::Bool,
},
pos: pos.clone(),
},
Expand Down
42 changes: 21 additions & 21 deletions src/cfg/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
EggCCError,
};
use bril2json::parse_abstract_program_from_read;
use bril_rs::{load_program_from_read, Program};
use bril_rs::{load_program_from_read, Program, Type};

fn parse_from_string(input: &str) -> Program {
let abs_program = parse_abstract_program_from_read(input.as_bytes(), true, false, None);
Expand Down Expand Up @@ -32,8 +32,8 @@ cfg_test_function_to_cfg!(
include_str!("../../tests/brils/failing/mem/fib.bril"),
[
ENTRY = (Jmp) => "loop",
"loop" = (Cond { arg: "cond".into(), val: true.into() }) => "body",
"loop" = (Cond { arg: "cond".into(), val: false.into() }) => "done",
"loop" = (true_cond("cond")) => "body",
"loop" = (false_cond("cond")) => "done",
"body" = (Jmp) => "loop",
"done" = (Jmp) => EXIT,
]
Expand All @@ -43,12 +43,12 @@ cfg_test_function_to_cfg!(
queen,
include_str!("../../tests/small/failing/queens-func.bril"),
[
ENTRY = (Cond { arg: "ret_cond".into(), val: true.into() }) => "main.next.ret",
ENTRY = (Cond { arg: "ret_cond".into(), val: false.into() }) => "main.for.cond",
"main.for.cond" = (Cond { arg: "for_cond_0".into(), val: true.into() }) => "main.for.body",
"main.for.cond" = (Cond { arg: "for_cond_0".into(), val: false.into() }) => "main.next.ret.1",
"main.for.body" = (Cond { arg: "is_valid".into(), val: true.into() }) => "main.rec.func",
"main.for.body" = (Cond { arg: "is_valid".into(), val: false.into() }) => "main.next.loop",
ENTRY = (Cond { arg: "ret_cond".into(), val: true.into(), bril_type: Type::Bool }) => "main.next.ret",
ENTRY = (Cond { arg: "ret_cond".into(), val: false.into(), bril_type: Type::Bool }) => "main.for.cond",
"main.for.cond" = (Cond { arg: "for_cond_0".into(), val: true.into(), bril_type: Type::Bool }) => "main.for.body",
"main.for.cond" = (Cond { arg: "for_cond_0".into(), val: false.into(), bril_type: Type::Bool }) => "main.next.ret.1",
"main.for.body" = (Cond { arg: "is_valid".into(), val: true.into(), bril_type: Type::Bool }) => "main.rec.func",
"main.for.body" = (Cond { arg: "is_valid".into(), val: false.into(), bril_type: Type::Bool }) => "main.next.loop",
"main.rec.func" = (Jmp) => "main.next.loop",
"main.next.loop" = (Jmp) => "main.for.cond",
"main.next.ret" = (Jmp) => "main.print",
Expand All @@ -69,8 +69,8 @@ cfg_test_function_to_cfg!(
diamond,
include_str!("../../tests/small/diamond.bril"),
[
ENTRY = (Cond { arg: "cond".into(), val: true.into() }) => "B",
ENTRY = (Cond { arg: "cond".into(), val: false.into() }) => "C",
ENTRY = (Cond { arg: "cond".into(), val: true.into(), bril_type: Type::Bool }) => "B",
ENTRY = (Cond { arg: "cond".into(), val: false.into(), bril_type: Type::Bool }) => "C",
"B" = (Jmp) => "D",
"C" = (Jmp) => "D",
"D" = (Jmp) => EXIT,
Expand All @@ -81,10 +81,10 @@ cfg_test_function_to_cfg!(
block_diamond,
include_str!("../../tests/small/block-diamond.bril"),
[
ENTRY = (Cond { arg: "a_cond".into(), val: true.into() }) => "B",
ENTRY = (Cond { arg: "a_cond".into(), val: false.into() }) => "D",
"B" = (Cond { arg: "b_cond".into(), val: true.into() }) => "C",
"B" = (Cond { arg: "b_cond".into(), val: false.into() }) => "E",
ENTRY = (true_cond("a_cond")) => "B",
ENTRY = (false_cond("a_cond")) => "D",
"B" = (true_cond("b_cond")) => "C",
"B" = (false_cond("b_cond")) => "E",
"C" = (Jmp) => "F",
"D" = (Jmp) => "E",
"E" = (Jmp) => "F",
Expand All @@ -96,10 +96,10 @@ cfg_test_function_to_cfg!(
unstructured,
include_str!("../../tests/small/should_fail/unstructured.bril"),
[
ENTRY = (Cond { arg: "a_cond".into(), val: true.into() }) => "B",
ENTRY = (Cond { arg: "a_cond".into(), val: false.into() }) => "C",
"B" = (Cond { arg: "b_cond".into(), val: true.into() }) => "C",
"B" = (Cond { arg: "b_cond".into(), val: false.into() }) => "D",
ENTRY = (true_cond("a_cond")) => "B",
ENTRY = (false_cond("a_cond")) => "C",
"B" = (true_cond("b_cond")) => "C",
"B" = (false_cond("b_cond")) => "D",
"C" = (Jmp) => "B",
"D" = (Jmp) => EXIT,
]
Expand All @@ -110,8 +110,8 @@ cfg_test_function_to_cfg!(
include_str!("../../tests/small/fib_shape.bril"),
[
ENTRY = (Jmp) => "loop",
"loop" = (Cond { arg: "cond".into(), val: true.into() }) => "body",
"loop" = (Cond { arg: "cond".into(), val: false.into() }) => "done",
"loop" = (true_cond("cond")) => "body",
"loop" = (false_cond("cond")) => "done",
"body" = (Jmp) => "loop",
"done" = (Jmp) => EXIT,
]
Expand Down
2 changes: 2 additions & 0 deletions src/cfg/to_structured.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

use std::collections::HashMap;

use bril_rs::Type;
use petgraph::{
algo::dominators::{self, Dominators},
prelude::NodeIndex,
Expand Down Expand Up @@ -149,6 +150,7 @@ impl<'a> StructuredCfgBuilder<'a> {
BranchOp::Cond {
val: val1,
arg: arg1,
bril_type: Type::Bool,
},
..
},
Expand Down
88 changes: 70 additions & 18 deletions src/rvsdg/from_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use petgraph::visit::EdgeRef;
use petgraph::Direction;
use petgraph::{algo::dominators::Dominators, stable_graph::NodeIndex};

use crate::cfg::{ret_id, Annotation, BranchOp, CondVal, Identifier, SwitchCfgFunction};
use crate::cfg::{ret_id, Annotation, BranchOp, CondVal, SwitchCfgFunction};
use crate::rvsdg::Result;

use super::live_variables::{live_variables, Names};
Expand Down Expand Up @@ -205,11 +205,17 @@ impl<'a> RvsdgBuilder<'a> {
BranchOp::Cond {
arg,
val: CondVal { val, of },
bril_type,
} => {
assert_eq!(
of, 2,
"loop predicate has more than two options (restructuring should avoid this)"
);
assert_eq!(
bril_type,
Type::Bool,
"loop predicate is not a boolean in RVSDG translation"
);
let var = self.analysis.intern.intern(arg);
let op = get_op(var, &None, &self.store, &self.analysis.intern)?;
if val == 0 {
Expand Down Expand Up @@ -259,18 +265,42 @@ impl<'a> RvsdgBuilder<'a> {
.neighbors_directed(block, Direction::Outgoing)
.next());
}
let placeholder = Identifier::Num(!0);
let mut pred = placeholder.clone();
let mut succs = Vec::from_iter(self.cfg.graph.edges_directed(block, Direction::Outgoing).map(|e| {
if let BranchOp::Cond { arg, val: CondVal { val, of:_ }} = &e.weight().op {
if pred == placeholder {
pred = arg.clone();
}
(*val, e.target())

let mut succs_iter = self.cfg.graph.edges_directed(block, Direction::Outgoing);
let mut succs = vec![];
let first_e = succs_iter.next();
// Bind pred, first_val, and bril_type from the first edge
let Some(BranchOp::Cond {
arg: pred,
val: CondVal {
val: first_val,
of: _,
},
bril_type,
}) = first_e.map(|e| e.weight().op.clone())
else {
panic!("Couldn't find a conditional branch in block {block:?}");
oflatt marked this conversation as resolved.
Show resolved Hide resolved
};
succs.push((first_val, first_e.unwrap().target()));
// for the rest of the edges, make sure pred and bril_type match up
for e in succs_iter {
if let BranchOp::Cond {
arg,
val: CondVal { val, of: _ },
bril_type: other_bril_type,
} = &e.weight().op
{
assert_eq!(
bril_type, *other_bril_type,
"Mismatched types in conditional branches in block {block:?}"
);
assert_eq!(pred, *arg, "Multiple predicates in block {block:?}");
succs.push((*val, e.target()));
} else {
panic!("Invalid mix of conditional and non-conditional branches in block {block:?}")
}
}));
}

let pred_var = self.analysis.intern.intern(pred);
let pred_op = get_op(
pred_var,
Expand Down Expand Up @@ -353,14 +383,36 @@ impl<'a> RvsdgBuilder<'a> {

let next = next.unwrap();
let pred = pred_op;
let gamma_node = get_id(
&mut self.expr,
RvsdgBody::Gamma {
pred,
inputs,
outputs,
},
);
let gamma_node = if bril_type == Type::Bool {
oflatt marked this conversation as resolved.
Show resolved Hide resolved
assert_eq!(
outputs.len(),
2,
"Found wrong number of branches for boolean.",
);
get_id(
&mut self.expr,
RvsdgBody::If {
pred,
inputs,
then_branch: outputs[1].clone(),
else_branch: outputs[0].clone(),
},
)
} else {
assert_eq!(
bril_type,
Type::Int,
"Branch predicate should be bool or integer"
);
get_id(
&mut self.expr,
RvsdgBody::Gamma {
pred,
inputs,
outputs,
},
)
};
// Remap all input variables to the output of this node.
for (i, var) in output_vars.iter().copied().enumerate() {
self.store.insert(var, Operand::Project(i, gamma_node));
Expand Down
1 change: 1 addition & 0 deletions src/rvsdg/live_variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub(crate) fn live_variables(cfg: &SwitchCfgFunction) -> LiveVariableAnalysis {
if let BranchOp::Cond {
arg,
val: CondVal { val: _, of },
bril_type: _,
} = &edge.weight().op
{
if *of > 1 {
Expand Down
10 changes: 9 additions & 1 deletion src/rvsdg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,17 @@ pub(crate) enum Operand {
pub(crate) enum RvsdgBody {
BasicOp(BasicExpr<Operand>),

/// Conditional branch, witha boolean predicate.
oflatt marked this conversation as resolved.
Show resolved Hide resolved
If {
pred: Operand,
inputs: Vec<Operand>,
/// invariant: then_branch and else_branch have same length
then_branch: Vec<Operand>,
else_branch: Vec<Operand>,
},

/// Conditional branch, where the outputs chosen depend on the predicate.
Gamma {
/// always has type bool
pred: Operand,
inputs: Vec<Operand>,
/// invariant: all of the vecs in output have
Expand Down
5 changes: 5 additions & 0 deletions src/rvsdg/restructure.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Convert a potentially irreducible CFG to a reducible one.

use bril_rs::Type;
use hashbrown::{HashMap, HashSet};
use petgraph::{
algo::{dominators, tarjan_scc},
Expand Down Expand Up @@ -55,6 +56,9 @@ impl SwitchCfgFunction {
self.restructure_branches(&mut state);
}

/// Using a boolean predicate,
/// add a branch to the graph that jumps from `from` and to
/// `to` when the predicate has value `cv`.
fn branch_if(
&mut self,
from: NodeIndex,
Expand All @@ -69,6 +73,7 @@ impl SwitchCfgFunction {
op: BranchOp::Cond {
arg: id.clone(),
val: cv,
bril_type: Type::Bool,
},
pos: None,
},
Expand Down
18 changes: 18 additions & 0 deletions src/rvsdg/rvsdg2svg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,18 @@ fn mk_node_and_input_edges(index: Id, nodes: &[RvsdgBody]) -> (Node, Vec<Edge>)
),
once(pred).chain(inputs).copied().collect::<Vec<_>>(),
),
RvsdgBody::If {
pred,
inputs,
then_branch,
else_branch,
} => (
Node::Match(vec![
("true".into(), mk_region(inputs.len(), then_branch, nodes)),
("false".into(), mk_region(inputs.len(), else_branch, nodes)),
]),
once(pred).chain(inputs).copied().collect::<Vec<_>>(),
),
RvsdgBody::Theta {
pred,
inputs,
Expand Down Expand Up @@ -581,6 +593,12 @@ fn reachable_nodes(reachable: &mut BTreeSet<Id>, all: &[RvsdgBody], output: Oper
| RvsdgBody::BasicOp(BasicExpr::Print(xs)) => xs.clone(),
RvsdgBody::BasicOp(BasicExpr::Const(..)) => vec![],
RvsdgBody::Gamma { pred, inputs, .. } => once(pred).chain(inputs).copied().collect(),
RvsdgBody::If {
pred,
inputs,
then_branch: _,
else_branch: _,
} => once(pred).chain(inputs).copied().collect(),
RvsdgBody::Theta { inputs, .. } => inputs.clone(),
};
for input in inputs {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ expression: svg
text-anchor="middle"
x="50"
y="62.5">
2
1
</text>
</g>
<circle
Expand Down Expand Up @@ -254,7 +254,7 @@ expression: svg
text-anchor="middle"
x="200"
y="100">
0
true
</text>
</g>
<g>
Expand Down Expand Up @@ -309,7 +309,7 @@ expression: svg
text-anchor="middle"
x="50"
y="62.5">
1
2
</text>
</g>
<circle
Expand Down Expand Up @@ -394,7 +394,7 @@ expression: svg
text-anchor="middle"
x="575"
y="100">
1
false
</text>
</g>
<circle
Expand Down
Loading
Loading