Skip to content

Commit

Permalink
[Triton] Remove upstream bug workaround (NFC) (#5152)
Browse files Browse the repository at this point in the history
Upstream handling of splatted bools in `DenseElementsAttr` was fixed, so
the workaround can be removed when lowering `arith.constant` to
TritonGPU.

Co-authored-by: peterbell10 <[email protected]>
  • Loading branch information
Mogball and peterbell10 authored Nov 14, 2024
1 parent 38f6a6d commit 8bf3ae9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
13 changes: 4 additions & 9 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,10 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
Type retType = getTypeConverter()->convertType(op.getType());
auto retShapedType = cast<ShapedType>(retType);
auto value = dyn_cast<DenseElementsAttr>(adaptor.getValue());
if (dyn_cast<RankedTensorType>(retShapedType)) {
assert(value);
if (value.getElementType().isInteger(1) && value.isSplat())
// Workaround until https://reviews.llvm.org/D133743 is included.
value =
DenseElementsAttr::get(retShapedType, value.getSplatValue<bool>());
else
// This is a hack. We just want to add encoding
value = value.reshape(retShapedType);
if (isa<RankedTensorType>(retShapedType)) {
assert(value && "expected a dense elements attribute");
// This is a hack. We just want to add encoding.
value = value.reshape(retShapedType);
}
addNamedAttrs(rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, retShapedType, value),
Expand Down
13 changes: 13 additions & 0 deletions test/Conversion/triton_to_tritongpu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,16 @@ tt.func public @select_op(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
tt.return
}
}

// -----

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} {
tt.func @arith_splat_bool(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// CHECK-LABEL: arith_splat_bool

// Test arith.constant with splatted bool.
// CHECK-NEXT: arith.constant dense<true> : tensor<128xi1, #{{.*}}>
%mask = arith.constant dense<true> : tensor<128xi1>
tt.return
}
}

0 comments on commit 8bf3ae9

Please sign in to comment.