Skip to content

Commit

Permalink
rename_inames: replace old inames that appear as params in other domains
Browse files Browse the repository at this point in the history
  • Loading branch information
isuruf committed Jan 12, 2023
1 parent c072a86 commit 19dc291
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
13 changes: 12 additions & 1 deletion loopy/transform/iname.py
Original file line number Diff line number Diff line change
Expand Up @@ -2327,6 +2327,7 @@ def rename_inames(kernel, old_inames, new_iname, existing_ok=False,
raise LoopyError(f"iname '{new_iname}' conflicts with an existing identifier"
" --cannot rename")

orig_old_inames = old_inames
if not does_exist:
# {{{ rename old_inames[0] -> new_iname
# so that the code below can focus on "merging" inames that already exist
Expand Down Expand Up @@ -2404,6 +2405,16 @@ def does_insn_involve_iname(kernel, insn, *args):
smap.map_kernel(kernel, within=does_insn_involve_iname,
map_tvs=False, map_args=False))

# replace instances where the old inames appear as a param
new_domains = []
for dom in kernel.domains:
for old_iname in orig_old_inames:
d = dom.get_var_dict()
if old_iname in d and new_iname not in d:
var_type, var_num = d[old_iname]
dom = dom.set_dim_name(var_type, var_num, new_iname)
new_domains.append(dom)

new_instructions = [insn.copy(within_inames=((insn.within_inames
- frozenset(old_inames))
| frozenset([new_iname])))
Expand All @@ -2412,7 +2423,7 @@ def does_insn_involve_iname(kernel, insn, *args):
else insn
for insn in kernel.instructions]

kernel = kernel.copy(instructions=new_instructions)
kernel = kernel.copy(instructions=new_instructions, domains=new_domains)

return kernel

Expand Down
15 changes: 15 additions & 0 deletions test/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,6 +1468,21 @@ def test_rename_inames(ctx_factory):
lp.auto_test_vs_ref(knl, ctx, ref_knl)


def test_rename_inames_with_params(ctx_factory):
ctx = ctx_factory()

knl = lp.make_kernel(
[
"{ [i]: 0<=i<10 }",
"{ [k]: 0<=k<i }",
],
"out[i, k] = 2"
)
ref_knl = knl
knl = lp.rename_inames(knl, ["i"], "j")
lp.auto_test_vs_ref(knl, ctx, ref_knl)


def test_buffer_array_preserves_rev_deps(ctx_factory):
# See https://github.com/inducer/loopy/issues/546
ctx = ctx_factory()
Expand Down

0 comments on commit 19dc291

Please sign in to comment.