diff --git a/dag_in_context/src/add_context.rs b/dag_in_context/src/add_context.rs index 8ceecc98a..5b7d9cea5 100644 --- a/dag_in_context/src/add_context.rs +++ b/dag_in_context/src/add_context.rs @@ -217,6 +217,7 @@ impl Expr { out_ty.clone(), body.add_ctx_with_cache(current_ctx, cache), )), + Expr::Symbolic(_) => panic!("found symbol"), }; cache .with_ctx diff --git a/dag_in_context/src/interpreter.rs b/dag_in_context/src/interpreter.rs index fb5052f47..1bd4b4da3 100644 --- a/dag_in_context/src/interpreter.rs +++ b/dag_in_context/src/interpreter.rs @@ -388,6 +388,7 @@ impl<'a> VirtualMachine<'a> { let e_val = self.interpret_expr(e, arg); self.interpret_call(func_name, &e_val) } + Expr::Symbolic(_) => panic!("found symbolic"), }; self.eval_cache.insert(Rc::as_ptr(expr), res.clone()); res diff --git a/dag_in_context/src/linearity.rs b/dag_in_context/src/linearity.rs index 5dd67440b..a7cf4f523 100644 --- a/dag_in_context/src/linearity.rs +++ b/dag_in_context/src/linearity.rs @@ -192,6 +192,7 @@ impl<'a> Extractor<'a> { self.find_effectful_nodes_in_region(body, linearity) } Expr::Const(_, _, _) => panic!("Const has no effect"), + Expr::Symbolic(_) => panic!("found symbolic"), } } diff --git a/dag_in_context/src/pretty_print.rs b/dag_in_context/src/pretty_print.rs index 135d46489..417e70b39 100644 --- a/dag_in_context/src/pretty_print.rs +++ b/dag_in_context/src/pretty_print.rs @@ -1,18 +1,28 @@ -use std::{collections::HashMap, rc::Rc}; - use crate::{ - from_egglog::FromEgglog, - prologue, - schema::{self, Assumption, BaseType, BinaryOp, Expr, RcExpr, TernaryOp, Type, UnaryOp}, + from_egglog::FromEgglog, optimizations::body_contains, prologue, schema::{self, Assumption, BaseType, BinaryOp, Expr, RcExpr, TernaryOp, Type, UnaryOp} }; -use egglog::TermDag; +use egglog::{util::IndexMap, TermDag}; +use indexmap::IndexMap; +use std::{collections::HashMap, rc::Rc}; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +enum Either3 { + Left(A), + Mid(B), + Right(C), +} pub struct PrettyPrinter { pub expr: RcExpr, + pub cache: IndexMap<*const schema::Expr, String>, + pub symbols: IndexMap, String>, + // pub types: IndexMap, + // pub assums: IndexMap, } + impl PrettyPrinter { - pub fn new(str_expr: String) -> std::result::Result { + pub fn from_string(str_expr: String) -> std::result::Result { let bounded_expr = format!("(let EXPR___ {})", str_expr); let prog = prologue().to_owned() + &bounded_expr; let mut egraph = egglog::EGraph::default(); @@ -26,217 +36,323 @@ impl PrettyPrinter { termdag: &termdag, conversion_cache: HashMap::new(), }; - Ok(PrettyPrinter { - expr: converter.expr_from_egglog(extracted), - }) + Ok(Self::from_expr(converter.expr_from_egglog(extracted))) + } + + pub fn from_expr(expr: RcExpr) -> PrettyPrinter { + let mut cache = IndexMap::new(); + let mut symbols = IndexMap::new(); + Self::assign_fresh_var(&expr, &mut cache, &mut symbols); + PrettyPrinter { + expr, + cache, + symbols, + } } pub fn to_egglog_default(&self) -> String { - self.to_egglog(&|rc, len| (rc > 1 && len > 80) || rc > 4 || len > 200) + self.to_egglog(&|rc, len| (rc > 1 && len > 30) || len > 80) } pub fn to_egglog(&self, fold_when: &dyn Fn(usize, usize) -> bool) -> String { - let mut log = String::new(); - let mut cache: HashMap<*const schema::Expr, String> = HashMap::new(); - let mut symbols: HashMap = HashMap::new(); - let res = Self::to_egglog_helper(&self.expr, &mut cache, &mut symbols, &mut log, fold_when); - log + &format!("(let EXPR___ {res})") + let mut log = IndexMap::new(); + let res = self.to_nested_expr(&self.expr, &mut log, fold_when); + let log = self + .symbols + .iter() + .map(|(expr, symbol)| format!("(let {symbol}\n{expr}) \n")) + .chain( + log.iter() + .map(|(var, expr)| format!("(let {var} \n{}) \n", expr.pretty())) + .collect::>(), + ) + .collect::>() + .join(""); + log + &format!("(let EXPR___\n{})", res.pretty()) + } + + pub fn to_rust_default(&self) -> String { + self.to_rust(&|rc, len| (rc > 1 && len > 30) || rc > 4 || len > 80) } // turn the Expr to a rust ast macro string. // return a rust ast macro // fold_when: provide a function that decide when to fold the macro to a let binding pub fn to_rust(&self, fold_when: &dyn Fn(usize, usize) -> bool) -> String { - let mut log = String::new(); - let mut cache: HashMap<*const schema::Expr, String> = HashMap::new(); - let mut symbols: HashMap = HashMap::new(); - let res = Self::to_ast(&self.expr, &mut cache, &mut symbols, &mut log, fold_when); - log + &format!("let expr___ = {res}; \n") + let mut log = IndexMap::new(); + let res = self.to_nested_expr(&self.expr, &mut log, fold_when); + let log = self + .symbols + .iter() + .map(|(expr, symbol)| format!("let {symbol} = {expr}; \n")) + .chain( + log.iter() + .map(|(var, expr)| format!("(let {var} \n{}) \n", expr.pretty())) + .collect::>(), + ) + .collect::>() + .join(""); + log + &format!("(let EXPR___\n{})", res.pretty()) } - pub fn to_rust_default(&self) -> String { - self.to_rust(&|rc, len| (rc > 1 && len > 30) || rc > 4 || len > 80) + fn assign_fresh_var( + expr: &RcExpr, + cache: &mut IndexMap<*const schema::Expr, String>, + // types: &mut IndexMap, + // assums: &mut IndexMap<*const Assumption, String>, + symbols : &mut IndexMap, String> + ) { + let len = cache.len(); + let make_fresh = |info: String| format!("{info}_{}", len); + // let try_insert_fresh = + // |var: String, info: String, symbols: IndexMap| { + // if !symbols.contains_key(&var) { + // let fresh_var = format!("{info}_{}", symbols.len()); + // symbols.insert(var, fresh_var); + // } + // }; + fn try_insert_fresh (var : T, info: String, symbols: &mut IndexMap) { + if !symbols.contains_key(&var) { + let fresh_var = format!("{info}_{}", symbols.len()); + symbols.insert(var, fresh_var); + } + } + + let expr_ptr = Rc::as_ptr(expr); + + // some expr need fresh var, other do not + if !cache.contains_key(&expr_ptr) { + match expr.as_ref() { + Expr::Const(c, ty, assum) => { + try_insert_fresh(ty.to_owned(), ty.abbrev(), types); + try_insert_fresh( Rc::as_ptr(&Rc::new(assum.to_owned())) , assum.abbrev(), assums); + let c = match c { + schema::Constant::Int(i) => format!("int{i}"), + schema::Constant::Bool(b) => format!("bool{b}"), + }; + cache.insert(expr_ptr, make_fresh(c)); + } + Expr::Top(op, lhs, mid, rhs) => { + Self::assign_fresh_var(lhs, cache, types, assums); + Self::assign_fresh_var(mid, cache, types, assums); + Self::assign_fresh_var(rhs, cache, types, assums); + cache.insert(expr_ptr, make_fresh(op.to_ast())); + } + Expr::Bop(op, lhs, rhs) => { + Self::assign_fresh_var(lhs, cache, types, assums); + Self::assign_fresh_var(rhs, cache, types, assums); + cache.insert(expr_ptr, make_fresh(op.to_ast())); + } + Expr::Uop(op, expr) => { + Self::assign_fresh_var(expr, cache, types, assums); + cache.insert(expr_ptr, make_fresh(op.to_ast())); + } + Expr::Get(expr, usize) => { + if let Expr::Arg(..) = expr.as_ref() { + cache.insert(expr_ptr, make_fresh(format!("get_at_{usize}"))); + } + Self::assign_fresh_var(expr, cache, types, assums); + } + Expr::Alloc(id, x, y, ptrty) => { + Self::assign_fresh_var(x, cache, types, assums); + Self::assign_fresh_var(y, cache, types, assums); + try_insert_fresh(ptrty.to_string(), ptrty.abbrev(), types, assums); + cache.insert(expr_ptr, expr.as_ref().abbrev() + &id.to_string()); + } + Expr::Call(name, arg) => { + Self::assign_fresh_var(arg, cache, types, assums); + cache.insert(expr_ptr, make_fresh("call_".to_owned() + name)); + } + Expr::Empty(ty, assum) => { + try_insert_fresh(ty.pretty(), ty.abbrev(), types, assums); + try_insert_fresh(assum.pretty(), assum.abbrev(), types, assums); + } + Expr::Single(expr) => { + Self::assign_fresh_var(expr, cache, types, assums); + } + Expr::Concat(lhs, rhs) => { + Self::assign_fresh_var(lhs, cache, types, assums); + Self::assign_fresh_var(rhs, cache, types, assums); + } + Expr::If(cond, input, then, els) => { + Self::assign_fresh_var(cond, cache, types, assums); + Self::assign_fresh_var(input, cache, types, assums); + Self::assign_fresh_var(then, cache, types, assums); + Self::assign_fresh_var(els, cache, types, assums); + cache.insert(expr_ptr, make_fresh("if".into())); + } + Expr::Switch(cond, input, branch) => { + Self::assign_fresh_var(cond, cache, types, assums); + Self::assign_fresh_var(input, cache, types, assums); + branch + .iter() + .for_each(|expr| Self::assign_fresh_var(expr, cache, types, assums)); + cache.insert(expr_ptr, make_fresh("switch".into())); + } + Expr::DoWhile(input, body) => { + Self::assign_fresh_var(input, cache, types, assums); + Self::assign_fresh_var(body, cache, types, assums); + cache.insert(expr_ptr, make_fresh("dowhile".into())); + } + Expr::Arg(ty, assum) => { + try_insert_fresh(ty.pretty(), ty.abbrev(), types, assums); + try_insert_fresh(assum.pretty(), assum.abbrev(), types, assums); + } + Expr::Function(_, tyin, tyout, body) => { + try_insert_fresh(tyin.pretty(), tyin.abbrev(), types, assums); + try_insert_fresh(tyout.pretty(), tyout.abbrev(), types, assums); + Self::assign_fresh_var(body, cache, types, assums); + } + Expr::Symbolic(_) => panic!("no symbolic should occur here"), + } + } } - // symbols: Type and Assumption's string -> their binding var - fn to_egglog_helper( + fn to_nested_expr( + &self, expr: &RcExpr, - cache: &mut HashMap<*const schema::Expr, String>, - symbols: &mut HashMap, - log: &mut String, + log: &mut IndexMap, fold_when: &dyn Fn(usize, usize) -> bool, - ) -> String { - use self::*; - let find_or_insert = |var: String, - info: String, - str_builder: &mut String, - symbols: &mut HashMap| { - let fresh_var = format!("{}_{}", info, symbols.len()); - symbols - .entry(var.clone()) - .or_insert_with(|| { - str_builder.push_str(&format!("(let {} \n {}) \n", fresh_var.clone(), var)); - fresh_var - }) + ) -> Expr { + let fold = |egglog: Expr, log: &mut IndexMap| { + let fresh_var = self.cache.get(&Rc::as_ptr(&expr)).unwrap(); + if !log.contains_key(fresh_var) { + log.insert(fresh_var.into(), egglog); + } + Expr::Symbolic(fresh_var.into()) + }; + let fold_or_plain = |egglog: Expr, log: &mut IndexMap| { + let rc = Rc::strong_count(expr); + let size = egglog .clone() + .to_string() + .replace(&['(', ')', ' '][..], "") + .len(); + if fold_when(rc, size) { + fold(egglog, log) + } else { + egglog + } }; - let fold_or_plain = - |egglog_str: String, - info: String, - str_builder: &mut String, - cache: &mut HashMap<*const schema::Expr, String>| { - let rc = Rc::strong_count(expr); - if fold_when(rc, egglog_str.len()) { - let fresh_var = format!("{}_{}", info, cache.len()); - cache - .entry(Rc::as_ptr(expr)) - .or_insert_with(|| { - str_builder.push_str(&format!( - "(let {} \n {}) \n", - fresh_var.clone(), - egglog_str - )); - fresh_var - }) - .clone() + match expr.as_ref() { + Expr::Function(name, inty, outty, body) => { + let inty_str = self.symbols.get(&inty.pretty()).unwrap(); + let outty_str = self.symbols.get(&outty.pretty()).unwrap(); + let body = self.to_nested_expr(body, log, fold_when); + Expr::Function( + name.into(), + Type::Symbolic(inty_str.into()), + Type::Symbolic(outty_str.into()), + Rc::new(body), + ) + } + Expr::Const(c, ty, assum) => { + let ty = self.symbols.get(&ty.pretty()).unwrap(); + let assum = self.symbols.get(&assum.pretty()).unwrap(); + let c = Expr::Const( + c.clone(), + Type::Symbolic(ty.into()), + Assumption::WildCard(assum.into()), + ); + fold(c, log) + } + Expr::Top(op, x, y, z) => { + let left = self.to_nested_expr(x, log, fold_when); + let mid = self.to_nested_expr(y, log, fold_when); + let right = self.to_nested_expr(z, log, fold_when); + let top = Expr::Top(op.clone(), Rc::new(left), Rc::new(mid), Rc::new(right)); + fold_or_plain(top, log) + } + Expr::Bop(op, x, y) => { + let left = self.to_nested_expr(x, log, fold_when); + let right = self.to_nested_expr(y, log, fold_when); + let bop = Expr::Bop(op.clone(), Rc::new(left), Rc::new(right)); + fold_or_plain(bop, log) + } + Expr::Uop(op, x) => { + let sub_expr = self.to_nested_expr(x, log, fold_when); + let uop = Expr::Uop(op.clone(), Rc::new(sub_expr)); + fold_or_plain(uop, log) + } + Expr::Get(x, pos) => { + let sub_expr = self.to_nested_expr(x, log, fold_when); + let get = Expr::Get(Rc::new(sub_expr), pos.clone()); + // fold Get Arg i anyway + if let Expr::Arg(_, _) = x.as_ref() { + fold(get, log) } else { - egglog_str - } - }; - match cache.get(&Rc::as_ptr(expr)) { - Some(str) => str.to_owned(), - None => { - let expr = expr.as_ref(); - match expr { - Expr::Function(name, inty, outty, body) => { - let inty_str = - find_or_insert(inty.to_string(), inty.abbrev(), log, symbols); - let outty_str = - find_or_insert(outty.to_string(), outty.abbrev(), log, symbols); - let body = Self::to_egglog_helper(body, cache, symbols, log, fold_when); - let fun = format!("(Function {name} {inty_str} {outty_str} \n {body})"); - fold_or_plain(fun, format!("Fun_{name}"), log, cache) - } - Expr::Const(c, ty, assum) => { - let ty = find_or_insert(ty.to_string(), ty.abbrev(), log, symbols); - let assum = find_or_insert(assum.to_string(), assum.abbrev(), log, symbols); - let constant = format!("(Const {c} {ty} {assum})"); - cache.insert(expr, constant.clone()); - constant - } - Expr::Top(op, x, y, z) => { - let left = Self::to_egglog_helper(x, cache, symbols, log, fold_when); - let mid = Self::to_egglog_helper(y, cache, symbols, log, fold_when); - let right = Self::to_egglog_helper(z, cache, symbols, log, fold_when); - let top = - format!("(Top ({:?}) \n {} \n {} \n {})", op, left, mid, right); - fold_or_plain(top, format!("{:?}", op), log, cache) - } - Expr::Bop(op, x, y) => { - let left = Self::to_egglog_helper(x, cache, symbols, log, fold_when); - let right = Self::to_egglog_helper(y, cache, symbols, log, fold_when); - let bop = format!("(Bop ({:?}) \n {} \n {})", op, left, right); - - fold_or_plain(bop, format!("{:?}", op), log, cache) - } - Expr::Uop(op, x) => { - let sub_expr = Self::to_egglog_helper(x, cache, symbols, log, fold_when); - let uop = format!("(Uop ({:?}) {})", op, sub_expr); - - fold_or_plain(uop, format!("{:?}", op), log, cache) - } - Expr::Get(x, pos) => { - let sub_expr = Self::to_egglog_helper(x, cache, symbols, log, fold_when); - let get = format!("(Get {sub_expr} {pos})"); - cache.insert(expr, get.clone()); - get - } - Expr::Alloc(id, x, y, pointer_ty) => { - let amount = Self::to_egglog_helper(x, cache, symbols, log, fold_when); - let state_edge = Self::to_egglog_helper(y, cache, symbols, log, fold_when); - let ty = find_or_insert( - pointer_ty.to_string(), - pointer_ty.abbrev(), - log, - symbols, - ); - let alloc = - format!("(Alloc {id} \n {amount} \n {state_edge} \n {ty})"); - fold_or_plain(alloc, format!("Alloc{id}"), log, cache) - } - Expr::Call(name, x) => { - let sub_expr = Self::to_egglog_helper(x, cache, symbols, log, fold_when); - let call = format!("(Call {name} {sub_expr})"); - fold_or_plain(call, format!("CallFun_{name}"), log, cache) - } - Expr::Empty(ty, assum) => { - let ty = find_or_insert(ty.to_string(), ty.abbrev(), log, symbols); - let assum = find_or_insert(assum.to_string(), assum.abbrev(), log, symbols); - let empty = format!("(Empty {ty} {assum})"); - cache.insert(expr, empty.clone()); - empty - } - // doesn't fold Tuple - Expr::Single(x) => { - let sub_expr = Self::to_egglog_helper(x, cache, symbols, log, fold_when); - let single = format!("(Single {})", sub_expr.clone()); - cache.insert(expr, single.clone()); - single - } - Expr::Concat(x, y) => { - let left = Self::to_egglog_helper(x, cache, symbols, log, fold_when); - let right = Self::to_egglog_helper(y, cache, symbols, log, fold_when); - let concat = format!("(Concat {left} {right})"); - cache.insert(expr, concat.clone()); - concat - } - Expr::Switch(x, inputs, _branches) => { - let cond = Self::to_egglog_helper(x, cache, symbols, log, fold_when); - let inputs = Self::to_egglog_helper(inputs, cache, symbols, log, fold_when); - - fn cons_list(vec: Vec) -> String { - match vec.get(0) { - Some(str) => { - format!("(Cons {} {})", str, cons_list(vec[1..].to_vec())) - } - None => "(Nil)".to_string(), - } - } - let branches = _branches - .iter() - .map(|branch| { - Self::to_egglog_helper(branch, cache, symbols, log, fold_when) - }) - .collect::>(); - let branch_list = cons_list(branches); - let switch = format!("(Switch \n {cond}\n {inputs}\n {branch_list})"); - fold_or_plain(switch, "switch".into(), log, cache) - } - Expr::If(x, inputs, y, z) => { - let pred = Self::to_egglog_helper(x, cache, symbols, log, fold_when); - let inputs = Self::to_egglog_helper(inputs, cache, symbols, log, fold_when); - let left = Self::to_egglog_helper(y, cache, symbols, log, fold_when); - let right = Self::to_egglog_helper(z, cache, symbols, log, fold_when); - let if_expr = - format!("(If \n {pred}\n {inputs}\n {left}\n {right})"); - fold_or_plain(if_expr, "if".into(), log, cache) - } - Expr::DoWhile(inputs, body) => { - let inputs = Self::to_egglog_helper(inputs, cache, symbols, log, fold_when); - let body = Self::to_egglog_helper(body, cache, symbols, log, fold_when); - let dowhile = format!("(DoWhile\n {inputs}\n {body})"); - fold_or_plain(dowhile, "dowhile".into(), log, cache) - } - Expr::Arg(ty, assum) => { - let ty = find_or_insert(ty.to_string(), ty.abbrev(), log, symbols); - let assum = find_or_insert(assum.to_string(), assum.abbrev(), log, symbols); - let arg = format!("(Arg {ty} {assum})"); - cache.insert(expr, arg.clone()); - arg - } + get } } + Expr::Alloc(id, x, y, ty) => { + let amount = self.to_nested_expr(x, log, fold_when); + let state_edge = self.to_nested_expr(y, log, fold_when); + let alloc = + Expr::Alloc(id.clone(), Rc::new(amount), Rc::new(state_edge), ty.clone()); + fold_or_plain(alloc, log) + } + Expr::Call(name, x) => { + let sub_expr = self.to_nested_expr(x, log, fold_when); + let call = Expr::Call(name.into(), Rc::new(sub_expr)); + fold_or_plain(call, log) + } + Expr::Empty(ty, assum) => { + let ty = self.symbols.get(&ty.pretty()).unwrap(); + let assum = self.symbols.get(&assum.pretty()).unwrap(); + Expr::Empty( + Type::Symbolic(ty.into()), + Assumption::WildCard(assum.into()), + ) + } + // doesn't fold Tuple + Expr::Single(x) => { + let sub_expr = self.to_nested_expr(x, log, fold_when); + Expr::Single(Rc::new(sub_expr)) + } + Expr::Concat(x, y) => { + let left = self.to_nested_expr(x, log, fold_when); + let right = self.to_nested_expr(y, log, fold_when); + Expr::Concat(Rc::new(left), Rc::new(right)) + } + Expr::Switch(x, inputs, _branches) => { + let cond = self.to_nested_expr(x, log, fold_when); + let inputs = self.to_nested_expr(inputs, log, fold_when); + let branches = _branches + .iter() + .map(|branch| Rc::new(self.to_nested_expr(branch, log, fold_when))) + .collect::>(); + let switch = Expr::Switch(Rc::new(cond), Rc::new(inputs), branches); + fold_or_plain(switch, log) + } + Expr::If(x, inputs, y, z) => { + let pred = self.to_nested_expr(x, log, fold_when); + let inputs = self.to_nested_expr(inputs, log, fold_when); + let left = self.to_nested_expr(y, log, fold_when); + let right = self.to_nested_expr(z, log, fold_when); + let if_expr = Expr::If( + Rc::new(pred), + Rc::new(inputs), + Rc::new(left), + Rc::new(right), + ); + fold_or_plain(if_expr, log) + } + Expr::DoWhile(inputs, body) => { + let inputs = self.to_nested_expr(inputs, log, fold_when); + let body = self.to_nested_expr(body, log, fold_when); + let dowhile = Expr::DoWhile(Rc::new(inputs), Rc::new(body)); + fold_or_plain(dowhile, log) + } + Expr::Arg(ty, assum) => { + let ty = self.symbols.get(&ty.pretty()).unwrap(); + let assum = self.symbols.get(&assum.pretty()).unwrap(); + Expr::Arg( + Type::Symbolic(ty.into()), + Assumption::WildCard(assum.into()), + ) + } + Expr::Symbolic(_) => panic!("No symbolic should occur here"), } } @@ -258,163 +374,257 @@ impl PrettyPrinter { let expr = Self::to_ast(expr, cache, symbols, log, fold_when); vec![expr] } - _ => panic!("not well formed Concat, expr not wrapped with Single"), + _ => panic!("Not well formed Concat, expr not in Single"), } } - fn to_ast( - expr: &RcExpr, - cache: &mut HashMap<*const schema::Expr, String>, - symbols: &mut HashMap, - log: &mut String, - fold_when: &dyn Fn(usize, usize) -> bool, - ) -> String { - let find_or_insert = |var: String, - info: String, - str_builder: &mut String, - symbols: &mut HashMap| { - let fresh_var = format!("{}_{}", info, symbols.len()); - symbols - .entry(var.clone()) - .or_insert_with(|| { - str_builder.push_str(&format!("let {} = {}; \n", fresh_var.clone(), var)); - fresh_var - }) - .clone() - }; + // fn to_ast( + // expr: &RcExpr, + // cache: &mut HashMap<*const schema::Expr, String>, + // symbols: &mut HashMap, + // log: &mut String, + // fold_when: &dyn Fn(usize, usize) -> bool, + // ) -> String { + // let find_or_insert = |var: String, + // info: String, + // str_builder: &mut String, + // symbols: &mut HashMap| { + // let fresh_var = format!("{}_{}", info, symbols.len()); + // symbols + // .entry(var.clone()) + // .or_insert_with(|| { + // str_builder.push_str(&format!("let {} = {}; \n", fresh_var.clone(), var)); + // fresh_var + // }) + // .clone() + // }; - let fold_or_plain = - |ast_str: String, - info: String, - str_builder: &mut String, - cache: &mut HashMap<*const schema::Expr, String>| { - let rc = Rc::strong_count(expr); - if fold_when(rc, ast_str.len()) { - let fresh_var = format!("{}_{}", info, cache.len()); - let lookup = cache - .entry(Rc::as_ptr(expr)) - .or_insert_with(|| { - str_builder.push_str(&format!( - "let {} = {}; \n", - fresh_var.clone(), - ast_str - )); - fresh_var - }) - .clone(); - format!("{lookup}.clone()") - } else { - ast_str - } - }; - - match cache.get(&Rc::as_ptr(expr)) { - Some(str) => format!("{}.clone()", str), - None => { - match expr.as_ref() { - // just don't fold simple things like expr, getat anyway - Expr::Const(c, _, _) => match c { - schema::Constant::Bool(true) => "ttrue()".into(), - schema::Constant::Bool(false) => "tfalse()".into(), - schema::Constant::Int(n) => format!("int({})", n), - }, - Expr::Bop(op, lhs, rhs) => { - let left = Self::to_ast(lhs, cache, symbols, log, fold_when); - let right = Self::to_ast(rhs, cache, symbols, log, fold_when); - let ast_str = format!("{}({}, {})", op.to_ast(), left, right); - fold_or_plain(ast_str, op.to_ast(), log, cache) - } - Expr::Top(op, x, y, z) => { - let left = Self::to_ast(x, cache, symbols, log, fold_when); - let mid = Self::to_ast(y, cache, symbols, log, fold_when); - let right = Self::to_ast(z, cache, symbols, log, fold_when); - let ast_str = format!("{}({}, {}, {})", op.to_ast(), left, mid, right); - fold_or_plain(ast_str, op.to_ast(), log, cache) - } - Expr::Uop(op, expr) => { - let expr = Self::to_ast(expr, cache, symbols, log, fold_when); - let ast_str = format!("{}({})", op.to_ast(), expr); - fold_or_plain(ast_str, op.to_ast(), log, cache) - } - Expr::Get(expr, index) => match expr.as_ref() { - Expr::Arg(_, _) => { - format!("getat({index})") - } - _ => { - let expr = Self::to_ast(expr, cache, symbols, log, fold_when); - format!("get({expr}, {index})") - } - }, - Expr::Alloc(id, expr, state, ty) => { - let expr = Self::to_ast(expr, cache, symbols, log, fold_when); - let state = Self::to_ast(state, cache, symbols, log, &fold_when); - let ty_str = ty.to_ast(); - let ty_binding = find_or_insert(ty_str, ty.abbrev(), log, symbols); - let ast_str = format!("alloc({id}, {expr}, {state}, {ty_binding})"); - fold_or_plain(ast_str, "alloc".into(), log, cache) - } - Expr::Call(name, arg) => { - let arg = Self::to_ast(arg, cache, symbols, log, fold_when); - format!("call({name}, {arg})") - } - Expr::Empty(..) => "empty()".into(), - Expr::Single(expr) => { - let expr = Self::to_ast(expr, cache, symbols, log, fold_when); - format!("single({expr})") - } - Expr::Concat(..) => { - let vec = Self::concat_helper(expr, cache, symbols, log, fold_when); - let inside = vec.join(", "); - format!("parallel!({inside})") - } - Expr::Switch(cond, inputs, cases) => { - let cond = Self::to_ast(cond, cache, symbols, log, fold_when); - let inputs = Self::to_ast(inputs, cache, symbols, log, fold_when); - let cases = cases - .iter() - .map(|expr| Self::to_ast(expr, cache, symbols, log, fold_when)) - .collect::>() - .join(", "); - let ast_str = format!("switch!({cond}, {inputs}; parallel!({cases}))"); - fold_or_plain(ast_str, "switch".into(), log, cache) - } - Expr::If(cond, input, then, els) => { - let cond = Self::to_ast(cond, cache, symbols, log, fold_when); - let input = Self::to_ast(input, cache, symbols, log, fold_when); - let then = Self::to_ast(then, cache, symbols, log, fold_when); - let els = Self::to_ast(els, cache, symbols, log, fold_when); - let ast_str = format!("tif({cond}, {input}, {then}, {els})"); - fold_or_plain(ast_str, "if".into(), log, cache) - } - Expr::DoWhile(input, body) => { - let input = Self::to_ast(input, cache, symbols, log, fold_when); - let body = Self::to_ast(body, cache, symbols, log, fold_when); - let ast_str = format!("dowhile({input}, {body})"); - fold_or_plain(ast_str, "dowhile".into(), log, cache) - } - Expr::Arg(..) => "arg()".into(), - Expr::Function(name, ty_in, ty_out, body) => { - let ty_in_str = ty_in.to_ast(); - let ty_in_binding = find_or_insert(ty_in_str, ty_in.abbrev(), log, symbols); - let ty_out_str = ty_out.to_ast(); - let ty_out_binding = - find_or_insert(ty_out_str, ty_out.abbrev(), log, symbols); - let body = Self::to_ast(body, cache, symbols, log, fold_when); - format!("function(\"{name}\", {ty_in_binding}, {ty_out_binding}, {body})") - } - } - } - } - } + // let fold_or_plain = + // |ast_str: String, + // info: String, + // str_builder: &mut String, + // cache: &mut HashMap<*const schema::Expr, String>| { + // let rc = Rc::strong_count(expr); + // if fold_when(rc, ast_str.len()) { + // let fresh_var = format!("{}_{}", info, cache.len()); + // let lookup = cache + // .entry(Rc::as_ptr(expr)) + // .or_insert_with(|| { + // str_builder.push_str(&format!( + // "let {} = {}; \n", + // fresh_var.clone(), + // ast_str + // )); + // fresh_var + // }) + // .clone(); + // format!("{lookup}.clone()") + // } else { + // ast_str + // } + // }; + + // match cache.get(&Rc::as_ptr(expr)) { + // Some(str) => format!("{}.clone()", str), + // None => { + // match expr.as_ref() { + // // just don't fold simple things like expr, getat anyway + // Expr::Const(c, _, _) => match c { + // schema::Constant::Bool(true) => "ttrue()".into(), + // schema::Constant::Bool(false) => "tfalse()".into(), + // schema::Constant::Int(n) => format!("int({})", n), + // }, + // Expr::Bop(op, lhs, rhs) => { + // let left = Self::to_ast(lhs, cache, symbols, log, fold_when); + // let right = Self::to_ast(rhs, cache, symbols, log, fold_when); + // let ast_str = format!("{}({}, {})", op.to_ast(), left, right); + // fold_or_plain(ast_str, op.to_ast(), log, cache) + // } + // Expr::Top(op, x, y, z) => { + // let left = Self::to_ast(x, cache, symbols, log, fold_when); + // let mid = Self::to_ast(y, cache, symbols, log, fold_when); + // let right = Self::to_ast(z, cache, symbols, log, fold_when); + // let ast_str = format!("{}({}, {}, {})", op.to_ast(), left, mid, right); + // fold_or_plain(ast_str, op.to_ast(), log, cache) + // } + // Expr::Uop(op, expr) => { + // let expr = Self::to_ast(expr, cache, symbols, log, fold_when); + // let ast_str = format!("{}({})", op.to_ast(), expr); + // fold_or_plain(ast_str, op.to_ast(), log, cache) + // } + // Expr::Get(expr, index) => match expr.as_ref() { + // Expr::Arg(_, _) => { + // format!("getat({index})") + // } + // _ => { + // let expr = Self::to_ast(expr, cache, symbols, log, fold_when); + // format!("get({expr}, {index})") + // } + // }, + // Expr::Alloc(id, expr, state, ty) => { + // let expr = Self::to_ast(expr, cache, symbols, log, fold_when); + // let state = Self::to_ast(state, cache, symbols, log, &fold_when); + // let ty_str = ty.to_ast(); + // let ty_binding = find_or_insert(ty_str, ty.abbrev(), log, symbols); + // let ast_str = format!("alloc({id}, {expr}, {state}, {ty_binding})"); + // fold_or_plain(ast_str, "alloc".into(), log, cache) + // } + // Expr::Call(name, arg) => { + // let arg = Self::to_ast(arg, cache, symbols, log, fold_when); + // format!("call({name}, {arg})") + // } + // Expr::Empty(..) => "empty()".into(), + // Expr::Single(expr) => { + // let expr = Self::to_ast(expr, cache, symbols, log, fold_when); + // format!("single({expr})") + // } + // Expr::Concat(..) => { + // let vec = Self::concat_helper(expr, cache, symbols, log, fold_when); + // let inside = vec.join(", "); + // format!("parallel!({inside})") + // } + // Expr::Switch(cond, inputs, cases) => { + // let cond = Self::to_ast(cond, cache, symbols, log, fold_when); + // let inputs = Self::to_ast(inputs, cache, symbols, log, fold_when); + // let cases = cases + // .iter() + // .map(|expr| Self::to_ast(expr, cache, symbols, log, fold_when)) + // .collect::>() + // .join(", "); + // let ast_str = format!("switch!({cond}, {inputs}; parallel!({cases}))"); + // fold_or_plain(ast_str, "switch".into(), log, cache) + // } + // Expr::If(cond, input, then, els) => { + // let cond = Self::to_ast(cond, cache, symbols, log, fold_when); + // let input = Self::to_ast(input, cache, symbols, log, fold_when); + // let then = Self::to_ast(then, cache, symbols, log, fold_when); + // let els = Self::to_ast(els, cache, symbols, log, fold_when); + // let ast_str = format!("tif({cond}, {input}, {then}, {els})"); + // fold_or_plain(ast_str, "if".into(), log, cache) + // } + // Expr::DoWhile(input, body) => { + // let input = Self::to_ast(input, cache, symbols, log, fold_when); + // let body = Self::to_ast(body, cache, symbols, log, fold_when); + // let ast_str = format!("dowhile({input}, {body})"); + // fold_or_plain(ast_str, "dowhile".into(), log, cache) + // } + // Expr::Arg(..) => "arg()".into(), + // Expr::Function(name, ty_in, ty_out, body) => { + // let ty_in_str = ty_in.to_ast(); + // let ty_in_binding = find_or_insert(ty_in_str, ty_in.abbrev(), log, symbols); + // let ty_out_str = ty_out.to_ast(); + // let ty_out_binding = + // find_or_insert(ty_out_str, ty_out.abbrev(), log, symbols); + // let body = Self::to_ast(body, cache, symbols, log, fold_when); + // format!("function(\"{name}\", {ty_in_binding}, {ty_out_binding}, {body})") + // } + // Expr::Symbolic(_) => panic!("no symbolic should occur here"), + // } + // } + // } + //} } impl Expr { pub fn abbrev(&self) -> String { format!("{:?}", self) } + + pub fn pretty(&self) -> String { + let (term, termdag) = Rc::new(self.clone()).to_egglog(); + let expr = termdag.term_to_expr(&term); + expr.to_sexp().pretty() + } + + pub fn to_ast(&self) -> String { + let e = String::new(); + match self { + Expr::Const(c, ..) => match c { + schema::Constant::Bool(true) => "ttrue()".into(), + schema::Constant::Bool(false) => "tfalse()".into(), + schema::Constant::Int(n) => format!("int({})", n), + } + Expr::Top(op, x, y, z) => { + let left = x.to_ast(); + let mid = y.to_ast(); + let right = x.to_ast(); + format!("{}({}, {}, {})", op.to_ast(), left, mid, right) + }, + Expr::Bop(op, x, y) => { + let left = x.to_ast(); + let right = y.to_ast(); + format!("{}({}, {})", op.to_ast(), left, right) + }, + Expr::Uop(op, x) => { + let expr = x.to_ast(); + format!("{}({})", op.to_ast(), expr) + }, + Expr::Get(expr, index) => match expr.as_ref() { + Expr::Arg(_, _) => { + format!("getat({index})") + } + _ => { + let expr = expr.to_ast(); + format!("get({expr}, {index})") + } + }, + Expr::Alloc(id, expr, state, ty) => { + let expr = expr.to_ast(); + let state = state.to_ast(); + let ty_str = ty.to_ast(); + format!("alloc({id}, {expr}, {state}, {ty_str})") + }, + Expr::Call(name, arg) => { + let arg = arg.to_ast(); + format!("call({name}, {arg})") + }, + Expr::Empty(..) => "empty()".into(), + Expr::Single(expr) => { + let expr = expr.to_ast(); + format!("single({expr})") + }, + Expr::Concat(..) => {e}, + Expr::If(cond, inputs, x, y) => { + let cond = cond.to_ast(); + let input = inputs.to_ast(); + let then = x.to_ast(); + let els = y.to_ast(); + format!("tif({cond}, {input}, {then}, {els})") + }, + Expr::Switch(cond, inputs, cases) => { + let cond = cond.to_ast(); + let inputs = inputs.to_ast(); + let cases = cases + .iter() + .map(|expr| expr.to_ast()) + .collect::>() + .join(", "); + format!("switch!({cond}, {inputs}; parallel!({cases}))") + }, + Expr::DoWhile(inputs, body) => { + let inputs = inputs.to_ast(); + let body = body.to_ast(); + format!("dowhile({inputs}, {body})") + }, + Expr::Arg(..) => "arg()".into(), + Expr::Function(name, inty, outty, body) => { + let inty = inty.to_ast(); + let outty = outty.to_ast(); + let body = body.to_ast(); + format!("function(\"{name}\", {inty}, {outty}, {body})") + }, + Expr::Symbolic(str) => str.into(), + } + } } impl Assumption { + pub fn pretty(&self) -> String { + let (term, termdag) = self.to_egglog(); + let expr = termdag.term_to_expr(&term); + expr.to_sexp().pretty() + } + pub fn to_ast( &self, cache: &mut HashMap<*const schema::Expr, String>, @@ -484,6 +694,12 @@ impl BaseType { } impl Type { + pub fn pretty(&self) -> String { + let (term, termdag) = self.to_egglog(); + let expr = termdag.term_to_expr(&term); + expr.to_sexp().pretty() + } + pub fn to_ast(&self) -> String { match self { Type::Base(t) => format!("base({})", BaseType::to_ast(t)), @@ -499,6 +715,7 @@ impl Type { format!("tuplet!({vec_ty_str})") } Type::Unknown => panic!("found unknown in to_ast"), + Type::Symbolic(_) => panic!("found symbolic in to_ast"), } } @@ -515,6 +732,7 @@ impl Type { format!("tpl_{}", vec_ty_str) } Type::Unknown => "unknown".into(), + Type::Symbolic(str) => str.into(), } } } @@ -585,7 +803,9 @@ fn test_pretty_print() { let expr_str = my_loop.to_string(); - PrettyPrinter::new(expr_str.clone()) + let res = PrettyPrinter::from_string(expr_str.clone()) .unwrap() - .to_rust_default(); + .to_egglog_default(); + + println!("{res}") } diff --git a/dag_in_context/src/schema.rs b/dag_in_context/src/schema.rs index 561bb45f1..924b95399 100644 --- a/dag_in_context/src/schema.rs +++ b/dag_in_context/src/schema.rs @@ -26,6 +26,7 @@ pub enum Type { /// `to_egglog` calls `with_arg_types`, so there are never any /// unknown types in the egraph. Unknown, + Symbolic(String), } #[derive(Debug, Clone, PartialEq, Eq, EnumIter, PartialOrd, Ord)] @@ -103,6 +104,7 @@ pub enum Expr { DoWhile(RcExpr, RcExpr), Arg(Type, Assumption), Function(String, Type, Type, RcExpr), + Symbolic(String), // now only used for pretty printer } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/dag_in_context/src/schema_helpers.rs b/dag_in_context/src/schema_helpers.rs index 9fb99ccf5..dfc789e29 100644 --- a/dag_in_context/src/schema_helpers.rs +++ b/dag_in_context/src/schema_helpers.rs @@ -111,6 +111,7 @@ impl Expr { Expr::Empty(..) => Constructor::Empty, Expr::Alloc(..) => Constructor::Alloc, Expr::Top(..) => Constructor::Top, + Expr::Symbolic(_) => panic!("found symbolic"), } } pub fn func_name(&self) -> Option { @@ -200,6 +201,7 @@ impl Expr { Expr::Const(_, _, _) => vec![], Expr::Empty(_, _) => vec![], Expr::Arg(_, _) => vec![], + Expr::Symbolic(_) => panic!("found symbolic"), } } @@ -228,6 +230,7 @@ impl Expr { } Expr::DoWhile(inputs, _body) => vec![inputs.clone()], Expr::Arg(_, _) => vec![], + Expr::Symbolic(_) => panic!("found symbolic"), } } @@ -248,6 +251,7 @@ impl Expr { Expr::DoWhile(x, _) => x.get_arg_type(), Expr::Arg(ty, _) => ty.clone(), Expr::Function(_, ty, _, _) => ty.clone(), + Expr::Symbolic(_) => panic!("found symbolic"), } } @@ -268,6 +272,7 @@ impl Expr { Expr::DoWhile(x, _) => x.get_ctx(), Expr::Arg(_, ctx) => ctx, Expr::Function(_, _, _, x) => x.get_ctx(), + Expr::Symbolic(_) => panic!("found symbolic"), } } @@ -383,6 +388,7 @@ impl Expr { Rc::new(Expr::Const(c.clone(), arg_ty.clone(), arg_ctx.clone())) } Expr::Empty(_, _) => Rc::new(Expr::Empty(arg_ty.clone(), arg_ctx.clone())), + Expr::Symbolic(_) => panic!("found symbolic"), }; // Add the substituted to cache @@ -751,6 +757,7 @@ impl Type { Type::Base(basety) => basety.contains_state(), Type::TupleT(types) => types.iter().any(|ty| ty.contains_state()), Type::Unknown => panic!("Unknown type"), + Type::Symbolic(_) => panic!("Symbolic type"), } } } diff --git a/dag_in_context/src/to_egglog.rs b/dag_in_context/src/to_egglog.rs index f0825cad4..85e776fca 100644 --- a/dag_in_context/src/to_egglog.rs +++ b/dag_in_context/src/to_egglog.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, rc::Rc}; +use std::{collections::HashMap, rc::Rc, vec}; use egglog::{ ast::{Literal, Symbol}, @@ -99,6 +99,7 @@ impl Type { // Unknown shouldn't show up in the egglog file, but is useful for printing // before types are annotated. Type::Unknown => term_dag.app("Unknown".into(), vec![]), + Type::Symbolic(str) => term_dag.var(str.into()), } } } @@ -257,6 +258,7 @@ impl Expr { let name_lit = term_dag.lit(Literal::String(name.into())); term_dag.app("Function".into(), vec![name_lit, ty_in, ty_out, body]) } + Expr::Symbolic(name) => term_dag.var(name.into()), }; term_dag diff --git a/src/util.rs b/src/util.rs index ce955c2fe..27c53db41 100644 --- a/src/util.rs +++ b/src/util.rs @@ -634,11 +634,11 @@ impl Run { RunType::PrettyPrint => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; let dag = rvsdg.to_dag_encoding(true); - let res = std::iter::once(PrettyPrinter { expr: dag.entry }.to_rust_default()) + let res = std::iter::once(PrettyPrinter::from_expr(dag.entry).to_rust_default()) .chain( dag.functions .into_iter() - .map(|expr| PrettyPrinter { expr }.to_rust_default()), + .map(|expr| PrettyPrinter::from_expr(expr).to_rust_default()), ) .collect::>() .join("\n\n"); @@ -656,16 +656,14 @@ impl Run { let dag = rvsdg.to_dag_encoding(true); let optimized = dag_in_context::optimize(&dag).map_err(EggCCError::EggLog)?; let res = std::iter::once( - PrettyPrinter { - expr: optimized.entry, - } + PrettyPrinter::from_expr(optimized.entry) .to_rust_default(), ) .chain( optimized .functions .into_iter() - .map(|expr| PrettyPrinter { expr }.to_rust_default()), + .map(|expr| PrettyPrinter::from_expr(expr).to_rust_default()), ) .collect::>() .join("\n\n");