[MLIR][AMDGPU] Add OCP FP8 support for new hardware #106160

Open
wants to merge 6 commits into
base: main
Changes from 4 commits
14 changes: 7 additions & 7 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :

def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
-Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
-VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
+Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
+VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
Results<(outs F32:$res)> {
let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
Arguments<(ins F32:$sourceA,
Optional<F32>:$sourceB,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
-Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
-Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round two floats into a packed vector of 8-bit floats";
let description = [{
Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
Arguments<(ins F32:$source,
I32:$stochiasticParam,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
-Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
-Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
let description = [{
Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
VectorOfLengthAndType<[4], [F16]>,
VectorOfLengthAndType<[2, 4], [BF16]>,
VectorOfLengthAndType<[4, 8], [I8]>,
-VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
+VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
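
Editorial note: the type lists in this file now accept two families of 8-bit floats, the FNUZ types and the OCP types (`F8E5M2`, `F8E4M3FN`). The sketch below is not code from this PR; it is a rough, self-contained illustration of how the same byte decodes differently in the two families. `F8Format`, `decodeF8`, and `extractByte` are hypothetical names, and the assumption that element `index` occupies byte `index` of the packed 32-bit word is made for illustration only.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical stand-in for the four MLIR element types named in the diff.
enum class F8Format { E5M2, E4M3FN, E5M2FNUZ, E4M3FNUZ };

// Decode one 8-bit float to a 32-bit float.
inline float decodeF8(uint8_t bits, F8Format fmt) {
  const bool isE5 = fmt == F8Format::E5M2 || fmt == F8Format::E5M2FNUZ;
  const bool isFnuz = fmt == F8Format::E5M2FNUZ || fmt == F8Format::E4M3FNUZ;
  const int manBits = isE5 ? 2 : 3;
  const int expBits = 7 - manBits;
  // FNUZ formats use an exponent bias one larger than their OCP counterparts.
  const int bias = (isE5 ? 15 : 7) + (isFnuz ? 1 : 0);

  const bool negative = (bits & 0x80) != 0;
  const int exp = (bits >> manBits) & ((1 << expBits) - 1);
  const int man = bits & ((1 << manBits) - 1);
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float inf = std::numeric_limits<float>::infinity();

  if (isFnuz) {
    if (bits == 0x80)
      return nan; // Sole NaN encoding; no infinities and no negative zero.
  } else if (isE5) {
    if (exp == 31) // IEEE-like: all-ones exponent encodes inf or NaN.
      return man != 0 ? nan : (negative ? -inf : inf);
  } else {
    if (exp == 15 && man == 7)
      return nan; // E4M3FN: NaN only, no infinities.
  }

  // Subnormals (exp == 0) have no implicit leading one.
  const float magnitude =
      exp == 0 ? std::ldexp(static_cast<float>(man), 1 - bias - manBits)
               : std::ldexp(static_cast<float>((1 << manBits) + man),
                            exp - bias - manBits);
  return negative ? -magnitude : magnitude;
}

// Assumption: element `index` (0..3) occupies byte `index` of the packed word.
inline uint8_t extractByte(uint32_t packed, unsigned index) {
  return static_cast<uint8_t>((packed >> (8 * index)) & 0xFF);
}
```

Under these definitions the byte `0x38` decodes to 1.0 as `F8E4M3FN` but to 0.5 as `F8E4M3FNUZ`, which is why a lowering has to match the element type against the format the target's instructions expect.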
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -49,6 +49,14 @@ struct Chipset {
#undef DEFINE_COMP_OPERATOR
};

+inline bool isGfx940Series(const Chipset &chipset) {
+return chipset.majorVersion == 9 && chipset.minorVersion == 4;
+}
+inline bool hasOcpFp8(const Chipset &chipset) {
+return (chipset.majorVersion == 9 && chipset.minorVersion >= 5) ||
+chipset.majorVersion >= 12;
+}

} // namespace mlir::amdgpu

#endif
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/Types.h
@@ -140,6 +143,9 @@ class Type {
bool isF64() const;
bool isF80() const;
bool isF128() const;
+/// Return true if this is a float type (with the specified width).
+bool isFloat() const;
+bool isFloat(unsigned width) const;
Comment on lines +143 to +145
Member

nice!


/// Return true if this is an integer type (with the specified width).
bool isInteger() const;
52 changes: 33 additions & 19 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -454,6 +454,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
}
}

+/// Return true if `type` is the E5M2 variant of an 8-bit float that is
+/// supported by the `_bf8` instructions on the given `chipset`.
+static bool isNativeBf8(Chipset chipset, Type type) {
Member

Suggested change
-static bool isNativeBf8(Chipset chipset, Type type) {
+static bool hasNativeBF8(Chipset chipset, Type type) {

Contributor Author

Since it's a query about the type, not about the chipset, I'm happier with an "is" form. Since I'm wordy, how about isNativeBf8ForChipset(Type,Chipset)?

Member

I don't understand the distinction. The ISA and the supported types are the property of the chipset, no?

Contributor Author

Well, yes, but all chipsets have a native BF8 type -- FNUZ for the older ones, OCP for the newer -- so the answer to "has native BF8" is "yes". Rather, we want to know "is this type the same as the chipset's native BF8 type?" Also, in a context in which we're doing isa<T1,T2>() and isFloat(), "is" flows better.

Member

> Also, in a context in which we're doing isa<T1,T2>() and isFloat(), "is" flows better.

The issue with chains of .isType checks is that they won't be optimized across multiple types and are always performed one by one (with short-circuiting). Our casting infra might not do it today, but in order to utilize a faster 'group' isa, you have to use it across the codebase in the first place.
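
Editorial note: the "group" form mentioned here is the multi-type call already used elsewhere in this diff, `isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType)`. The sketch below is not LLVM's implementation; it is a minimal, self-contained imitation built on a C++17 fold expression, showing how one call site can cover several candidate types. The miniature `Type` hierarchy and the name `isaAny` are hypothetical.

```cpp
#include <cassert>

// Hypothetical miniature type with a kind tag, standing in for mlir::Type.
struct Type {
  enum Kind { F32, F8E5M2, F8E5M2FNUZ, I32 };
  Kind kind;
};

// Each "class" exposes an LLVM-style classof hook.
struct Float32Type {
  static bool classof(const Type &t) { return t.kind == Type::F32; }
};
struct Float8E5M2Type {
  static bool classof(const Type &t) { return t.kind == Type::F8E5M2; }
};
struct Float8E5M2FNUZType {
  static bool classof(const Type &t) { return t.kind == Type::F8E5M2FNUZ; }
};

// Variadic check: true if `t` is an instance of any of `Ts`.
// The fold expands to classof(t) || classof(t) || ... with short-circuiting.
template <typename... Ts>
bool isaAny(const Type &t) {
  return (Ts::classof(t) || ...);
}
```

Because every candidate type appears in a single template argument list, an implementation is free to replace the chained comparison with something cheaper later without touching call sites, which is the point being made about adopting the group form across the codebase.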

Member

@kuhar, Sep 19, 2024

> "is this type the same as the chipset's native BF8 type?"

Thanks for the explanation, I see the difference now. Maybe something like this, then: isSupportedByNativeBf8(Chipset, Type)?

If you prefer the current name, could you add a comment that explains what the intention is (similar to how you explained it above)?

Contributor

An equivalent way of phrasing it, from where I'm standing, is "is this one of the types that the _fp8 or _bf8 instructions on the given chipset will be expecting"?

Contributor Author

Previous job would have called it "typeIsNativelySupportedBf8ForChipset(Type,Chipset)" -- we liked naming the parameters in the function name -- or "typeIsExpectedBf8ForChipset(Type,Chipset)". As I wrote those out, I think I like "expected" better than "native" or "supported," but I could be persuaded.

Contributor

Hm ... isImplementedBf8OnChipset()?

But Expected works too

+return (isGfx940Series(chipset) && type.isFloat8E5M2FNUZ()) ||
+(hasOcpFp8(chipset) && type.isFloat8E5M2());
+}

+/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
+/// supported by the `_fp8` instructions on the given `chipset`.
+static bool isNativeFp8(Chipset chipset, Type type) {
Member

Suggested change
-static bool isNativeFp8(Chipset chipset, Type type) {
+static bool hasNativeFp8(Chipset chipset, Type type) {

+return (isGfx940Series(chipset) && type.isFloat8E4M3FNUZ()) ||
+(hasOcpFp8(chipset) && type.isFloat8E4M3FN());
+}
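
Editorial note: the naming discussion above turns on what these two helpers actually ask, namely "is this type the one the chipset's `_bf8` / `_fp8` instructions expect?". The following self-contained model is a sketch, not the PR's code: `Chipset` is reduced to two fields (any stepping field is omitted), and the `F8Type` enum is a hypothetical stand-in for the MLIR float types. The predicate bodies mirror the ones in the diff.

```cpp
#include <cassert>

// Hypothetical reduced stand-in for mlir::amdgpu::Chipset.
struct Chipset {
  unsigned majorVersion;
  unsigned minorVersion;
};

// Hypothetical stand-in for the four MLIR 8-bit float types.
enum class F8Type { E5M2FNUZ, E4M3FNUZ, E5M2, E4M3FN };

inline bool isGfx940Series(const Chipset &c) {
  return c.majorVersion == 9 && c.minorVersion == 4;
}
inline bool hasOcpFp8(const Chipset &c) {
  return (c.majorVersion == 9 && c.minorVersion >= 5) || c.majorVersion >= 12;
}

// True if `t` is the E5M2 flavour that this chipset's `_bf8` instructions use.
inline bool isNativeBf8(const Chipset &c, F8Type t) {
  return (isGfx940Series(c) && t == F8Type::E5M2FNUZ) ||
         (hasOcpFp8(c) && t == F8Type::E5M2);
}

// True if `t` is the E4M3 flavour that this chipset's `_fp8` instructions use.
inline bool isNativeFp8(const Chipset &c, F8Type t) {
  return (isGfx940Series(c) && t == F8Type::E4M3FNUZ) ||
         (hasOcpFp8(c) && t == F8Type::E4M3FN);
}
```

In this model a 9.4 chipset accepts only the FNUZ types, a 12.0 chipset accepts only the OCP types, and a chipset matching neither predicate (for example 9.0) accepts none, so the same instruction name maps to a different element type depending on the target.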

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
@@ -550,38 +564,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}

-if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+if (destElem.isF32() && isNativeBf8(chipset, sourceElem)) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
-if (sourceBElem.isFloat8E5M2FNUZ())
+if (isNativeBf8(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
-if (sourceBElem.isFloat8E4M3FNUZ())
+if (isNativeFp8(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
-if (sourceBElem.isFloat8E5M2FNUZ())
+if (isNativeBf8(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
-if (sourceBElem.isFloat8E4M3FNUZ())
+if (isNativeFp8(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}

-if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+if (destElem.isF32() && isNativeFp8(chipset, sourceElem)) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
-if (sourceBElem.isFloat8E5M2FNUZ())
+if (isNativeBf8(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
-if (sourceBElem.isFloat8E4M3FNUZ())
+if (isNativeFp8(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
-if (sourceBElem.isFloat8E5M2FNUZ())
+if (isNativeBf8(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
-if (sourceBElem.isFloat8E4M3FNUZ())
+if (isNativeFp8(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
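
Editorial note: after this change the f8 MFMA selection depends only on how each operand is classified (bf8 or fp8 for the target) and on the tile shape. The sketch below condenses the eight branches of the hunk into one table-like function. It is an illustration under assumptions: `F8Class` and `f8MfmaName` are hypothetical names, and the returned strings follow the C++ class names seen in the hunk (such as `mfma_f32_16x16x32_bf8_fp8`), not necessarily the strings that `getOperationName()` returns.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical classification result: what isNativeBf8 / isNativeFp8 decided.
enum class F8Class { Bf8, Fp8 };

// Compose the intrinsic class name for an f8 MFMA, or nullopt if the shape
// is not one of the two handled in the hunk.
inline std::optional<std::string> f8MfmaName(F8Class a, F8Class b, int m,
                                             int n, int k, int blocks) {
  if (blocks != 1)
    return std::nullopt;

  std::string shape;
  if (m == 16 && n == 16 && k == 32)
    shape = "16x16x32";
  else if (m == 32 && n == 32 && k == 16)
    shape = "32x32x16";
  else
    return std::nullopt;

  auto suffix = [](F8Class c) -> const char * {
    return c == F8Class::Bf8 ? "bf8" : "fp8";
  };
  return "mfma_f32_" + shape + "_" + suffix(a) + "_" + suffix(b);
}
```

Since the operand classification already folds in the chipset, nothing in this selection step needs to know whether the element types are FNUZ or OCP.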
@@ -757,7 +771,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
-if (chipset.majorVersion != 9 || chipset < kGfx940)
+if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -787,10 +801,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
-if (sourceElemType.isFloat8E5M2FNUZ()) {
+if (isNativeBf8(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
-} else if (sourceElemType.isFloat8E4M3FNUZ()) {
+} else if (isNativeFp8(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
@@ -801,7 +815,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
-if (chipset.majorVersion != 9 || chipset < kGfx940)
+if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -822,10 +836,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

Value result;
-if (resultElemType.isFloat8E5M2FNUZ())
+if (isNativeBf8(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
-else if (resultElemType.isFloat8E4M3FNUZ())
+else if (isNativeFp8(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);

@@ -838,7 +852,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
-if (chipset.majorVersion != 9 || chipset < kGfx940)
+if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -857,10 +871,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

Value result;
-if (resultElemType.isFloat8E5M2FNUZ())
+if (isNativeBf8(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
-else if (resultElemType.isFloat8E4M3FNUZ())
+else if (isNativeFp8(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);

21 changes: 17 additions & 4 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;

+Chipset chipset;
+ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
+: OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}

LogicalResult match(arith::ExtFOp op) const override;
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};
@@ -68,6 +72,14 @@ struct TruncfToFloat16RewritePattern final

} // end namespace

+static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
+if (isGfx940Series(chipset))
+return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
+if (hasOcpFp8(chipset))
+return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
+return failure();
+}

static Value castF32To(Type elementType, Value f32, Location loc,
PatternRewriter &rewriter) {
if (elementType.isF32())
@@ -86,7 +98,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
-return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+return isSupportedF8(inType, chipset);
}

void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +228,8 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
-return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());

+return isSupportedF8(outType, chipset);
}

void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -365,7 +378,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {

if (convertFP8Arithmetic) {
-patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
saturateFP8Truncf, chipset);
}
@@ -384,7 +397,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
}

bool convertFP8Arithmetic =
-maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
+isGfx940Series(*maybeChipset) || hasOcpFp8(*maybeChipset);
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
}

Type sourceBType = getSourceB().getType();
-if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+if (sourceElem.isFloat(8)) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
-if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+if (!sourceBElem.isFloat(8))
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
4 changes: 3 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -509,7 +509,9 @@ bool TosaValidation::isValidElementType(Type type) {
if (isa<FloatType>(type)) {
if (profile == TosaProfileEnum::BaseInference)
return false;
-return type.isF32() || type.isF16() || type.isBF16();
+return type.isF32() || type.isF16() || type.isBF16() ||
+type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() ||
+type.isFloat8E4M3FN() || type.isFloat8E5M2();
Member

use isa

}
if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isUnsigned()) {
9 changes: 9 additions & 0 deletions mlir/lib/IR/Types.cpp
@@ -56,6 +56,15 @@ bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }

+bool Type::isFloat() const { return llvm::isa<FloatType>(*this); }

+/// Return true if this is a float type with the specified width.
+bool Type::isFloat(unsigned width) const {
+if (auto fltTy = llvm::dyn_cast<FloatType>(*this))
+return fltTy.getWidth() == width;
+return false;
+}
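
Editorial note: the two `isFloat` overloads added here follow the usual "test the kind, then test a property" shape. The sketch below models that shape without MLIR, using `std::variant` in place of `dyn_cast`; the `FloatType`, `IntegerType`, and `IndexType` structs are hypothetical stand-ins, not the MLIR classes.

```cpp
#include <cassert>
#include <variant>

// Hypothetical stand-ins for a few MLIR builtin types.
struct FloatType { unsigned width; };
struct IntegerType { unsigned width; };
struct IndexType {};
using Type = std::variant<FloatType, IntegerType, IndexType>;

// True for any float type, regardless of width.
inline bool isFloat(const Type &t) {
  return std::holds_alternative<FloatType>(t);
}

// True only for a float type of exactly `width` bits.
inline bool isFloat(const Type &t, unsigned width) {
  // std::get_if plays the role of dyn_cast: null when `t` is not a float.
  if (const auto *fltTy = std::get_if<FloatType>(&t))
    return fltTy->width == width;
  return false;
}
```

Note that a width-only check such as `isFloat(8)` does not distinguish among 8-bit formats; in this diff that distinction is left to the operand type constraints in `AMDGPU.td` and to the lowering helpers.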

bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }

bool Type::isInteger() const { return llvm::isa<IntegerType>(*this); }