Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Documentation #106

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions src/enumo/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ mod test {

#[test]
fn metric_lt() {
let wkld = Workload::from_vec(vec![
let wkld = Workload::new([
"(+ a a)",
"(+ a b)",
"(+ a (+ a b))",
Expand All @@ -47,13 +47,13 @@ mod test {
"(~ (+ a b))",
]);
let actual = wkld.filter(Filter::MetricLt(Metric::Atoms, 5)).force();
let expected = Workload::from_vec(vec!["(+ a a)", "(+ a b)", "(~ (+ a b))"]).force();
let expected = Workload::new(["(+ a a)", "(+ a b)", "(~ (+ a b))"]).force();
assert_eq!(actual, expected)
}

#[test]
fn contains() {
let wkld = Workload::from_vec(vec![
let wkld = Workload::new([
"(+ a a)",
"(+ a b)",
"(+ a (+ a b))",
Expand All @@ -64,29 +64,26 @@ mod test {
let actual = wkld
.filter(Filter::Contains("(+ ?x ?x)".parse().unwrap()))
.force();
let expected =
Workload::from_vec(vec!["(+ a a)", "(+ a (+ b b))", "(+ (+ a b) (+ a b))"]).force();
let expected = Workload::new(["(+ a a)", "(+ a (+ b b))", "(+ (+ a b) (+ a b))"]).force();
assert_eq!(actual, expected);
}

#[test]
fn and() {
let wkld = Workload::from_vec(vec![
"x", "y", "(x y)", "(y x)", "(x x x)", "(y y z)", "(x y z)",
]);
let wkld = Workload::new(["x", "y", "(x y)", "(y x)", "(x x x)", "(y y z)", "(x y z)"]);
let actual = wkld
.filter(Filter::And(
Box::new(Filter::Contains("x".parse().unwrap())),
Box::new(Filter::Contains("y".parse().unwrap())),
))
.force();
let expected = Workload::from_vec(vec!["(x y)", "(y x)", "(x y z)"]).force();
let expected = Workload::new(["(x y)", "(y x)", "(x y z)"]).force();
assert_eq!(actual, expected);
}

#[test]
fn invert() {
let wkld = Workload::from_vec(vec![
let wkld = Workload::new([
"(+ a a)",
"(+ a b)",
"(+ a (+ a b))",
Expand All @@ -99,8 +96,7 @@ mod test {
"(+ ?x ?x)".parse().unwrap(),
))))
.force();
let expected =
Workload::from_vec(vec!["(+ a b)", "(+ a (+ a b))", "(+ (+ a b) (+ b a))"]).force();
let expected = Workload::new(["(+ a b)", "(+ a (+ a b))", "(+ (+ a b) (+ b a))"]).force();
assert_eq!(actual, expected);
}
}
3 changes: 1 addition & 2 deletions src/enumo/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ mod test {
.map(|x| x.parse::<Pattern>().unwrap())
.collect();

let exprs =
Workload::from_vec(vec!["a", "x", "(+ x y)", "(+ y y)", "(+ (* a b) (* a b))"]).force();
let exprs = Workload::new(["a", "x", "(+ x y)", "(+ y y)", "(+ (* a b) (* a b))"]).force();

let expected = vec![
vec![true, true, true, true, true],
Expand Down
20 changes: 10 additions & 10 deletions src/enumo/sexp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ mod test {
#[test]
fn plug() {
let x = "x".parse::<Sexp>().unwrap();
let pegs = Workload::from_vec(vec!["1", "2", "3"]).force();
let pegs = Workload::new(["1", "2", "3"]).force();
let expected = vec![x.clone()];
let actual = x.plug("a", &pegs);
assert_eq!(actual, expected);
Expand All @@ -199,8 +199,8 @@ mod test {
#[test]
fn plug_cross_product() {
let term = "(x x)";
let pegs = Workload::from_vec(vec!["1", "2", "3"]).force();
let expected = Workload::from_vec(vec![
let pegs = Workload::new(["1", "2", "3"]).force();
let expected = Workload::new([
"(1 1)", "(1 2)", "(1 3)", "(2 1)", "(2 2)", "(2 3)", "(3 1)", "(3 2)", "(3 3)",
])
.force();
Expand All @@ -210,11 +210,11 @@ mod test {

#[test]
fn multi_plug() {
let wkld = Workload::from_vec(vec!["(a b)", "(a)", "(b)"]);
let a_s = Workload::from_vec(vec!["1", "2", "3"]);
let b_s = Workload::from_vec(vec!["x", "y"]);
let actual = wkld.plug("a", &a_s).plug("b", &b_s).force();
let expected = Workload::from_vec(vec![
let wkld = Workload::new(["(a b)", "(a)", "(b)"]);
let a_s = Workload::new(["1", "2", "3"]);
let b_s = Workload::new(["x", "y"]);
let actual = wkld.plug("a", a_s).plug("b", b_s).force();
let expected = Workload::new([
"(1 x)", "(1 y)", "(2 x)", "(2 y)", "(3 x)", "(3 y)", "(1)", "(2)", "(3)", "(x)", "(y)",
])
.force();
Expand All @@ -223,7 +223,7 @@ mod test {

#[test]
fn canon() {
let inputs = Workload::from_vec(vec![
let inputs = Workload::new([
"(+ (/ c b) a)",
"(+ (- c c) (/ a a))",
"a",
Expand All @@ -241,7 +241,7 @@ mod test {
"(+ a (+ c b))",
])
.force();
let expecteds = Workload::from_vec(vec![
let expecteds = Workload::new([
"(+ (/ a b) c)",
"(+ (- a a) (/ b b))",
"a",
Expand Down
129 changes: 113 additions & 16 deletions src/enumo/workload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,40 @@ use crate::{HashSet, SynthAnalysis, SynthLanguage};

use super::*;

/// A `Workload` compactly and lazily describes the set of terms that the initial
/// egraph will be seeded with. Terms are described in a top-down manner. For
/// example, you can start with the following pattern:
///
/// ```lisp
/// (binop expr expr)
/// ```
///
/// By itself, this only describes a single term. However, you can "plug" in other
/// workloads to any atom in the pattern. So if you have a workload that describes
/// the terms: `[+, -]`. You can plug that in for `binop` to get the terms:
///
/// ```lisp
/// (+ expr expr)
/// (- expr expr)
/// ```
///
/// You can now expand `expr` to get more expressions. You can plug in the workload
/// `[0, 1, a, b]` for `expr` to get the terms:
///
/// ```lisp
/// (+ 0 0) (+ 1 0) (+ a 0) (+ b 0)
/// (+ 0 1) (+ 1 1) (+ a 1) (+ b 1)
/// (+ 0 a) (+ 1 a) (+ a a) (+ b a)
/// (+ 0 b) (+ 1 b) (+ a b) (+ b b)
///
/// (- 0 0) (- 1 0) (- a 0) (- b 0)
/// (- 0 1) (- 1 1) (- a 1) (- b 1)
/// (- 0 a) (- 1 a) (- a a) (- b a)
/// (- 0 b) (- 1 b) (- a b) (- b b)
/// ```
///
/// This simple mechanism allows you to express a large number of terms through
/// the composition of smaller workloads.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum Workload {
Set(Vec<Sexp>),
Expand All @@ -13,10 +47,12 @@ pub enum Workload {
}

impl Workload {
pub fn from_vec(strs: Vec<&str>) -> Self {
Self::Set(strs.iter().map(|x| x.parse().unwrap()).collect())
/// Construct a new workload from anything that can iterator over `&str`.
pub fn new<'a>(vals: impl IntoIterator<Item = &'a str>) -> Self {
Self::Set(vals.into_iter().map(|x| x.parse().unwrap()).collect())
}

/// Construct an `EGraph` from the terms represented by this workload.
pub fn to_egraph<L: SynthLanguage>(&self) -> EGraph<L, SynthAnalysis> {
let mut egraph = EGraph::default();
let sexps = self.force();
Expand Down Expand Up @@ -47,6 +83,7 @@ impl Workload {
egraph
}

/// Force the construction of all the terms represented in this workload.
pub fn force(&self) -> Vec<Sexp> {
match self {
Workload::Set(set) => set.clone(),
Expand All @@ -73,40 +110,94 @@ impl Workload {
}
}

/// Recursively expands `self` to a depth of `n` by plugging in `self` for `atom`.
///
/// For the workload:
/// ```
/// let wkld = Workload::new(["x", "(bop expr expr)"]);
/// ```
///
/// `wkld.iter("expr", 2)` generates
///
/// ```
/// x
/// (uop x x)
/// ```
///
/// and `wkld.iter("expr", 3)` generates
///
/// ```
/// x
/// (uop x x)
/// (uop x (uop x x))
/// (uop (uop x x) x)
/// (uop (uop x x) (uop x x))
/// ```
///
///
fn iter(self, atom: &str, n: usize) -> Self {
if n == 0 {
Self::Set(vec![])
} else {
let rec = self.clone().iter(atom, n - 1);
self.plug(atom, &rec)
self.plug(atom, rec)
}
}

/// Expands the workload at `atom` up to a depth of `n` and then filters by `met`.
pub fn iter_metric(self, atom: &str, met: Metric, n: usize) -> Self {
self.iter(atom, n).filter(Filter::MetricLt(met, n + 1))
}

/// A convenience function to quickly create a workload for a standard language.
/// - `n`: The depth of terms to generate in the language
/// - `consts`: The constant expressions in the language
/// - `vars`: The variables in the language
/// - `uops`: The unary operators in the language
/// - `bops`: The binary operators in the language
///
/// The following workload:
///
/// ```rust
/// Workload::iter_lang(
/// 3,
/// &["0", "1"],
/// &["x", "y"],
/// &["~"],
/// &["+", "-"]
/// )
/// ```
///
/// will generate terms in a simple arithmetic language.
pub fn iter_lang(
n: usize,
consts: &[&str],
vars: &[&str],
uops: &[&str],
bops: &[&str],
) -> Self {
let lang = Workload::from_vec(vec!["cnst", "var", "(uop expr)", "(bop expr expr)"]);
let lang = Workload::new(["cnst", "var", "(uop expr)", "(bop expr expr)"]);

lang.iter_metric("expr", Metric::Atoms, n)
.filter(Filter::Contains("var".parse().unwrap()))
.plug("cnst", &Workload::from_vec(consts.to_vec()))
.plug("var", &Workload::from_vec(vars.to_vec()))
.plug("uop", &Workload::from_vec(uops.to_vec()))
.plug("bop", &Workload::from_vec(bops.to_vec()))
.plug("cnst", consts)
.plug("var", vars)
.plug("uop", uops)
.plug("bop", bops)
}

pub fn plug(self, name: impl Into<String>, workload: &Workload) -> Self {
Workload::Plug(Box::new(self), name.into(), Box::new(workload.clone()))
/// Compose two workloads together by replacing every instance of `name` with
/// `workload`.
pub fn plug(self, name: impl Into<String>, workload: impl Into<Workload>) -> Self {
Workload::Plug(Box::new(self), name.into(), Box::new(workload.into()))
}

/// Append to workloads together.
pub fn append(self, workload: impl Into<Workload>) -> Self {
Workload::Append(vec![self, workload.into()])
}

/// Modify a workload by excluding all the terms matched by `filter`.
pub fn filter(self, filter: Filter) -> Self {
if filter.is_monotonic() {
if let Workload::Plug(wkld, name, pegs) = self {
Expand All @@ -123,13 +214,19 @@ impl Workload {
}
}

impl From<&[&str]> for Workload {
fn from(value: &[&str]) -> Self {
Workload::new(value.iter().map(|x| *x))
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn iter() {
let lang = Workload::from_vec(vec!["cnst", "var", "(uop expr)", "(bop expr expr)"]);
let lang = Workload::new(["cnst", "var", "(uop expr)", "(bop expr expr)"]);
let actual2 = lang.clone().iter("expr", 2).force();
assert_eq!(actual2.len(), 8);

Expand All @@ -139,7 +236,7 @@ mod test {

#[test]
fn iter_metric() {
let lang = Workload::from_vec(vec!["cnst", "var", "(uop expr)", "(bop expr expr)"]);
let lang = Workload::new(["cnst", "var", "(uop expr)", "(bop expr expr)"]);
let actual2 = lang.clone().iter_metric("expr", Metric::Atoms, 2).force();
assert_eq!(actual2.len(), 4);

Expand All @@ -150,22 +247,22 @@ mod test {
#[test]
fn iter_metric_fast() {
// This test will not finish if the pushing monotonic filters through plugs optimization is not working.
let lang = Workload::from_vec(vec!["cnst", "var", "(uop expr)", "(bop expr expr)"]);
let lang = Workload::new(["cnst", "var", "(uop expr)", "(bop expr expr)"]);
let six = lang.iter_metric("expr", Metric::Atoms, 6);
assert_eq!(six.force().len(), 188);
}

#[test]
fn contains() {
let lang = Workload::from_vec(vec!["cnst", "var", "(uop expr)", "(bop expr expr)"]);
let lang = Workload::new(["cnst", "var", "(uop expr)", "(bop expr expr)"]);

let actual3 = lang
.clone()
.iter_metric("expr", Metric::Atoms, 3)
.filter(Filter::Contains("var".parse().unwrap()))
.force();

let expected3 = Workload::from_vec(vec![
let expected3 = Workload::new([
"var",
"(uop var)",
"(uop (uop var))",
Expand All @@ -182,7 +279,7 @@ mod test {
.filter(Filter::Contains("var".parse().unwrap()))
.force();

let expected4 = Workload::from_vec(vec![
let expected4 = Workload::new([
"var",
"(uop var)",
"(uop (uop var))",
Expand Down
12 changes: 12 additions & 0 deletions src/equality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ use std::{str::FromStr, sync::Arc};

use crate::*;

/// A equality between the patterns `lhs` and `rhs`.
/// - `name` is a string representing this equality.
/// - `lhs` and `rhs` are the left hand side and right hand side of the equality.
/// - `rewrite` holds the `egg` version of the rewrite rule that this equality
/// represents.
#[derive(Clone, Debug)]
pub struct Equality<L: SynthLanguage> {
pub name: Arc<str>,
Expand Down Expand Up @@ -70,6 +75,7 @@ impl<L: SynthLanguage> Applier<L, SynthAnalysis> for Rhs<L> {
}

impl<L: SynthLanguage> Equality<L> {
/// Construct a new equality from two `egg::RecExpr`s.
pub fn new(e1: &RecExpr<L>, e2: &RecExpr<L>) -> Option<Self> {
let map = &mut HashMap::default();
let l_pat = L::generalize(e1, map);
Expand All @@ -86,6 +92,12 @@ impl<L: SynthLanguage> Equality<L> {
})
}

/// Checks if an equality will only ever shrink, or keep constant, the number of
/// e-classes in an e-graph.
///
/// This works by adding both sides of an equality to an egraph. If after unioning
/// the root of these two expressions, there are fewer e-classes than there were
/// originally, we say that this equality is saturating.
pub fn is_saturating(&self) -> bool {
let mut egraph: EGraph<L, SynthAnalysis> = Default::default();
let l_id = egraph.add_expr(&L::instantiate(&self.lhs));
Expand Down
Loading