Skip to content

Commit

Permalink
Merge pull request #222 from oflatt/oflatt-desugar-merge-action
Browse files Browse the repository at this point in the history
Desugar merge actions for consistency
  • Loading branch information
oflatt authored Sep 5, 2023
2 parents c83fc75 + 0acee28 commit 152110c
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 36 deletions.
26 changes: 21 additions & 5 deletions src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ fn desugar_datatype(name: Symbol, variants: Vec<Variant>) -> Vec<NCommand> {
vec![NCommand::Sort(name, None)]
.into_iter()
.chain(variants.into_iter().map(|variant| {
NCommand::Function(FunctionDecl {
NCommand::Function(NormFunctionDecl {
name: variant.name,
schema: Schema {
input: variant.types,
Expand Down Expand Up @@ -513,6 +513,9 @@ pub(crate) fn rewrite_name(rewrite: &Rewrite) -> String {
rewrite.to_string().replace('\"', "'")
}

/// Desugars a single command into the normalized form.
/// Gets rid of a bunch of syntactic sugar, but also
/// makes rules into a SSA-like format (see [`NormFact`]).
pub(crate) fn desugar_command(
command: Command,
desugar: &mut Desugar,
Expand All @@ -523,9 +526,7 @@ pub(crate) fn desugar_command(
Command::SetOption { name, value } => {
vec![NCommand::SetOption { name, value }]
}
Command::Function(fdecl) => {
vec![NCommand::Function(fdecl)]
}
Command::Function(fdecl) => desugar.desugar_function(&fdecl),
Command::Declare { name, sort } => desugar.declare(name, sort),
Command::Datatype { name, variants } => desugar_datatype(name, variants),
Command::Rewrite(ruleset, rewrite) => {
Expand Down Expand Up @@ -781,7 +782,7 @@ impl Desugar {
pub fn declare(&mut self, name: Symbol, sort: Symbol) -> Vec<NCommand> {
let fresh = self.get_fresh();
vec![
NCommand::Function(FunctionDecl {
NCommand::Function(NormFunctionDecl {
name: fresh,
schema: Schema {
input: vec![],
Expand All @@ -796,4 +797,19 @@ impl Desugar {
NCommand::NormAction(NormAction::Let(name, NormExpr::Call(fresh, vec![]))),
]
}

pub fn desugar_function(&mut self, fdecl: &FunctionDecl) -> Vec<NCommand> {
let mut res = vec![];
let schema = fdecl.schema.clone();
res.push(NCommand::Function(NormFunctionDecl {
name: fdecl.name,
schema,
default: fdecl.default.clone(),
merge: fdecl.merge.clone(),
merge_action: flatten_actions(&fdecl.merge_action, self),
cost: fdecl.cost,
unextractable: fdecl.unextractable,
}));
res
}
}
37 changes: 34 additions & 3 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub enum NCommand {
value: Expr,
},
Sort(Symbol, Option<(Symbol, Vec<Expr>)>),
Function(FunctionDecl),
Function(NormFunctionDecl),
AddRuleset(Symbol),
NormRule {
name: Symbol,
Expand Down Expand Up @@ -132,7 +132,7 @@ impl NCommand {
value: value.clone(),
},
NCommand::Sort(name, params) => Command::Sort(*name, params.clone()),
NCommand::Function(f) => Command::Function(f.clone()),
NCommand::Function(f) => Command::Function(f.to_fdecl()),
NCommand::AddRuleset(name) => Command::AddRuleset(*name),
NCommand::NormRule {
name,
Expand Down Expand Up @@ -498,12 +498,43 @@ impl NormRunConfig {
}
}

/// A normalized function declaration- the desugared
/// version of a [`FunctionDecl`].
/// TODO so far only the merge action is normalized,
/// not the default value or merge expression.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct NormFunctionDecl {
pub name: Symbol,
pub schema: Schema,
// todo desugar default, merge
pub default: Option<Expr>,
pub merge: Option<Expr>,
pub merge_action: Vec<NormAction>,
pub cost: Option<usize>,
pub unextractable: bool,
}

impl NormFunctionDecl {
pub fn to_fdecl(&self) -> FunctionDecl {
FunctionDecl {
name: self.name,
schema: self.schema.clone(),
default: self.default.clone(),
merge: self.merge.clone(),
merge_action: self.merge_action.iter().map(|a| a.to_action()).collect(),
cost: self.cost,
unextractable: self.unextractable,
}
}
}

/// Represents the declaration of a function
/// directly parsed from source syntax.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct FunctionDecl {
pub name: Symbol,
pub schema: Schema,
pub default: Option<Expr>,
// TODO we should desugar merge and merge action
pub merge: Option<Expr>,
pub merge_action: Vec<Action>,
pub cost: Option<usize>,
Expand Down
29 changes: 2 additions & 27 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,8 @@ impl EGraph {
self.unionfind.n_unions() - n_unions + function.clear_updates()
}

pub fn declare_function(&mut self, decl: &FunctionDecl) -> Result<(), Error> {
let function = Function::new(self, decl)?;
pub fn declare_function(&mut self, decl: &NormFunctionDecl) -> Result<(), Error> {
let function = Function::new(self, &decl.to_fdecl())?;
let old = self.functions.insert(decl.name, function);
if old.is_some() {
panic!(
Expand All @@ -453,31 +453,6 @@ impl EGraph {
Ok(())
}

pub fn declare_constructor(
&mut self,
variant: Variant,
sort: impl Into<Symbol>,
) -> Result<(), Error> {
let name = variant.name;
let sort = sort.into();
self.declare_function(&FunctionDecl {
name,
schema: Schema {
input: variant.types,
output: sort,
},
merge: None,
merge_action: vec![],
default: None,
cost: variant.cost,
unextractable: false,
})?;
// if let Some(ctors) = self.sorts.get_mut(&sort) {
// ctors.push(name);
// }
Ok(())
}

pub fn eval_lit(&self, lit: &Literal) -> Value {
match lit {
Literal::Int(i) => i.store(&self.type_info().get_sort()).unwrap(),
Expand Down
5 changes: 4 additions & 1 deletion src/typechecking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ impl TypeInfo {
Ok(())
}

pub(crate) fn function_to_functype(&self, func: &FunctionDecl) -> Result<FuncType, TypeError> {
pub(crate) fn function_to_functype(
&self,
func: &NormFunctionDecl,
) -> Result<FuncType, TypeError> {
let input = func
.schema
.input
Expand Down

0 comments on commit 152110c

Please sign in to comment.