[BACKEND] Cleanup redundant broadcast combine pattern #5167
The triton-combine pass has a pattern that folds broadcast(cst) -> cst but this is redundant with the existing folder. Also, in the triton-reorder-broadcast pass we create a pattern which calls `Op::canonicalize`, but we can just use `Op::getCanonicalizationPatterns` instead.
07dcfa7
to
3c84c26
Compare
passes.common.add_canonicalizer(pm)
passes.ttir.add_combine(pm)
I'm swapping these passes around so that the combine pass can assume the input is canonicalized.
I think that makes sense. I wonder if this will prevent some canonicalization from happening, as broadcast ops may be in the middle. What kind of canonicalization do you need for this pass?
Worst case, if we see problems, we can run canonicalization one more time.
I think all the broadcast-related things are now in canonicalize, and combine is now mainly dot-related things, plus the addptr(addptr(ptr, a), b) -> addptr(ptr, a + b) pattern, which shouldn't really be affected by broadcasting. So I think this is a good separation.
I do wonder if it might be worthwhile to add another canonicalize pass though, since LICM and loop unrolling could potentially connect patterns that were previously separated by region boundaries. The reorder broadcast pass could also connect arith patterns that were broken up by broadcasts and/or splats.
Just a hunch though, I don't have any examples in mind.
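For readers unfamiliar with the addptr pattern mentioned in this comment, here is a minimal Python sketch of the rewrite on a toy expression tree. The `Ptr`, `Add`, and `AddPtr` classes and the `combine_addptr` function are hypothetical stand-ins for illustration only; they are not Triton or MLIR APIs, and the real pattern is written in C++ against MLIR ops.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Ptr:
    """Toy base pointer (hypothetical)."""
    name: str


@dataclass(frozen=True)
class Add:
    """Toy integer addition of two offsets (hypothetical)."""
    lhs: "Offset"
    rhs: "Offset"


@dataclass(frozen=True)
class AddPtr:
    """Toy addptr(ptr, offset) node (hypothetical)."""
    ptr: Union[Ptr, "AddPtr"]
    offset: "Offset"


Offset = Union[str, int, Add]


def combine_addptr(expr):
    """Rewrite addptr(addptr(ptr, a), b) -> addptr(ptr, a + b), innermost first."""
    if not isinstance(expr, AddPtr):
        return expr
    base = combine_addptr(expr.ptr)
    if isinstance(base, AddPtr):
        # Merge the two offsets into a single add on the original pointer.
        return AddPtr(base.ptr, Add(base.offset, expr.offset))
    return AddPtr(base, expr.offset)


nested = AddPtr(AddPtr(Ptr("p"), "a"), "b")
combined = combine_addptr(nested)
print(combined)
```

Note that the rewrite only inspects addptr nodes, which is consistent with the point above that broadcasts should not interfere with it.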
if (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()))
  return true;
// broadcast(constant_0)
if (auto bc = val.getDefiningOp<BroadcastOp>()) {
Since the input is now canonicalized, we don't need to match against non-canonical forms.
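To illustrate why the pass order matters for this check, here is a minimal Python sketch on a toy IR. All names (`Const`, `Broadcast`, `canonicalize`, `is_zero`) are hypothetical and only model the idea; the actual code uses MLIR's `matchPattern` and the existing broadcast folder.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Const:
    """Toy splat constant: one value repeated over a shape (hypothetical)."""
    value: float
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class Broadcast:
    """Toy broadcast of an operand to a larger shape (hypothetical)."""
    operand: Union["Const", "Broadcast"]
    shape: Tuple[int, ...]


def canonicalize(expr):
    """Toy canonicalizer: fold broadcast(cst) into a constant of the new shape."""
    if isinstance(expr, Broadcast):
        inner = canonicalize(expr.operand)
        if isinstance(inner, Const):
            return Const(inner.value, expr.shape)
        return Broadcast(inner, expr.shape)
    return expr


def is_zero(expr) -> bool:
    """Zero check that assumes canonical input: only plain constants match."""
    return isinstance(expr, Const) and expr.value == 0


raw = Broadcast(Const(0.0, (1, 4)), (8, 4))
print(is_zero(raw))                # non-canonical form is not matched
print(is_zero(canonicalize(raw)))  # canonical form is matched
```

In this sketch the simplified `is_zero` only recognizes the zero after `canonicalize` has run, which is the assumption the reordered pipeline is meant to guarantee.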
Summary of changes:
- Remove `broadcast(cst) -> cst` from the triton-combine pass since it's redundant with the existing folder.
- Reorder the triton-combine pass to come after the canonicalize pass, to simplify pattern matching.
- Clean up patterns in triton-reorder-broadcast that called `Op::canonicalize` in favor of `Op::getCanonicalizationPatterns`.