Skip to content

Commit

Permalink
datatype*
Browse files Browse the repository at this point in the history
  • Loading branch information
yihozhang committed Oct 4, 2024
1 parent 7dba971 commit ebc31ec
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 5 deletions.
43 changes: 43 additions & 0 deletions src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,49 @@ pub(crate) fn desugar_command(
Command::Input { span, name, file } => {
vec![NCommand::Input { span, name, file }]
}
Command::Datatypes { span: _, datatypes } => {
let mut res = vec![];
for datatype in datatypes.iter() {
let span = datatype.0.clone();
let name = datatype.1;
if datatype.2.is_ok() {
res.push(NCommand::Sort(span, name, None));
}
}
let (variants_vec, sorts): (Vec<_>, Vec<_>) = datatypes
.into_iter()
.partition(|datatype| datatype.2.is_ok());

for sort in sorts {
let span = sort.0.clone();
let name = sort.1;
let constructor = sort.2.unwrap_err();
res.push(NCommand::Sort(span, name, Some(constructor)));
}

for variants in variants_vec {
let datatype = variants.1;
let variants = variants.2.unwrap();
for variant in variants {
res.push(NCommand::Function(FunctionDecl {
name: variant.name,
schema: Schema {
input: variant.types,
output: datatype,
},
merge: None,
merge_action: Actions::default(),
default: None,
cost: variant.cost,
unextractable: false,
ignore_viz: false,
span: variant.span,
}));
}
}

res
}
};

Ok(res)
Expand Down
22 changes: 17 additions & 5 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ pub type Command = GenericCommand<Symbol, Symbol>;

pub type Subsume = bool;

pub type Subdatatypes = Result<Vec<Variant>, (Symbol, Vec<Expr>)>;

/// A [`Command`] is the top-level construct in egglog.
/// It includes defining rules, declaring functions,
/// adding to tables, and running rules (via a [`Schedule`]).
Expand Down Expand Up @@ -488,6 +490,10 @@ where
name: Symbol,
variants: Vec<Variant>,
},
Datatypes {
span: Span,
datatypes: Vec<(Span, Symbol, Subdatatypes)>,
},
/// Create a new user-defined sort, which can then
/// be used in new [`Command::Function`] declarations.
/// The [`Command::Datatype`] command desugars directly to this command, with one [`Command::Function`]
Expand All @@ -503,11 +509,7 @@ where
/// ```
///
/// Now `MathVec` can be used as an input or output sort.
Sort(
Span,
Symbol,
Option<(Symbol, Vec<GenericExpr<Symbol, Symbol>>)>,
),
Sort(Span, Symbol, Option<(Symbol, Vec<Expr>)>),
/// Declare an egglog function, which is a database table with a
/// a functional dependency (also called a primary key) on its inputs to one output.
///
Expand Down Expand Up @@ -878,6 +880,16 @@ where
expr,
schedule,
} => list!("simplify", schedule, expr),
GenericCommand::Datatypes { span: _, datatypes } => {
let datatypes: Vec<_> = datatypes
.iter()
.map(|(_, name, variants)| match variants {
Ok(variants) => list!(name, ++ variants),
Err((head, args)) => list!("sort", name, list!(head, ++ args)),
})
.collect();
list!("datatypes", ++ datatypes)
}
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/ast/parse.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,19 @@ Comma<T>: Vec<T> = {
}
};

RecDatatype: (Span, Symbol, Result<Vec<Variant>, (Symbol, Vec<Expr>)>) = {
<lo:LParen> <name:Ident> <variants:(Variant)*> <hi:RParen> => (Span(srcfile.clone(), lo, hi), name, Ok(variants)),
<lo:LParen> "sort" <name:Ident> LParen <head:Ident> <exprs:(Expr)*> RParen <hi:RParen> => (Span(srcfile.clone(), lo, hi), name, Err((head, exprs))),
}

Command: Command = {
LParen "set-option" <name:Ident> <value:Expr> RParen => Command::SetOption { name, value },
<lo:LParen> "datatype" <name:Ident> <variants:(Variant)*> <hi:RParen> => Command::Datatype { span: Span(srcfile.clone(), lo, hi), name, variants },
<lo:LParen> "sort" <name:Ident> LParen <head:Ident> <tail:(Expr)*> RParen <hi:RParen> => Command::Sort (Span(srcfile.clone(), lo, hi), name, Some((head, tail))),
<lo:LParen> "sort" <name:Ident> <hi:RParen> => Command::Sort (Span(srcfile.clone(), lo, hi), name, None),
<lo:LParen> "datatype*"
<datatypes:RecDatatype*>
<hi:RParen> => Command::Datatypes { span: Span(srcfile.clone(), lo, hi), datatypes },
<lo:LParen> "function" <name:Ident> <schema:Schema> <cost:Cost>
<unextractable:(":unextractable")?>
<merge_action:(":on_merge" <List<Action>>)?>
Expand Down
11 changes: 11 additions & 0 deletions tests/datatypes.egg
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
(datatype*
(Math
(Add Math Math)
(Sum MathVec)
(B Bool))
(sort MathVec (Vec Math))
(Bool
(True)
(False)))

(let expr (Add (Sum (vec-of (B (True)) (B (False)))) (B (True))))

0 comments on commit ebc31ec

Please sign in to comment.