From 3a840d54f0cf0c37f36ea6055e27783ffe0775c5 Mon Sep 17 00:00:00 2001 From: zhiyuan yan Date: Tue, 18 Jun 2024 15:08:35 -0700 Subject: [PATCH] fix of yihong's review --- dag_in_context/src/pretty_print.rs | 17 ++++++- dag_in_context/src/schema_helpers.rs | 2 + ...n_context__pretty_print__pretty_print.snap | 49 ++++++------------- 3 files changed, 33 insertions(+), 35 deletions(-) diff --git a/dag_in_context/src/pretty_print.rs b/dag_in_context/src/pretty_print.rs index 96cba0a9..cd788ac1 100644 --- a/dag_in_context/src/pretty_print.rs +++ b/dag_in_context/src/pretty_print.rs @@ -1,3 +1,9 @@ +// behaviors: the pretty printer take an RcExpr and return a log with folded top level Expr at the end and folded children +// of top level Expr before it. +// limitations: +// if two similar expr have different context, the to rust pretty printer will still print two expr +// when to rust, print to parallel!/switch! might not produce the optimal looking thing + use crate::{ from_egglog::FromEgglog, prologue, @@ -182,8 +188,13 @@ impl PrettyPrinter { .iter() .map(|fresh_var| { let node = self.table.get(fresh_var).unwrap(); - let ast = node.ast_node_to_str(true); - format!("let {fresh_var} = {ast};") + match node { + AstNode::Assumption(_) => String::new(), + _ => { + let ast = node.ast_node_to_str(true); + format!("let {fresh_var} = {ast};") + } + } }) .collect::>() .join("\n"); @@ -315,6 +326,8 @@ impl PrettyPrinter { } } Expr::Symbolic(_) => panic!("Expected non symbolic"), + Expr::Concat(..) | Expr::Single(..) if to_rust => expr + .map_expr_children(|e| self.refactor_shared_expr(e, fold_when, to_rust, log)), _ => { let expr2 = expr.map_expr_type(|ty| self.refactor_shared_type(ty, log)); let expr3 = expr2.map_expr_assum(|assum| { diff --git a/dag_in_context/src/schema_helpers.rs b/dag_in_context/src/schema_helpers.rs index 47b74297..dd39a719 100644 --- a/dag_in_context/src/schema_helpers.rs +++ b/dag_in_context/src/schema_helpers.rs @@ -274,6 +274,8 @@ impl Expr { } } + // this function might violate RcExpr's invariant + // for example function map_child is id function that create new RcExpr, and &self have two same children pub fn map_expr_children(self: &RcExpr, mut map_child: F) -> RcExpr where F: FnMut(&RcExpr) -> RcExpr, diff --git a/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap b/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap index 36b31f6b..ff785350 100644 --- a/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap +++ b/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap @@ -3,40 +3,23 @@ source: src/pretty_print.rs expression: ast --- let tpl_s_v0 = tuplet!(statet()); -let in_func_v1 = infunc("dummy"); -let concat_v2 = parallel!(int(4), getat(0)); -let concat_v3 = concat(single(int(2)), -concat(single(int(3)), -concat_v2.clone())); -let less_than_v4 = less_than(getat(0), + +let less_than_v2 = less_than(getat(0), getat(3)); -let single_v5 = single(getat(0)); -let single_v6 = single(getat(1)); -let single_v7 = single(getat(2)); -let single_v8 = single(getat(3)); -let sub_v9 = sub(getat(2), +let sub_v3 = sub(getat(2), getat(1)); -let tprint_v10 = tprint(sub_v9.clone(), +let tprint_v4 = tprint(sub_v3.clone(), getat(4)); -let concat_v11 = concat(concat(single(less_than_v4.clone()), -concat(single_v5.clone(), -single_v6.clone())), -concat(concat(single_v7.clone(), -single_v8.clone()), -single(tprint_v10.clone()))); -let tpl__v12 = emptyt(); -let tpl_i_v13 = tuplet!(intt()); -let in_func_v14 = infunc(" loop_ctx_0"); -let less_than_v15 = less_than(getat(0), +let dowhile_v5 = dowhile(parallel!(int(1), int(2), int(3), int(4), getat(0)), +parallel!(less_than_v2.clone(), getat(0), getat(1), getat(2), getat(3), tprint_v4.clone())); +let tpl__v6 = emptyt(); +let tpl_i_v7 = tuplet!(intt()); + +let less_than_v9 = less_than(getat(0), int(3)); -let in_switch_v16 = inswitch(0, -int(0), -arg()); -let switch_v17 = switch!(int(0), arg(); parallel!(int(4), int(5))); -let dowhile_v18 = dowhile(single(int(1)), -parallel!(less_than_v15.clone(), get(switch_v17.clone(), 0))); -let concat_v19 = concat(dowhile(concat(single(int(1)), -concat_v3.clone()), -concat_v11.clone()), -dowhile_v18.clone()); -let concat_v20 = concat_v19.clone(); + +let switch_v11 = switch!(int(0), arg(); parallel!(int(4), int(5))); +let dowhile_v12 = dowhile(single(int(1)), +parallel!(less_than_v9.clone(), get(switch_v11.clone(), 0))); +let concat_v13 = concat(dowhile_v5.clone(), +dowhile_v12.clone());