From 7d2c619cf6f055b762eae7dba08adbf3b4836a74 Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Fri, 26 Jan 2024 18:20:19 -0800 Subject: [PATCH] add ExprUsesArgs analysis --- tree_unique_args/src/arg_used_analysis.rs | 129 ++++++++++++++++++++++ tree_unique_args/src/lib.rs | 2 + tree_unique_args/src/schema.egg | 4 + 3 files changed, 135 insertions(+) create mode 100644 tree_unique_args/src/arg_used_analysis.rs diff --git a/tree_unique_args/src/arg_used_analysis.rs b/tree_unique_args/src/arg_used_analysis.rs new file mode 100644 index 000000000..c6bf08740 --- /dev/null +++ b/tree_unique_args/src/arg_used_analysis.rs @@ -0,0 +1,129 @@ +use crate::ir::{Constructor, ESort, Purpose}; +use strum::IntoEnumIterator; + +fn arg_used_rule_for_ctor(ctor: Constructor) -> Option { + if ctor == Constructor::Arg { + return Some(format!( + " + (rule ( + (ExprUsesArgs-demand e) + (= e (Get (Arg id) i)) + ) ( + (set (ExprUsesArgs e) (set-of i)) + ) + :ruleset always-run)" + )); + } + + let children_queries = ctor + .filter_map_fields(|field| match field.purpose { + Purpose::Static(_) + | Purpose::CapturingId + | Purpose::ReferencingId + | Purpose::CapturedExpr => None, + Purpose::SubExpr | Purpose::SubListExpr => { + let var = field.var(); + let sort = field.sort().name(); + Some(format!("(= args-{var} ({sort}UsesArgs {var}))")) + } + }) + .join(" "); + + let children_demand = ctor + .filter_map_fields(|field| match field.purpose { + Purpose::Static(_) + | Purpose::CapturingId + | Purpose::ReferencingId + | Purpose::CapturedExpr => None, + Purpose::SubExpr | Purpose::SubListExpr => { + let var = field.var(); + let sort = field.sort().name(); + Some(format!("({sort}UsesArgs-demand {var})")) + } + }) + .join(" "); + + let fields = ctor + .fields() + .into_iter() + .filter(|field| field.purpose == Purpose::SubExpr || field.purpose == Purpose::SubListExpr) + .collect::>(); + let union_expr = match fields.len() { + 0 => return None, + 1 => format!("args-{}", fields[0].var()), + _ => { + let mut union_expr = vec![]; + let (last_field, fields) = fields.split_last().unwrap(); + for field in fields { + let var = field.var(); + union_expr.push(format!("(set-union args-{var} ")); + } + union_expr.push(format!("args-{}", last_field.var())); + for _ in fields { + union_expr.push(")".into()); + } + union_expr.join(" ") + } + }; + + let ctor_pattern = ctor.construct(|field| field.var()); + + let sort = ctor.sort().name(); + Some(format!( + " + ;; propagation of demand + (rule ( + ({sort}UsesArgs-demand e) + (= e {ctor_pattern}) + ) ( + {children_demand} + ) + :ruleset always-run) + + ;; collecting set of args + (rule ( + ({sort}UsesArgs-demand e) + (= e {ctor_pattern}) + {children_queries} + ) ( + (set ({sort}UsesArgs e) {union_expr}) + ) + :ruleset always-run)" + )) +} + +pub(crate) fn arg_used_analysis_rules() -> Vec { + ESort::iter() + .map(|sort| { + " + (function *UsesArgs (*) I64Set :merge (set-union old new)) + (relation *UsesArgs-demand (*)) + + (rule ((*UsesArgs-demand e)) + ((set (*UsesArgs e) (set-empty))) :ruleset always-run) + " + .replace('*', sort.name()) + }) + .chain(Constructor::iter().filter_map(arg_used_rule_for_ctor)) + .collect::>() +} + +#[test] +fn test_args_used_analysis() -> Result<(), egglog::Error> { + let build = &*" + (let id1 (Id (i64-fresh!))) + (let id2 (Id (i64-fresh!))) + (let expr1 + (All (Parallel) (Pair (Let id2 (All (Parallel) (Pair (Get (Arg id1) 3) + (Num id1 1))) + (Get (Arg id2) 0)) + (Add (Get (Arg id1) 1) + (Get (Arg id1) 2))))) + (ExprUsesArgs-demand expr1) + " + .to_string(); + let check = " + (check (= (ExprUsesArgs expr1) (set-of 1 2 3))) + "; + crate::run_test(build, check) +} diff --git a/tree_unique_args/src/lib.rs b/tree_unique_args/src/lib.rs index 1caeddce1..696ca065c 100644 --- a/tree_unique_args/src/lib.rs +++ b/tree_unique_args/src/lib.rs @@ -1,3 +1,4 @@ +pub(crate) mod arg_used_analysis; pub mod ast; pub(crate) mod body_contains; pub(crate) mod conditional_invariant_code_motion; @@ -137,6 +138,7 @@ pub fn run_test(build: &str, check: &str) -> Result { &is_valid::rules().join("\n"), &purity_analysis::purity_analysis_rules().join("\n"), &body_contains::rules().join("\n"), + &arg_used_analysis::arg_used_analysis_rules().join("\n"), &subst::subst_rules().join("\n"), &deep_copy::deep_copy_rules().join("\n"), include_str!("sugar.egg"), diff --git a/tree_unique_args/src/schema.egg b/tree_unique_args/src/schema.egg index c0f553385..c1425b850 100644 --- a/tree_unique_args/src/schema.egg +++ b/tree_unique_args/src/schema.egg @@ -1,3 +1,7 @@ +;; container definitions + +(sort I64Set (Set i64)) + ; We could generate this from ir.rs, this manual version is just easier reference. (sort IdSort)