From f0bf927d4c5b72bdafa42ee4a4737979329e8ffc Mon Sep 17 00:00:00 2001 From: oflatt Date: Mon, 11 Nov 2024 12:38:37 -0800 Subject: [PATCH] functionhastype --- dag_in_context/src/lib.rs | 13 ++++ .../src/optimizations/function_inlining.rs | 59 +++++++++---------- dag_in_context/src/schema.egg | 4 +- dag_in_context/src/type_analysis.egg | 4 +- 4 files changed, 45 insertions(+), 35 deletions(-) diff --git a/dag_in_context/src/lib.rs b/dag_in_context/src/lib.rs index 89472f186..1e576e1a8 100644 --- a/dag_in_context/src/lib.rs +++ b/dag_in_context/src/lib.rs @@ -170,6 +170,19 @@ pub fn build_program( let loop_context_unions = cache.get_unions_with_sharing(&mut printed, &mut tree_state, &mut term_cache); + // set the type of each function + for func in fns { + let func = program.get_function(func).unwrap(); + let func_name = func.func_name().unwrap(); + let input_ty = func.func_input_ty().unwrap(); + let func_ty = func.func_output_ty().unwrap(); + writeln!( + &mut printed, + "(FunctionHasType \"{func_name}\" {input_ty} {func_ty})", + ) + .unwrap(); + } + let prologue = prologue(); format!( diff --git a/dag_in_context/src/optimizations/function_inlining.rs b/dag_in_context/src/optimizations/function_inlining.rs index 8a9e5e4bf..aad08cfb3 100644 --- a/dag_in_context/src/optimizations/function_inlining.rs +++ b/dag_in_context/src/optimizations/function_inlining.rs @@ -137,42 +137,37 @@ pub fn print_function_inlining_pairs( let printed_pairs = function_inlining_pairs .iter() .map(|cb| { - if let Expr::Call(callee, _) = cb.call.as_ref() { - let call_term = cb.call.to_egglog_with(tree_state); - let call_with_intermed = print_with_intermediate_helper( - &tree_state.termdag, - call_term.clone(), - term_cache, - printed, - ); - - let body_term = cb.body.to_egglog_with(tree_state); - let inlined_with_intermed = print_with_intermediate_helper( - &tree_state.termdag, - body_term, - term_cache, - printed, - ); - - let call_args = cb.call.children_exprs()[0].to_egglog_with(tree_state); - let call_args_with_intermed = print_with_intermediate_helper( - &tree_state.termdag, - call_args.clone(), - term_cache, - printed, - ); - format!( - // We need to subsume, otherwise the Call in the original program could get - // substituted into another context during optimization and no longer match InlinedCall. - " + let Expr::Call(callee, _) = cb.call.as_ref() else { + panic!("Tried to inline non-call") + }; + let call_term = cb.call.to_egglog_with(tree_state); + let call_with_intermed = print_with_intermediate_helper( + &tree_state.termdag, + call_term.clone(), + term_cache, + printed, + ); + + let body_term = cb.body.to_egglog_with(tree_state); + let inlined_with_intermed = + print_with_intermediate_helper(&tree_state.termdag, body_term, term_cache, printed); + + let call_args = cb.call.children_exprs()[0].to_egglog_with(tree_state); + let call_args_with_intermed = print_with_intermediate_helper( + &tree_state.termdag, + call_args.clone(), + term_cache, + printed, + ); + format!( + // We need to subsume, otherwise the Call in the original program could get + // substituted into another context during optimization and no longer match InlinedCall. + " (union {call_with_intermed} {inlined_with_intermed}) (InlinedCall \"{callee}\" {call_args_with_intermed}) (subsume (Call \"{callee}\" {call_args_with_intermed})) ", - ) - } else { - panic!("Tried to inline non-call") - } + ) }) .collect::>() .join("\n"); diff --git a/dag_in_context/src/schema.egg b/dag_in_context/src/schema.egg index 5f4db287d..e6351b840 100644 --- a/dag_in_context/src/schema.egg +++ b/dag_in_context/src/schema.egg @@ -200,7 +200,9 @@ ; name input ty output ty output (function Function (String Type Type Expr) Expr) - +; to get the type of a funciton, look in this table +; since we might not be optimizing the entire program +(relation FunctionHasType (String Type Type)) ; Rulesets (ruleset always-run) diff --git a/dag_in_context/src/type_analysis.egg b/dag_in_context/src/type_analysis.egg index a0289a597..69da1e8cd 100644 --- a/dag_in_context/src/type_analysis.egg +++ b/dag_in_context/src/type_analysis.egg @@ -523,7 +523,7 @@ (rule ( (= lhs (Call name arg)) - (Function name in-ty out-ty body) + (FunctionHasType name in-ty out-ty) ) ; Expect the arg to have the right type for the function ((ExpectType arg in-ty "function called with wrong arg type")) @@ -531,7 +531,7 @@ (rule ( (= lhs (Call name arg)) - (Function name in-ty out-ty body) + (FunctionHasType name in-ty out-ty) (HasType arg in-ty) ; We don't need to check the type of the function body, it will ; be checked elsewhere. If we did require (HasType body out-ty),