Skip to content

Commit

Permalink
Merge pull request #389 from egraphs-good/yihozhang-span-annotated-asts
Browse files Browse the repository at this point in the history
Adding span annotations to the internal representation
  • Loading branch information
yihozhang committed Jul 25, 2024
2 parents 6e70b79 + 0179742 commit fa45d46
Show file tree
Hide file tree
Showing 29 changed files with 1,048 additions and 737 deletions.
32 changes: 16 additions & 16 deletions src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@ struct ActionCompiler<'a> {
impl<'a> ActionCompiler<'a> {
fn compile_action(&mut self, action: &GenericCoreAction<ResolvedCall, ResolvedVar>) {
match action {
GenericCoreAction::Let(v, f, args) => {
GenericCoreAction::Let(_ann, v, f, args) => {
self.do_call(f, args);
self.locals.insert(v.clone());
}
GenericCoreAction::LetAtomTerm(v, at) => {
GenericCoreAction::LetAtomTerm(_ann, v, at) => {
self.do_atom_term(at);
self.locals.insert(v.clone());
}
GenericCoreAction::Extract(e, b) => {
GenericCoreAction::Extract(_ann, e, b) => {
self.do_atom_term(e);
self.do_atom_term(b);
self.instructions.push(Instruction::Extract(2));
}
GenericCoreAction::Set(f, args, e) => {
GenericCoreAction::Set(_ann, f, args, e) => {
let ResolvedCall::Func(func) = f else {
panic!("Cannot set primitive- should have been caught by typechecking!!!")
};
Expand All @@ -39,7 +39,7 @@ impl<'a> ActionCompiler<'a> {
self.do_atom_term(e);
self.instructions.push(Instruction::Set(func.name));
}
GenericCoreAction::Change(change, f, args) => {
GenericCoreAction::Change(_ann, change, f, args) => {
let ResolvedCall::Func(func) = f else {
panic!("Cannot change primitive- should have been caught by typechecking!!!")
};
Expand All @@ -49,12 +49,12 @@ impl<'a> ActionCompiler<'a> {
self.instructions
.push(Instruction::Change(*change, func.name));
}
GenericCoreAction::Union(arg1, arg2) => {
GenericCoreAction::Union(_ann, arg1, arg2) => {
self.do_atom_term(arg1);
self.do_atom_term(arg2);
self.instructions.push(Instruction::Union(2));
}
GenericCoreAction::Panic(msg) => {
GenericCoreAction::Panic(_ann, msg) => {
self.instructions.push(Instruction::Panic(msg.clone()));
}
}
Expand All @@ -72,18 +72,18 @@ impl<'a> ActionCompiler<'a> {

fn do_atom_term(&mut self, at: &ResolvedAtomTerm) {
match at {
ResolvedAtomTerm::Var(var) => {
ResolvedAtomTerm::Var(_ann, var) => {
if let Some((i, _ty)) = self.locals.get_full(var) {
self.instructions.push(Instruction::Load(Load::Stack(i)));
} else {
let (i, _, _ty) = self.types.get_full(&var.name).unwrap();
self.instructions.push(Instruction::Load(Load::Subst(i)));
}
}
ResolvedAtomTerm::Literal(lit) => {
ResolvedAtomTerm::Literal(_ann, lit) => {
self.instructions.push(Instruction::Literal(lit.clone()));
}
ResolvedAtomTerm::Global(_var) => {
ResolvedAtomTerm::Global(_ann, _var) => {
panic!("Global variables should have been desugared");
}
}
Expand Down Expand Up @@ -301,16 +301,16 @@ impl EGraph {
value
}
_ => {
return Err(Error::NotFoundError(NotFoundError(Expr::Var(
(),
format!("No value found for {f} {:?}", values).into(),
return Err(Error::NotFoundError(NotFoundError(format!(
"No value found for {f} {:?}",
values
))))
}
}
} else {
return Err(Error::NotFoundError(NotFoundError(Expr::Var(
(),
format!("No value found for {f} {:?}", values).into(),
return Err(Error::NotFoundError(NotFoundError(format!(
"No value found for {f} {:?}",
values
))));
};

Expand Down
130 changes: 89 additions & 41 deletions src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,22 @@ fn desugar_rewrite(
rewrite: &Rewrite,
subsume: bool,
) -> Vec<NCommand> {
let span = rewrite.span.clone();
let var = Symbol::from("rewrite_var__");
let mut head = Actions::singleton(Action::Union((), Expr::Var((), var), rewrite.rhs.clone()));
let mut head = Actions::singleton(Action::Union(
span.clone(),
Expr::Var(span.clone(), var),
rewrite.rhs.clone(),
));
if subsume {
match &rewrite.lhs {
Expr::Call(_, f, args) => {
head.0
.push(Action::Change((), Change::Subsume, *f, args.to_vec()));
head.0.push(Action::Change(
span.clone(),
Change::Subsume,
*f,
args.to_vec(),
));
}
_ => {
panic!("Subsumed rewrite must have a function call on the lhs");
Expand All @@ -65,17 +74,23 @@ fn desugar_rewrite(
ruleset,
name,
rule: Rule {
body: [Fact::Eq(vec![Expr::Var((), var), rewrite.lhs.clone()])]
.into_iter()
.chain(rewrite.conditions.clone())
.collect(),
span: span.clone(),
body: [Fact::Eq(
span.clone(),
vec![Expr::Var(span, var), rewrite.lhs.clone()],
)]
.into_iter()
.chain(rewrite.conditions.clone())
.collect(),
head,
},
}]
}

fn desugar_birewrite(ruleset: Symbol, name: Symbol, rewrite: &Rewrite) -> Vec<NCommand> {
let span = rewrite.span.clone();
let rw2 = Rewrite {
span,
lhs: rewrite.rhs.clone(),
rhs: rewrite.lhs.clone(),
conditions: rewrite.conditions.clone(),
Expand All @@ -91,6 +106,7 @@ fn desugar_birewrite(ruleset: Symbol, name: Symbol, rewrite: &Rewrite) -> Vec<NC
.collect()
}

// TODO(yz): we can delete this code once we enforce that all rule bodies cannot read the database (except EqSort).
fn add_semi_naive_rule(desugar: &mut Desugar, rule: Rule) -> Option<Rule> {
let mut new_rule = rule;
// Whenever an Let(_, expr@Call(...)) or Set(_, expr@Call(...)) is present in action,
Expand All @@ -102,24 +118,24 @@ fn add_semi_naive_rule(desugar: &mut Desugar, rule: Rule) -> Option<Rule> {
let mut var_set = HashSet::default();
for head_slice in new_rule.head.0.iter_mut().rev() {
match head_slice {
Action::Set(_ann, _, _, expr) => {
Action::Set(span, _, _, expr) => {
var_set.extend(expr.vars());
if let Expr::Call((), _, _) = expr {
if let Expr::Call(..) = expr {
add_new_rule = true;

let fresh_symbol = desugar.get_fresh();
let fresh_var = Expr::Var((), fresh_symbol);
let fresh_var = Expr::Var(span.clone(), fresh_symbol);
let expr = std::mem::replace(expr, fresh_var.clone());
new_head_atoms.push(Fact::Eq(vec![fresh_var, expr]));
new_head_atoms.push(Fact::Eq(span.clone(), vec![fresh_var, expr]));
};
}
Action::Let(_ann, symbol, expr) if var_set.contains(symbol) => {
Action::Let(span, symbol, expr) if var_set.contains(symbol) => {
var_set.extend(expr.vars());
if let Expr::Call((), _, _) = expr {
if let Expr::Call(..) = expr {
add_new_rule = true;

let var = Expr::Var((), *symbol);
new_head_atoms.push(Fact::Eq(vec![var, expr.clone()]));
let var = Expr::Var(span.clone(), *symbol);
new_head_atoms.push(Fact::Eq(span.clone(), vec![var, expr.clone()]));
}
}
_ => (),
Expand All @@ -140,15 +156,20 @@ fn add_semi_naive_rule(desugar: &mut Desugar, rule: Rule) -> Option<Rule> {
}

fn desugar_simplify(desugar: &mut Desugar, expr: &Expr, schedule: &Schedule) -> Vec<NCommand> {
let span = expr.span();
let mut res = vec![NCommand::Push(1)];
let lhs = desugar.get_fresh();
res.push(NCommand::CoreAction(Action::Let((), lhs, expr.clone())));
res.push(NCommand::CoreAction(Action::Let(
span.clone(),
lhs,
expr.clone(),
)));
res.push(NCommand::RunSchedule(schedule.clone()));
res.extend(
desugar_command(
Command::QueryExtract {
variants: 0,
expr: Expr::Var((), lhs),
expr: Expr::Var(span, lhs),
},
desugar,
false,
Expand All @@ -163,6 +184,7 @@ fn desugar_simplify(desugar: &mut Desugar, expr: &Expr, schedule: &Schedule) ->

pub(crate) fn desugar_calc(
desugar: &mut Desugar,
span: Span,
idents: Vec<IdentSort>,
exprs: Vec<Expr>,
seminaive_transform: bool,
Expand All @@ -171,7 +193,11 @@ pub(crate) fn desugar_calc(

// first, push all the idents
for IdentSort { ident, sort } in idents {
res.push(Command::Declare { name: ident, sort });
res.push(Command::Declare {
span: span.clone(),
name: ident,
sort,
});
}

// now, for every pair of exprs we need to prove them equal
Expand All @@ -182,23 +208,30 @@ pub(crate) fn desugar_calc(

// add the two exprs only when they are calls (consts and vars don't need to be populated).
if let Expr::Call(..) = expr1 {
res.push(Command::Action(Action::Expr((), expr1.clone())));
res.push(Command::Action(Action::Expr(expr1.span(), expr1.clone())));
}
if let Expr::Call(..) = expr2 {
res.push(Command::Action(Action::Expr((), expr2.clone())));
res.push(Command::Action(Action::Expr(expr2.span(), expr2.clone())));
}

res.push(Command::RunSchedule(Schedule::Saturate(Box::new(
Schedule::Run(RunConfig {
ruleset: "".into(),
until: Some(vec![Fact::Eq(vec![expr1.clone(), expr2.clone()])]),
}),
))));
res.push(Command::RunSchedule(Schedule::Saturate(
span.clone(),
Box::new(Schedule::Run(
span.clone(),
RunConfig {
ruleset: "".into(),
until: Some(vec![Fact::Eq(
span.clone(),
vec![expr1.clone(), expr2.clone()],
)]),
},
)),
)));

res.push(Command::Check(vec![Fact::Eq(vec![
expr1.clone(),
expr2.clone(),
])]));
res.push(Command::Check(
span.clone(),
vec![Fact::Eq(span.clone(), vec![expr1.clone(), expr2.clone()])],
));

res.push(Command::Pop(1));
}
Expand Down Expand Up @@ -228,7 +261,7 @@ pub(crate) fn desugar_command(
constructor,
inputs,
} => desugar.desugar_function(&FunctionDecl::relation(constructor, inputs)),
Command::Declare { name, sort } => desugar.declare(name, sort),
Command::Declare { span, name, sort } => desugar.declare(span, name, sort),
Command::Datatype { name, variants } => desugar_datatype(name, variants),
Command::Rewrite(ruleset, rewrite, subsume) => {
desugar_rewrite(ruleset, rewrite_name(&rewrite).into(), &rewrite, subsume)
Expand All @@ -240,7 +273,7 @@ pub(crate) fn desugar_command(
let s = std::fs::read_to_string(&file)
.unwrap_or_else(|_| panic!("Failed to read file {file}"));
return desugar_commands(
desugar.parse_program(&s)?,
desugar.parse_program(Some(file), &s)?,
desugar,
get_all_proofs,
seminaive_transform,
Expand Down Expand Up @@ -280,7 +313,9 @@ pub(crate) fn desugar_command(
}
Command::Action(action) => vec![NCommand::CoreAction(action)],
Command::Simplify { expr, schedule } => desugar_simplify(desugar, &expr, &schedule),
Command::Calc(idents, exprs) => desugar_calc(desugar, idents, exprs, seminaive_transform)?,
Command::Calc(span, idents, exprs) => {
desugar_calc(desugar, span, idents, exprs, seminaive_transform)?
}
Command::RunSchedule(sched) => {
vec![NCommand::RunSchedule(sched.clone())]
}
Expand All @@ -290,7 +325,7 @@ pub(crate) fn desugar_command(
Command::QueryExtract { variants, expr } => {
let fresh = desugar.get_fresh();
let fresh_ruleset = desugar.get_fresh();
let desugaring = if let Expr::Var((), v) = expr {
let desugaring = if let Expr::Var(_, v) = expr {
format!("(extract {v} {variants})")
} else {
format!(
Expand All @@ -304,13 +339,13 @@ pub(crate) fn desugar_command(
};

desugar.desugar_program(
desugar.parse_program(&desugaring).unwrap(),
desugar.parse_program(None, &desugaring).unwrap(),
get_all_proofs,
seminaive_transform,
)?
}
Command::Check(facts) => {
let res = vec![NCommand::Check(facts)];
Command::Check(span, facts) => {
let res = vec![NCommand::Check(span, facts)];

if get_all_proofs {
// TODO check proofs
Expand Down Expand Up @@ -383,15 +418,24 @@ impl Desugar {
Ok(res)
}

pub fn parse_program(&self, input: &str) -> Result<Vec<Command>, Error> {
pub fn parse_program(
&self,
filename: Option<String>,
input: &str,
) -> Result<Vec<Command>, Error> {
let filename = filename.unwrap_or_else(|| DEFAULT_FILENAME.to_string());
let srcfile = Arc::new(SrcFile {
name: filename,
contents: Some(input.to_string()),
});
Ok(self
.parser
.parse(input)
.parse(&srcfile, input)
.map_err(|e| e.map_token(|tok| tok.to_string()))?)
}

// TODO declare by creating a new global function. See issue #334
pub fn declare(&mut self, name: Symbol, sort: Symbol) -> Vec<NCommand> {
pub fn declare(&mut self, span: Span, name: Symbol, sort: Symbol) -> Vec<NCommand> {
let fresh = self.get_fresh();
vec![
NCommand::Function(FunctionDecl {
Expand All @@ -407,7 +451,11 @@ impl Desugar {
unextractable: false,
ignore_viz: false,
}),
NCommand::CoreAction(Action::Let((), name, Expr::Call((), fresh, vec![]))),
NCommand::CoreAction(Action::Let(
span.clone(),
name,
Expr::Call(span.clone(), fresh, vec![]),
)),
]
}

Expand Down
Loading

0 comments on commit fa45d46

Please sign in to comment.