Skip to content

Commit

Permalink
Fix rebuilding and extraction bugs for EqSort containers (#191)
Browse files Browse the repository at this point in the history
* fix backoff

* fix

* add test

* add trailing newline

* fix container (and primitive) extraction

* fix extraction

* fix environemtn canonicalization

* nits and use saturating_add for cost

* add an additional example
  • Loading branch information
yihozhang authored Aug 21, 2023
1 parent 94c5a7d commit f93adcb
Show file tree
Hide file tree
Showing 14 changed files with 241 additions and 88 deletions.
91 changes: 45 additions & 46 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ use crate::termdag::{Term, TermDag};
use crate::util::HashMap;
use crate::{ArcSort, EGraph, Function, Id, Value};

type Cost = usize;
pub type Cost = usize;

#[derive(Debug)]
pub(crate) struct Node<'a> {
sym: Symbol,
inputs: &'a [Value],
}

pub(crate) struct Extractor<'a> {
costs: HashMap<Id, (Cost, Term)>,
pub struct Extractor<'a> {
pub costs: HashMap<Id, (Cost, Term)>,
ctors: Vec<Symbol>,
egraph: &'a EGraph,
}
Expand All @@ -31,7 +31,29 @@ impl EGraph {
}

pub fn extract(&self, value: Value, termdag: &mut TermDag, arcsort: &ArcSort) -> (Cost, Term) {
Extractor::new(self, termdag).find_best(value, termdag, arcsort)
let extractor = Extractor::new(self, termdag);
extractor
.find_best(value, termdag, arcsort)
.unwrap_or_else(|| {
log::error!("No cost for {:?}", value);
for func in self.functions.values() {
for (inputs, output) in func.nodes.iter() {
if output.value == value {
log::error!("Found unextractable function: {:?}", func.decl.name);
log::error!("Inputs: {:?}", inputs);
log::error!(
"{:?}",
inputs
.iter()
.map(|input| extractor.costs.get(&extractor.find(input)))
.collect::<Vec<_>>()
);
}
}
}

panic!("No cost for {:?}", value)
})
}

