diff --git a/Cargo.lock b/Cargo.lock index ab96e00e..e72c7f2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -333,7 +333,7 @@ dependencies = [ [[package]] name = "egraph-serialize" version = "0.1.0" -source = "git+https://github.com/egraphs-good/egraph-serialize?rev=9ce281291635b0e1e7685b488de67bb5a3fee3db#9ce281291635b0e1e7685b488de67bb5a3fee3db" +source = "git+https://github.com/egraphs-good/egraph-serialize?rev=325f7c6b4b909752ee0f57f9619f698a52cc343e#325f7c6b4b909752ee0f57f9619f698a52cc343e" dependencies = [ "graphviz-rust", "indexmap", diff --git a/Cargo.toml b/Cargo.toml index 6a9e397f..cb8838a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,7 @@ smallvec = "1.11" generic_symbolic_expressions = "5.0.4" -egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", rev = "9ce281291635b0e1e7685b488de67bb5a3fee3db", features = [ +egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", rev = "325f7c6b4b909752ee0f57f9619f698a52cc343e", features = [ "serde", "graphviz", ] } diff --git a/src/serialize.rs b/src/serialize.rs index 9b56a38d..4746e2ed 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -2,7 +2,7 @@ use ordered_float::NotNan; use std::collections::VecDeque; use symbol_table::GlobalSymbol; -use crate::{ast::ResolvedFunctionDecl, util::HashMap, EGraph, Value}; +use crate::{ast::ResolvedFunctionDecl, util::HashMap, EGraph, TupleOutput, Value}; pub struct SerializeConfig { // Maximumum number of functions to include in the serialized graph, any after this will be discarded @@ -87,7 +87,7 @@ impl EGraph { let all_calls: Vec<( &ResolvedFunctionDecl, &[Value], - &Value, + &TupleOutput, egraph_serialize::ClassId, egraph_serialize::NodeId, )> = self @@ -103,7 +103,7 @@ impl EGraph { ( &function.decl, input, - &output.value, + output, self.value_to_class_id(&output.value), self.to_node_id(SerializedNode::Function { name: *name, @@ -126,7 +126,11 @@ impl EGraph { let mut node_ids: NodeIDs = all_calls.iter().fold( HashMap::default(), |mut acc, (_decl, _input, output, class_id, node_id)| { - if self.get_sort_from_value(output).unwrap().is_eq_sort() { + if self + .get_sort_from_value(&output.value) + .unwrap() + .is_eq_sort() + { acc.entry(class_id.clone()) .or_insert_with(VecDeque::new) .push_back(node_id.clone()); @@ -137,7 +141,8 @@ impl EGraph { let mut egraph = egraph_serialize::EGraph::default(); for (decl, input, output, class_id, node_id) in all_calls { - self.serialize_value(&mut egraph, &mut node_ids, output, &class_id); + self.serialize_value(&mut egraph, &mut node_ids, &output.value, &class_id); + let children: Vec<_> = input .iter() .map(|v| { @@ -151,6 +156,9 @@ impl EGraph { eclass: class_id.clone(), cost: NotNan::new(decl.cost.unwrap_or(1) as f64).unwrap(), children, + data: egraph_serialize::NodeData { + subsumed: output.subsumed, + }, }, ); } @@ -256,6 +264,7 @@ impl EGraph { eclass: class_id.clone(), cost: NotNan::new(f64::INFINITY).unwrap(), children: vec![], + data: egraph_serialize::NodeData { subsumed: false }, }, ); VecDeque::from(vec![node_id]) @@ -287,6 +296,8 @@ impl EGraph { eclass: class_id.clone(), cost: NotNan::new(1.0).unwrap(), children, + // primitives can never be subsumed + data: egraph_serialize::NodeData { subsumed: false }, }, ); }; diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 7cc6b24c..f3e9f982 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -1,4 +1,6 @@ -use egglog::{ast::Expr, EGraph, ExtractReport, Function, SerializeConfig, Term, Value}; +use egglog::{ + ast::Expr, sort::EqSort, EGraph, ExtractReport, Function, SerializeConfig, Term, Value, +}; use symbol_table::GlobalSymbol; #[test] @@ -406,3 +408,34 @@ fn test_value_to_classid() { assert!(serialized.class_data.get(&class_id).is_some()); assert_eq!(value, egraph.class_id_to_value(&class_id)); } + +#[test] +fn test_serialize_subsume_status() { + let mut egraph = EGraph::default(); + + egraph + .parse_and_run_program( + None, + r#" + (datatype Math) + (function a () Math ) + (function b () Math ) + (a) + (b) + (subsume (a)) + "#, + ) + .unwrap(); + + let serialized = egraph.serialize(SerializeConfig::default()); + let a_id = egraph.to_node_id(egglog::SerializedNode::Function { + name: ("a").into(), + offset: 0, + }); + let b_id = egraph.to_node_id(egglog::SerializedNode::Function { + name: "b".into(), + offset: 0, + }); + assert!(serialized.nodes[&a_id].data.subsumed); + assert!(!serialized.nodes[&b_id].data.subsumed); +}