Skip to content

Commit

Permalink
Don't create new TypeInfo when desugaring
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Aug 24, 2023
1 parent 9e53038 commit 0fafc19
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 60 deletions.
5 changes: 4 additions & 1 deletion src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fn normalize_expr(
panic!("handled above");
}
Expr::Call(f, children) => {
let is_compute = TypeInfo::default().is_primitive(*f);
let is_compute = desugar.type_info.is_primitive(*f);
let mut new_children = vec![];
for child in children {
match child {
Expand Down Expand Up @@ -418,6 +418,7 @@ pub struct Desugar {
// TODO fix getting fresh names using modules
pub(crate) number_underscores: usize,
pub(crate) global_variables: HashSet<Symbol>,
pub(crate) type_info: TypeInfo,
}

impl Default for Desugar {
Expand All @@ -429,6 +430,7 @@ impl Default for Desugar {
parser: ast::parse::ProgramParser::new(),
number_underscores: 3,
global_variables: Default::default(),
type_info: TypeInfo::default(),
}
}
}
Expand Down Expand Up @@ -689,6 +691,7 @@ impl Clone for Desugar {
parser: ast::parse::ProgramParser::new(),
number_underscores: self.number_underscores,
global_variables: self.global_variables.clone(),
type_info: self.type_info.clone(),
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ impl Function {
pub fn new(egraph: &EGraph, decl: &FunctionDecl) -> Result<Self, Error> {
let mut input = Vec::with_capacity(decl.schema.input.len());
for s in &decl.schema.input {
input.push(match egraph.proof_state.type_info.sorts.get(s) {
input.push(match egraph.desugar.type_info.sorts.get(s) {
Some(sort) => sort.clone(),
None => return Err(Error::TypeError(TypeError::Unbound(*s))),
})
}

let output = match egraph.proof_state.type_info.sorts.get(&decl.schema.output) {
let output = match egraph.desugar.type_info.sorts.get(&decl.schema.output) {
Some(sort) => sort.clone(),
None => return Err(Error::TypeError(TypeError::Unbound(decl.schema.output))),
};
Expand Down
44 changes: 20 additions & 24 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ pub mod ast;
mod extract;
mod function;
mod gj;
mod proofs;
mod serialize;
pub mod sort;
mod termdag;
Expand All @@ -12,6 +11,7 @@ mod unionfind;
pub mod util;
mod value;

use ast::desugar::Desugar;
use extract::Extractor;
use hashbrown::hash_map::Entry;
use index::ColumnIndex;
Expand All @@ -21,8 +21,6 @@ use sort::*;
pub use termdag::{Term, TermDag};
use thiserror::Error;

use proofs::ProofState;

use symbolic_expressions::Sexp;

use ast::*;
Expand Down Expand Up @@ -201,7 +199,7 @@ impl FromStr for CompilerPassStop {
pub struct EGraph {
egraphs: Vec<Self>,
unionfind: UnionFind,
pub(crate) proof_state: ProofState,
pub(crate) desugar: Desugar,
functions: HashMap<Symbol, Function>,
rulesets: HashMap<Symbol, HashMap<Symbol, Rule>>,
ruleset_iteration: HashMap<Symbol, usize>,
Expand Down Expand Up @@ -240,7 +238,7 @@ impl Default for EGraph {
functions: Default::default(),
rulesets: Default::default(),
ruleset_iteration: Default::default(),
proof_state: ProofState::default(),
desugar: Desugar::default(),
global_bindings: Default::default(),
match_limit: usize::MAX,
node_limit: usize::MAX,
Expand Down Expand Up @@ -469,10 +467,10 @@ impl EGraph {

pub fn eval_lit(&self, lit: &Literal) -> Value {
match lit {
Literal::Int(i) => i.store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::F64(f) => f.store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::String(s) => s.store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::Unit => ().store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::Int(i) => i.store(&self.desugar.type_info.get_sort()).unwrap(),
Literal::F64(f) => f.store(&self.desugar.type_info.get_sort()).unwrap(),
Literal::String(s) => s.store(&self.desugar.type_info.get_sort()).unwrap(),
Literal::Unit => ().store(&self.desugar.type_info.get_sort()).unwrap(),
}
}

Expand Down Expand Up @@ -985,7 +983,7 @@ impl EGraph {
}
NormAction::LetLit(var, lit) => {
let value = self.eval_lit(lit);
let etype = self.proof_state.type_info.infer_literal(lit);
let etype = self.desugar.type_info.infer_literal(lit);
let present = self
.global_bindings
.insert(*var, (etype, value, self.timestamp));
Expand Down Expand Up @@ -1162,33 +1160,31 @@ impl EGraph {
}

pub fn set_underscores_for_desugaring(&mut self, underscores: usize) {
self.proof_state.desugar.number_underscores = underscores;
self.desugar.number_underscores = underscores;
}

fn process_command(
&mut self,
command: Command,
stop: CompilerPassStop,
) -> Result<Vec<NormCommand>, Error> {
let program = self.proof_state.desugar.desugar_program(
vec![command],
self.test_proofs,
self.seminaive,
)?;
let program =
self.desugar
.desugar_program(vec![command], self.test_proofs, self.seminaive)?;
if stop == CompilerPassStop::Desugar {
return Ok(program);
}

let type_info_before = self.proof_state.type_info.clone();
let type_info_before = self.desugar.type_info.clone();

self.proof_state.type_info.typecheck_program(&program)?;
self.desugar.type_info.typecheck_program(&program)?;
if stop == CompilerPassStop::TypecheckDesugared {
return Ok(program);
}

// reset type info
self.proof_state.type_info = type_info_before;
self.proof_state.type_info.typecheck_program(&program)?;
self.desugar.type_info = type_info_before;
self.desugar.type_info.typecheck_program(&program)?;
if stop == CompilerPassStop::TypecheckTermEncoding {
return Ok(program);
}
Expand Down Expand Up @@ -1222,11 +1218,11 @@ impl EGraph {
}

pub fn parse_program(&self, input: &str) -> Result<Vec<Command>, Error> {
self.proof_state.desugar.parse_program(input)
self.desugar.parse_program(input)
}

pub fn parse_and_run_program(&mut self, input: &str) -> Result<Vec<String>, Error> {
let parsed = self.proof_state.desugar.parse_program(input)?;
let parsed = self.desugar.parse_program(input)?;
self.run_program(parsed)
}

Expand All @@ -1235,11 +1231,11 @@ impl EGraph {
}

pub(crate) fn get_sort(&self, value: &Value) -> Option<&ArcSort> {
self.proof_state.type_info.sorts.get(&value.tag)
self.desugar.type_info.sorts.get(&value.tag)
}

pub fn add_arcsort(&mut self, arcsort: ArcSort) -> Result<(), TypeError> {
self.proof_state.type_info.add_arcsort(arcsort)
self.desugar.type_info.add_arcsort(arcsort)
}

/// Gets the last extract report and returns it, if the last command saved it.
Expand Down
11 changes: 0 additions & 11 deletions src/proofs.rs

This file was deleted.

2 changes: 1 addition & 1 deletion src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ impl EGraph {
///
/// Checks for pattern created by Desugar.get_fresh
fn is_temp_name(&self, name: String) -> bool {
let number_underscores = self.proof_state.desugar.number_underscores;
let number_underscores = self.desugar.number_underscores;
let res = name.starts_with('v')
&& name.ends_with("_".repeat(number_underscores).as_str())
&& name[1..name.len() - number_underscores]
Expand Down
27 changes: 8 additions & 19 deletions src/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl<'a> Context<'a> {
pub fn new(egraph: &'a EGraph) -> Self {
Self {
egraph,
unit: egraph.proof_state.type_info.sorts[&Symbol::from(UNIT_SYM)].clone(),
unit: egraph.desugar.type_info.sorts[&Symbol::from(UNIT_SYM)].clone(),
types: Default::default(),
errors: Vec::default(),
unionfind: UnionFind::default(),
Expand Down Expand Up @@ -396,7 +396,7 @@ impl<'a> Context<'a> {
(self.add_node(ENode::Var(*sym)), ty)
}
Expr::Lit(lit) => {
let t = self.egraph.proof_state.type_info.infer_literal(lit);
let t = self.egraph.desugar.type_info.infer_literal(lit);
(self.add_node(ENode::Literal(lit.clone())), Some(t))
}
Expr::Call(sym, args) => {
Expand All @@ -415,7 +415,7 @@ impl<'a> Context<'a> {
.collect();
let t = f.schema.output.clone();
(self.add_node(ENode::Func(*sym, ids)), Some(t))
} else if let Some(prims) = self.egraph.proof_state.type_info.primitives.get(sym) {
} else if let Some(prims) = self.egraph.desugar.type_info.primitives.get(sym) {
let (ids, arg_tys): (Vec<Id>, Vec<Option<ArcSort>>) =
args.iter().map(|arg| self.infer_query_expr(arg)).unzip();

Expand Down Expand Up @@ -533,13 +533,7 @@ impl<'a> ExprChecker<'a> for ActionChecker<'a> {
}

fn do_function(&mut self, f: Symbol, _args: Vec<Self::T>) -> Self::T {
let func_type = self
.egraph
.proof_state
.type_info
.func_types
.get(&f)
.unwrap();
let func_type = self.egraph.desugar.type_info.func_types.get(&f).unwrap();
self.instructions.push(Instruction::CallFunction(
f,
func_type.has_default || !func_type.has_merge,
Expand Down Expand Up @@ -601,11 +595,11 @@ trait ExprChecker<'a> {
match expr {
Expr::Lit(lit) => {
let t = self.do_lit(lit);
Ok((t, self.egraph().proof_state.type_info.infer_literal(lit)))
Ok((t, self.egraph().desugar.type_info.infer_literal(lit)))
}
Expr::Var(sym) => self.infer_var(*sym),
Expr::Call(sym, args) => {
if let Some(functype) = self.egraph().proof_state.type_info.func_types.get(sym) {
if let Some(functype) = self.egraph().desugar.type_info.func_types.get(sym) {
assert!(functype.input.len() == args.len());

let mut ts = vec![];
Expand All @@ -615,8 +609,7 @@ trait ExprChecker<'a> {

let t = self.do_function(*sym, ts);
Ok((t, functype.output.clone()))
} else if let Some(prims) = self.egraph().proof_state.type_info.primitives.get(sym)
{
} else if let Some(prims) = self.egraph().desugar.type_info.primitives.get(sym) {
let mut ts = Vec::with_capacity(args.len());
let mut tys = Vec::with_capacity(args.len());
for arg in args {
Expand Down Expand Up @@ -880,11 +873,7 @@ impl EGraph {
let (cost, expr) = self.extract(
values[0],
&mut termdag,
self.proof_state
.type_info
.sorts
.get(&values[0].tag)
.unwrap(),
self.desugar.type_info.sorts.get(&values[0].tag).unwrap(),
);
let extracted = termdag.to_string(&expr);
log::info!("extracted with cost {cost}: {}", extracted);
Expand Down
6 changes: 4 additions & 2 deletions src/typechecking.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::{proofs::RULE_PROOF_KEYWORD, *};
use crate::*;

pub const RULE_PROOF_KEYWORD: &str = "rule-proof";

#[derive(Clone, Debug)]
pub struct FuncType {
Expand Down Expand Up @@ -630,7 +632,7 @@ pub enum TypeError {
#[error("Arity mismatch, expected {expected} args: {expr}")]
Arity { expr: Expr, expected: usize },
#[error(
"Type mismatch: expr = {expr}, expected = {}, actual = {}, reason: {reason}",
"Type mismatch: expr = {expr}, expected = {}, actual = {}, reason: {reason}",
.expected.name(), .actual.name(),
)]
Mismatch {
Expand Down

0 comments on commit 0fafc19

Please sign in to comment.