Skip to content

Commit

Permalink
unify constructor and expr
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Jan 30, 2024
1 parent 2384297 commit 0e63305
Show file tree
Hide file tree
Showing 12 changed files with 217 additions and 173 deletions.
2 changes: 1 addition & 1 deletion src/rvsdg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ impl RvsdgType {
pub(crate) fn to_tree_type(&self) -> TreeType {
match self {
RvsdgType::Bril(t) => TreeType::Bril(t.clone()),
RvsdgType::PrintState => TreeType::Unit,
RvsdgType::PrintState => TreeType::Tuple(vec![]),
}
}
}
Expand Down
16 changes: 8 additions & 8 deletions src/rvsdg/tree_unique/to_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ fn translate_simple_loop() {
.to_tree_encoding()
.assert_eq_ignoring_ids(&program!(function(
"myfunc",
TreeType::Tuple(vec![TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Bril(Type::Int), TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
TreeType::Tuple(vec![TreeType::Bril(Type::Int), TreeType::Tuple(vec![])]),
cbind(
num(1), // [(), 1]
cbind(
Expand Down Expand Up @@ -286,8 +286,8 @@ fn translate_loop() {
.to_tree_encoding()
.assert_eq_ignoring_ids(&program!(function(
"main",
TreeType::Tuple(vec![TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
cbind(
num(0), // [(), 0]
cbind(
Expand Down Expand Up @@ -334,8 +334,8 @@ fn simple_translation() {
.to_tree_encoding()
.assert_eq_ignoring_ids(&program!(function(
"add",
TreeType::Tuple(vec![TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Bril(Type::Int), TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
TreeType::Tuple(vec![TreeType::Bril(Type::Int), TreeType::Tuple(vec![])]),
cbind(
num(1),
cbind(
Expand Down Expand Up @@ -366,8 +366,8 @@ fn two_print_translation() {
.to_tree_encoding()
.assert_eq_ignoring_ids(&program!(function(
"add",
TreeType::Tuple(vec![TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Unit]),
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
TreeType::Tuple(vec![TreeType::Tuple(vec![])]),
cbind(
num(2),
cbind(
Expand Down
8 changes: 4 additions & 4 deletions tree_optimizer/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ fn test_complex_program_ids() {
// a let, a loop, a switch, and a call
let prog = program!(function(
"main",
TreeType::Unit,
TreeType::Unit,
TreeType::Tuple(vec![]),
TreeType::Tuple(vec![]),
tlet(
num(0),
tloop(
Expand All @@ -256,8 +256,8 @@ fn test_complex_program_ids() {
Program(vec![Function(
Unique(1),
"main".into(),
TreeType::Unit,
TreeType::Unit,
TreeType::Tuple(vec![]),
TreeType::Tuple(vec![]),
Box::new(Let(
Unique(2),
Box::new(Num(0)),
Expand Down
5 changes: 4 additions & 1 deletion tree_optimizer/src/deep_copy.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::ir::{Constructor, ESort, Purpose};
use crate::{
expr::ESort,
ir::{Constructor, Purpose},
};
use strum::IntoEnumIterator;

fn deep_copy_rule_for_ctor(ctor: Constructor) -> String {
Expand Down
12 changes: 7 additions & 5 deletions tree_optimizer/src/error_checking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
//! such as checking that switch children
//! are all `Branch`es.

use strum::IntoEnumIterator;

use crate::ir::{Constructor, ESort};
use crate::{
expr::{ESort, Expr},
ir::Constructor,
};

pub(crate) fn error_checking_rules() -> Vec<String> {
let mut res = vec![format!(
Expand All @@ -25,15 +26,16 @@ pub(crate) fn error_checking_rules() -> Vec<String> {
if ctor.sort() == ESort::ListExpr {
continue;
}
if ctor == Constructor::Branch {
if let Constructor::Expr(Expr::Branch(..)) = ctor {
continue;
}

let pat = ctor.construct(|field| field.var());
let ctor_name = ctor.name();
res.push(format!(
"
(rule ((IsBranchList (Cons {pat} rest)))
((panic \"Expected Branch, got {ctor}\"))
((panic \"Expected Branch, got {ctor_name}\"))
:ruleset error-checking)
"
));
Expand Down
100 changes: 97 additions & 3 deletions tree_optimizer/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,63 @@ impl PureUOp {
}
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum Sort {
Expr,
ListExpr,
Order,
BinPureOp,
UnaryPureOp,
IdSort,
I64,
Bool,
Type,
String,
}

impl Sort {
pub(crate) fn name(&self) -> &'static str {
match self {
Sort::Expr => "Expr",
Sort::ListExpr => "ListExpr",
Sort::Order => "Order",
Sort::IdSort => "IdSort",
Sort::I64 => "i64",
Sort::String => "String",
Sort::Bool => "bool",
Sort::Type => "Type",
Sort::BinPureOp => "BinPureOp",
Sort::UnaryPureOp => "UnaryPureOp",
}
}
}

// Subset of sorts that refer to expressions
#[derive(Debug, EnumIter, PartialEq)]
pub(crate) enum ESort {
Expr,
ListExpr,
}

impl std::fmt::Display for ESort {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name())
}
}

impl ESort {
pub(crate) fn to_sort(&self) -> Sort {
match self {
ESort::Expr => Sort::Expr,
ESort::ListExpr => Sort::ListExpr,
}
}

pub(crate) fn name(&self) -> &'static str {
self.to_sort().name()
}
}

#[derive(Clone, Debug, PartialEq, EnumIter)]
pub enum Expr {
Num(i64),
Expand Down Expand Up @@ -151,6 +208,39 @@ impl Default for Expr {
}

impl Expr {
pub fn is_pure(&self) -> bool {
use Expr::*;
match self {
Num(..) | Boolean(..) | Arg(..) | BOp(..) | UOp(..) | Get(..) | Concat(..)
| Read(..) | All(..) | Switch(..) | Branch(..) | Loop(..) | Let(..) | Function(..)
| Program(..) | Call(..) => true,
Print(..) | Write(..) => false,
}
}

pub fn name(&self) -> &'static str {
match self {
Expr::Num(_) => "Num",
Expr::Boolean(_) => "Boolean",
Expr::BOp(_, _, _) => "BOp",
Expr::UOp(_, _) => "UOp",
Expr::Get(_, _) => "Get",
Expr::Concat(_, _) => todo!("Remove concat from ast"),
Expr::Print(_) => "Print",
Expr::Read(_) => "Read",
Expr::Write(_, _) => "Write",
Expr::All(_, _, _) => "All",
Expr::Switch(_, _) => "Switch",
Expr::Branch(_, _) => "Branch",
Expr::Loop(_, _, _) => "Loop",
Expr::Let(_, _, _) => "Let",
Expr::Arg(_) => "Arg",
Expr::Function(_, _, _, _, _) => "Function",
Expr::Program(_) => "Program",
Expr::Call(_, _, _) => "Call",
}
}

/// Runs `func` on every child of this expression.
pub fn for_each_child(&mut self, mut func: impl FnMut(&mut Expr)) {
match self {
Expand Down Expand Up @@ -206,14 +296,18 @@ pub enum Value {
Tuple(Vec<Value>),
}

#[derive(Clone, PartialEq, Debug, Default)]
#[derive(Clone, PartialEq, Debug)]
pub enum TreeType {
#[default]
Unit,
Bril(Type),
Tuple(Vec<TreeType>),
}

impl Default for TreeType {
fn default() -> Self {
TreeType::Tuple(vec![])
}
}

pub enum TypeError {
ExpectedType(Expr, TreeType, TreeType),
ExpectedTupleType(Expr, TreeType),
Expand Down
5 changes: 4 additions & 1 deletion tree_optimizer/src/id_analysis.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::ir::{Constructor, ESort, Purpose};
use crate::{
expr::ESort,
ir::{Constructor, Purpose},
};
use strum::IntoEnumIterator;

fn id_analysis_rules_for_ctor(ctor: Constructor) -> String {
Expand Down
Loading

0 comments on commit 0e63305

Please sign in to comment.