Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support marking exprs as subsumed #301

Merged
merged 40 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
42c9eba
Add support for making expression unextractable
saulshanabrook Nov 27, 2023
50fa8b4
Add support for unextractable arg to functions
saulshanabrook Nov 27, 2023
ad4cec2
Add unextractable example
saulshanabrook Nov 27, 2023
190dcd5
Add two failing tests about rebuilding
saulshanabrook Nov 28, 2023
3fcc3bd
Fix breaking test to be accurate
saulshanabrook Nov 28, 2023
e7b430a
Add another test
saulshanabrook Nov 29, 2023
78bcdb5
Move unextractable to table
saulshanabrook Nov 29, 2023
fb58cf8
Add support for subsuming nodes
saulshanabrook Nov 29, 2023
5433414
Change rewrite arg to subsume and mark as unextractable
saulshanabrook Nov 29, 2023
6ab1817
Fix example file
saulshanabrook Nov 29, 2023
1dcbd72
Add subsumption to example file
saulshanabrook Nov 29, 2023
043fdb7
Fix subsumption handling by moving to querying
saulshanabrook Nov 29, 2023
5b56a67
Allow replace tests to fail with term encoding
saulshanabrook Nov 29, 2023
f0ea440
Skip those tests all together
saulshanabrook Nov 29, 2023
2d8a39f
Add failing test case based on @yihozhang's example
saulshanabrook Nov 30, 2023
5985cd1
Fix unsound behavior by taking lattice join of old and new functions
saulshanabrook Nov 30, 2023
6390c26
Move include subsumed to TrieAccess from LazyTrie
saulshanabrook Nov 30, 2023
effad84
Move include subsumed to get_index
saulshanabrook Nov 30, 2023
e1530e9
Nits
saulshanabrook Nov 30, 2023
53b337a
Type error on marking prim as unextractable
saulshanabrook Dec 1, 2023
66f016e
Test that primitives can be subsumed
saulshanabrook Dec 1, 2023
05c8b13
Move subsume and unextractable to output
saulshanabrook Dec 1, 2023
90f4ba2
Split :replace into :subsume and :unextractable
saulshanabrook Dec 1, 2023
64df4d5
Reduce repeated match statements
saulshanabrook Dec 1, 2023
2160283
Move integration tests to separate file
saulshanabrook Dec 6, 2023
02233a5
remove commented function
saulshanabrook Dec 26, 2023
a49e6d8
Merge egraphs-good/main into unextractable
saulshanabrook Feb 8, 2024
bdc7219
Update to combine flags and fix tests
saulshanabrook Feb 12, 2024
c4e9bd6
Remove unnecessary changes
saulshanabrook Feb 12, 2024
428e485
Fix typo
saulshanabrook Feb 12, 2024
e8691d9
Revert terms since it isn't used anymore
saulshanabrook Feb 12, 2024
67a6ebb
Dont need to expose typeinfo
saulshanabrook Feb 12, 2024
8627b32
Error when subsuming function with merge
saulshanabrook Feb 12, 2024
686b7d4
Clarify subsume's behavior in docs
saulshanabrook Feb 12, 2024
809adee
Add include subsumed flag to iter as well
saulshanabrook Feb 14, 2024
24d61e9
Merge egraphs-good/main into unextractable
saulshanabrook Feb 14, 2024
0837035
Clarify action docstring
saulshanabrook Feb 14, 2024
55e6d46
Make rewrite subsume doc more detailed
saulshanabrook Feb 14, 2024
29dff37
Add link to change
saulshanabrook Feb 14, 2024
c240560
Merge branch 'main' into unextractable
saulshanabrook Feb 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 78 additions & 22 deletions src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,25 @@ fn desugar_rewrite(
ruleset: Symbol,
name: Symbol,
rewrite: &Rewrite,
replace: bool,
desugar: &mut Desugar,
) -> Vec<NCommand> {
let var = Symbol::from("rewrite_var__");
// make two rules- one to insert the rhs, and one to union
// this way, the union rule can only be fired once,
// which helps proofs not add too much info
let mut head = vec![Action::Union(Expr::Var(var), rewrite.rhs.clone())];
if replace {
match &rewrite.lhs {
Expr::Call(f, args) => {
head.push(Action::Unextractable(*f, args.to_vec()));
head.push(Action::Subsume(*f, args.to_vec()));
}
_ => {
panic!("Unextractable rewrite must have a function call on the lhs");
}
}
}
vec![NCommand::NormRule {
ruleset,
name,
Expand All @@ -40,7 +53,7 @@ fn desugar_rewrite(
.into_iter()
.chain(rewrite.conditions.clone())
.collect(),
head: vec![Action::Union(Expr::Var(var), rewrite.rhs.clone())],
head,
},
desugar,
),
Expand All @@ -58,15 +71,22 @@ fn desugar_birewrite(
rhs: rewrite.lhs.clone(),
conditions: rewrite.conditions.clone(),
};
desugar_rewrite(ruleset, format!("{}=>", name).into(), rewrite, desugar)
.into_iter()
.chain(desugar_rewrite(
ruleset,
format!("{}<=", name).into(),
&rw2,
desugar,
))
.collect()
desugar_rewrite(
ruleset,
format!("{}=>", name).into(),
rewrite,
false,
desugar,
)
.into_iter()
.chain(desugar_rewrite(
ruleset,
format!("{}<=", name).into(),
&rw2,
false,
desugar,
))
.collect()
}

fn normalize_expr(
Expand Down Expand Up @@ -257,16 +277,48 @@ fn flatten_actions(actions: &Vec<Action>, desugar: &mut Desugar) -> Vec<NormActi
let added_variants = add_expr(variants.clone(), &mut res);
res.push(NormAction::Extract(added, added_variants));
}
// TODO: Reduce duplication in these three cases
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
Action::Delete(symbol, exprs) => {
let del = NormAction::Delete(NormExpr::Call(
*symbol,
exprs
.clone()
.into_iter()
.map(|ex| add_expr(ex, &mut res))
.collect(),
));
res.push(del);
let unex = NormAction::ChangeRow(
ChangeRow::Delete,
NormExpr::Call(
*symbol,
exprs
.clone()
.into_iter()
.map(|ex| add_expr(ex, &mut res))
.collect(),
),
);
res.push(unex);
}
Action::Unextractable(symbol, exprs) => {
let unex = NormAction::ChangeRow(
ChangeRow::Unextractable,
NormExpr::Call(
*symbol,
exprs
.clone()
.into_iter()
.map(|ex| add_expr(ex, &mut res))
.collect(),
),
);
res.push(unex);
}
Action::Subsume(symbol, exprs) => {
let unex = NormAction::ChangeRow(
ChangeRow::Subsume,
NormExpr::Call(
*symbol,
exprs
.clone()
.into_iter()
.map(|ex| add_expr(ex, &mut res))
.collect(),
),
);
res.push(unex);
}
Action::Union(lhs, rhs) => {
let un = NormAction::Union(
Expand Down Expand Up @@ -550,9 +602,13 @@ pub(crate) fn desugar_command(
} => desugar.desugar_function(&FunctionDecl::relation(constructor, inputs)),
Command::Declare { name, sort } => desugar.declare(name, sort),
Command::Datatype { name, variants } => desugar_datatype(name, variants),
Command::Rewrite(ruleset, rewrite) => {
desugar_rewrite(ruleset, rewrite_name(&rewrite).into(), &rewrite, desugar)
}
Command::Rewrite(ruleset, rewrite, replace) => desugar_rewrite(
ruleset,
rewrite_name(&rewrite).into(),
&rewrite,
replace,
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
desugar,
),
Command::BiRewrite(ruleset, rewrite) => {
desugar_birewrite(ruleset, rewrite_name(&rewrite).into(), &rewrite, desugar)
}
Expand Down
77 changes: 64 additions & 13 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ impl Display for NormSchedule {
}
}

pub type Replace = bool;

// TODO command before and after desugaring should be different
/// A [`Command`] is the top-level construct in egglog.
/// It includes defining rules, declaring functions,
Expand Down Expand Up @@ -541,7 +543,13 @@ pub enum Command {
/// :when ((= a (Num 0)))
/// ```
///
Rewrite(Symbol, Rewrite),
/// To make the left hand side unextractable and unmatchable (subsumsed) use the `:replace ` clause
///
/// ```text
/// (rewrite (Mul a 2) (bitshift-left a 1) :replace)
/// ```
///
Rewrite(Symbol, Rewrite, Replace),
/// Similar to [`Command::Rewrite`], but
/// generates two rules, one for each direction.
///
Expand Down Expand Up @@ -679,8 +687,10 @@ impl ToSexp for Command {
fn to_sexp(&self) -> Sexp {
match self {
Command::SetOption { name, value } => list!("set-option", name, value),
Command::Rewrite(name, rewrite) => rewrite.to_sexp(*name, false),
Command::BiRewrite(name, rewrite) => rewrite.to_sexp(*name, true),
Command::Rewrite(name, rewrite, unextractable) => {
rewrite.to_sexp(*name, false, *unextractable)
}
Command::BiRewrite(name, rewrite) => rewrite.to_sexp(*name, true, false),
Command::Datatype { name, variants } => list!("datatype", name, ++ variants),
Command::Declare { name, sort } => list!("declare", name, sort),
Command::Action(a) => a.to_sexp(),
Expand Down Expand Up @@ -1080,6 +1090,12 @@ pub enum Action {
/// Be wary! Only delete entries that are subsumed in some way or
/// guaranteed to be not useful.
Delete(Symbol, Vec<Expr>),
/// Set an entry to be `unextractable`
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
/// so that it cannot be extracted.
/// Note that this cannot be an expr but has to be a symbol and args, because we need to refer to a specific row
Unextractable(Symbol, Vec<Expr>),
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
/// `subsume` an an entry so that it cannot be queries for during rules or rewrites.
Subsume(Symbol, Vec<Expr>),
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
/// `union` two datatypes, making them equal
/// in the implicit, global equality relation
/// of egglog.
Expand Down Expand Up @@ -1107,14 +1123,31 @@ pub enum Action {
// If(Expr, Action, Action),
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
pub enum ChangeRow {
Delete,
Subsume,
Unextractable,
}

impl ChangeRow {
pub fn to_action(&self, op: Symbol, args: Vec<Expr>) -> Action {
match self {
ChangeRow::Delete => Action::Delete(op, args),
ChangeRow::Subsume => Action::Subsume(op, args),
ChangeRow::Unextractable => Action::Unextractable(op, args),
}
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum NormAction {
Let(Symbol, NormExpr),
LetVar(Symbol, Symbol),
LetLit(Symbol, Literal),
Extract(Symbol, Symbol),
Set(NormExpr, Symbol),
Delete(NormExpr),
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
ChangeRow(ChangeRow, NormExpr),
Union(Symbol, Symbol),
Panic(String),
}
Expand All @@ -1133,8 +1166,8 @@ impl NormAction {
NormAction::Extract(symbol, variants) => {
Action::Extract(Expr::Var(*symbol), Expr::Var(*variants))
}
NormAction::Delete(NormExpr::Call(symbol, args)) => {
Action::Delete(*symbol, args.iter().map(|s| Expr::Var(*s)).collect())
NormAction::ChangeRow(change, NormExpr::Call(symbol, args)) => {
change.to_action(*symbol, args.iter().map(|s| Expr::Var(*s)).collect())
}
NormAction::Union(lhs, rhs) => Action::Union(Expr::Var(*lhs), Expr::Var(*rhs)),
NormAction::Panic(msg) => Action::Panic(msg.clone()),
Expand All @@ -1148,7 +1181,7 @@ impl NormAction {
NormAction::LetLit(symbol, lit) => NormAction::LetLit(*symbol, lit.clone()),
NormAction::Set(expr, other) => NormAction::Set(f(expr), *other),
NormAction::Extract(var, variants) => NormAction::Extract(*var, *variants),
NormAction::Delete(expr) => NormAction::Delete(f(expr)),
NormAction::ChangeRow(change, expr) => NormAction::ChangeRow(*change, f(expr)),
NormAction::Union(lhs, rhs) => NormAction::Union(*lhs, *rhs),
NormAction::Panic(msg) => NormAction::Panic(msg.clone()),
}
Expand All @@ -1170,7 +1203,9 @@ impl NormAction {
NormAction::Extract(var, variants) => {
NormAction::Extract(fvar(*var, false), fvar(*variants, false))
}
NormAction::Delete(expr) => NormAction::Delete(expr.map_def_use(fvar, false)),
NormAction::ChangeRow(change, expr) => {
NormAction::ChangeRow(*change, expr.map_def_use(fvar, false))
}
NormAction::Union(lhs, rhs) => NormAction::Union(fvar(*lhs, false), fvar(*rhs, false)),
NormAction::Panic(msg) => NormAction::Panic(msg.clone()),
}
Expand All @@ -1184,6 +1219,8 @@ impl ToSexp for Action {
Action::Set(lhs, args, rhs) => list!("set", list!(lhs, ++ args), rhs),
Action::Union(lhs, rhs) => list!("union", lhs, rhs),
Action::Delete(lhs, args) => list!("delete", list!(lhs, ++ args)),
Action::Unextractable(lhs, args) => list!("unextractable", list!(lhs, ++ args)),
Action::Subsume(lhs, args) => list!("subsume", list!(lhs, ++ args)),
Action::Extract(expr, variants) => list!("extract", expr, variants),
Action::Panic(msg) => list!("panic", format!("\"{}\"", msg.clone())),
Action::Expr(e) => e.to_sexp(),
Expand All @@ -1200,6 +1237,10 @@ impl Action {
Action::Set(*lhs, args.iter().map(f).collect(), right)
}
Action::Delete(lhs, args) => Action::Delete(*lhs, args.iter().map(f).collect()),
Action::Unextractable(lhs, args) => {
Action::Unextractable(*lhs, args.iter().map(f).collect())
}
Action::Subsume(lhs, args) => Action::Subsume(*lhs, args.iter().map(f).collect()),
Action::Union(lhs, rhs) => Action::Union(f(lhs), f(rhs)),
Action::Extract(expr, variants) => Action::Extract(f(expr), f(variants)),
Action::Panic(msg) => Action::Panic(msg.clone()),
Expand All @@ -1218,6 +1259,12 @@ impl Action {
Action::Delete(lhs, args) => {
Action::Delete(*lhs, args.iter().map(|e| e.subst(canon)).collect())
}
Action::Unextractable(lhs, args) => {
Action::Unextractable(*lhs, args.iter().map(|e| e.subst(canon)).collect())
}
Action::Subsume(lhs, args) => {
Action::Subsume(*lhs, args.iter().map(|e| e.subst(canon)).collect())
}
Action::Union(lhs, rhs) => Action::Union(lhs.subst(canon), rhs.subst(canon)),
Action::Extract(expr, variants) => {
Action::Extract(expr.subst(canon), variants.subst(canon))
Expand Down Expand Up @@ -1399,7 +1446,7 @@ impl NormRule {
_ => panic!("Expected call in set"),
}
}
NormAction::Delete(expr) => {
NormAction::ChangeRow(change, expr) => {
let new_expr = expr.to_expr();
new_expr.map(&mut |subexpr| {
if let Expr::Var(v) = subexpr {
Expand All @@ -1409,11 +1456,12 @@ impl NormRule {
});
match new_expr.subst(subst) {
Expr::Call(op, children) => {
head.push(Action::Delete(op, children));
head.push(change.to_action(op, children));
}
_ => panic!("Expected call in delete"),
_ => panic!("Expected call"),
}
}

NormAction::Union(lhs, rhs) => {
let new_lhs = subst.get(lhs).unwrap_or(&Expr::Var(*lhs)).clone();
let new_rhs = subst.get(rhs).unwrap_or(&Expr::Var(*rhs)).clone();
Expand Down Expand Up @@ -1559,7 +1607,7 @@ pub struct Rewrite {

impl Rewrite {
/// Converts the rewrite into an s-expression.
pub fn to_sexp(&self, ruleset: Symbol, is_bidirectional: bool) -> Sexp {
pub fn to_sexp(&self, ruleset: Symbol, is_bidirectional: bool, unextractable: bool) -> Sexp {
let mut res = vec![
Sexp::Symbol(if is_bidirectional {
"birewrite".into()
Expand All @@ -1569,6 +1617,9 @@ impl Rewrite {
self.lhs.to_sexp(),
self.rhs.to_sexp(),
];
if unextractable {
res.insert(1, Sexp::Symbol(":unextractable".into()));
}

if !self.conditions.is_empty() {
res.push(Sexp::Symbol(":when".into()));
Expand All @@ -1587,6 +1638,6 @@ impl Rewrite {

impl Display for Rewrite {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.to_sexp("".into(), false))
write!(f, "{}", self.to_sexp("".into(), false, false))
}
}
13 changes: 8 additions & 5 deletions src/ast/parse.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ pub Program: Vec<Command> = { (Command)* => <> }

LParen: () = {
"(" => (),
"[" => (),
"[" => (),
};
RParen: () = {
")" => (),
"]" => (),
};

List<T>: Vec<T> = {
List<T>: Vec<T> = {
LParen <T*> RParen => <>,
}

Expand Down Expand Up @@ -60,9 +60,10 @@ Command: Command = {
LParen "ruleset" <name:Ident> RParen => Command::AddRuleset(name),
LParen "rule" <body:List<Fact>> <head:List<Action>> <ruleset:(":ruleset" <Ident>)?> <name:(":name" <String>)?> RParen => Command::Rule{ruleset: ruleset.unwrap_or("".into()), name: name.unwrap_or("".to_string()).into(), rule: Rule { head, body }},
LParen "rewrite" <lhs:Expr> <rhs:Expr>
<replace:(":replace")?>
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
<conditions:(":when" <List<Fact>>)?>
<ruleset:(":ruleset" <Ident>)?>
RParen => Command::Rewrite(ruleset.unwrap_or("".into()), Rewrite { lhs, rhs, conditions: conditions.unwrap_or_default() }),
RParen => Command::Rewrite(ruleset.unwrap_or("".into()), Rewrite { lhs, rhs, conditions: conditions.unwrap_or_default() }, replace.is_some()),
LParen "birewrite" <lhs:Expr> <rhs:Expr>
<conditions:(":when" <List<Fact>>)?>
<ruleset:(":ruleset" <Ident>)?>
Expand All @@ -71,7 +72,7 @@ Command: Command = {
<NonLetAction> => Command::Action(<>),
LParen "run" <limit:UNum> <until:(":until" <(Fact)*>)?> RParen => Command::RunSchedule(Schedule::Repeat(limit, Box::new(Schedule::Run(RunConfig { ruleset : "".into(), until })))),
LParen "run" <ruleset: Ident> <limit:UNum> <until:(":until" <(Fact)*>)?> RParen => Command::RunSchedule(Schedule::Repeat(limit, Box::new(Schedule::Run(RunConfig { ruleset, until })))),
LParen "simplify" <schedule:Schedule> <expr:Expr> RParen
LParen "simplify" <schedule:Schedule> <expr:Expr> RParen
=> Command::Simplify { expr, schedule },
LParen "calc" LParen <idents:IdentSort*> RParen <exprs:Expr+> RParen => Command::Calc(idents, exprs),
LParen "query-extract" <variants:(":variants" <UNum>)?> <expr:Expr> RParen => Command::QueryExtract { expr, variants: variants.unwrap_or(0) },
Expand All @@ -93,7 +94,7 @@ Schedule: Schedule = {
LParen "saturate" <Schedule*> RParen => Schedule::Saturate(Box::new(Schedule::Sequence(<>))),
LParen "seq" <Schedule*> RParen => Schedule::Sequence(<>),
LParen "repeat" <limit:UNum> <scheds:Schedule*> RParen => Schedule::Repeat(limit, Box::new(Schedule::Sequence(scheds))),
LParen "run" <until:(":until" <(Fact)*>)?> RParen =>
LParen "run" <until:(":until" <(Fact)*>)?> RParen =>
Schedule::Run(RunConfig { ruleset: "".into(), until }),
LParen "run" <ruleset: Ident> <until:(":until" <(Fact)*>)?> RParen => Schedule::Run(RunConfig { ruleset, until }),
<ident:Ident> => Schedule::Run(RunConfig { ruleset: ident, until: None }),
Expand All @@ -107,6 +108,8 @@ Cost: Option<usize> = {
NonLetAction: Action = {
LParen "set" LParen <f: Ident> <args:Expr*> RParen <v:Expr> RParen => Action::Set ( f, args, v ),
LParen "delete" LParen <f: Ident> <args:Expr*> RParen RParen => Action::Delete ( f, args),
LParen "unextractable" LParen <f: Ident> <args:Expr*> RParen RParen => Action::Unextractable ( f, args),
LParen "subsume" LParen <f: Ident> <args:Expr*> RParen RParen => Action::Subsume ( f, args),
LParen "union" <e1:Expr> <e2:Expr> RParen => Action::Union(<>),
LParen "panic" <msg:String> RParen => Action::Panic(msg),
LParen "extract" <expr:Expr> RParen => Action::Extract(expr, Expr::Lit(Literal::Int(0))),
Expand Down
Loading