Skip to content

Commit

Permalink
Merge pull request #396 from saulshanabrook/back-and-forth
Browse files Browse the repository at this point in the history
Serialized Class ID <-> Value
  • Loading branch information
saulshanabrook authored Jul 31, 2024
2 parents fa45d46 + e8d4f72 commit 5d8c025
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 87 deletions.
172 changes: 86 additions & 86 deletions src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ use ordered_float::NotNan;
use std::collections::VecDeque;

use crate::{
ast::{Id, ResolvedFunctionDecl},
function::table::hash_values,
util::HashMap,
EGraph, Value,
ast::ResolvedFunctionDecl, function::table::hash_values, util::HashMap, EGraph, Value,
};

pub struct SerializeConfig {
Expand Down Expand Up @@ -101,10 +98,7 @@ impl EGraph {
.iter()
.filter_map(|(_decl, _input, output, node_id)| {
if self.get_sort_from_value(output).unwrap().is_eq_sort() {
let id = output.bits as usize;
let canonical: usize = self.unionfind.find(Id::from(id)).into();
let canonical_id: egraph_serialize::ClassId = canonical.to_string().into();
Some((canonical_id, node_id))
Some((self.value_to_class_id(output), node_id))
} else {
None
}
Expand All @@ -118,17 +112,21 @@ impl EGraph {

let mut egraph = egraph_serialize::EGraph::default();
for (decl, input, output, node_id) in all_calls {
let prim_node_id = if config.split_primitive_outputs {
Some(format!("{}-value", node_id.clone()))
// If we are splitting primitive outputs, then we will use the function node ID as the e-class for the output, so
// that even if two functions have the same primitive output, they will be in different e-classes.
let eclass = if config.split_primitive_outputs
&& !self.get_sort_from_value(output).unwrap().is_eq_sort()
{
format!("{}-value", node_id.clone()).into()
} else {
None
self.value_to_class_id(output)
};
let eclass = self
.serialize_value(&mut egraph, &mut node_ids, output, prim_node_id)
.0;
self.serialize_value(&mut egraph, &mut node_ids, output, &eclass);
let children: Vec<_> = input
.iter()
.map(|v| self.serialize_value(&mut egraph, &mut node_ids, v, None).1)
.map(|v| {
self.serialize_value(&mut egraph, &mut node_ids, v, &self.value_to_class_id(v))
})
.collect();
egraph.nodes.insert(
node_id,
Expand All @@ -144,12 +142,35 @@ impl EGraph {
egraph.root_eclasses = config
.root_eclasses
.iter()
.map(|v| self.serialize_value(&mut egraph, &mut node_ids, v, None).0)
.map(|v| self.value_to_class_id(v))
.collect();

egraph
}

/// Gets the serialized class ID for a value.
pub fn value_to_class_id(&self, value: &Value) -> egraph_serialize::ClassId {
// Canonicalize the value first so that we always use the canonical e-class ID
let sort = self.get_sort_from_value(value).unwrap();
let mut value = *value;
sort.canonicalize(&mut value, &self.unionfind);
assert!(
!value.tag.to_string().contains('-'),
"Tag cannot contain '-' when serializing"
);
format!("{}-{}", value.tag, value.bits).into()
}

/// Gets the value for a serialized class ID.
pub fn class_id_to_value(&self, eclass_id: &egraph_serialize::ClassId) -> Value {
let s = eclass_id.to_string();
let (tag, bits) = s.split_once('-').unwrap();
Value {
tag: tag.into(),
bits: bits.parse().unwrap(),
}
}

/// Serialize the value and return the eclass and node ID
/// If this is a primitive value, we will add the node to the data, but if it is an eclass, we will not
/// When this is called on the output of a node, we only use the e-class to know which e-class its a part of
Expand All @@ -159,86 +180,65 @@ impl EGraph {
egraph: &mut egraph_serialize::EGraph,
node_ids: &mut NodeIDs,
value: &Value,
// The node ID to use for a primitive value, if this is None, use the hash of the value and the sort name
// Set iff `split_primitive_outputs` is set and this is an output of a function.
prim_node_id: Option<String>,
) -> (egraph_serialize::ClassId, egraph_serialize::NodeId) {
class_id: &egraph_serialize::ClassId,
) -> egraph_serialize::NodeId {
let sort = self.get_sort_from_value(value).unwrap();
let (class_id, node_id): (egraph_serialize::ClassId, egraph_serialize::NodeId) =
if sort.is_eq_sort() {
let id: usize = value.bits as usize;
let canonical: usize = self.unionfind.find(Id::from(id)).into();
let class_id: egraph_serialize::ClassId = canonical.to_string().into();
(class_id.clone(), get_node_id(egraph, node_ids, class_id))
} else {
let (class_id, node_id): (egraph_serialize::ClassId, egraph_serialize::NodeId) =
if let Some(node_id) = prim_node_id {
(node_id.clone().into(), node_id.into())
} else {
let sort_name = sort.name().to_string();
let node_id_str =
format!("{}-{}", sort_name, hash_values(vec![*value].as_slice()));
(node_id_str.clone().into(), node_id_str.into())
};
// Add node for value
{
// Children will be empty unless this is a container sort
let children: Vec<egraph_serialize::NodeId> = sort
.inner_values(value)
.into_iter()
.map(|(_, v)| self.serialize_value(egraph, node_ids, &v, None).1)
.collect();
// If this is a container sort, use the name, otherwise use the value
let op = if sort.is_container_sort() {
sort.serialized_name(value).to_string()
} else {
sort.make_expr(self, *value).1.to_string()
};
egraph.nodes.insert(
node_id.clone(),
egraph_serialize::Node {
op,
eclass: class_id.clone(),
cost: NotNan::new(0.0).unwrap(),
children,
},
);
let node_id = if sort.is_eq_sort() {
let node_ids = node_ids.entry(class_id.clone()).or_insert_with(|| {
// If we don't find node IDs for this class, it means that all nodes for it were omitted due to size constraints
// In this case, add a dummy node in this class to represent the missing nodes
let node_id = egraph_serialize::NodeId::from(format!("{}-dummy", class_id));
egraph.nodes.insert(
node_id.clone(),
egraph_serialize::Node {
op: "[...]".to_string(),
eclass: class_id.clone(),
cost: NotNan::new(f64::INFINITY).unwrap(),
children: vec![],
},
);
VecDeque::from(vec![node_id])
});
node_ids.rotate_left(1);
node_ids.front().unwrap().clone()
} else {
let node_id: egraph_serialize::NodeId = class_id.to_string().into();
// Add node for value
{
// Children will be empty unless this is a container sort
let children: Vec<egraph_serialize::NodeId> = sort
.inner_values(value)
.into_iter()
.map(|(_, v)| {
self.serialize_value(egraph, node_ids, &v, &self.value_to_class_id(&v))
})
.collect();
// If this is a container sort, use the name, otherwise use the value
let op = if sort.is_container_sort() {
sort.serialized_name(value).to_string()
} else {
sort.make_expr(self, *value).1.to_string()
};
(class_id, node_id)
egraph.nodes.insert(
node_id.clone(),
egraph_serialize::Node {
op,
eclass: class_id.clone(),
cost: NotNan::new(1.0).unwrap(),
children,
},
);
};
node_id
};
egraph.class_data.insert(
class_id.clone(),
egraph_serialize::ClassData {
typ: Some(sort.name().to_string()),
},
);
(class_id, node_id)
node_id
}
}

type NodeIDs = HashMap<egraph_serialize::ClassId, VecDeque<egraph_serialize::NodeId>>;

/// Returns the node ID for the given class ID, rotating the queue
fn get_node_id(
egraph: &mut egraph_serialize::EGraph,
node_ids: &mut HashMap<egraph_serialize::ClassId, VecDeque<egraph_serialize::NodeId>>,
class_id: egraph_serialize::ClassId,
) -> egraph_serialize::NodeId {
// If we don't find node IDs for this class, it means that all nodes for it were omitted due to size constraints
// In this case, add a dummy node in this class to represent the missing nodes
let node_ids = node_ids.entry(class_id.clone()).or_insert_with(|| {
let node_id = egraph_serialize::NodeId::from(format!("{}-dummy", class_id));
egraph.nodes.insert(
node_id.clone(),
egraph_serialize::Node {
op: "[...]".to_string(),
eclass: class_id.clone(),
cost: NotNan::new(f64::INFINITY).unwrap(),
children: vec![],
},
);
VecDeque::from(vec![node_id])
});
node_ids.rotate_left(1);
node_ids.front().unwrap().clone()
}
30 changes: 29 additions & 1 deletion tests/integration_test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use egglog::{ast::Expr, EGraph, ExtractReport, Function, Term, Value};
use egglog::{ast::Expr, EGraph, ExtractReport, Function, SerializeConfig, Term, Value};
use symbol_table::GlobalSymbol;

#[test]
Expand Down Expand Up @@ -377,3 +377,31 @@ fn test_cant_subsume_merge() {
);
assert!(res.is_err());
}

#[test]
fn test_value_to_classid() {
let mut egraph = EGraph::default();

egraph
.parse_and_run_program(
None,
r#"
(datatype Math)
(function exp () Math )
(exp)
(query-extract (exp))
"#,
)
.unwrap();
let report = egraph.get_extract_report().clone().unwrap();
let ExtractReport::Best { term, termdag, .. } = report else {
panic!();
};
let expr = termdag.term_to_expr(&term);
let value = egraph.eval_expr(&expr).unwrap().1;

let serialized = egraph.serialize(SerializeConfig::default());
let class_id = egraph.value_to_class_id(&value);
assert!(serialized.class_data.get(&class_id).is_some());
assert_eq!(value, egraph.class_id_to_value(&class_id));
}

0 comments on commit 5d8c025

Please sign in to comment.