Commit

[BACKEND] fix bf16 representation in TritonNvidiaGPU and bf16 tl.sort bug (#3975)

Since LLVM now supports `bf16`, it is no longer necessary to [represent
`bf16` as `i16`](#1245 (comment)) in the TritonGPUToLLVM conversion. With
the `i16` representation, a `bf16` compare goes wrong: the compare is
lowered to `arith.cmpf`, which is not compatible with `i16` operands, so
both `bf16` compare and `tl.sort` report a [bug](#3873).
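
As an illustration, here is a minimal `bf16` comparison sketch (the kernel
and launch code below are hypothetical, assuming a CUDA GPU and a Triton
build that includes this change):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def cmp_kernel(x_ptr, y_ptr, out_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)
    y = tl.load(y_ptr + offs)
    # The elementwise compare is lowered through arith.cmpf; with bf16 stored
    # as i16 that lowering was invalid, with a native bf16 type it is legal.
    lt = x < y
    tl.store(out_ptr + offs, lt.to(tl.int8))


x = torch.randn(64, device="cuda", dtype=torch.bfloat16)
y = torch.randn(64, device="cuda", dtype=torch.bfloat16)
out = torch.empty(64, device="cuda", dtype=torch.int8)
cmp_kernel[(1,)](x, y, out, N=64)
assert torch.equal(out.bool(), x < y)
```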

Meanwhile, the use of `core.arange` in `_compare_and_swap` produces a
mismatched data type when `tl.sort` is called on `bf16` inputs. The data
types of `left` and `right` need to be cast back to `y.dtype` to fix
`tl.sort`.
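
A minimal `tl.sort` sketch under the same assumptions (hypothetical kernel
and shapes, a CUDA GPU, and a build containing the `y.dtype` cast):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def sort_kernel(x_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
    offs_m = tl.arange(0, M)[:, None]
    offs_n = tl.arange(0, N)[None, :]
    ptrs = offs_m * N + offs_n
    x = tl.load(x_ptr + ptrs)
    # tl.sort sorts along the last dimension; it is built on _compare_and_swap,
    # where `left`/`right` are now cast back to y.dtype so bf16 stays bf16.
    tl.store(out_ptr + ptrs, tl.sort(x))


x = torch.randn(4, 128, device="cuda", dtype=torch.bfloat16)
out = torch.empty_like(x)
sort_kernel[(1,)](x, out, M=4, N=128)
assert torch.equal(out, torch.sort(x, dim=-1).values)
```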

The revision has passed the Python tests below, run in Docker on an H100:
```sh
$ sudo pip uninstall pytorch-triton
$ cd triton
$ pip install -e python
$ python -m pytest python/test/unit
# ...
  11309 passed, 1219 skipped, 156 warnings in 3222.26s (0:53:42)
```
However, I cannot build the C++ tests; the build fails with the following errors:
```sh
$ cd python/build/cmake.linux-x86_64-cpython-3.10/
$ ninja test
[0/1] Re-running CMake...
/bin/bash: line 1: /tmp/pip-build-env-phcu6k1b/overlay/local/lib/python3.10/dist-packages/cmake/data/bin/cmake: No such file or directory
FAILED: build.ninja 
/tmp/pip-build-env-phcu6k1b/overlay/local/lib/python3.10/dist-packages/cmake/data/bin/cmake --regenerate-during-build -S/home/scratch.haoruoc_gpu/repos/triton -B/home/scratch.haoruoc_gpu/repos/triton/python/build/cmake.linux-x86_64-cpython-3.10
ninja: error: rebuilding 'build.ninja': subcommand failed
```
The CMake path recorded in `build.ninja` does not exist (it points into a temporary pip build environment).

Besides, I have not revised the AMD backend, as I have no access to the
corresponding hardware.

---------

Co-authored-by: haoruoc <[email protected]>
horrorChen and haoruoc authored May 29, 2024
1 parent 100e2aa commit 445d5ed
Showing 8 changed files with 36 additions and 33 deletions.
4 changes: 0 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

@@ -40,10 +40,6 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
   addConversion([&](mlir::Float8E5M2FNUZType type) -> std::optional<Type> {
     return IntegerType::get(type.getContext(), 8);
   });
-  // Internally store bfloat16 as int16
-  addConversion([&](BFloat16Type type) -> std::optional<Type> {
-    return IntegerType::get(type.getContext(), 16);
-  });
 }
 
 Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
4 changes: 2 additions & 2 deletions python/triton/language/standard.py

@@ -326,8 +326,8 @@ def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr):
     y = core.reshape(x, shape)
     # slice left/right with 'stride' 2**(n_dims - i - 1)
     mask = core.arange(0, 2)[None, :, None]
-    left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape)
-    right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape)
+    left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype)
+    right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype)
     left = core.reshape(left, x.shape)
     right = core.reshape(right, x.shape)
     # actual compare-and-swap
2 changes: 1 addition & 1 deletion test/Conversion/tritongpu_to_llvm.mlir

@@ -1567,7 +1567,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
 module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: test_local_load_bf16
-  // CHECK: llvm.extractelement {{.*}} : vector<8xi16>
+  // CHECK: llvm.extractelement {{.*}} : vector<8xbf16>
   tt.func public @test_local_load_bf16() {
     %c0_i32 = arith.constant 0 : i32
     %19 = triton_gpu.local_alloc : () -> !tt.memdesc<1x1x2048xbf16, #shared, mutable>
11 changes: 9 additions & 2 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

@@ -278,11 +278,18 @@ struct DotOpMFMAConversionHelper {
     int kpack = kWidth / kBase;
     SmallVector<Value> results;
     auto vecTy = vec_ty(type, kBase);
+    if (type.isBF16())
+      vecTy = vec_ty(i16_ty, kBase);
     for (int k = 0; k < kpack; ++k) {
       Value vec = undef(vecTy);
       for (int elemId = 0; elemId < kBase; ++elemId) {
         auto val = extract_element(type, rawElems, i32_val(elemId + k * kBase));
-        vec = insert_element(vecTy, vec, val, i32_val(elemId));
+        if (type.isBF16()) {
+          // rocdl.mfma.f32.32x32x8bf16.1k calls for input of i16 type
+          auto cast = bitcast(val, i16_ty);
+          vec = insert_element(vecTy, vec, cast, i32_val(elemId));
+        } else
+          vec = insert_element(vecTy, vec, val, i32_val(elemId));
       }
       if (type.getIntOrFloatBitWidth() == 8) {
         if (4 == kBase)
@@ -329,7 +336,7 @@ struct DotOpMFMAConversionHelper {
     if (type.getIntOrFloatBitWidth() == 8) {
       vals = extractOperands(rawElems, kWidth, kBase, i8_ty);
     } else if (type.isBF16()) {
-      vals = extractOperands(rawElems, kWidth, kBase, i16_ty);
+      vals = extractOperands(rawElems, kWidth, kBase, bf16_ty);
     } else {
       assert(type.isF16() && "Unsupported data type");
       vals = extractOperands(rawElems, kWidth, kBase, f16_ty);
30 changes: 15 additions & 15 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

@@ -312,7 +312,7 @@ static Value convertFp32ToBf16(Location loc,
     auto as_int32 = bitcast(v, i32_ty);
     auto shifted = lshr(i32_ty, as_int32, i32_val(16));
     auto truncated = trunc(i16_ty, shifted);
-    return bitcast(truncated, i16_ty);
+    return bitcast(truncated, bf16_ty);
   }
   // Otherwise it is (rounding == RoundingMode::RTNE)
   auto as_uint32 = bitcast(v, i32_ty);
@@ -335,7 +335,7 @@
 
   auto shifted = lshr(i32_ty, res, i32_val(16));
   auto truncated = trunc(i16_ty, shifted);
-  return truncated;
+  return bitcast(truncated, bf16_ty);
 }
 
 static Value Fp8E5M2FNUZ_to_Fp16_oneValue(Location loc,
@@ -445,20 +445,20 @@ static SmallVector<Value> Fp8E5M2_to_Bf16(Location loc,
   out0 = or_(i32_ty, out0, sign0);
   out1 = or_(i32_ty, out1, sign1);
 
-  auto bf16x2VecTy = vec_ty(i16_ty, 2);
+  auto bf16x2VecTy = vec_ty(bf16_ty, 2);
   out0 = bitcast(out0, bf16x2VecTy);
   out1 = bitcast(out1, bf16x2VecTy);
 
-  return {extract_element(i16_ty, out0, i32_val(0)),
-          extract_element(i16_ty, out0, i32_val(1)),
-          extract_element(i16_ty, out1, i32_val(0)),
-          extract_element(i16_ty, out1, i32_val(1))};
+  return {extract_element(bf16_ty, out0, i32_val(0)),
+          extract_element(bf16_ty, out0, i32_val(1)),
+          extract_element(bf16_ty, out1, i32_val(0)),
+          extract_element(bf16_ty, out1, i32_val(1))};
 }
 
 static SmallVector<Value> Bf16_to_Fp8E5M2(Location loc,
                                           ConversionPatternRewriter &rewriter,
                                           const SmallVector<Value> &v) {
-  auto bf16x2VecTy = vec_ty(i16_ty, 2);
+  auto bf16x2VecTy = vec_ty(bf16_ty, 2);
   Value bf16x2Vec0 = undef(bf16x2VecTy);
   Value bf16x2Vec1 = undef(bf16x2VecTy);
   bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[0], i32_val(0));
@@ -714,22 +714,22 @@ static SmallVector<Value> Fp8E4M3_to_Bf16(Location loc,
   Value sign0 = and_(i32_ty, a0, i32_val(0x80008000));
   Value sign1 = and_(i32_ty, a1, i32_val(0x80008000));
 
-  auto bf16x2VecTy = vec_ty(i16_ty, 2);
+  auto bf16x2VecTy = vec_ty(bf16_ty, 2);
   Value bf16x2Vec0 = or_(i32_ty, sign0, b0);
   Value bf16x2Vec1 = or_(i32_ty, sign1, b1);
   bf16x2Vec0 = bitcast(bf16x2Vec0, bf16x2VecTy);
   bf16x2Vec1 = bitcast(bf16x2Vec1, bf16x2VecTy);
 
-  return {extract_element(i16_ty, bf16x2Vec0, i32_val(0)),
-          extract_element(i16_ty, bf16x2Vec0, i32_val(1)),
-          extract_element(i16_ty, bf16x2Vec1, i32_val(0)),
-          extract_element(i16_ty, bf16x2Vec1, i32_val(1))};
+  return {extract_element(bf16_ty, bf16x2Vec0, i32_val(0)),
+          extract_element(bf16_ty, bf16x2Vec0, i32_val(1)),
+          extract_element(bf16_ty, bf16x2Vec1, i32_val(0)),
+          extract_element(bf16_ty, bf16x2Vec1, i32_val(1))};
 }
 
 static SmallVector<Value> Bf16_to_Fp8E4M3(Location loc,
                                           ConversionPatternRewriter &rewriter,
                                           const SmallVector<Value> &v) {
-  auto bf16x2VecTy = vec_ty(i16_ty, 2);
+  auto bf16x2VecTy = vec_ty(bf16_ty, 2);
   Value bf16x2Vec0 = undef(bf16x2VecTy);
   Value bf16x2Vec1 = undef(bf16x2VecTy);
   bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[0], i32_val(0));
@@ -1102,7 +1102,7 @@ static SmallVector<Value> S8_to_Bf16(Location loc,
     f32Val = bitcast(f32Val, i32_ty);
     auto shifted = lshr(i32_ty, f32Val, i32_val(16));
     auto truncated = trunc(i16_ty, shifted);
-    outValues.push_back(truncated);
+    outValues.push_back(bitcast(truncated, bf16_ty));
   }
   return outValues;
 }

@@ -150,8 +150,8 @@ static Value loadA(Value tensor, const SharedMemoryObject &smemObj,
   Type elemX2Ty = vec_ty(f16_ty, 2);
   Type elemTy = f16_ty;
   if (tensorTy.getElementType().isBF16()) {
-    elemX2Ty = vec_ty(i16_ty, 2);
-    elemTy = i16_ty;
+    elemX2Ty = vec_ty(bf16_ty, 2);
+    elemTy = bf16_ty;
   }
 
   // prepare arguments
@@ -276,8 +276,8 @@ static Value loadB(Value tensor, const SharedMemoryObject &smemObj,
   Type elemTy = f16_ty;
   Type elemX2Ty = vec_ty(f16_ty, 2);
   if (tensorTy.getElementType().isBF16()) {
-    elemTy = i16_ty;
-    elemX2Ty = vec_ty(i16_ty, 2);
+    elemTy = bf16_ty;
+    elemX2Ty = vec_ty(bf16_ty, 2);
   }
 
   SmallVector<Value> ptrB(numPtrB);

@@ -491,7 +491,7 @@ Type getSharedMemTy(Type argType) {
   if (argType.isF16())
     return type::f16Ty(ctx);
   else if (argType.isBF16())
-    return type::i16Ty(ctx);
+    return type::bf16Ty(ctx);
   else if (argType.isF32())
     return type::f32Ty(ctx);
   else if (argType.getIntOrFloatBitWidth() == 8)

@@ -355,7 +355,7 @@ struct FpToFpOpConversion
     cvt(res, operand);
     // TODO: This is a hack to get the right type. We should be able to invoke
     // the type converter
-    return builder.launch(rewriter, loc, i16_ty, false);
+    return builder.launch(rewriter, loc, bf16_ty, false);
   }
 
   static Value convertFp32ToFp16(Location loc,
@@ -574,7 +574,7 @@ struct FMulOpConversion
       auto lhs = builder.newOperand(operands[0][0], "h");
       auto rhs = builder.newOperand(operands[0][1], "h");
       fMul({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
-      return {builder.launch(rewriter, loc, i16_ty, false)};
+      return {builder.launch(rewriter, loc, bf16_ty, false)};
     } else {
       return {rewriter.create<LLVM::FMulOp>(loc, elemTy, operands[0][0],
                                             operands[0][1])};
@@ -604,7 +604,7 @@ struct FAddOpConversion
      auto lhs = builder.newOperand(operands[0][0], "h");
       auto rhs = builder.newOperand(operands[0][1], "h");
       fAdd({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
-      return {builder.launch(rewriter, loc, i16_ty, false)};
+      return {builder.launch(rewriter, loc, bf16_ty, false)};
     } else {
      return {rewriter.create<LLVM::FAddOp>(loc, elemTy, operands[0][0],
                                            operands[0][1])};
@@ -634,7 +634,7 @@ struct FSubOpConversion
       auto lhs = builder.newOperand(operands[0][0], "h");
       auto rhs = builder.newOperand(operands[0][1], "h");
       fSub({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
-      return {builder.launch(rewriter, loc, i16_ty, false)};
+      return {builder.launch(rewriter, loc, bf16_ty, false)};
     } else {
       return {rewriter.create<LLVM::FSubOp>(loc, elemTy, operands[0][0],
                                             operands[0][1])};
