diff --git a/loopy/kernel/dependency.py b/loopy/kernel/dependency.py index 5992a1cb1..ca93fc52b 100644 --- a/loopy/kernel/dependency.py +++ b/loopy/kernel/dependency.py @@ -41,76 +41,62 @@ def add_lexicographic_happens_after(knl: LoopKernel) -> LoopKernel: """ new_insns = [] - for iafter, insn_after in enumerate(knl.instructions): - if iafter == 0: new_insns.append(insn_after) else: - - insn_before = knl.instructions[iafter - 1] - shared_inames = insn_after.within_inames & insn_before.within_inames + insn_before = knl.instructions[iafter-1] domain_before = knl.get_inames_domain(insn_before.within_inames) domain_after = knl.get_inames_domain(insn_after.within_inames) - happens_before = isl.Map.from_domain_and_range( - domain_before, domain_after - ) - - for idim in range(happens_before.dim(dim_type.out)): - happens_before = happens_before.set_dim_name( - dim_type.out, idim, - happens_before.get_dim_name(dim_type.out, idim) + "'" - ) - - n_inames_before = happens_before.dim(dim_type.in_) - happens_before_set = happens_before.move_dims( - dim_type.out, 0, - dim_type.in_, 0, - n_inames_before).range() + + shared_inames = insn_before.within_inames & insn_after.within_inames + + happens_after = isl.Map.from_domain_and_range( + domain_before, + domain_after) + + for idim in range(happens_after.dim(dim_type.out)): + happens_after = happens_after.set_dim_name( + dim_type.out, + idim, + happens_after.get_dim_name(dim_type.out, idim) + "'") shared_inames_order_before = [ domain_before.get_dim_name(dim_type.out, idim) for idim in range(domain_before.dim(dim_type.out)) if domain_before.get_dim_name(dim_type.out, idim) - in shared_inames - ] + in shared_inames] + shared_inames_order_after = [ domain_after.get_dim_name(dim_type.out, idim) for idim in range(domain_after.dim(dim_type.out)) if domain_after.get_dim_name(dim_type.out, idim) - in shared_inames - ] + in shared_inames] + assert shared_inames_order_after == shared_inames_order_before shared_inames_order = shared_inames_order_after - affs = isl.affs_from_space(happens_before_set.space) + affs_in = isl.affs_from_space(happens_after.domain().space) + affs_out = isl.affs_from_space(happens_after.range().space) - lex_set = isl.Set.empty(happens_before_set.space) - for iinnermost, innermost_iname in enumerate(shared_inames_order): - - innermost_set = affs[innermost_iname].lt_set( - affs[innermost_iname+"'"] - ) + lex_map = isl.Map.empty(happens_after.space) + for iinnermost, innermost_iname in enumerate(shared_inames): + innermost_map = affs_in[innermost_iname].lt_map( + affs_out[innermost_iname + "'"]) for outer_iname in shared_inames_order[:iinnermost]: - innermost_set = innermost_set & ( - affs[outer_iname].eq_set(affs[outer_iname + "'"]) - ) - - lex_set = lex_set | innermost_set + innermost_map = innermost_map & ( + affs_in[outer_iname].eq_map( + affs_out[outer_iname + "'"])) - lex_map = isl.Map.from_range(lex_set).move_dims( - dim_type.in_, 0, - dim_type.out, 0, - n_inames_before) + lex_map = lex_map | innermost_map - happens_before = happens_before & lex_map + happens_after = happens_after & lex_map new_happens_after = { - insn_before.id: HappensAfter(None, happens_before) - } + insn_before.id: HappensAfter(None, happens_after)} insn_after = insn_after.copy(happens_after=new_happens_after)