Skip to content

Commit

Permalink
[MLIR][AMDGPU] Clean up and redo after other recent patches here.
Browse files Browse the repository at this point in the history
  • Loading branch information
pcf000 committed Sep 16, 2024
1 parent 1848df4 commit 1b790e3
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ struct Chipset {
#undef DEFINE_COMP_OPERATOR

// True when majorVersion is 9 and minorVersion lies in the half-open range
// tested by the return statement below.
// NOTE(review): two return statements appear back to back. This looks like the
// before/after pair of a diff rendered without +/- markers (minor version
// written as 0x40..0x50 in the first, 4..5 in the second) — confirm which
// encoding of minorVersion is current. As written, only the first return can
// execute; the second is unreachable.
bool isGfx940() const {
return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
return majorVersion == 9 && minorVersion >= 4 && minorVersion < 5;
}
// True when majorVersion is 9 with minorVersion at or above the threshold in
// the return statement below, or when majorVersion is 12 or greater.
// NOTE(review): two return statements appear back to back. This looks like the
// before/after pair of a diff rendered without +/- markers (threshold written
// as 0x50 in the first, 5 in the second) — confirm which encoding of
// minorVersion is current. As written, only the first return can execute; the
// second is unreachable.
bool hasOcpFp8() const {
return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
return (majorVersion == 9 && minorVersion >= 5) || majorVersion >= 12;
}
};

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset < kGfx940)
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Expand Down Expand Up @@ -815,7 +815,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset < kGfx940)
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Expand Down Expand Up @@ -852,7 +852,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset < kGfx940)
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Expand Down
25 changes: 17 additions & 8 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
// Rewrite pattern over arith::ExtFOp that remembers the chipset it was built
// for. match() and rewrite() are only declared here; their definitions live
// out of line.
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
  // Keep the base-class constructors available alongside the chipset-taking
  // one declared below.
  using OpRewritePattern::OpRewritePattern;

  // Builds the pattern in `ctx` and stores a copy of `chipset`.
  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
      : OpRewritePattern(ctx), chipset(chipset) {}

  LogicalResult match(arith::ExtFOp op) const override;
  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;

  // Chipset supplied at construction time.
  Chipset chipset;
};
Expand Down Expand Up @@ -68,6 +72,15 @@ struct TruncfToFloat16RewritePattern final

} // end namespace

// Decides whether `elementType` is an 8-bit float type accepted for `chipset`.
//
// - If chipset.isGfx940() holds, only the FNUZ variants (E5M2FNUZ, E4M3FNUZ)
//   are accepted; the OCP check is not consulted in that case.
// - Otherwise, if chipset.hasOcpFp8() holds, only E5M2 and E4M3FN are accepted.
// - Otherwise nothing is accepted.
//
// Returns success() when the type is accepted and a failed result otherwise.
static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) {
  bool accepted = false;
  if (chipset.isGfx940()) {
    accepted =
        elementType.isFloat8E5M2FNUZ() || elementType.isFloat8E4M3FNUZ();
  } else if (chipset.hasOcpFp8()) {
    accepted = elementType.isFloat8E5M2() || elementType.isFloat8E4M3FN();
  }
  // success(false) is the failed result, matching the explicit failure() path.
  return success(accepted);
}

static Value castF32To(Type elementType, Value f32, Location loc,
PatternRewriter &rewriter) {
if (elementType.isF32())
Expand All @@ -86,8 +99,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() ||
inType.isFloat8E5M2() || inType.isFloat8E4M3FN());
return isSupportedFp8(inType, chipset);
}

void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Expand Down Expand Up @@ -218,10 +230,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();

return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) &&
chipset.isGfx940()) ||
((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) &&
chipset.hasOcpFp8())));
return isSupportedFp8(outType, chipset);
}

void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
Expand Down Expand Up @@ -370,7 +379,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {

if (convertFP8Arithmetic) {
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
saturateFP8Truncf, chipset);
}
Expand All @@ -389,7 +398,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
}

bool convertFP8Arithmetic =
maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
maybeChipset->isGfx940() || maybeChipset->hasOcpFp8();
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
Expand Down

0 comments on commit 1b790e3

Please sign in to comment.