diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h index e5c14c1cb68278..5b071a46f49ed9 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h +++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h @@ -49,10 +49,10 @@ struct Chipset { #undef DEFINE_COMP_OPERATOR bool isGfx940() const { - return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50; + return majorVersion == 9 && minorVersion >= 4 && minorVersion < 5; } bool hasOcpFp8() const { - return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12; + return (majorVersion == 9 && minorVersion >= 5) || majorVersion >= 12; } }; diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 770c5072e2e79d..75cd9a499e61f2 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -771,7 +771,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); - if (chipset.majorVersion != 9 || chipset < kGfx940) + if (!(chipset.isGfx940() || chipset.hasOcpFp8())) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); @@ -815,7 +815,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); - if (chipset.majorVersion != 9 || chipset < kGfx940) + if (!(chipset.isGfx940() || chipset.hasOcpFp8())) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); @@ -852,7 +852,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); - if (chipset.majorVersion != 9 || chipset < kGfx940) + if (!(chipset.isGfx940() || chipset.hasOcpFp8())) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 423069d406f472..542f3ed0043e03 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final struct ExtFOnFloat8RewritePattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; + Chipset chipset; + ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset) + : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {} + LogicalResult match(arith::ExtFOp op) const override; void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override; }; @@ -68,6 +72,15 @@ struct TruncfToFloat16RewritePattern final } // end namespace +static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) { + if (chipset.isGfx940()) + return success(elementType.isFloat8E5M2FNUZ() || + elementType.isFloat8E4M3FNUZ()); + if (chipset.hasOcpFp8()) + return success(elementType.isFloat8E5M2() || elementType.isFloat8E4M3FN()); + return failure(); +} + static Value castF32To(Type elementType, Value f32, Location loc, PatternRewriter &rewriter) { if (elementType.isF32()) @@ -86,8 +99,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const { return failure(); inType = inVecType.getElementType(); } - return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() || - inType.isFloat8E5M2() || inType.isFloat8E4M3FN()); + return isSupportedFp8(inType, chipset); } void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, @@ -218,10 +230,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const { // Conversion between 8-bit floats is not supported with truncation enabled. return failure(); - return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) && - chipset.isGfx940()) || - ((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) && - chipset.hasOcpFp8()))); + return isSupportedFp8(outType, chipset); } void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, @@ -370,7 +379,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns( bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) { if (convertFP8Arithmetic) { - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext(), chipset); patterns.add(patterns.getContext(), saturateFP8Truncf, chipset); } @@ -389,7 +398,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() { } bool convertFP8Arithmetic = - maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0); + maybeChipset->isGfx940() || maybeChipset->hasOcpFp8(); arith::populateArithToAMDGPUConversionPatterns( patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz, *maybeChipset);