Skip to content

Commit

Permalink
Merge pull request #448 from Alex-Fischman/cfg-tag
Browse files Browse the repository at this point in the history
Disable `Value.tag` in release mode
  • Loading branch information
Alex-Fischman authored Oct 26, 2024
2 parents 993582f + ef44dea commit 2e16561
Show file tree
Hide file tree
Showing 25 changed files with 574 additions and 284 deletions.
87 changes: 46 additions & 41 deletions src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ impl<'a> ActionCompiler<'a> {
self.locals.insert(v.clone());
}
GenericCoreAction::Extract(_ann, e, b) => {
self.do_atom_term(e);
let sort = self.do_atom_term(e);
self.do_atom_term(b);
self.instructions.push(Instruction::Extract(2));
self.instructions.push(Instruction::Extract(2, sort));
}
GenericCoreAction::Set(_ann, f, args, e) => {
let ResolvedCall::Func(func) = f else {
Expand All @@ -50,9 +50,9 @@ impl<'a> ActionCompiler<'a> {
.push(Instruction::Change(*change, func.name));
}
GenericCoreAction::Union(_ann, arg1, arg2) => {
self.do_atom_term(arg1);
let sort = self.do_atom_term(arg1);
self.do_atom_term(arg2);
self.instructions.push(Instruction::Union(2));
self.instructions.push(Instruction::Union(2, sort));
}
GenericCoreAction::Panic(_ann, msg) => {
self.instructions.push(Instruction::Panic(msg.clone()));
Expand All @@ -70,18 +70,21 @@ impl<'a> ActionCompiler<'a> {
}
}

fn do_atom_term(&mut self, at: &ResolvedAtomTerm) {
fn do_atom_term(&mut self, at: &ResolvedAtomTerm) -> ArcSort {
match at {
ResolvedAtomTerm::Var(_ann, var) => {
if let Some((i, _ty)) = self.locals.get_full(var) {
if let Some((i, ty)) = self.locals.get_full(var) {
self.instructions.push(Instruction::Load(Load::Stack(i)));
ty.sort.clone()
} else {
let (i, _, _ty) = self.types.get_full(&var.name).unwrap();
let (i, _, ty) = self.types.get_full(&var.name).unwrap();
self.instructions.push(Instruction::Load(Load::Subst(i)));
ty.clone()
}
}
ResolvedAtomTerm::Literal(_ann, lit) => {
self.instructions.push(Instruction::Literal(lit.clone()));
crate::sort::literal_sort(lit)
}
ResolvedAtomTerm::Global(_ann, _var) => {
panic!("Global variables should have been desugared");
Expand All @@ -97,10 +100,8 @@ impl<'a> ActionCompiler<'a> {
}

fn do_prim(&mut self, prim: &SpecializedPrimitive) {
self.instructions.push(Instruction::CallPrimitive(
prim.primitive.clone(),
prim.input.len(),
));
self.instructions
.push(Instruction::CallPrimitive(prim.clone(), prim.input.len()));
}
}

Expand All @@ -126,19 +127,19 @@ enum Instruction {
CallFunction(Symbol, bool),
/// Pop primitive arguments off the stack, calls the primitive,
/// and push the result onto the stack.
CallPrimitive(Primitive, usize),
CallPrimitive(SpecializedPrimitive, usize),
/// Pop function arguments off the stack and either deletes or subsumes the corresponding row
/// in the function.
Change(Change, Symbol),
/// Pop the value to be set and the function arguments off the stack.
/// Set the function at the given arguments to the new value.
Set(Symbol),
/// Union the last `n` values on the stack.
Union(usize),
Union(usize, ArcSort),
/// Extract the best expression. `n` is always 2.
/// The first value on the stack is the expression to extract,
/// and the second value is the number of variants to extract.
Extract(usize),
Extract(usize, ArcSort),
/// Panic with the given message.
Panic(String),
}
Expand Down Expand Up @@ -223,10 +224,11 @@ impl EGraph {
MergeFn::AssertEq => {
return Err(Error::MergeError(table, new_value, old_value));
}
MergeFn::Union => {
self.unionfind
.union_values(old_value, new_value, old_value.tag)
}
MergeFn::Union => self.unionfind.union_values(
old_value,
new_value,
function.decl.schema.output,
),
MergeFn::Expr(merge_prog) => {
let values = [old_value, new_value];
let mut stack = vec![];
Expand Down Expand Up @@ -268,14 +270,15 @@ impl EGraph {
},
Instruction::CallFunction(f, make_defaults) => {
let function = self.functions.get_mut(f).unwrap();
let output_tag = function.schema.output.name();
let new_len = stack.len() - function.schema.input.len();
let values = &stack[new_len..];

if cfg!(debug_assertions) {
for (ty, val) in function.schema.input.iter().zip(values) {
assert_eq!(ty.name(), val.tag,);
}
#[cfg(debug_assertions)]
let output_tag = function.schema.output.name();

#[cfg(debug_assertions)]
for (ty, val) in function.schema.input.iter().zip(values) {
assert_eq!(ty.name(), val.tag);
}

let value = if let Some(out) = function.nodes.get(values) {
Expand All @@ -289,8 +292,11 @@ impl EGraph {
Value::unit()
}
None if out.is_eq_sort() => {
let id = self.unionfind.make_set();
let value = Value::from_id(out.name(), id);
let value = Value {
#[cfg(debug_assertions)]
tag: out.name(),
bits: self.unionfind.make_set(),
};
function.insert(values, value, ts);
value
}
Expand All @@ -314,18 +320,24 @@ impl EGraph {
))));
};

// cfg is necessary because debug_assert_eq still evaluates its
// arguments in release mode (is has to because of side effects)
#[cfg(debug_assertions)]
debug_assert_eq!(output_tag, value.tag);

stack.truncate(new_len);
stack.push(value);
}
Instruction::CallPrimitive(p, arity) => {
let new_len = stack.len() - arity;
let values = &stack[new_len..];
if let Some(value) = p.apply(values, Some(self)) {
if let Some(value) =
p.primitive.apply(values, (&p.input, &p.output), Some(self))
{
stack.truncate(new_len);
stack.push(value);
} else {
return Err(Error::PrimitiveError(p.clone(), values.to_vec()));
return Err(Error::PrimitiveError(p.primitive.clone(), values.to_vec()));
}
}
Instruction::Set(f) => {
Expand All @@ -338,32 +350,25 @@ impl EGraph {
self.perform_set(*f, new_value, stack)?;
stack.truncate(new_len)
}
Instruction::Union(arity) => {
Instruction::Union(arity, sort) => {
let new_len = stack.len() - arity;
let values = &stack[new_len..];
let sort = values[0].tag;
let first = self.unionfind.find(Id::from(values[0].bits as usize));
let first = self.unionfind.find(values[0].bits);
values[1..].iter().fold(first, |a, b| {
let b = self.unionfind.find(Id::from(b.bits as usize));
self.unionfind.union(a, b, sort)
let b = self.unionfind.find(b.bits);
self.unionfind.union(a, b, sort.name())
});
stack.truncate(new_len);
}
Instruction::Extract(arity) => {
Instruction::Extract(arity, sort) => {
let new_len = stack.len() - arity;
let values = &stack[new_len..];
let new_len = stack.len() - arity;
let mut termdag = TermDag::default();
let num_sort = values[1].tag;
assert!(num_sort.to_string() == "i64");

let variants = values[1].bits as i64;
if variants == 0 {
let (cost, term) = self.extract(
values[0],
&mut termdag,
self.type_info.sorts.get(&values[0].tag).unwrap(),
);
let (cost, term) = self.extract(values[0], &mut termdag, sort);
let extracted = termdag.to_string(&term);
log::info!("extracted with cost {cost}: {extracted}");
self.print_msg(extracted);
Expand All @@ -377,7 +382,7 @@ impl EGraph {
panic!("Cannot extract negative number of variants");
}
let terms =
self.extract_variants(values[0], variants as usize, &mut termdag);
self.extract_variants(sort, values[0], variants as usize, &mut termdag);
log::info!("extracted variants:");
let mut msg = String::default();
msg += "(\n";
Expand Down
21 changes: 0 additions & 21 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,6 @@ pub use expr::*;
pub mod desugar;
pub(crate) mod remove_globals;

#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct Id(usize);

impl From<usize> for Id {
fn from(n: usize) -> Self {
Id(n)
}
}

impl From<Id> for usize {
fn from(id: Id) -> Self {
id.0
}
}

impl Display for Id {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "id{}", self.0)
}
}

#[derive(Clone, Debug)]
/// The egglog internal representation of already compiled rules
pub(crate) enum Ruleset {
Expand Down
6 changes: 3 additions & 3 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ impl std::fmt::Display for Query<ResolvedCall, Symbol> {
writeln!(
f,
"({} {})",
filter.head.name(),
filter.head.primitive.name(),
ListDisplay(&filter.args, " ")
)?;
}
Expand All @@ -350,12 +350,12 @@ impl std::fmt::Display for Query<ResolvedCall, Symbol> {
}

impl<Leaf: Clone> Query<ResolvedCall, Leaf> {
pub fn filters(&self) -> impl Iterator<Item = GenericAtom<Primitive, Leaf>> + '_ {
pub fn filters(&self) -> impl Iterator<Item = GenericAtom<SpecializedPrimitive, Leaf>> + '_ {
self.atoms.iter().filter_map(|atom| match &atom.head {
ResolvedCall::Func(_) => None,
ResolvedCall::Primitive(head) => Some(GenericAtom {
span: atom.span.clone(),
head: head.primitive.clone(),
head: head.clone(),
args: atom.args.clone(),
}),
})
Expand Down
38 changes: 22 additions & 16 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub type Cost = usize;
#[derive(Debug)]
pub(crate) struct Node<'a> {
sym: Symbol,
func: &'a Function,
inputs: &'a [Value],
}

Expand Down Expand Up @@ -51,11 +52,16 @@ impl EGraph {
if output.value == value {
log::error!("Found unextractable function: {:?}", func.decl.name);
log::error!("Inputs: {:?}", inputs);

assert_eq!(inputs.len(), func.schema.input.len());
log::error!(
"{:?}",
inputs
.iter()
.map(|input| extractor.costs.get(&extractor.find_id(*input)))
.zip(&func.schema.input)
.map(|(input, sort)| extractor
.costs
.get(&extractor.egraph.find(sort, *input).bits))
.collect::<Vec<_>>()
);
}
Expand All @@ -68,11 +74,13 @@ impl EGraph {

pub fn extract_variants(
&mut self,
sort: &ArcSort,
value: Value,
limit: usize,
termdag: &mut TermDag,
) -> Vec<Term> {
let output_value = self.find(value);
let output_sort = sort.name();
let output_value = self.find(sort, value);
let ext = &Extractor::new(self, termdag);
ext.ctors
.iter()
Expand All @@ -85,9 +93,11 @@ impl EGraph {

func.nodes
.iter(false)
.filter(|&(_, output)| (output.value == output_value))
.filter(|&(_, output)| {
func.schema.output.name() == output_sort && output.value == output_value
})
.map(|(inputs, _output)| {
let node = Node { sym, inputs };
let node = Node { sym, func, inputs };
ext.expr_from_node(&node, termdag).expect(
"extract_variants should be called after extractor initialization",
)
Expand Down Expand Up @@ -123,8 +133,12 @@ impl<'a> Extractor<'a> {

fn expr_from_node(&self, node: &Node, termdag: &mut TermDag) -> Option<Term> {
let mut children = vec![];
for value in node.inputs {
let arcsort = self.egraph.get_sort_from_value(value).unwrap();

let values = node.inputs;
let arcsorts = &node.func.schema.input;
assert_eq!(values.len(), arcsorts.len());

for (value, arcsort) in values.iter().zip(arcsorts) {
children.push(self.find_best(*value, termdag, arcsort)?.1)
}

Expand All @@ -138,7 +152,7 @@ impl<'a> Extractor<'a> {
sort: &ArcSort,
) -> Option<(Cost, Term)> {
if sort.is_eq_sort() {
let id = self.find_id(value);
let id = self.egraph.find(sort, value).bits;
let (cost, node) = self.costs.get(&id)?.clone();
Some((cost, node))
} else {
Expand All @@ -164,14 +178,6 @@ impl<'a> Extractor<'a> {
Some((terms, cost))
}

fn find(&self, value: Value) -> Value {
self.egraph.find(value)
}

fn find_id(&self, value: Value) -> Id {
Id::from(self.find(value).bits as usize)
}

fn find_costs(&mut self, termdag: &mut TermDag) {
let mut did_something = true;
while did_something {
Expand All @@ -186,7 +192,7 @@ impl<'a> Extractor<'a> {
{
let make_new_pair = || (new_cost, termdag.app(sym, term_inputs));

let id = self.find_id(output.value);
let id = self.egraph.find(&func.schema.output, output.value).bits;
match self.costs.entry(id) {
Entry::Vacant(e) => {
did_something = true;
Expand Down
1 change: 1 addition & 0 deletions src/function/binary_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ mod tests {

fn make_value(bits: u32) -> Value {
Value {
#[cfg(debug_assertions)]
tag: "testing".into(),
bits: bits as u64,
}
Expand Down
Loading

0 comments on commit 2e16561

Please sign in to comment.