pub fn extract_variants(
Expand All @@ -57,7 +79,9 @@ impl EGraph {
.filter_map(|(inputs, output)| {
(&output.value == output_value).then(|| {
let node = Node { sym, inputs };
ext.expr_from_node(&node, termdag)
ext.expr_from_node(&node, termdag).expect(
"extract_variants should be called after extractor initialization",
)
})
})
.collect()
Expand Down Expand Up @@ -89,46 +113,29 @@ impl<'a> Extractor<'a> {
extractor
}

fn expr_from_node(&self, node: &Node, termdag: &mut TermDag) -> Term {
fn expr_from_node(&self, node: &Node, termdag: &mut TermDag) -> Option<Term> {
let mut children = vec![];
for value in node.inputs {
let arcsort = self.egraph.get_sort(value).unwrap();
children.push(self.find_best(*value, termdag, arcsort).1)
children.push(self.find_best(*value, termdag, arcsort)?.1)
}

termdag.make(node.sym, children)
Some(termdag.make(node.sym, children))
}

pub fn find_best(&self, value: Value, termdag: &mut TermDag, sort: &ArcSort) -> (Cost, Term) {
pub fn find_best(
&self,
value: Value,
termdag: &mut TermDag,
sort: &ArcSort,
) -> Option<(Cost, Term)> {
if sort.is_eq_sort() {
let id = self.find(&value);
let (cost, node) = self
.costs
.get(&id)
.unwrap_or_else(|| {
log::error!("No cost for {:?}", value);
for func in self.egraph.functions.values() {
for (inputs, output) in func.nodes.iter() {
if output.value == value {
log::error!("Found unextractable function: {:?}", func.decl.name);
log::error!("Inputs: {:?}", inputs);
log::error!(
"{:?}",
inputs
.iter()
.map(|input| self.costs.get(&self.find(input)))
.collect::<Vec<_>>()
);
}
}
}

panic!("No cost for {:?}", value)
})
.clone();
(cost, node)
let (cost, node) = self.costs.get(&id)?.clone();
Some((cost, node))
} else {
(0, termdag.expr_to_term(&sort.make_expr(self.egraph, value)))
let (cost, node) = sort.extract_expr(self.egraph, value, self, termdag)?;
Some((cost, termdag.expr_to_term(&node)))
}
}

Expand All @@ -142,17 +149,9 @@ impl<'a> Extractor<'a> {
let types = &function.schema.input;
let mut terms: Vec<Term> = vec![];
for (ty, value) in types.iter().zip(children) {
cost = cost.saturating_add(if ty.is_eq_sort() {
let id = self.egraph.find(Id::from(value.bits as usize));
// TODO costs should probably map values?
let (cost, term) = self.costs.get(&id)?;
terms.push(term.clone());
*cost
} else {
let term = termdag.expr_to_term(&ty.make_expr(self.egraph, *value));
terms.push(term);
1
});
let (term_cost, term) = self.find_best(*value, termdag, ty)?;
terms.push(term.clone());
cost = cost.saturating_add(term_cost);
}
Some((terms, cost))
}
Expand Down
9 changes: 8 additions & 1 deletion src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,14 @@ impl Function {
) -> Result<(usize, Vec<DeferredMerge>), Error> {
// Make sure indexes are up to date.
self.update_indexes(self.nodes.num_offsets());
if self.schema.input.iter().all(|s| !s.is_eq_sort()) && !self.schema.output.is_eq_sort() {
if self
.schema
.input
.iter()
.all(|s| !s.is_eq_sort() && !s.is_eq_container_sort())
&& !self.schema.output.is_eq_sort()
&& !self.schema.output.is_eq_container_sort()
{
return Ok((std::mem::take(&mut self.updates), Default::default()));
}
let mut deferred_merges = Vec::new();
Expand Down
16 changes: 9 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,11 @@ impl EGraph {
}

// now update global bindings
let mut new_global_bindings = self.global_bindings.clone();
for (_sym, (_sort, value, ts)) in new_global_bindings.iter_mut() {
*value = self.bad_find_value(*value);
*ts = self.timestamp;
let mut new_global_bindings = std::mem::take(&mut self.global_bindings);
for (_sym, (sort, value, ts)) in new_global_bindings.iter_mut() {
if sort.canonicalize(value, &self.unionfind) {
*ts = self.timestamp;
}
}
self.global_bindings = new_global_bindings;

Expand Down Expand Up @@ -496,18 +497,19 @@ impl EGraph {
let mut children = Vec::new();
for (a, a_type) in ins.iter().copied().zip(&schema.input) {
if a_type.is_eq_sort() {
children.push(extractor.find_best(a, &mut termdag, a_type).1);
children.push(extractor.find_best(a, &mut termdag, a_type).unwrap().1);
} else {
children.push(termdag.expr_to_term(&a_type.make_expr(self, a)));
children.push(termdag.expr_to_term(&a_type.make_expr(self, a).1));
};
}

let out = if schema.output.is_eq_sort() {
extractor
.find_best(out.value, &mut termdag, &schema.output)
.unwrap()
.1
} else {
termdag.expr_to_term(&schema.output.make_expr(self, out.value))
termdag.expr_to_term(&schema.output.make_expr(self, out.value).1)
};
terms.push((termdag.make(sym, children), out));
}
Expand Down
2 changes: 1 addition & 1 deletion src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ impl EGraph {
log::warn!("{} is a container sort", sort.name());
sort.name().to_string()
} else {
sort.make_expr(self, *value).to_string()
sort.make_expr(self, *value).1.to_string()
};
egraph.nodes.insert(
node_id.clone(),
Expand Down
7 changes: 5 additions & 2 deletions src/sort/f64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ impl Sort for F64Sort {

}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
assert!(value.tag == self.name());
Expr::Lit(Literal::F64(OrderedFloat(f64::from_bits(value.bits))))
(
1,
Expr::Lit(Literal::F64(OrderedFloat(f64::from_bits(value.bits)))),
)
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/sort/i64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ impl Sort for I64Sort {

}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
assert!(value.tag == self.name());
Expr::Lit(Literal::Int(value.bits as _))
(1, Expr::Lit(Literal::Int(value.bits as _)))
}
}

Expand Down
26 changes: 20 additions & 6 deletions src/sort/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,33 @@ impl Sort for MapSort {
});
}

fn make_expr(&self, egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr) {
let mut termdag = TermDag::default();
let extractor = Extractor::new(egraph, &mut termdag);
self.extract_expr(egraph, value, &extractor, &mut termdag)
.expect("Extraction should be successful since extractor has been fully initialized")
}

fn extract_expr(
&self,
_egraph: &EGraph,
value: Value,
extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Expr)> {
let map = ValueMap::load(self, &value);
let mut expr = Expr::call("map-empty", []);
let mut termdag = TermDag::default();
let mut cost = 0usize;
for (k, v) in map.iter().rev() {
let k = egraph.extract(*k, &mut termdag, &self.key).1;
let v = egraph.extract(*v, &mut termdag, &self.value).1;
let k = extractor.find_best(*k, termdag, &self.key)?;
let v = extractor.find_best(*v, termdag, &self.value)?;
cost = cost.saturating_add(k.0).saturating_add(v.0);
expr = Expr::call(
"map-insert",
[expr, termdag.term_to_expr(&k), termdag.term_to_expr(&v)],
[expr, termdag.term_to_expr(&k.1), termdag.term_to_expr(&v.1)],
)
}
expr
Some((cost, expr))
}
}

Expand Down
21 changes: 19 additions & 2 deletions src/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub use set::*;
mod vec;
pub use vec::*;

use crate::extract::{Cost, Extractor};
use crate::*;

pub trait Sort: Any + Send + Sync + Debug {
Expand Down Expand Up @@ -69,7 +70,23 @@ pub trait Sort: Any + Send + Sync + Debug {
let _ = info;
}

fn make_expr(&self, egraph: &EGraph, value: Value) -> Expr;
/// Extracting an expression (with smallest cost) out of a primitive value
fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr);

/// For values like EqSort containers, to make/extract an expression from it
/// requires an extractor. Moreover, the extraction may be unsuccessful if
/// the extractor is not fully initialized.
///
/// The default behavior is to call make_expr
fn extract_expr(
&self,
egraph: &EGraph,
value: Value,
_extractor: &Extractor,
_termdag: &mut TermDag,
) -> Option<(Cost, Expr)> {
Some(self.make_expr(egraph, value))
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -101,7 +118,7 @@ impl Sort for EqSort {
}
}

fn make_expr(&self, _egraph: &EGraph, _value: Value) -> Expr {
fn make_expr(&self, _egraph: &EGraph, _value: Value) -> (Cost, Expr) {
unimplemented!("No make_expr for EqSort {}", self.name)
}
}
Expand Down
17 changes: 10 additions & 7 deletions src/sort/rational.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,20 @@ impl Sort for RationalSort {
add_primitives!(eg, "<=" = |a: R, b: R| -> Opt { if a <= b {Some(())} else {None} });
add_primitives!(eg, ">=" = |a: R, b: R| -> Opt { if a >= b {Some(())} else {None} });
}
fn make_expr(&self, _egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
assert!(value.tag == self.name());
let rat = R::load(self, &value);
let numer = *rat.numer();
let denom = *rat.denom();
Expr::call(
"rational",
vec![
Expr::Lit(Literal::Int(numer)),
Expr::Lit(Literal::Int(denom)),
],
(
1,
Expr::call(
"rational",
vec![
Expr::Lit(Literal::Int(numer)),
Expr::Lit(Literal::Int(denom)),
],
),
)
}
}
Expand Down
24 changes: 19 additions & 5 deletions src/sort/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,29 @@ impl Sort for SetSort {
});
}

fn make_expr(&self, egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr) {
let mut termdag = TermDag::default();
let extractor = Extractor::new(egraph, &mut termdag);
self.extract_expr(egraph, value, &extractor, &mut termdag)
.expect("Extraction should be successful since extractor has been fully initialized")
}

fn extract_expr(
&self,
_egraph: &EGraph,
value: Value,
extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Expr)> {
let set = ValueSet::load(self, &value);
let mut expr = Expr::call("set-empty", []);
let mut termdag = TermDag::default();
let mut cost = 0usize;
for e in set.iter().rev() {
let e = egraph.extract(*e, &mut termdag, &self.element).1;
expr = Expr::call("set-insert", [expr, termdag.term_to_expr(&e)])
let e = extractor.find_best(*e, termdag, &self.element)?;
cost = cost.saturating_add(e.0);
expr = Expr::call("set-insert", [expr, termdag.term_to_expr(&e.1)])
}
expr
Some((cost, expr))
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/sort/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ impl Sort for StringSort {
self
}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
assert!(value.tag == self.name);
let sym = Symbol::from(NonZeroU32::new(value.bits as _).unwrap());
Expr::Lit(Literal::String(sym))
(1, Expr::Lit(Literal::String(sym)))
}

fn register_primitives(self: Arc<Self>, typeinfo: &mut TypeInfo) {
Expand Down
Loading

0 comments on commit f93adcb

Please sign in to comment.