-
Notifications
You must be signed in to change notification settings - Fork 11.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][AMDGPU] Add OCP FP8 support for new hardware #106160
base: main
Are you sure you want to change the base?
Changes from 4 commits
0272474
1848df4
1b790e3
de5a263
a62c37c
fd0f519
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -454,6 +454,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, | |||||
} | ||||||
} | ||||||
|
||||||
/// Return true if `type` is the E5M2 variant of an 8-bit float that is | ||||||
/// supported by the `_bf8` instructions on the given `chipset`. | ||||||
static bool isNativeBf8(Chipset chipset, Type type) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since it's a query about the type, not about the chipset, I'm happier with an "is" form. Since I'm wordy, how about isNativeBf8ForChipset(Type,Chipset)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand the distinction. The ISA and the supported types are the property of the chipset, no? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, yes, but all chipsets have a native BF8 type -- FNUZ for the older ones, OCP for the newer -- so the answer to "has native BF8" is "yes". Rather, we want to know "is this type the same as the chipset's native BF8 type?" Also, in a context in which we're doing isa<T1,T2>() and isFloat(), "is" flows better. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The issue with chains of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation, I see the difference now. Maybe something like this then If you prefer the current name, could you add a comment that explains what the intention is (similar to how you explained it above)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An equivalent way of phrasing it, from where I'm standing, is "is this one of the types that the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previous job would have called it "typeIsNativelySupportedBf8ForChipset(Type,Chipset)" -- we liked naming the parameters in the function name -- or "typeIsExpectedBf8ForChipset(Type,Chipset)". As I wrote those out, I think I like "expected" better than "native" or "supported," but I could be persuaded. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm ... But |
||||||
return (isGfx940Series(chipset) && type.isFloat8E5M2FNUZ()) || | ||||||
(hasOcpFp8(chipset) && type.isFloat8E5M2()); | ||||||
} | ||||||
|
||||||
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is | ||||||
/// supported by the `_fp8` instructions on the given `chipset`. | ||||||
static bool isNativeFp8(Chipset chipset, Type type) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
return (isGfx940Series(chipset) && type.isFloat8E4M3FNUZ()) || | ||||||
(hasOcpFp8(chipset) && type.isFloat8E4M3FN()); | ||||||
} | ||||||
|
||||||
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma` | ||||||
/// if one exists. This includes checking to ensure the intrinsic is supported | ||||||
/// on the architecture you are compiling for. | ||||||
|
@@ -550,38 +564,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma, | |||||
return ROCDL::mfma_f64_4x4x4f64::getOperationName(); | ||||||
} | ||||||
|
||||||
if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) { | ||||||
if (destElem.isF32() && isNativeBf8(chipset, sourceElem)) { | ||||||
// Known to be correct because there are no scalar f8 instructions and | ||||||
// because a length mismatch will have been caught by the verifier. | ||||||
Type sourceBElem = | ||||||
cast<VectorType>(mfma.getSourceB().getType()).getElementType(); | ||||||
if (m == 16 && n == 16 && k == 32 && b == 1) { | ||||||
if (sourceBElem.isFloat8E5M2FNUZ()) | ||||||
if (isNativeBf8(chipset, sourceBElem)) | ||||||
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(); | ||||||
if (sourceBElem.isFloat8E4M3FNUZ()) | ||||||
if (isNativeFp8(chipset, sourceBElem)) | ||||||
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName(); | ||||||
} | ||||||
if (m == 32 && n == 32 && k == 16 && b == 1) { | ||||||
if (sourceBElem.isFloat8E5M2FNUZ()) | ||||||
if (isNativeBf8(chipset, sourceBElem)) | ||||||
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName(); | ||||||
if (sourceBElem.isFloat8E4M3FNUZ()) | ||||||
if (isNativeFp8(chipset, sourceBElem)) | ||||||
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName(); | ||||||
} | ||||||
} | ||||||
|
||||||
if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) { | ||||||
if (destElem.isF32() && isNativeFp8(chipset, sourceElem)) { | ||||||
Type sourceBElem = | ||||||
cast<VectorType>(mfma.getSourceB().getType()).getElementType(); | ||||||
if (m == 16 && n == 16 && k == 32 && b == 1) { | ||||||
if (sourceBElem.isFloat8E5M2FNUZ()) | ||||||
if (isNativeBf8(chipset, sourceBElem)) | ||||||
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName(); | ||||||
if (sourceBElem.isFloat8E4M3FNUZ()) | ||||||
if (isNativeFp8(chipset, sourceBElem)) | ||||||
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName(); | ||||||
} | ||||||
if (m == 32 && n == 32 && k == 16 && b == 1) { | ||||||
if (sourceBElem.isFloat8E5M2FNUZ()) | ||||||
if (isNativeBf8(chipset, sourceBElem)) | ||||||
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName(); | ||||||
if (sourceBElem.isFloat8E4M3FNUZ()) | ||||||
if (isNativeFp8(chipset, sourceBElem)) | ||||||
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName(); | ||||||
} | ||||||
} | ||||||
|
@@ -757,7 +771,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( | |||||
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, | ||||||
ConversionPatternRewriter &rewriter) const { | ||||||
Location loc = op.getLoc(); | ||||||
if (chipset.majorVersion != 9 || chipset < kGfx940) | ||||||
if (!(isGfx940Series(chipset) || hasOcpFp8(chipset))) | ||||||
return rewriter.notifyMatchFailure( | ||||||
loc, "Fp8 conversion instructions are not available on target " | ||||||
"architecture and their emulation is not implemented"); | ||||||
|
@@ -787,10 +801,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( | |||||
} | ||||||
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); | ||||||
Value wordSel = createI32Constant(rewriter, loc, op.getIndex()); | ||||||
if (sourceElemType.isFloat8E5M2FNUZ()) { | ||||||
if (isNativeBf8(chipset, sourceElemType)) { | ||||||
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source, | ||||||
wordSel); | ||||||
} else if (sourceElemType.isFloat8E4M3FNUZ()) { | ||||||
} else if (isNativeFp8(chipset, sourceElemType)) { | ||||||
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source, | ||||||
wordSel); | ||||||
} | ||||||
|
@@ -801,7 +815,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( | |||||
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, | ||||||
ConversionPatternRewriter &rewriter) const { | ||||||
Location loc = op.getLoc(); | ||||||
if (chipset.majorVersion != 9 || chipset < kGfx940) | ||||||
if (!(isGfx940Series(chipset) || hasOcpFp8(chipset))) | ||||||
return rewriter.notifyMatchFailure( | ||||||
loc, "Fp8 conversion instructions are not available on target " | ||||||
"architecture and their emulation is not implemented"); | ||||||
|
@@ -822,10 +836,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( | |||||
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex()); | ||||||
|
||||||
Value result; | ||||||
if (resultElemType.isFloat8E5M2FNUZ()) | ||||||
if (isNativeBf8(chipset, resultElemType)) | ||||||
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB, | ||||||
existing, wordSel); | ||||||
else if (resultElemType.isFloat8E4M3FNUZ()) | ||||||
else if (isNativeFp8(chipset, resultElemType)) | ||||||
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB, | ||||||
existing, wordSel); | ||||||
|
||||||
|
@@ -838,7 +852,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( | |||||
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, | ||||||
ConversionPatternRewriter &rewriter) const { | ||||||
Location loc = op.getLoc(); | ||||||
if (chipset.majorVersion != 9 || chipset < kGfx940) | ||||||
if (!(isGfx940Series(chipset) || hasOcpFp8(chipset))) | ||||||
return rewriter.notifyMatchFailure( | ||||||
loc, "Fp8 conversion instructions are not available on target " | ||||||
"architecture and their emulation is not implemented"); | ||||||
|
@@ -857,10 +871,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( | |||||
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex()); | ||||||
|
||||||
Value result; | ||||||
if (resultElemType.isFloat8E5M2FNUZ()) | ||||||
if (isNativeBf8(chipset, resultElemType)) | ||||||
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch, | ||||||
existing, byteSel); | ||||||
else if (resultElemType.isFloat8E4M3FNUZ()) | ||||||
else if (isNativeFp8(chipset, resultElemType)) | ||||||
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch, | ||||||
existing, byteSel); | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -509,7 +509,9 @@ bool TosaValidation::isValidElementType(Type type) { | |
if (isa<FloatType>(type)) { | ||
if (profile == TosaProfileEnum::BaseInference) | ||
return false; | ||
return type.isF32() || type.isF16() || type.isBF16(); | ||
return type.isF32() || type.isF16() || type.isBF16() || | ||
type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() || | ||
type.isFloat8E4M3FN() || type.isFloat8E5M2(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use isa |
||
} | ||
if (auto intTy = dyn_cast<IntegerType>(type)) { | ||
if (intTy.isUnsigned()) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!