Skip to content

Commit

Permalink
Merge pull request #223 from oflatt/oflatt-distinguish-datatypes
Browse files Browse the repository at this point in the history
Distinguish datatypes from other tables
  • Loading branch information
oflatt authored Sep 12, 2023
2 parents f283176 + f6df3ff commit 4d67f26
Show file tree
Hide file tree
Showing 18 changed files with 196 additions and 139 deletions.
13 changes: 7 additions & 6 deletions src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,10 @@ pub(crate) fn desugar_command(
vec![NCommand::SetOption { name, value }]
}
Command::Function(fdecl) => desugar.desugar_function(&fdecl),
Command::Relation {
constructor,
inputs,
} => desugar.desugar_function(&FunctionDecl::relation(constructor, inputs)),
Command::Declare { name, sort } => desugar.declare(name, sort),
Command::Datatype { name, variants } => desugar_datatype(name, variants),
Command::Rewrite(ruleset, rewrite) => {
Expand Down Expand Up @@ -803,17 +807,14 @@ impl Desugar {
}

pub fn desugar_function(&mut self, fdecl: &FunctionDecl) -> Vec<NCommand> {
let mut res = vec![];
let schema = fdecl.schema.clone();
res.push(NCommand::Function(NormFunctionDecl {
vec![NCommand::Function(NormFunctionDecl {
name: fdecl.name,
schema,
schema: fdecl.schema.clone(),
default: fdecl.default.clone(),
merge: fdecl.merge.clone(),
merge_action: flatten_actions(&fdecl.merge_action, self),
cost: fdecl.cost,
unextractable: fdecl.unextractable,
}));
res
})]
}
}
35 changes: 34 additions & 1 deletion src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,29 @@ pub enum Command {
sort: Symbol,
},
Sort(Symbol, Option<(Symbol, Vec<Expr>)>),
/// Declare an egglog function.
/// The function is a datatype when:
/// - The output is not a primitive
/// - No merge function is provided
/// - No default is provided
Function(FunctionDecl),
/// Declare an egglog relation, which is simply sugar
/// for a function returning the `Unit` type.
/// Example:
/// ```lisp
/// (relation path (i64 i64))
/// (relation edge (i64 i64))
/// ```
///
/// Desugars to:
/// ```lisp
/// (function path (i64 i64) Unit :default ())
/// (function edge (i64 i64) Unit :default ())
/// ```
Relation {
constructor: Symbol,
inputs: Vec<Symbol>,
},
AddRuleset(Symbol),
Rule {
name: Symbol,
Expand Down Expand Up @@ -393,6 +415,10 @@ impl ToSexp for Command {
Command::Sort(name, None) => list!("sort", name),
Command::Sort(name, Some((name2, args))) => list!("sort", name, list!( name2, ++ args)),
Command::Function(f) => f.to_sexp(),
Command::Relation {
constructor,
inputs,
} => list!("relation", constructor, list!(++ inputs)),
Command::AddRuleset(name) => list!("ruleset", name),
Command::Rule {
name,
Expand Down Expand Up @@ -597,7 +623,7 @@ impl FunctionDecl {
},
merge: None,
merge_action: vec![],
default: None,
default: Some(Expr::Lit(Literal::Unit)),
cost: None,
unextractable: false,
}
Expand Down Expand Up @@ -748,8 +774,15 @@ impl Display for Fact {
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Action {
Let(Symbol, Expr),
/// `set` a table to a particular result.
/// `set` should not be used on datatypes;
/// use `union` instead.
Set(Symbol, Vec<Expr>, Expr),
Delete(Symbol, Vec<Expr>),
/// `union` two datatypes, making them equal
/// in the implicit, global equality relation
/// of egglog.
/// All rules match modulo this equality relation.
Union(Expr, Expr),
Extract(Expr, Expr),
Panic(String),
Expand Down
4 changes: 2 additions & 2 deletions src/ast/parse.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Command: Command = {
Command::Function(FunctionDecl { name, schema, merge, merge_action: merge_action.unwrap_or_default(), default, cost, unextractable: unextractable.is_some() })
},
LParen "declare" <name:Ident> <sort:Ident> RParen => Command::Declare{name, sort},
LParen "relation" <name:Ident> <types:List<Type>> RParen => Command::Function(FunctionDecl::relation(name, types)),
LParen "relation" <constructor:Ident> <inputs:List<Type>> RParen => Command::Relation{constructor, inputs},
LParen "ruleset" <name:Ident> RParen => Command::AddRuleset(name),
LParen "rule" <body:List<Fact>> <head:List<Action>> <ruleset:(":ruleset" <Ident>)?> <name:(":name" <String>)?> RParen => Command::Rule{ruleset: ruleset.unwrap_or("".into()), name: name.unwrap_or("".to_string()).into(), rule: Rule { head, body }},
LParen "rewrite" <lhs:Expr> <rhs:Expr>
Expand Down Expand Up @@ -141,7 +141,7 @@ pub Expr: Expr = {
};

Literal: Literal = {
// "(" ")" => Literal::Unit, // shouldn't need unit literals for now
"(" ")" => Literal::Unit,
<Num> => Literal::Int(<>),
<F64> => Literal::F64(<>),
<SymString> => Literal::String(<>),
Expand Down
118 changes: 67 additions & 51 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1133,57 +1133,7 @@ impl EGraph {
}
}
NCommand::Input { name, file } => {
let func = self.functions.get_mut(&name).unwrap();
let is_unit = func.schema.output.name().as_str() == "Unit";

let mut filename = self.fact_directory.clone().unwrap_or_default();
filename.push(file.as_str());

// check that the function uses supported types
for t in &func.schema.input {
match t.name().as_str() {
"i64" | "String" => {}
s => panic!("Unsupported type {} for input", s),
}
}
match func.schema.output.name().as_str() {
"i64" | "String" | "Unit" => {}
s => panic!("Unsupported type {} for input", s),
}

log::info!("Opening file '{:?}'...", filename);
let mut f = File::open(filename).unwrap();
let mut contents = String::new();
f.read_to_string(&mut contents).unwrap();

let mut actions: Vec<Action> = vec![];
let mut str_buf: Vec<&str> = vec![];
for line in contents.lines() {
str_buf.clear();
str_buf.extend(line.split('\t').map(|s| s.trim()));
if str_buf.is_empty() {
continue;
}

let parse = |s: &str| -> Expr {
if let Ok(i) = s.parse() {
Expr::Lit(Literal::Int(i))
} else {
Expr::Lit(Literal::String(s.into()))
}
};

let mut exprs: Vec<Expr> = str_buf.iter().map(|&s| parse(s)).collect();

actions.push(if is_unit {
Action::Expr(Expr::Call(name, exprs))
} else {
let out = exprs.pop().unwrap();
Action::Set(name, exprs, out)
});
}
self.eval_actions(&actions)?;
log::info!("Read {} facts into {name} from '{file}'.", actions.len())
self.input_file(name, file)?;
}
NCommand::Output { file, exprs } => {
let mut filename = self.fact_directory.clone().unwrap_or_default();
Expand All @@ -1210,6 +1160,72 @@ impl EGraph {
Ok(())
}

/// Loads facts for the function `name` from the tab-separated file `file`.
///
/// The path is `file` appended to `self.fact_directory` (or to an empty
/// path when no fact directory is set). Every non-blank line yields one
/// action: the line is split on tabs, each column is trimmed, and a column
/// becomes a `Literal::Int` if it parses as an integer, otherwise a
/// `Literal::String`.
///
/// * For datatypes and functions whose output sort is `Unit`, all columns
///   are arguments and the line is evaluated as a call expression.
/// * For every other function, the last column is the output value and the
///   remaining columns are the arguments of a `set` action.
///
/// # Panics
///
/// Panics if `name` is not a known function, if any input sort is not
/// `i64` or `String`, if a non-datatype function has an output sort other
/// than `i64`, `String` or `Unit`, or if the file cannot be read.
///
/// # Errors
///
/// Propagates any error returned by `eval_actions`.
fn input_file(&mut self, name: Symbol, file: String) -> Result<(), Error> {
    let function_type = self
        .type_info()
        .func_types
        .get(&name)
        .unwrap_or_else(|| panic!("Unrecognized function name {}", name))
        .clone();
    // Only the schema is read, so a shared borrow is sufficient.
    let func = self
        .functions
        .get(&name)
        .unwrap_or_else(|| panic!("No table found for function {}", name));

    let mut filename = self.fact_directory.clone().unwrap_or_default();
    filename.push(file.as_str());

    // Check that the function uses supported types.
    for t in &func.schema.input {
        match t.name().as_str() {
            "i64" | "String" => {}
            s => panic!("Unsupported type {} for input", s),
        }
    }

    // Datatype outputs are never read from the file, so only constrain the
    // output sort of ordinary functions.
    if !function_type.is_datatype {
        match func.schema.output.name().as_str() {
            "i64" | "String" | "Unit" => {}
            s => panic!("Unsupported type {} for input", s),
        }
    }

    log::info!("Opening file '{:?}'...", filename);
    let contents = std::fs::read_to_string(&filename)
        .unwrap_or_else(|e| panic!("Failed to read input file {:?}: {}", filename, e));

    // Columns that look like integers become integer literals; everything
    // else is kept as a string literal.
    let parse = |s: &str| -> Expr {
        if let Ok(i) = s.parse() {
            Expr::Lit(Literal::Int(i))
        } else {
            Expr::Lit(Literal::String(s.into()))
        }
    };

    // Decided once: does each line describe a bare call, or a `set` whose
    // last column is the output value?
    let all_columns_are_args =
        function_type.is_datatype || function_type.output.name() == UNIT_SYM.into();

    let mut actions: Vec<Action> = vec![];
    for line in contents.lines() {
        // `split` always yields at least one (possibly empty) column, so
        // blank lines have to be detected on the line itself.
        if line.trim().is_empty() {
            continue;
        }

        let mut exprs: Vec<Expr> = line.split('\t').map(|s| parse(s.trim())).collect();

        actions.push(if all_columns_are_args {
            Action::Expr(Expr::Call(name, exprs))
        } else {
            let out = exprs
                .pop()
                .expect("a non-blank line always has at least one column");
            Action::Set(name, exprs, out)
        });
    }

    self.eval_actions(&actions)?;
    log::info!("Read {} facts into {name} from '{file}'.", actions.len());
    Ok(())
}

pub fn clear(&mut self) {
for f in self.functions.values_mut() {
f.clear();
Expand Down
12 changes: 7 additions & 5 deletions src/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ impl<'a> ExprChecker<'a> for ActionChecker<'a> {
let func_type = self.egraph.type_info().func_types.get(&f).unwrap();
self.instructions.push(Instruction::CallFunction(
f,
func_type.has_default || !func_type.has_merge,
func_type.has_default || func_type.is_datatype,
));
}

Expand Down Expand Up @@ -600,7 +600,12 @@ trait ExprChecker<'a> {
Expr::Var(sym) => self.infer_var(*sym),
Expr::Call(sym, args) => {
if let Some(functype) = self.egraph().type_info().func_types.get(sym) {
assert!(functype.input.len() == args.len());
assert_eq!(
functype.input.len(),
args.len(),
"Got wrong number of arguments for function {}",
functype.name
);

let mut ts = vec![];
for (expected, arg) in functype.input.iter().zip(args) {
Expand Down Expand Up @@ -747,9 +752,6 @@ impl EGraph {
let value = if let Some(out) = function.nodes.get(values) {
out.value
} else if make_defaults {
if function.merge.on_merge.is_some() {
panic!("No value found for function {} with values {:?}", f, values);
}
let ts = self.timestamp;
let out = &function.schema.output;
match function.decl.default.as_ref() {
Expand Down
47 changes: 26 additions & 21 deletions src/typechecking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,15 @@ pub const RULE_PROOF_KEYWORD: &str = "rule-proof";

#[derive(Clone, Debug)]
pub struct FuncType {
pub name: Symbol,
pub input: Vec<ArcSort>,
pub output: ArcSort,
pub has_merge: bool,
pub is_datatype: bool,
pub has_default: bool,
}

impl FuncType {
pub fn new(input: Vec<ArcSort>, output: ArcSort, has_merge: bool, has_default: bool) -> Self {
Self {
input,
output,
has_merge,
has_default,
}
}
}

/// Stores resolved typechecking information.
/// TODO make these not public, use accessor methods
#[derive(Clone)]
pub struct TypeInfo {
// get the sort from the sorts name()
Expand Down Expand Up @@ -144,12 +136,14 @@ impl TypeInfo {
} else {
Err(TypeError::Unbound(func.schema.output))
}?;
Ok(FuncType::new(

Ok(FuncType {
name: func.name,
input,
output,
func.merge.is_some(),
func.default.is_some(),
))
output: output.clone(),
is_datatype: output.is_eq_sort() && func.merge.is_none() && func.default.is_none(),
has_default: func.default.is_some(),
})
}

fn typecheck_ncommand(&mut self, command: &NCommand, id: CommandId) -> Result<(), TypeError> {
Expand Down Expand Up @@ -421,10 +415,13 @@ impl TypeInfo {
self.typecheck_expr(ctx, expr, true)?;
}
NormAction::Set(expr, other) => {
let func_type = self.typecheck_expr(ctx, expr, true)?.output;
let func_type = self.typecheck_expr(ctx, expr, true)?;
let other_type = self.lookup(ctx, *other)?;
if func_type.name() != other_type.name() {
return Err(TypeError::TypeMismatch(func_type, other_type));
if func_type.output.name() != other_type.name() {
return Err(TypeError::TypeMismatch(func_type.output, other_type));
}
if func_type.is_datatype {
return Err(TypeError::SetDatatype(func_type));
}
}
NormAction::Union(var1, var2) => {
Expand Down Expand Up @@ -580,7 +577,13 @@ impl TypeInfo {
if let Some(prims) = self.primitives.get(&sym) {
for prim in prims {
if let Some(return_type) = prim.accept(&input_types) {
return Ok(FuncType::new(input_types, return_type, false, true));
return Ok(FuncType {
name: sym,
input: input_types,
output: return_type,
is_datatype: false,
has_default: true,
});
}
}
}
Expand Down Expand Up @@ -656,6 +659,8 @@ pub enum TypeError {
FunctionAlreadyBound(Symbol),
#[error("Function declarations are not allowed after a push.")]
FunctionAfterPush(Symbol),
#[error("Cannot set the datatype {} to a value. Did you mean to use union?", .0.name)]
SetDatatype(FuncType),
#[error("Sort declarations are not allowed after a push.")]
SortAfterPush(Symbol),
#[error("Global already bound {0}")]
Expand Down
4 changes: 2 additions & 2 deletions tests/array.egg
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@
(rewrite (select (store mem i e) i) e)
; select passes through wrong index
(rule ((= (select (store mem i1 e) i2) e1) (neq i1 i2))
((set (select mem i2) e1)))
((union (select mem i2) e1)))
; aliasing writes destroy old value
(rewrite (store (store mem i e1) i e2) (store mem i e2))
; non-aliasing writes commutes
(rule ((= (store (store mem i2 e2) i1 e1) mem1) (neq i1 i2))
((set (store (store mem i1 e1) i2 e2) mem1)))
((union (store (store mem i1 e1) i2 e2) mem1)))

; typical math rules
(rewrite (add x y) (add y x))
Expand Down
Loading

0 comments on commit 4d67f26

Please sign in to comment.