Skip to content

Commit

Permalink
[BACKEND] Cleanup redundant broadcast combine pattern (triton-lang#5167)
Browse files Browse the repository at this point in the history
Summary of changes:
- Remove `broadcast(cst) -> cst` from the triton-combine pass since it's
redundant with the existing folder.
- Reorder the triton-combine pass to come after the canonicalize pass,
to simplify pattern matching
- Cleanup patterns in triton-reorder-broadcast that called
`Op::canonicalize` in favor of `Op::getCanonicalizationPatterns`.
  • Loading branch information
peterbell10 authored and hmalgewatta committed Nov 15, 2024
1 parent f2ce541 commit a9a7a04
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 64 deletions.
32 changes: 1 addition & 31 deletions lib/Dialect/Triton/Transforms/Combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"

Expand All @@ -18,35 +17,7 @@ namespace mlir::triton {
namespace {

bool isZero(Value val) {
if (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()))
return true;
// broadcast(constant_0)
if (auto bc = val.getDefiningOp<BroadcastOp>()) {
if (matchPattern(bc.getSrc(), m_Zero()) ||
matchPattern(bc.getSrc(), m_AnyZeroFloat()))
return true;
}
return false;
}

bool isBroadcastConstantCombinable(Attribute value) {
if (auto denseValue = dyn_cast<DenseElementsAttr>(value)) {
return denseValue.isSplat();
}
return isa<FloatAttr, IntegerAttr>(value);
}

DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
Value bcast_res) {
auto resType = cast<ShapedType>(bcast_res.getType());
DenseElementsAttr res;
if (auto denseValue = dyn_cast<DenseElementsAttr>(value)) {
res =
DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
} else {
res = DenseElementsAttr::get(resType, value);
}
return res;
return (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()));
}

bool isAddPtrOffsetCombinable(Value first, Value second) {
Expand Down Expand Up @@ -231,7 +202,6 @@ class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
// %}
patterns.add<CombineSelectMaskedLoadPattern>(context);
patterns.add<CombineAddPtrPattern>(context);
patterns.add<CombineBroadcastConstantPattern>(context);
patterns.add<CombineBroadcastMulReducePattern>(context);

if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
Expand Down
7 changes: 0 additions & 7 deletions lib/Dialect/Triton/Transforms/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,4 @@ def CombineAddPtrPattern : Pat<
(TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)),
[(Constraint<CPred<"isAddPtrOffsetCombinable($0, $1)">> $idx0, $idx1)]>;

// broadcast(cst) => cst
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
def CombineBroadcastConstantPattern : Pat<
(TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)),
(Arith_ConstantOp (getConstantValue $value, $bcast_res), (location $bcast_res)),
[(Constraint<CPred<"isBroadcastConstantCombinable($0)">> $value)]>;

#endif
16 changes: 2 additions & 14 deletions lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,18 +206,6 @@ struct MoveBroadcastAfterElementwisePattern
}
};

template <typename OpType>
class CanonicalizePattern : public OpRewritePattern<OpType> {
public:
explicit CanonicalizePattern(MLIRContext *context)
: OpRewritePattern<OpType>(context) {}

LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override {
return OpType::canonicalize(op, rewriter);
}
};

class ReorderBroadcastPass
: public ::impl::TritonReorderBroadcastBase<ReorderBroadcastPass> {
public:
Expand All @@ -226,8 +214,8 @@ class ReorderBroadcastPass
RewritePatternSet patterns(context);
ModuleOp m = getOperation();

patterns.add<CanonicalizePattern<BroadcastOp>>(context);
patterns.add<CanonicalizePattern<ExpandDimsOp>>(context);
BroadcastOp::getCanonicalizationPatterns(patterns, context);
ExpandDimsOp::getCanonicalizationPatterns(patterns, context);
// elementwise(broadcast(a)) => broadcast(elementwise(a))
patterns.add<MoveBroadcastAfterElementwisePattern>(context);
// elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))
Expand Down
12 changes: 12 additions & 0 deletions test/Triton/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked>
}
} // end module

// -----

// CHECK-LABEL: @fold_broadcast_constant_pattern
tt.func @fold_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
%const = arith.constant dense<1.0> : tensor<8x1xf32>
%bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32>

// CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
tt.return %bst_out : tensor<8x2xf32>
}
10 changes: 0 additions & 10 deletions test/Triton/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,6 @@ tt.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32
tt.return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}

// CHECK-LABEL: @test_combine_broadcast_constant_pattern
tt.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
%const = arith.constant dense<1.0> : tensor<8x1xf32>
%bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32>

// CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
tt.return %bst_out : tensor<8x2xf32>
}

// CHECK-LABEL: @test_canonicalize_masked_load_pattern
tt.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
%true_mask = arith.constant dense<true> : tensor<8xi1>
Expand Down
2 changes: 1 addition & 1 deletion third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ def make_ttir(mod, metadata, options):
pm.enable_debug()
passes.common.add_inliner(pm)
passes.ttir.add_rewrite_tensor_pointer(pm)
passes.ttir.add_combine(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_combine(pm)
passes.ttir.add_reorder_broadcast(pm)
passes.common.add_cse(pm)
passes.common.add_licm(pm)
Expand Down
2 changes: 1 addition & 1 deletion third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ def make_ttir(mod, metadata, opt):
pm.enable_debug()
passes.common.add_inliner(pm)
passes.ttir.add_rewrite_tensor_pointer(pm)
passes.ttir.add_combine(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_combine(pm)
passes.ttir.add_reorder_broadcast(pm)
passes.common.add_cse(pm)
passes.common.add_licm(pm)
Expand Down

0 comments on commit a9a7a04

Please sign in to comment.