Skip to content

Commit

Permalink
[MLIR][AMDGPU] Renaming using suggestions from review.
Browse files Browse the repository at this point in the history
  • Loading branch information
pcf000 committed Sep 30, 2024
1 parent de5a263 commit a62c37c
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,14 +456,14 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
///
/// The accepted E5M2 type depends on the chipset: `Float8E5M2FNUZ` is accepted
/// when `isGfx940Series(chipset)` holds, and `Float8E5M2` is accepted when
/// `hasOcpFp8(chipset)` holds. Every other chipset/type pairing yields false.
// NOTE(review): this text is a rendered diff without +/- markers; the first
// signature line below appears to be the pre-rename name and the second its
// replacement — confirm against the actual file before editing.
static bool isNativeBf8(Chipset chipset, Type type) {
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
return (isGfx940Series(chipset) && type.isFloat8E5M2FNUZ()) ||
(hasOcpFp8(chipset) && type.isFloat8E5M2());
}

/// Return true if `type` is the E4M3 variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
///
/// The accepted E4M3 type depends on the chipset: `Float8E4M3FNUZ` is accepted
/// when `isGfx940Series(chipset)` holds, and `Float8E4M3FN` is accepted when
/// `hasOcpFp8(chipset)` holds. Every other chipset/type pairing yields false.
// NOTE(review): this text is a rendered diff without +/- markers; the first
// signature line below appears to be the pre-rename name and the second its
// replacement — confirm against the actual file before editing.
static bool isNativeFp8(Chipset chipset, Type type) {
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
return (isGfx940Series(chipset) && type.isFloat8E4M3FNUZ()) ||
(hasOcpFp8(chipset) && type.isFloat8E4M3FN());
}
Expand Down Expand Up @@ -564,38 +564,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}

if (destElem.isF32() && isNativeBf8(chipset, sourceElem)) {
if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
// The cast of source B to VectorType below is known to be correct because
// there are no scalar f8 instructions and because a length mismatch will
// have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (isNativeBf8(chipset, sourceBElem))
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
if (isNativeFp8(chipset, sourceBElem))
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
if (isNativeBf8(chipset, sourceBElem))
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
if (isNativeFp8(chipset, sourceBElem))
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}

if (destElem.isF32() && isNativeFp8(chipset, sourceElem)) {
if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (isNativeBf8(chipset, sourceBElem))
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
if (isNativeFp8(chipset, sourceBElem))
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
if (isNativeBf8(chipset, sourceBElem))
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
if (isNativeFp8(chipset, sourceBElem))
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
Expand Down Expand Up @@ -801,10 +801,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
if (isNativeBf8(chipset, sourceElemType)) {
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
} else if (isNativeFp8(chipset, sourceElemType)) {
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
Expand Down Expand Up @@ -836,10 +836,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

Value result;
if (isNativeBf8(chipset, resultElemType))
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
else if (isNativeFp8(chipset, resultElemType))
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);

Expand Down Expand Up @@ -871,10 +871,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

Value result;
if (isNativeBf8(chipset, resultElemType))
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
else if (isNativeFp8(chipset, resultElemType))
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);

Expand Down

0 comments on commit a62c37c

Please sign in to comment.