Skip to content

Commit

Permalink
extract may fail
Browse files Browse the repository at this point in the history
  • Loading branch information
yihozhang committed Aug 24, 2024
1 parent 5a75caa commit b424a0b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ impl EGraph {
values[0],
&mut termdag,
self.type_info().sorts.get(&values[0].tag).unwrap(),
);
)?;
let extracted = termdag.to_string(&term);
log::info!("extracted with cost {cost}: {extracted}");
self.print_msg(extracted);
Expand Down
48 changes: 25 additions & 23 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use hashbrown::hash_map::Entry;
use crate::ast::Symbol;
use crate::termdag::{Term, TermDag};
use crate::util::HashMap;
use crate::{ArcSort, EGraph, Function, Id, Value};
use crate::{ArcSort, EGraph, Error, Function, Id, Value};

pub type Cost = usize;

Expand Down Expand Up @@ -37,33 +37,35 @@ impl EGraph {
/// let (sort, value) = egraph
/// .eval_expr(&egglog::ast::Expr::var_no_span("expr"))
/// .unwrap();
/// let (_, extracted) = egraph.extract(value, &mut termdag, &sort);
/// let (_, extracted) = egraph.extract(value, &mut termdag, &sort).unwrap();
/// assert_eq!(termdag.to_string(&extracted), "(Add 1 1)");
/// ```
pub fn extract(&self, value: Value, termdag: &mut TermDag, arcsort: &ArcSort) -> (Cost, Term) {
pub fn extract(
&self,
value: Value,
termdag: &mut TermDag,
arcsort: &ArcSort,
) -> Result<(Cost, Term), Error> {
let extractor = Extractor::new(self, termdag);
extractor
.find_best(value, termdag, arcsort)
.unwrap_or_else(|| {
log::error!("No cost for {:?}", value);
for func in self.functions.values() {
for (inputs, output) in func.nodes.iter(false) {
if output.value == value {
log::error!("Found unextractable function: {:?}", func.decl.name);
log::error!("Inputs: {:?}", inputs);
log::error!(
"{:?}",
inputs
.iter()
.map(|input| extractor.costs.get(&extractor.find_id(*input)))
.collect::<Vec<_>>()
);
}
extractor.find_best(value, termdag, arcsort).ok_or_else(|| {
log::error!("No cost for {:?}", value);
for func in self.functions.values() {
for (inputs, output) in func.nodes.iter(false) {
if output.value == value {
log::error!("Found unextractable function: {:?}", func.decl.name);
log::error!("Inputs: {:?}", inputs);
log::error!(
"{:?}",
inputs
.iter()
.map(|input| extractor.costs.get(&extractor.find_id(*input)))
.collect::<Vec<_>>()
);
}
}

panic!("No cost for {:?}", value)
})
}
Error::ExtractError(value)
})
}

pub fn extract_variants(
Expand Down
16 changes: 9 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -888,18 +888,18 @@ impl EGraph {
/// Extract a value to a [`TermDag`] and [`Term`]
/// in the [`TermDag`].
/// See also extract_value_to_string for convenience.
pub fn extract_value(&self, value: Value) -> (TermDag, Term) {
pub fn extract_value(&self, value: Value) -> Result<(TermDag, Term), Error> {
let mut termdag = TermDag::default();
let sort = self.type_info().sorts.get(&value.tag).unwrap();
let term = self.extract(value, &mut termdag, sort).1;
(termdag, term)
let term = self.extract(value, &mut termdag, sort)?.1;
Ok((termdag, term))
}

/// Extract a value to a string for printing.
/// See also extract_value for more control.
pub fn extract_value_to_string(&self, value: Value) -> String {
let (termdag, term) = self.extract_value(value);
termdag.to_string(&term)
pub fn extract_value_to_string(&self, value: Value) -> Result<String, Error> {
let (termdag, term) = self.extract_value(value)?;
Ok(termdag.to_string(&term))
}

fn run_rules(&mut self, span: &Span, config: &ResolvedRunConfig) -> RunReport {
Expand Down Expand Up @@ -1376,7 +1376,7 @@ impl EGraph {
for expr in exprs {
let value = self.eval_resolved_expr(&expr, true)?;
let expr_type = expr.output_type(self.type_info());
let term = self.extract(value, &mut termdag, &expr_type).1;
let term = self.extract(value, &mut termdag, &expr_type)?.1;
use std::io::Write;
writeln!(f, "{}", termdag.to_string(&term))
.map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
Expand Down Expand Up @@ -1655,6 +1655,8 @@ pub enum Error {
IoError(PathBuf, std::io::Error, Span),
#[error("Cannot subsume function with merge: {0}")]
SubsumeMergeError(Symbol),
#[error("extraction failure: {:?}", .0)]
ExtractError(Value),
}

#[cfg(test)]
Expand Down

0 comments on commit b424a0b

Please sign in to comment.