diff --git a/src/serialize.rs b/src/serialize.rs index c15ee7f7..3367aa35 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -2,10 +2,7 @@ use ordered_float::NotNan; use std::collections::VecDeque; use crate::{ - ast::{Id, ResolvedFunctionDecl}, - function::table::hash_values, - util::HashMap, - EGraph, Value, + ast::ResolvedFunctionDecl, function::table::hash_values, util::HashMap, EGraph, Value, }; pub struct SerializeConfig { @@ -101,10 +98,7 @@ impl EGraph { .iter() .filter_map(|(_decl, _input, output, node_id)| { if self.get_sort_from_value(output).unwrap().is_eq_sort() { - let id = output.bits as usize; - let canonical: usize = self.unionfind.find(Id::from(id)).into(); - let canonical_id: egraph_serialize::ClassId = canonical.to_string().into(); - Some((canonical_id, node_id)) + Some((self.value_to_class_id(output), node_id)) } else { None } @@ -118,17 +112,21 @@ impl EGraph { let mut egraph = egraph_serialize::EGraph::default(); for (decl, input, output, node_id) in all_calls { - let prim_node_id = if config.split_primitive_outputs { - Some(format!("{}-value", node_id.clone())) + // If we are splitting primitive outputs, then we will use the function node ID as the e-class for the output, so + // that even if two functions have the same primitive output, they will be in different e-classes. + let eclass = if config.split_primitive_outputs + && !self.get_sort_from_value(output).unwrap().is_eq_sort() + { + format!("{}-value", node_id.clone()).into() } else { - None + self.value_to_class_id(output) }; - let eclass = self - .serialize_value(&mut egraph, &mut node_ids, output, prim_node_id) - .0; + self.serialize_value(&mut egraph, &mut node_ids, output, &eclass); let children: Vec<_> = input .iter() - .map(|v| self.serialize_value(&mut egraph, &mut node_ids, v, None).1) + .map(|v| { + self.serialize_value(&mut egraph, &mut node_ids, v, &self.value_to_class_id(v)) + }) .collect(); egraph.nodes.insert( node_id, @@ -144,12 +142,35 @@ impl EGraph { egraph.root_eclasses = config .root_eclasses .iter() - .map(|v| self.serialize_value(&mut egraph, &mut node_ids, v, None).0) + .map(|v| self.value_to_class_id(v)) .collect(); egraph } + /// Gets the serialized class ID for a value. + pub fn value_to_class_id(&self, value: &Value) -> egraph_serialize::ClassId { + // Canonicalize the value first so that we always use the canonical e-class ID + let sort = self.get_sort_from_value(value).unwrap(); + let mut value = *value; + sort.canonicalize(&mut value, &self.unionfind); + assert!( + !value.tag.to_string().contains('-'), + "Tag cannot contain '-' when serializing" + ); + format!("{}-{}", value.tag, value.bits).into() + } + + /// Gets the value for a serialized class ID. + pub fn class_id_to_value(&self, eclass_id: &egraph_serialize::ClassId) -> Value { + let s = eclass_id.to_string(); + let (tag, bits) = s.split_once('-').unwrap(); + Value { + tag: tag.into(), + bits: bits.parse().unwrap(), + } + } + /// Serialize the value and return the eclass and node ID /// If this is a primitive value, we will add the node to the data, but if it is an eclass, we will not /// When this is called on the output of a node, we only use the e-class to know which e-class its a part of @@ -159,86 +180,65 @@ impl EGraph { egraph: &mut egraph_serialize::EGraph, node_ids: &mut NodeIDs, value: &Value, - // The node ID to use for a primitive value, if this is None, use the hash of the value and the sort name - // Set iff `split_primitive_outputs` is set and this is an output of a function. - prim_node_id: Option, - ) -> (egraph_serialize::ClassId, egraph_serialize::NodeId) { + class_id: &egraph_serialize::ClassId, + ) -> egraph_serialize::NodeId { let sort = self.get_sort_from_value(value).unwrap(); - let (class_id, node_id): (egraph_serialize::ClassId, egraph_serialize::NodeId) = - if sort.is_eq_sort() { - let id: usize = value.bits as usize; - let canonical: usize = self.unionfind.find(Id::from(id)).into(); - let class_id: egraph_serialize::ClassId = canonical.to_string().into(); - (class_id.clone(), get_node_id(egraph, node_ids, class_id)) - } else { - let (class_id, node_id): (egraph_serialize::ClassId, egraph_serialize::NodeId) = - if let Some(node_id) = prim_node_id { - (node_id.clone().into(), node_id.into()) - } else { - let sort_name = sort.name().to_string(); - let node_id_str = - format!("{}-{}", sort_name, hash_values(vec![*value].as_slice())); - (node_id_str.clone().into(), node_id_str.into()) - }; - // Add node for value - { - // Children will be empty unless this is a container sort - let children: Vec = sort - .inner_values(value) - .into_iter() - .map(|(_, v)| self.serialize_value(egraph, node_ids, &v, None).1) - .collect(); - // If this is a container sort, use the name, otherwise use the value - let op = if sort.is_container_sort() { - sort.serialized_name(value).to_string() - } else { - sort.make_expr(self, *value).1.to_string() - }; - egraph.nodes.insert( - node_id.clone(), - egraph_serialize::Node { - op, - eclass: class_id.clone(), - cost: NotNan::new(0.0).unwrap(), - children, - }, - ); + let node_id = if sort.is_eq_sort() { + let node_ids = node_ids.entry(class_id.clone()).or_insert_with(|| { + // If we don't find node IDs for this class, it means that all nodes for it were omitted due to size constraints + // In this case, add a dummy node in this class to represent the missing nodes + let node_id = egraph_serialize::NodeId::from(format!("{}-dummy", class_id)); + egraph.nodes.insert( + node_id.clone(), + egraph_serialize::Node { + op: "[...]".to_string(), + eclass: class_id.clone(), + cost: NotNan::new(f64::INFINITY).unwrap(), + children: vec![], + }, + ); + VecDeque::from(vec![node_id]) + }); + node_ids.rotate_left(1); + node_ids.front().unwrap().clone() + } else { + let node_id: egraph_serialize::NodeId = class_id.to_string().into(); + // Add node for value + { + // Children will be empty unless this is a container sort + let children: Vec = sort + .inner_values(value) + .into_iter() + .map(|(_, v)| { + self.serialize_value(egraph, node_ids, &v, &self.value_to_class_id(&v)) + }) + .collect(); + // If this is a container sort, use the name, otherwise use the value + let op = if sort.is_container_sort() { + sort.serialized_name(value).to_string() + } else { + sort.make_expr(self, *value).1.to_string() }; - (class_id, node_id) + egraph.nodes.insert( + node_id.clone(), + egraph_serialize::Node { + op, + eclass: class_id.clone(), + cost: NotNan::new(1.0).unwrap(), + children, + }, + ); }; + node_id + }; egraph.class_data.insert( class_id.clone(), egraph_serialize::ClassData { typ: Some(sort.name().to_string()), }, ); - (class_id, node_id) + node_id } } type NodeIDs = HashMap>; - -/// Returns the node ID for the given class ID, rotating the queue -fn get_node_id( - egraph: &mut egraph_serialize::EGraph, - node_ids: &mut HashMap>, - class_id: egraph_serialize::ClassId, -) -> egraph_serialize::NodeId { - // If we don't find node IDs for this class, it means that all nodes for it were omitted due to size constraints - // In this case, add a dummy node in this class to represent the missing nodes - let node_ids = node_ids.entry(class_id.clone()).or_insert_with(|| { - let node_id = egraph_serialize::NodeId::from(format!("{}-dummy", class_id)); - egraph.nodes.insert( - node_id.clone(), - egraph_serialize::Node { - op: "[...]".to_string(), - eclass: class_id.clone(), - cost: NotNan::new(f64::INFINITY).unwrap(), - children: vec![], - }, - ); - VecDeque::from(vec![node_id]) - }); - node_ids.rotate_left(1); - node_ids.front().unwrap().clone() -} diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 50026050..660913c4 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -1,4 +1,4 @@ -use egglog::{ast::Expr, EGraph, ExtractReport, Function, Term, Value}; +use egglog::{ast::Expr, EGraph, ExtractReport, Function, SerializeConfig, Term, Value}; use symbol_table::GlobalSymbol; #[test] @@ -377,3 +377,31 @@ fn test_cant_subsume_merge() { ); assert!(res.is_err()); } + +#[test] +fn test_value_to_classid() { + let mut egraph = EGraph::default(); + + egraph + .parse_and_run_program( + None, + r#" + (datatype Math) + (function exp () Math ) + (exp) + (query-extract (exp)) + "#, + ) + .unwrap(); + let report = egraph.get_extract_report().clone().unwrap(); + let ExtractReport::Best { term, termdag, .. } = report else { + panic!(); + }; + let expr = termdag.term_to_expr(&term); + let value = egraph.eval_expr(&expr).unwrap().1; + + let serialized = egraph.serialize(SerializeConfig::default()); + let class_id = egraph.value_to_class_id(&value); + assert!(serialized.class_data.get(&class_id).is_some()); + assert_eq!(value, egraph.class_id_to_value(&class_id)); +}