Vectorization error with predicates depending on vectorized loop #615
@inducer: I think this is unavoidable for the kernels that @sv2518 is working with. Since changing the logic for emitting predicates in

import loopy as lp
from loopy.kernel.data import VectorizeTag

tunit = lp.make_kernel(
    "{[i, j]: 0<=i<100 and 0<=j<4}",
    """
    for i
        for j
            <> tmp1[j] = i+j
            <> tmp2[j] = 0
            if j>2
                tmp2[j] = 2 * tmp1[j] {id=w_tmp2}
            end
            out[i, j] = 2*tmp2[j]
        end
    end
    """, seq_dependencies=True)
tunit = lp.tag_array_axes(tunit, "tmp1,tmp2", "vec")
tunit = lp.tag_inames(tunit, "j:vec")

# {{{ fallback -->

# Move the predicated write into its own j loop, keeping i as the outer loop...
knl = lp.distribute_loops(tunit.default_entrypoint,
                          "id:w_tmp2",
                          outer_inames=frozenset({"i"}))
# ...then unroll that new loop instead of vectorizing it.
renamed_j, = knl.id_to_insn["w_tmp2"].within_inames - {"i"}
knl = lp.untag_inames(knl, renamed_j, VectorizeTag)
knl = lp.tag_inames(knl, {renamed_j: "unr"})

# }}}

tunit = tunit.with_kernel(knl)
print(lp.generate_code_v2(tunit).device_code())

With this potential use case for loop distribution, any opinions on me moving forward with it?
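A plain-Python model of what the fallback relies on may help: loop distribution splits the j loop so that the predicated write w_tmp2 gets its own loop (which is then unrolled rather than vectorized), and the result must still match the original fused loop. This is a sketch, not loopy code: it assumes distribution yields three consecutive j loops (unconditional writes, the predicated write, the consumer of tmp2), and `fused` / `distributed` are hypothetical helper names.

```python
# Plain-Python model of the kernel above (not loopy code).
# Hypothetical helpers: fused() mirrors the original loop nest,
# distributed() mirrors the j loop after distributing w_tmp2 (assumed 3-way split).

NI, NJ = 100, 4


def fused():
    out = [[0] * NJ for _ in range(NI)]
    for i in range(NI):
        tmp1 = [0] * NJ
        tmp2 = [0] * NJ
        for j in range(NJ):
            tmp1[j] = i + j
            tmp2[j] = 0
            if j > 2:
                tmp2[j] = 2 * tmp1[j]  # w_tmp2
            out[i][j] = 2 * tmp2[j]
    return out


def distributed():
    out = [[0] * NJ for _ in range(NI)]
    for i in range(NI):
        tmp1 = [0] * NJ
        tmp2 = [0] * NJ
        # Loop 1 (still tagged vec): only unconditional writes.
        for j in range(NJ):
            tmp1[j] = i + j
            tmp2[j] = 0
        # Loop 2 (renamed j, retagged unr): only the predicated write.
        for j_1 in range(NJ):
            if j_1 > 2:
                tmp2[j_1] = 2 * tmp1[j_1]
        # Loop 3: the consumer of tmp2.
        for j_2 in range(NJ):
            out[i][j_2] = 2 * tmp2[j_2]
    return out


assert fused() == distributed()
print(fused()[5])  # prints [0, 0, 0, 32]
```

The split is legal here because every dependency of w_tmp2 lies in the loop before it and its only consumer lies in the loop after it.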
@sv2518: If you want to access this functionality, please check out the branch
Thanks!
Oops, yep. Messed up while cherry-picking the commits. Pushed a fix.
Thanks! Is the fallback code meant to go in our codebase, or can this be done in loopy? I added this snippet to PyOP2, but I run into the following error. Did I use it wrong?
I'm not sure I agree. We can expect each predicate to be an expression, so we can determine whether the predicates depend on that iname. Based on that, we just need to raise
We did that in #617. The issue has more to do with the way predicates are emitted in the codegen pipeline, which needs to be taught about UnvectorizableErrors.
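The check discussed above (deciding from the predicate expressions whether an instruction depends on the vectorized iname) can be sketched in a few lines. This is a hedged illustration only: it uses Python's `ast` module as a stand-in for loopy's pymbolic expressions, and `UnvectorizableError`, `names_in` and `check_vectorizable` are hypothetical local definitions, not loopy's API.

```python
import ast


class UnvectorizableError(Exception):
    """Hypothetical stand-in for the error codegen would need to handle."""


def names_in(expr_src):
    """Collect every variable name referenced by a predicate expression."""
    tree = ast.parse(expr_src, mode="eval")
    return {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}


def check_vectorizable(predicates, vec_iname):
    """Raise if any predicate reads the vectorized iname.

    A vector of lanes cannot be guarded by a single scalar ``if``, so a
    predicate that depends on ``vec_iname`` makes the instruction
    unvectorizable (codegen would then fall back, e.g. to unrolling).
    """
    for pred in predicates:
        if vec_iname in names_in(pred):
            raise UnvectorizableError(
                f"predicate {pred!r} depends on vectorized iname {vec_iname!r}")


check_vectorizable(["tmp[i] > n"], "j")  # passes: no dependency on j
try:
    check_vectorizable(["j > 2"], "j")
except UnvectorizableError as err:
    print(err)  # predicate 'j > 2' depends on vectorized iname 'j'
```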
I think there is a minor error there. It should be something like:

import loopy as lp
from loopy.kernel.data import VectorizeTag
from loopy.match import Id, Or

# `kernel` and `shifted_iname` come from the surrounding PyOP2 code.
cinsn_ids = [cinsn.id
             for cinsn in kernel.instructions
             if (isinstance(cinsn, lp.CInstruction) and cinsn.predicates)]
cinsn_match = Or(tuple(Id(cinsn_id) for cinsn_id in cinsn_ids))
outer_inames = frozenset([shifted_iname+"_outer"])
kernel = lp.distribute_loops(kernel,
                             cinsn_match,
                             outer_inames=outer_inames)
# Each instruction contributes a set of inames; flatten them into one set.
inames_to_untag = frozenset().union(
    *(kernel.id_to_insn[cinsn_id].within_inames - outer_inames
      for cinsn_id in cinsn_ids))
# untag_inames takes a single iname, so untag them one at a time.
for iname_to_untag in inames_to_untag:
    kernel = lp.untag_inames(kernel, iname_to_untag, VectorizeTag)
kernel = lp.tag_inames(kernel, {iname_to_untag: "unr"
                                for iname_to_untag in inames_to_untag})
Ah yes, I got further now, but now I hit the following error:
The
It needed a
Okay, cool, thanks! One step further, I now run into:
Sorry, I hadn't accounted for CInstruction. Pushed a fix that runs the following snippet as expected:

import loopy as lp
from loopy.symbolic import parse

tunit = lp.make_kernel(
    "{[i]: 0<=i<4}",
    ["<> tmp[i] = 0 {id=w_tmp}",
     lp.CInstruction(iname_exprs=("i", "i"),
                     code="break;",
                     predicates={parse("tmp[i] > n")},
                     read_variables={"i", "n"},
                     depends_on=frozenset({"w_tmp"}),
                     id="break",),
     lp.Assignment("out_callee",
                   "i",
                   depends_on=frozenset(["break"]))
     ],
    [lp.ValueArg("n", dtype="int32"), ...],
    name="circuit_breaker")
knl = tunit.default_entrypoint
knl = lp.tag_inames(knl, "i:vec")
# Give the predicated CInstruction its own loop; no inames stay outer.
knl = lp.distribute_loops(knl, "id:break", frozenset())
# Retag the distributed loop's iname ("i_1") from vec to unr.
knl = lp.untag_inames(knl, "i_1", lp.VectorizeTag)
knl = lp.tag_inames(knl, "i_1:unr")
tunit = tunit.with_kernel(knl)
print(lp.generate_code_v2(tunit).device_code())
Fixed as part of #557.
The following kernel --
generates:
Notice the stray (j) in the conditional. A short-term solution, which is to not vectorize such instructions, will be included as part of #557. /cc @sv2518
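Why the stray (j) is a problem: once j is vectorized, j > 2 is no longer one boolean but one boolean per lane, so it cannot guard a scalar if. A per-lane select is one way to express the same write. The sketch below is only an illustration: it models vector lanes as Python lists, assumes a kernel shaped like the one earlier in this thread (tmp2[j] = 2*tmp1[j] guarded by j > 2), and uses hypothetical helpers `scalar` and `masked`. It is not what loopy does; per the issue, the short-term fix is to not vectorize such instructions.

```python
NJ = 4  # vector width, matching 0 <= j < 4 in the kernel


def scalar(i):
    """Reference: j is a scalar loop variable, so a plain branch is fine."""
    tmp1 = [i + j for j in range(NJ)]
    tmp2 = [0] * NJ
    for j in range(NJ):
        if j > 2:
            tmp2[j] = 2 * tmp1[j]
    return tmp2


def masked(i):
    """Vector view: j is the whole lane vector, so the predicate is a mask."""
    lanes = list(range(NJ))
    tmp1 = [i + j for j in lanes]
    tmp2 = [0] * NJ
    mask = [j > 2 for j in lanes]  # one boolean per lane
    # Lane-wise select: take the new value where the mask is set,
    # keep the old tmp2 value elsewhere.
    tmp2 = [2 * t1 if m else t2 for t1, t2, m in zip(tmp1, tmp2, mask)]
    return tmp2


assert all(scalar(i) == masked(i) for i in range(100))
print(masked(1))  # [0, 0, 0, 8]
```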