[MLIR][AMDGPU] Add OCP FP8 support for new hardware #106160
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-tosa
Author: Paul C Fuqua (pcf000)
Changes
This part mostly just allows the new types in places where the other F8 formats were allowed.
Full diff: https://github.com/llvm/llvm-project/pull/106160.diff
6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e5c1a53f34bf64..04b66cea661afc 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
- Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
- VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
+ Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
+ VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
Results<(outs F32:$res)> {
let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
Arguments<(ins F32:$sourceA,
Optional<F32>:$sourceB,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
- Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
- Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+ Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+ Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round two floats into a packed vector of 8-bit floats";
let description = [{
Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
Arguments<(ins F32:$source,
I32:$stochiasticParam,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
- Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
- Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+ Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+ Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
let description = [{
Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -405,7 +405,7 @@ def AMDGPU_RawBufferAtomicUminOp :
def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
"The possible permutations for a DPP operation",
- [
+ [
I32EnumAttrCase<"quad_perm", 0>,
I32EnumAttrCase<"row_shl", 1>,
I32EnumAttrCase<"row_shr", 2>,
@@ -419,7 +419,7 @@ def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
I32EnumAttrCase<"row_bcast_15", 10>,
I32EnumAttrCase<"row_bcast_31", 11>
]> {
- let genSpecializedAttr = 0;
+ let genSpecializedAttr = 0;
let cppNamespace = "::mlir::amdgpu";
}
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
VectorOfLengthAndType<[4], [F16]>,
VectorOfLengthAndType<[2, 4], [BF16]>,
VectorOfLengthAndType<[4, 8], [I8]>,
- VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
+ VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index 38e0ebe68f943b..6de12a3d50878b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -18,6 +18,13 @@ struct Chipset {
: majorVersion(majorVersion), minorVersion(minorVersion){};
static FailureOr<Chipset> parse(StringRef name);
+ bool isGfx940() const {
+ return majorVersion == 9 && minorVersion >= 0x40 && majorVersion < 0x50;
+ }
+ bool hasOcpFp8() const {
+ return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
+ }
+
unsigned majorVersion = 0;
unsigned minorVersion = 0;
};
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 809e9448e80abf..9323fdc7dacd6d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -539,40 +539,42 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}
- if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
- chipset.minorVersion >= 0x40) {
+ if (destElem.isF32() &&
+ ((sourceElem.isFloat8E5M2FNUZ() && chipset.isGfx940()) ||
+ (sourceElem.isFloat8E5M2() && chipset.hasOcpFp8()))) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}
- if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
- chipset.minorVersion >= 0x40) {
+ if (destElem.isF32() &&
+ ((sourceElem.isFloat8E4M3FNUZ() && chipset.isGfx940()) ||
+ (sourceElem.isFloat8E4M3FN() && chipset.hasOcpFp8()))) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
@@ -762,10 +764,11 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
- if (sourceElemType.isFloat8E5M2FNUZ()) {
+ if (sourceElemType.isFloat8E5M2FNUZ() || sourceElemType.isFloat8E5M2()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
- } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+ } else if (sourceElemType.isFloat8E4M3FNUZ() ||
+ sourceElemType.isFloat8E4M3FN()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
@@ -797,10 +800,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
Value result;
- if (resultElemType.isFloat8E5M2FNUZ())
+ if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
- else if (resultElemType.isFloat8E4M3FNUZ())
+ else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
@@ -832,10 +835,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
Value result;
- if (resultElemType.isFloat8E5M2FNUZ())
+ if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
- else if (resultElemType.isFloat8E4M3FNUZ())
+ else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index d36583c8118ff4..a66c13caa6d0ab 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -86,7 +86,8 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
- return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+ return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() ||
+ inType.isFloat8E5M2() || inType.isFloat8E4M3FN());
}
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +217,11 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
- return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+
+ return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) &&
+ chipset.isGfx940()) ||
+ ((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) &&
+ chipset.hasOcpFp8())));
}
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 3943696364950f..2747eebebefa52 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -271,14 +271,16 @@ LogicalResult MFMAOp::verify() {
}
Type sourceBType = getSourceB().getType();
- if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+ if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ() ||
+ sourceElem.isFloat8E5M2() || sourceElem.isFloat8E4M3FN()) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
- if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+ if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ() &&
+ !sourceBElem.isFloat8E5M2() && !sourceBElem.isFloat8E4M3FN())
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b78c372af77e64..963fd6fd7c0511 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -509,7 +509,9 @@ bool TosaValidation::isValidElementType(Type type) {
if (isa<FloatType>(type)) {
if (profile == TosaProfileEnum::BaseInference)
return false;
- return type.isF32() || type.isF16() || type.isBF16();
+ return type.isF32() || type.isF16() || type.isBF16() ||
+ type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() ||
+ type.isFloat8E4M3FN() || type.isFloat8E5M2();
}
if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isUnsigned()) {
@pcf000 There have been downstream changes that fix bugs here; could we add those to this PR?
7a83d58 to a37ab7d
Upcoming hardware (gfx12 and some future gfx9) will support the OCP 8-bit float formats for their matrix multiplication intrinsics and conversion operations, retaining existing opcodes and compiler builtins. This commit adds support for these types to the MLIR wrappers around such operations, ensuring that the OCP types aren't used to generate those builtins on hardware that doesn't expect that format and, conversely, to ensure that the pre-OCP formats aren't used on new hardware.
those operations were not being converted to the LLVM intrinsics they correspond to because the rewrite patterns were still checking for gfx940+. As part of this, factor out the type-match tests into isNativeFp8() and isNativeBf8() functions in the AMDGPUToRocdl rewrites. Also, fix a typo in isGfx940() that caused it to be true for gfx950. Finally, test all these OCP format conversions by duplicating the gfx940 tests.
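To make the `isGfx940()` typo concrete, here is a minimal standalone sketch. The `Chipset` struct only mirrors the two fields and the predicates from the diff; it is not the actual MLIR header. The names `isGfx940Typo` and `isGfx940Fixed` are hypothetical, introduced here for comparison, and the corrected bound (`minorVersion < 0x50`) is an assumption about what the fix looks like.

```cpp
// Standalone model of the chipset predicates; not the real Chipset.h.
struct Chipset {
  Chipset(unsigned majorVersion, unsigned minorVersion)
      : majorVersion(majorVersion), minorVersion(minorVersion) {}

  // As written in the first revision of the diff: the upper bound tests
  // majorVersion, which is already known to be 9, so it excludes nothing.
  bool isGfx940Typo() const {
    return majorVersion == 9 && minorVersion >= 0x40 && majorVersion < 0x50;
  }

  // Assumed correction: bound the minor version instead.
  bool isGfx940Fixed() const {
    return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
  }

  // Same condition as in the diff.
  bool hasOcpFp8() const {
    return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
  }

  unsigned majorVersion = 0;
  unsigned minorVersion = 0;
};
```

With the typo, a gfx950 chipset satisfies both `isGfx940Typo()` and `hasOcpFp8()`, so the FNUZ and OCP paths would both be considered valid; with the assumed fix the two predicates are mutually exclusive.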
a37ab7d to 1b790e3
static bool isNativeBf8(Chipset chipset, Type type) {
  return (chipset.isGfx940() && type.isFloat8E5M2FNUZ()) ||
         (chipset.hasOcpFp8() && type.isFloat8E5M2());
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool isNativeFp8(Chipset chipset, Type type) {
nit: hasNativeF8 to follow the type naming scheme in llvm/mlir?
... yeah, I can see hasNativeF8 or isNativeF8 if we switched the argument order
This file uses "FP8" extensively, both as a generic term for Float8 and as the specific term for the 4.3 format (with "BF8" for 5.2). The two isNative functions are doing the latter.
@@ -68,6 +72,15 @@ struct TruncfToFloat16RewritePattern final

} // end namespace

static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) {
nit: supportsF8?
I think we've been using Fp8 in this file? But my memory of having written this code ages ago could be hazy
The various ops use "fp8" as do the saturate parameters. I made this guy isSupportedF8().
return success(elementType.isFloat8E5M2FNUZ() ||
               elementType.isFloat8E4M3FNUZ());
if (chipset.hasOcpFp8())
  return success(elementType.isFloat8E5M2() || elementType.isFloat8E4M3FN());
use isa<T1, T2>(elementType)
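The suggestion presumably refers to the variadic form of LLVM's `isa<...>`, which tests a value against several candidate types in one call instead of chaining `.isX() || .isY()`. The following is only a toy model of that idea under stated assumptions: `TypeKind`, `Type`, `KindTag`, the tag aliases, and `isaAnyOf` are all hypothetical names invented for this sketch and are not MLIR or LLVM APIs.

```cpp
// Toy stand-ins for MLIR types; the real code works on mlir::Type.
enum class TypeKind { F32, F8E5M2, F8E4M3FN, F8E5M2FNUZ, F8E4M3FNUZ };

struct Type {
  TypeKind kind;
};

// A tag type per kind, so that kinds can be passed as template arguments.
template <TypeKind K>
struct KindTag {
  static constexpr TypeKind value = K;
};
using F8E5M2Tag = KindTag<TypeKind::F8E5M2>;
using F8E4M3FNTag = KindTag<TypeKind::F8E4M3FN>;
using F8E5M2FNUZTag = KindTag<TypeKind::F8E5M2FNUZ>;
using F8E4M3FNUZTag = KindTag<TypeKind::F8E4M3FNUZ>;

// Base case: a single candidate.
template <typename Tag>
bool isaAnyOf(Type type) {
  return type.kind == Tag::value;
}

// Recursive case: true if the type matches any candidate in the list.
template <typename First, typename Second, typename... Rest>
bool isaAnyOf(Type type) {
  return isaAnyOf<First>(type) || isaAnyOf<Second, Rest...>(type);
}

// One grouped query per format family instead of a chain of .isX() calls.
bool isOcpF8(Type type) { return isaAnyOf<F8E5M2Tag, F8E4M3FNTag>(type); }
bool isFnuzF8(Type type) {
  return isaAnyOf<F8E5M2FNUZTag, F8E4M3FNUZTag>(type);
}
```

The grouped form keeps the candidate set in one place, which is what would let a casting layer optimize the check across all candidates if it chose to.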
@@ -454,6 +454,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
  }
}

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool isNativeBf8(Chipset chipset, Type type) {
- static bool isNativeBf8(Chipset chipset, Type type) {
+ static bool hasNativeBF8(Chipset chipset, Type type) {
Since it's a query about the type, not about the chipset, I'm happier with an "is" form. Since I'm wordy, how about isNativeBf8ForChipset(Type,Chipset)?
I don't understand the distinction. The ISA and the supported types are the property of the chipset, no?
Well, yes, but all chipsets have a native BF8 type -- FNUZ for the older ones, OCP for the newer -- so the answer to "has native BF8" is "yes". Rather, we want to know "is this type the same as the chipset's native BF8 type?" Also, in a context in which we're doing isa<T1,T2>() and isFloat(), "is" flows better.
Also, in a context in which we're doing isa<T1,T2>() and isFloat(), "is" flows better.
The issue with chains of .isType is that they won't be optimized across multiple types and are always performed one by one (with short-circuiting). Our casting infra might not do it today, but in order to utilize a faster 'group' isa, you have to use it across the codebase in the first place.
"is this type the same as the chipset's native BF8 type?"
Thanks for the explanation, I see the difference now. Maybe something like this then: isSupportedByNativeBf8(Chipset, Type)?
If you prefer the current name, could you add a comment that explains what the intention is (similar to how you explained it above)?
An equivalent way of phrasing it, from where I'm standing, is "is this one of the types that the _fp8 or _bf8 instructions on the given chipset will be expecting?"
Previous job would have called it "typeIsNativelySupportedBf8ForChipset(Type,Chipset)" -- we liked naming the parameters in the function name -- or "typeIsExpectedBf8ForChipset(Type,Chipset)". As I wrote those out, I think I like "expected" better than "native" or "supported," but I could be persuaded.
Hm ... isImplementedBf8OnChipset()? But Expected works too
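The question the thread converges on ("is this type the one that the chipset's `_bf8` or `_fp8` instructions expect?") can be sketched in a standalone form. This is a toy model, not the PR's code: `F8Kind` is a hypothetical enum standing in for the MLIR element types, the `Chipset` struct is a simplified copy, and the upper bound in `isGfx940()` is the assumed corrected form.

```cpp
// Hypothetical stand-in for the four 8-bit float element types.
enum class F8Kind { E5M2FNUZ, E4M3FNUZ, E5M2, E4M3FN };

// Simplified chipset model; the minor-version upper bound is an assumption.
struct Chipset {
  Chipset(unsigned majorVersion, unsigned minorVersion)
      : majorVersion(majorVersion), minorVersion(minorVersion) {}
  bool isGfx940() const {
    return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
  }
  bool hasOcpFp8() const {
    return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
  }
  unsigned majorVersion = 0;
  unsigned minorVersion = 0;
};

// True if `type` is the E5M2 variant that the `_bf8` instructions on
// `chipset` expect: FNUZ on gfx940-class parts, OCP on newer ones.
bool isNativeBf8(Chipset chipset, F8Kind type) {
  return (chipset.isGfx940() && type == F8Kind::E5M2FNUZ) ||
         (chipset.hasOcpFp8() && type == F8Kind::E5M2);
}

// True if `type` is the E4M3 variant that the `_fp8` instructions on
// `chipset` expect.
bool isNativeFp8(Chipset chipset, F8Kind type) {
  return (chipset.isGfx940() && type == F8Kind::E4M3FNUZ) ||
         (chipset.hasOcpFp8() && type == F8Kind::E4M3FN);
}
```

Every chipset in this model has some native BF8 format, which is why the helper asks about the pairing of type and chipset rather than about the chipset alone.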
/// Return true if this is an float type (with the specified width).
bool isFloat() const;
bool isFloat(unsigned width) const;
nice!
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool isNativeFp8(Chipset chipset, Type type) {
- static bool isNativeFp8(Chipset chipset, Type type) {
+ static bool hasNativeFp8(Chipset chipset, Type type) {
return type.isF32() || type.isF16() || type.isBF16() ||
       type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() ||
       type.isFloat8E4M3FN() || type.isFloat8E5M2();
use isa