Skip to content

Commit

Permalink
Add type_info alias method
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Aug 26, 2023
1 parent 0fafc19 commit 7ec920e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ impl Function {
pub fn new(egraph: &EGraph, decl: &FunctionDecl) -> Result<Self, Error> {
let mut input = Vec::with_capacity(decl.schema.input.len());
for s in &decl.schema.input {
input.push(match egraph.desugar.type_info.sorts.get(s) {
input.push(match egraph.type_info().sorts.get(s) {
Some(sort) => sort.clone(),
None => return Err(Error::TypeError(TypeError::Unbound(*s))),
})
}

let output = match egraph.desugar.type_info.sorts.get(&decl.schema.output) {
let output = match egraph.type_info().sorts.get(&decl.schema.output) {
Some(sort) => sort.clone(),
None => return Err(Error::TypeError(TypeError::Unbound(decl.schema.output))),
};
Expand Down
18 changes: 11 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,10 @@ impl EGraph {

pub fn eval_lit(&self, lit: &Literal) -> Value {
match lit {
Literal::Int(i) => i.store(&self.desugar.type_info.get_sort()).unwrap(),
Literal::F64(f) => f.store(&self.desugar.type_info.get_sort()).unwrap(),
Literal::String(s) => s.store(&self.desugar.type_info.get_sort()).unwrap(),
Literal::Unit => ().store(&self.desugar.type_info.get_sort()).unwrap(),
Literal::Int(i) => i.store(&self.type_info().get_sort()).unwrap(),
Literal::F64(f) => f.store(&self.type_info().get_sort()).unwrap(),
Literal::String(s) => s.store(&self.type_info().get_sort()).unwrap(),
Literal::Unit => ().store(&self.type_info().get_sort()).unwrap(),
}
}

Expand Down Expand Up @@ -983,7 +983,7 @@ impl EGraph {
}
NormAction::LetLit(var, lit) => {
let value = self.eval_lit(lit);
let etype = self.desugar.type_info.infer_literal(lit);
let etype = self.type_info().infer_literal(lit);
let present = self
.global_bindings
.insert(*var, (etype, value, self.timestamp));
Expand Down Expand Up @@ -1175,7 +1175,7 @@ impl EGraph {
return Ok(program);
}

let type_info_before = self.desugar.type_info.clone();
let type_info_before = self.type_info().clone();

self.desugar.type_info.typecheck_program(&program)?;
if stop == CompilerPassStop::TypecheckDesugared {
Expand Down Expand Up @@ -1231,7 +1231,7 @@ impl EGraph {
}

pub(crate) fn get_sort(&self, value: &Value) -> Option<&ArcSort> {
self.desugar.type_info.sorts.get(&value.tag)
self.type_info().sorts.get(&value.tag)
}

pub fn add_arcsort(&mut self, arcsort: ArcSort) -> Result<(), TypeError> {
Expand Down Expand Up @@ -1263,6 +1263,10 @@ impl EGraph {
self.msgs.dedup_by(|a, b| a.is_empty() && b.is_empty());
std::mem::take(&mut self.msgs)
}

pub(crate) fn type_info(&self) -> &TypeInfo {
&self.desugar.type_info
}
}

#[derive(Debug, Error)]
Expand Down
16 changes: 8 additions & 8 deletions src/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl<'a> Context<'a> {
pub fn new(egraph: &'a EGraph) -> Self {
Self {
egraph,
unit: egraph.desugar.type_info.sorts[&Symbol::from(UNIT_SYM)].clone(),
unit: egraph.type_info().sorts[&Symbol::from(UNIT_SYM)].clone(),
types: Default::default(),
errors: Vec::default(),
unionfind: UnionFind::default(),
Expand Down Expand Up @@ -396,7 +396,7 @@ impl<'a> Context<'a> {
(self.add_node(ENode::Var(*sym)), ty)
}
Expr::Lit(lit) => {
let t = self.egraph.desugar.type_info.infer_literal(lit);
let t = self.egraph.type_info().infer_literal(lit);
(self.add_node(ENode::Literal(lit.clone())), Some(t))
}
Expr::Call(sym, args) => {
Expand All @@ -415,7 +415,7 @@ impl<'a> Context<'a> {
.collect();
let t = f.schema.output.clone();
(self.add_node(ENode::Func(*sym, ids)), Some(t))
} else if let Some(prims) = self.egraph.desugar.type_info.primitives.get(sym) {
} else if let Some(prims) = self.egraph.type_info().primitives.get(sym) {
let (ids, arg_tys): (Vec<Id>, Vec<Option<ArcSort>>) =
args.iter().map(|arg| self.infer_query_expr(arg)).unzip();

Expand Down Expand Up @@ -533,7 +533,7 @@ impl<'a> ExprChecker<'a> for ActionChecker<'a> {
}

fn do_function(&mut self, f: Symbol, _args: Vec<Self::T>) -> Self::T {
let func_type = self.egraph.desugar.type_info.func_types.get(&f).unwrap();
let func_type = self.egraph.type_info().func_types.get(&f).unwrap();
self.instructions.push(Instruction::CallFunction(
f,
func_type.has_default || !func_type.has_merge,
Expand Down Expand Up @@ -595,11 +595,11 @@ trait ExprChecker<'a> {
match expr {
Expr::Lit(lit) => {
let t = self.do_lit(lit);
Ok((t, self.egraph().desugar.type_info.infer_literal(lit)))
Ok((t, self.egraph().type_info().infer_literal(lit)))
}
Expr::Var(sym) => self.infer_var(*sym),
Expr::Call(sym, args) => {
if let Some(functype) = self.egraph().desugar.type_info.func_types.get(sym) {
if let Some(functype) = self.egraph().type_info().func_types.get(sym) {
assert!(functype.input.len() == args.len());

let mut ts = vec![];
Expand All @@ -609,7 +609,7 @@ trait ExprChecker<'a> {

let t = self.do_function(*sym, ts);
Ok((t, functype.output.clone()))
} else if let Some(prims) = self.egraph().desugar.type_info.primitives.get(sym) {
} else if let Some(prims) = self.egraph().type_info().primitives.get(sym) {
let mut ts = Vec::with_capacity(args.len());
let mut tys = Vec::with_capacity(args.len());
for arg in args {
Expand Down Expand Up @@ -873,7 +873,7 @@ impl EGraph {
let (cost, expr) = self.extract(
values[0],
&mut termdag,
self.desugar.type_info.sorts.get(&values[0].tag).unwrap(),
self.type_info().sorts.get(&values[0].tag).unwrap(),
);
let extracted = termdag.to_string(&expr);
log::info!("extracted with cost {cost}: {}", extracted);
Expand Down

0 comments on commit 7ec920e

Please sign in to comment.