From a62c37c343249ff059ba1ea113e22b5ca68bc52b Mon Sep 17 00:00:00 2001 From: Paul Fuqua Date: Mon, 30 Sep 2024 12:42:13 -0500 Subject: [PATCH] [MLIR][AMDGPU] Renaming using suggestions from review. --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 616c81f12b9885..4a76739c7a06a8 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -456,14 +456,14 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, /// Return true if `type` is the E5M2 variant of an 8-bit float that is /// supported by the `_bf8` instructions on the given `chipset`. -static bool isNativeBf8(Chipset chipset, Type type) { +static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) { return (isGfx940Series(chipset) && type.isFloat8E5M2FNUZ()) || (hasOcpFp8(chipset) && type.isFloat8E5M2()); } /// Return true if `type` is the E4M3FN variant of an 8-bit float that is /// supported by the `_fp8` instructions on the given `chipset`. -static bool isNativeFp8(Chipset chipset, Type type) { +static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) { return (isGfx940Series(chipset) && type.isFloat8E4M3FNUZ()) || (hasOcpFp8(chipset) && type.isFloat8E4M3FN()); } @@ -564,38 +564,38 @@ static std::optional mfmaOpToIntrinsic(MFMAOp mfma, return ROCDL::mfma_f64_4x4x4f64::getOperationName(); } - if (destElem.isF32() && isNativeBf8(chipset, sourceElem)) { + if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) { // Known to be correct because there are no scalar f8 instructions and // because a length mismatch will have been caught by the verifier. Type sourceBElem = cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { - if (isNativeBf8(chipset, sourceBElem)) + if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(); - if (isNativeFp8(chipset, sourceBElem)) + if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName(); } if (m == 32 && n == 32 && k == 16 && b == 1) { - if (isNativeBf8(chipset, sourceBElem)) + if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName(); - if (isNativeFp8(chipset, sourceBElem)) + if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName(); } } - if (destElem.isF32() && isNativeFp8(chipset, sourceElem)) { + if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) { Type sourceBElem = cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { - if (isNativeBf8(chipset, sourceBElem)) + if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName(); - if (isNativeFp8(chipset, sourceBElem)) + if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName(); } if (m == 32 && n == 32 && k == 16 && b == 1) { - if (isNativeBf8(chipset, sourceBElem)) + if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName(); - if (isNativeFp8(chipset, sourceBElem)) + if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName(); } } @@ -801,10 +801,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( } Value i32Source = rewriter.create(loc, i32, source); Value wordSel = createI32Constant(rewriter, loc, op.getIndex()); - if (isNativeBf8(chipset, sourceElemType)) { + if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, wordSel); - } else if (isNativeFp8(chipset, sourceElemType)) { + } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, wordSel); } @@ -836,10 +836,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex()); Value result; - if (isNativeBf8(chipset, resultElemType)) + if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) result = rewriter.create(loc, i32, sourceA, sourceB, existing, wordSel); - else if (isNativeFp8(chipset, resultElemType)) + else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) result = rewriter.create(loc, i32, sourceA, sourceB, existing, wordSel); @@ -871,10 +871,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex()); Value result; - if (isNativeBf8(chipset, resultElemType)) + if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) result = rewriter.create(loc, i32, source, stoch, existing, byteSel); - else if (isNativeFp8(chipset, resultElemType)) + else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) result = rewriter.create(loc, i32, source, stoch, existing, byteSel);