[BACKEND] Cleanup redundant broadcast combine pattern #5167

Merged 2 commits on Nov 15, 2024
32 changes: 1 addition & 31 deletions lib/Dialect/Triton/Transforms/Combine.cpp
@@ -7,7 +7,6 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/Transforms/Passes.h"
 
@@ -18,35 +17,7 @@ namespace mlir::triton {
 namespace {
 
 bool isZero(Value val) {
-  if (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()))
-    return true;
-  // broadcast(constant_0)
-  if (auto bc = val.getDefiningOp<BroadcastOp>()) {

Contributor Author: Since the input is now canonicalized, we don't need to match against non-canonical forms.

-    if (matchPattern(bc.getSrc(), m_Zero()) ||
-        matchPattern(bc.getSrc(), m_AnyZeroFloat()))
-      return true;
-  }
-  return false;
-}
-
-bool isBroadcastConstantCombinable(Attribute value) {
-  if (auto denseValue = dyn_cast<DenseElementsAttr>(value)) {
-    return denseValue.isSplat();
-  }
-  return isa<FloatAttr, IntegerAttr>(value);
-}
-
-DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
-                                   Value bcast_res) {
-  auto resType = cast<ShapedType>(bcast_res.getType());
-  DenseElementsAttr res;
-  if (auto denseValue = dyn_cast<DenseElementsAttr>(value)) {
-    res =
-        DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
-  } else {
-    res = DenseElementsAttr::get(resType, value);
-  }
-  return res;
+  return (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()));
 }
 
 bool isAddPtrOffsetCombinable(Value first, Value second) {
@@ -231,7 +202,6 @@ class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
     // %}
     patterns.add<CombineSelectMaskedLoadPattern>(context);
     patterns.add<CombineAddPtrPattern>(context);
-    patterns.add<CombineBroadcastConstantPattern>(context);
     patterns.add<CombineBroadcastMulReducePattern>(context);
 
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
7 changes: 0 additions & 7 deletions lib/Dialect/Triton/Transforms/Combine.td
@@ -44,11 +44,4 @@ def CombineAddPtrPattern : Pat<
     (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)),
     [(Constraint<CPred<"isAddPtrOffsetCombinable($0, $1)">> $idx0, $idx1)]>;
 
-// broadcast(cst) => cst
-def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
-def CombineBroadcastConstantPattern : Pat<
-    (TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)),
-    (Arith_ConstantOp (getConstantValue $value, $bcast_res), (location $bcast_res)),
-    [(Constraint<CPred<"isBroadcastConstantCombinable($0)">> $value)]>;
-
 #endif
16 changes: 2 additions & 14 deletions lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp
@@ -206,18 +206,6 @@ struct MoveBroadcastAfterElementwisePattern
   }
 };
 
-template <typename OpType>
-class CanonicalizePattern : public OpRewritePattern<OpType> {
-public:
-  explicit CanonicalizePattern(MLIRContext *context)
-      : OpRewritePattern<OpType>(context) {}
-
-  LogicalResult matchAndRewrite(OpType op,
-                                PatternRewriter &rewriter) const override {
-    return OpType::canonicalize(op, rewriter);
-  }
-};
-
 class ReorderBroadcastPass
     : public ::impl::TritonReorderBroadcastBase<ReorderBroadcastPass> {
 public:
@@ -226,8 +214,8 @@ class ReorderBroadcastPass
     RewritePatternSet patterns(context);
     ModuleOp m = getOperation();
 
-    patterns.add<CanonicalizePattern<BroadcastOp>>(context);
-    patterns.add<CanonicalizePattern<ExpandDimsOp>>(context);
+    BroadcastOp::getCanonicalizationPatterns(patterns, context);
+    ExpandDimsOp::getCanonicalizationPatterns(patterns, context);
     // elementwise(broadcast(a)) => broadcast(elementwise(a))
     patterns.add<MoveBroadcastAfterElementwisePattern>(context);
     // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))
12 changes: 12 additions & 0 deletions test/Triton/canonicalize.mlir
@@ -134,3 +134,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked>
   }
 } // end module
+
+// -----
+
+// CHECK-LABEL: @fold_broadcast_constant_pattern
+tt.func @fold_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
+  // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
+  %const = arith.constant dense<1.0> : tensor<8x1xf32>
+  %bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32>
+
+  // CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
+  tt.return %bst_out : tensor<8x2xf32>
+}
10 changes: 0 additions & 10 deletions test/Triton/combine.mlir
@@ -208,16 +208,6 @@ tt.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32
   tt.return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
 }
 
-// CHECK-LABEL: @test_combine_broadcast_constant_pattern
-tt.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
-  // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
-  %const = arith.constant dense<1.0> : tensor<8x1xf32>
-  %bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32>
-
-  // CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
-  tt.return %bst_out : tensor<8x2xf32>
-}
-
 // CHECK-LABEL: @test_canonicalize_masked_load_pattern
 tt.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
   %true_mask = arith.constant dense<true> : tensor<8xi1>
2 changes: 1 addition & 1 deletion third_party/amd/backend/compiler.py
@@ -189,8 +189,8 @@ def make_ttir(mod, metadata, options):
         pm.enable_debug()
         passes.common.add_inliner(pm)
         passes.ttir.add_rewrite_tensor_pointer(pm)
-        passes.ttir.add_combine(pm)
         passes.common.add_canonicalizer(pm)
+        passes.ttir.add_combine(pm)
         passes.ttir.add_reorder_broadcast(pm)
         passes.common.add_cse(pm)
         passes.common.add_licm(pm)
2 changes: 1 addition & 1 deletion third_party/nvidia/backend/compiler.py
@@ -189,8 +189,8 @@ def make_ttir(mod, metadata, opt):
         pm.enable_debug()
         passes.common.add_inliner(pm)
         passes.ttir.add_rewrite_tensor_pointer(pm)
-        passes.ttir.add_combine(pm)
         passes.common.add_canonicalizer(pm)
+        passes.ttir.add_combine(pm)

Contributor Author: I'm swapping these passes around so that the combine pass can assume the input is canonicalized.

Collaborator: I think that makes sense. I wonder if this will prevent some canonicalization from happening, since broadcast ops may be in the middle. What kind of canonicalization do you need for this pass? Worst case, if we see problems we can run canonicalization one more time.

Contributor Author: I think all the broadcast-related things are now in canonicalize, and combine is now mainly dot-related things, plus the addptr(addptr(ptr, a), b) -> addptr(ptr, a + b) pattern, which shouldn't really be affected by broadcasting. So I think this is a good separation.

I do wonder if it might be worthwhile to add another canonicalize pass though, since LICM and loop unrolling could potentially connect patterns that were previously separated by region boundaries. The reorder-broadcast pass could also connect arith patterns that were broken up by broadcasts and/or splats.

Just a hunch though; I don't have any examples in mind.

         passes.ttir.add_reorder_broadcast(pm)
         passes.common.add_cse(pm)
         passes.common.add_licm(pm)
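
To make the discussion above concrete, here is a minimal sketch of the reordered make_ttir sequence, with the reviewers' "run canonicalization one more time" fallback behind a flag. Only the passes.common.* and passes.ttir.* calls come from the diff; the helper name, the ir.pass_manager constructor, and the final pm.run are assumptions about the surrounding backend code, not something this PR adds.

# Hypothetical sketch, not the actual compiler.py. `ir` and `passes` are assumed
# to be the backend's existing libtriton bindings.
def make_ttir_sketch(mod, ir, passes, extra_canonicalize=False):
    pm = ir.pass_manager(mod.context)  # assumed constructor, mirroring the real backends
    pm.enable_debug()
    passes.common.add_inliner(pm)
    passes.ttir.add_rewrite_tensor_pointer(pm)
    passes.common.add_canonicalizer(pm)  # canonicalize first ...
    passes.ttir.add_combine(pm)          # ... so combine can assume canonical input
    passes.ttir.add_reorder_broadcast(pm)
    if extra_canonicalize:
        # Reviewers' fallback: re-canonicalize if reordering broadcasts ever
        # exposes patterns the earlier canonicalizer run could not see.
        passes.common.add_canonicalizer(pm)
    passes.common.add_cse(pm)
    passes.common.add_licm(pm)
    pm.run(mod)  # assumed to mirror the real make_ttir's final run step
    return mod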