Skip to content

Commit

Permalink
Serialize Whether Nodes Subsumed
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Sep 11, 2024
1 parent ae75bb7 commit 67d3b40
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ smallvec = "1.11"

generic_symbolic_expressions = "5.0.4"

egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", rev = "9ce281291635b0e1e7685b488de67bb5a3fee3db", features = [
egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", rev = "325f7c6b4b909752ee0f57f9619f698a52cc343e", features = [
"serde",
"graphviz",
] }
Expand Down
21 changes: 16 additions & 5 deletions src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ordered_float::NotNan;
use std::collections::VecDeque;
use symbol_table::GlobalSymbol;

use crate::{ast::ResolvedFunctionDecl, util::HashMap, EGraph, Value};
use crate::{ast::ResolvedFunctionDecl, util::HashMap, EGraph, TupleOutput, Value};

pub struct SerializeConfig {
// Maximumum number of functions to include in the serialized graph, any after this will be discarded
Expand Down Expand Up @@ -87,7 +87,7 @@ impl EGraph {
let all_calls: Vec<(
&ResolvedFunctionDecl,
&[Value],
&Value,
&TupleOutput,
egraph_serialize::ClassId,
egraph_serialize::NodeId,
)> = self
Expand All @@ -103,7 +103,7 @@ impl EGraph {
(
&function.decl,
input,
&output.value,
output,
self.value_to_class_id(&output.value),
self.to_node_id(SerializedNode::Function {
name: *name,
Expand All @@ -126,7 +126,11 @@ impl EGraph {
let mut node_ids: NodeIDs = all_calls.iter().fold(
HashMap::default(),
|mut acc, (_decl, _input, output, class_id, node_id)| {
if self.get_sort_from_value(output).unwrap().is_eq_sort() {
if self
.get_sort_from_value(&output.value)
.unwrap()
.is_eq_sort()
{
acc.entry(class_id.clone())
.or_insert_with(VecDeque::new)
.push_back(node_id.clone());
Expand All @@ -137,7 +141,8 @@ impl EGraph {

let mut egraph = egraph_serialize::EGraph::default();
for (decl, input, output, class_id, node_id) in all_calls {
self.serialize_value(&mut egraph, &mut node_ids, output, &class_id);
self.serialize_value(&mut egraph, &mut node_ids, &output.value, &class_id);

let children: Vec<_> = input
.iter()
.map(|v| {
Expand All @@ -151,6 +156,9 @@ impl EGraph {
eclass: class_id.clone(),
cost: NotNan::new(decl.cost.unwrap_or(1) as f64).unwrap(),
children,
data: egraph_serialize::NodeData {
subsumed: output.subsumed,
},
},
);
}
Expand Down Expand Up @@ -256,6 +264,7 @@ impl EGraph {
eclass: class_id.clone(),
cost: NotNan::new(f64::INFINITY).unwrap(),
children: vec![],
data: egraph_serialize::NodeData { subsumed: false },
},
);
VecDeque::from(vec![node_id])
Expand Down Expand Up @@ -287,6 +296,8 @@ impl EGraph {
eclass: class_id.clone(),
cost: NotNan::new(1.0).unwrap(),
children,
// primitives can never be subsumed
data: egraph_serialize::NodeData { subsumed: false },
},
);
};
Expand Down
35 changes: 34 additions & 1 deletion tests/integration_test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use egglog::{ast::Expr, EGraph, ExtractReport, Function, SerializeConfig, Term, Value};
use egglog::{
ast::Expr, sort::EqSort, EGraph, ExtractReport, Function, SerializeConfig, Term, Value,
};
use symbol_table::GlobalSymbol;

#[test]
Expand Down Expand Up @@ -406,3 +408,34 @@ fn test_value_to_classid() {
assert!(serialized.class_data.get(&class_id).is_some());
assert_eq!(value, egraph.class_id_to_value(&class_id));
}

#[test]
fn test_serialize_subsume_status() {
let mut egraph = EGraph::default();

egraph
.parse_and_run_program(
None,
r#"
(datatype Math)
(function a () Math )
(function b () Math )
(a)
(b)
(subsume (a))
"#,
)
.unwrap();

let serialized = egraph.serialize(SerializeConfig::default());
let a_id = egraph.to_node_id(egglog::SerializedNode::Function {
name: ("a").into(),
offset: 0,
});
let b_id = egraph.to_node_id(egglog::SerializedNode::Function {
name: "b".into(),
offset: 0,
});
assert!(serialized.nodes[&a_id].data.subsumed);
assert!(!serialized.nodes[&b_id].data.subsumed);
}

0 comments on commit 67d3b40

Please sign in to comment.