From 19dc29161658cdfadb984fc7f7d5962daa935979 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Tue, 10 Jan 2023 09:43:59 +0530 Subject: [PATCH] rename_inames: replace old inames that appear as params in other domains --- loopy/transform/iname.py | 13 ++++++++++++- test/test_transform.py | 15 +++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 82112fd8a..c500f14ba 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -2327,6 +2327,7 @@ def rename_inames(kernel, old_inames, new_iname, existing_ok=False, raise LoopyError(f"iname '{new_iname}' conflicts with an existing identifier" " --cannot rename") + orig_old_inames = old_inames if not does_exist: # {{{ rename old_inames[0] -> new_iname # so that the code below can focus on "merging" inames that already exist @@ -2404,6 +2405,16 @@ def does_insn_involve_iname(kernel, insn, *args): smap.map_kernel(kernel, within=does_insn_involve_iname, map_tvs=False, map_args=False)) + # replace instances where the old inames appear as a param + new_domains = [] + for dom in kernel.domains: + for old_iname in orig_old_inames: + d = dom.get_var_dict() + if old_iname in d and new_iname not in d: + var_type, var_num = d[old_iname] + dom = dom.set_dim_name(var_type, var_num, new_iname) + new_domains.append(dom) + new_instructions = [insn.copy(within_inames=((insn.within_inames - frozenset(old_inames)) | frozenset([new_iname]))) @@ -2412,7 +2423,7 @@ def does_insn_involve_iname(kernel, insn, *args): else insn for insn in kernel.instructions] - kernel = kernel.copy(instructions=new_instructions) + kernel = kernel.copy(instructions=new_instructions, domains=new_domains) return kernel diff --git a/test/test_transform.py b/test/test_transform.py index 1b62f31d9..2c1ce434e 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -1468,6 +1468,21 @@ def test_rename_inames(ctx_factory): lp.auto_test_vs_ref(knl, ctx, ref_knl) +def test_rename_inames_with_params(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel( + [ + "{ [i]: 0<=i<10 }", + "{ [k]: 0<=k