From b311401a6566d8c31d5386ff668133c26f563e44 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Thu, 15 Aug 2024 13:17:26 -0700 Subject: [PATCH] pass mlas test for Atype 32 and 16 Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/inc/mlas_qnbit.h | 3 +- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 152 ++++++++++++++---- onnxruntime/core/mlas/lib/sqnbitgemm.h | 80 ++++++++- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 120 ++++++++++---- .../test/mlas/unittest/test_sqnbitgemm.cpp | 61 ++++--- 5 files changed, 324 insertions(+), 92 deletions(-) diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 61be7d7aadbff..d3aad9a9a47ea 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -43,7 +43,7 @@ typedef enum { * @brief Data parameters for float/n-bit quantized int GEMM routine. */ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const MLAS_FP16* A = nullptr; ///< address of A (float32 matrix) + const void* A = nullptr; ///< address of A (float32 matrix) size_t lda = 0; ///< leading dimension of A const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data @@ -85,6 +85,7 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { buffer with at least that many bytes. Otherwise, it may be nullptr. * @param[in] ThreadPool optional thread pool to use */ +template void MLASCALL MlasSQNBitGemmBatch( size_t M, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 530543f284fd5..c4b68e0e50be4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -76,13 +76,13 @@ MlasIsSQNBitGemmAvailable( switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompFp32: { - return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && + return Dispatch->SQ4BitGemmM1Kernel_CompFp32_ATypeFp32 != nullptr && Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; } case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 return (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || - (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); + (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8_ATypeFp32 != nullptr); } default: { return false; @@ -295,7 +295,35 @@ AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t } } -typedef void(SQNBitGemmFn)( +MLAS_FORCEINLINE +void +ConvertFp16ToFp32(const MLAS_FP16* a_row, std::vector& a_row_fp32) +{ + size_t size = a_row_fp32.size(); + size_t i = 0; + + // Process 16 elements at a time using AVX2 + for (; i + 15 < size; i += 16) { + // Load 16 FP16 values into an AVX2 register + __m256i fp16_values = _mm256_loadu_si256(reinterpret_cast(a_row + i)); + + // Convert FP16 values to FP32 + __m256 fp32_values1 = _mm256_cvtph_ps(_mm256_castsi256_si128(fp16_values)); + __m256 fp32_values2 = _mm256_cvtph_ps(_mm256_extracti128_si256(fp16_values, 1)); + + // Store the converted FP32 values into the output vector + _mm256_storeu_ps(a_row_fp32.data() + i, fp32_values1); + _mm256_storeu_ps(a_row_fp32.data() + i + 8, fp32_values2); + } + + // Process any remaining elements + for (; i < size; ++i) { + a_row_fp32[i] = a_row[i].ToFloat(); + } +} + +template +using SQNBitGemmFn = void( size_t BlkLen, size_t K, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, @@ -306,6 +334,7 @@ typedef void(SQNBitGemmFn)( size_t RangeCountN ); +template void SQ4BitGemm_CompFp32( const size_t BlkLen, @@ -337,7 +366,7 @@ SQ4BitGemm_CompFp32( const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - const MLAS_FP16* A = DataParams->A + RangeStartM * lda; + const AType* A = (AType*)(DataParams->A) + RangeStartM * lda; const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; @@ -355,7 +384,7 @@ SQ4BitGemm_CompFp32( for (size_t n = 0; n < RangeCountN; n += CountN) { CountN = std::min(RangeCountN - n, size_t{128}); - const MLAS_FP16* a_row = A; + const AType* a_row = A; const std::byte* b_col = QuantBData + n * ldb; const float* b_col_scale = QuantBScale + n * k_blks; const std::byte* b_col_zp = @@ -363,11 +392,7 @@ SQ4BitGemm_CompFp32( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - + GetMlasPlatform().SQNBitGemmDispatch->CallSQ4BitGemmM1Kernel_CompFp32_Fn(BlkLen, a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias); if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( DataParams->C, RangeStartM, RangeStartN + n, @@ -393,7 +418,7 @@ SQ4BitGemm_CompFp32( // // Step through each slice of matrix A along the M dimension. // - const MLAS_FP16* a_row = A; + const AType* a_row = A; const std::byte* b_col = QuantBData + n * ldb; const float* b_col_scale = QuantBScale + n * k_blks; const std::byte* b_col_zp = @@ -405,15 +430,28 @@ SQ4BitGemm_CompFp32( BlkLen, dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks ); - +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) + std::vector a_row_fp32_v; + float* a_row_fp32 = nullptr; + if constexpr (std::is_same::value) { + a_row_fp32_v.resize(lda * RangeCountM); + ConvertFp16ToFp32(a_row, a_row_fp32_v); + a_row_fp32 = &a_row_fp32_v[0]; + } +#endif size_t RowsRemaining = RangeCountM; while (RowsRemaining > 0) { #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) - assert(false); - auto RowsHandled = 0; - //auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - // a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true - //); + int64_t RowsHandled = 0; + if constexpr (std::is_same::value) { + RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_row_fp32, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true + ); + } else { + RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true + ); + } #else auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); #endif @@ -430,11 +468,17 @@ SQ4BitGemm_CompFp32( c_blk += ldc * RowsHandled; a_row += lda * RowsHandled; +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) + if constexpr (std::is_same::value) { + a_row_fp32 += lda * RowsHandled; + } +#endif RowsRemaining -= RowsHandled; } } } +template void SQ4BitGemm_CompInt8( const size_t BlkLen, @@ -564,7 +608,8 @@ SQ4BitGemm_CompInt8( } } -typedef void(InitializeWorkspaceFn)( +template +using InitializeWorkspaceFn = void( size_t M, size_t N, size_t K, @@ -576,6 +621,7 @@ typedef void(InitializeWorkspaceFn)( MLAS_THREADPOOL* ThreadPool ); +template void InitializeWorkspace_CompInt8( size_t M, @@ -592,7 +638,7 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(N); const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; - const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; + const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->GetQuantizeARowComputeBlkSum_CompInt8_Fn(); const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -614,7 +660,7 @@ InitializeWorkspace_CompInt8( } else { MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { const auto& data = DataParams[gemm_idx]; - const MLAS_FP16* ARowPtr = data.A; + const AType* ARowPtr = static_cast(data.A); void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); @@ -633,22 +679,47 @@ InitializeWorkspace_CompInt8( } struct Operations { - InitializeWorkspaceFn* InitializeWorkspace = nullptr; - SQNBitGemmFn* SQNBitGemm = nullptr; + InitializeWorkspaceFn* InitializeWorkspace_ATypeFp32 = nullptr; + InitializeWorkspaceFn* InitializeWorkspace_ATypeFp16 = nullptr; + template + InitializeWorkspaceFn* + GetInitializeWorkspaceFn() const { + if constexpr (std::is_same::value) { + return InitializeWorkspace_ATypeFp16; + } else { + return InitializeWorkspace_ATypeFp32; + } + } + SQNBitGemmFn* SQNBitGemm_ATypeFp32 = nullptr; + SQNBitGemmFn* SQNBitGemm_ATypeFp16 = nullptr; + template + SQNBitGemmFn* + GetSQNBitGemmFn() const + { + if constexpr (std::is_same::value) { + return SQNBitGemm_ATypeFp16; + } else { + return SQNBitGemm_ATypeFp32; + } + } }; constexpr auto OperationMap = []() { std::array ops; - ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32; + ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm_ATypeFp32 = SQ4BitGemm_CompFp32; + ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm_ATypeFp16 = SQ4BitGemm_CompFp32; - ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8; - ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8; + ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace_ATypeFp32 = InitializeWorkspace_CompInt8; + ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace_ATypeFp16 = InitializeWorkspace_CompInt8; + ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm_ATypeFp32 = SQ4BitGemm_CompInt8; + ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm_ATypeFp16 = SQ4BitGemm_CompInt8; return ops; }(); } // namespace +template void MLASCALL MlasSQNBitGemmBatch( const size_t M, @@ -679,14 +750,14 @@ MlasSQNBitGemmBatch( const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); - if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; + if (const auto InitializeWorkspaceOperation = OperationMap[Variant].GetInitializeWorkspaceFn(); InitializeWorkspaceOperation != nullptr) { InitializeWorkspaceOperation( M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool ); } - const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; + const auto ComputeOperation = OperationMap[Variant].GetSQNBitGemmFn(); const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -779,3 +850,30 @@ MlasSQNBitGemmBatch( } }); } + +// Explicit template instantiations +template void MLASCALL MlasSQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); + +template void MLASCALL MlasSQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 720d5d65550fd..dac7794a2f58c 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -184,9 +184,10 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. * @param Bias Bias vector of length N. */ - typedef void(SQ4BitGemmM1Kernel_CompFp32_Fn)( + template + using SQ4BitGemmM1Kernel_CompFp32_Fn = void( size_t BlkLen, - const MLAS_FP16* A, + const AType* A, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -197,7 +198,61 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { const float* Bias ); - SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32 = nullptr; + SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32_ATypeFp32 = nullptr; + SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32_ATypeFp16 = nullptr; + + template + SQ4BitGemmM1Kernel_CompFp32_Fn* + GetSQ4BitGemmM1Kernel_CompFp32_Fn() + { + if constexpr (std::is_same::value) { + return SQ4BitGemmM1Kernel_CompFp32_ATypeFp16; + } else { + return SQ4BitGemmM1Kernel_CompFp32_ATypeFp32; + } + } + + template + void CallSQ4BitGemmM1Kernel_CompFp32_Fn( + size_t BlkLen, + const AType* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias + ) const { + if constexpr (std::is_same::value) { + SQ4BitGemmM1Kernel_CompFp32_ATypeFp16( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompFp32_ATypeFp32( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } /** * @brief Dequantize B into the format expected by the Sgemm kernel. @@ -328,13 +383,26 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; - typedef void(QuantizeARowComputeBlkSum_CompInt8_Fn)( + template + using QuantizeARowComputeBlkSum_CompInt8_Fn = void( size_t BlkLen, - const MLAS_FP16* A, + const T* A, size_t CountK, std::byte* QuantA, float* QuantAScale, float* AScaledGroupSum // scale_k * Sum_blklen(a_i) ); - QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; + + QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8_ATypeFp32 = nullptr; + QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8_ATypeFp16 = nullptr; + template + QuantizeARowComputeBlkSum_CompInt8_Fn* + GetQuantizeARowComputeBlkSum_CompInt8_Fn() const + { + if constexpr (std::is_same::value) { + return QuantizeARowComputeBlkSum_CompInt8_ATypeFp16; + } else { + return QuantizeARowComputeBlkSum_CompInt8_ATypeFp32; + } + } }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index fdd3182f75f25..f845e46794701 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -634,11 +634,11 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni( return CountM; } -template +template MLAS_FORCEINLINE void ComputeDotProducts_BlkLen16_CompFp32_avx2( size_t BlkLen, - const MLAS_FP16* ARowPtr, + const AType* ARowPtr, const std::byte* QuantBDataColPtr, const float* QuantBScaleColPtr, const std::byte* QuantBZeroPointColPtr, @@ -703,9 +703,19 @@ ComputeDotProducts_BlkLen16_CompFp32_avx2( // Load A row vectors int n_to_read = std::min(kklen, 8); - __m256 av_lo = load_float16_n_avx2(ARowPtr + k + kk, n_to_read); + __m256 av_lo; + if constexpr (std::is_same::value) { + av_lo = load_float16_n_avx2(ARowPtr + k + kk, n_to_read); + } else { + av_lo = load_float_n_avx2(ARowPtr + k + kk, n_to_read); + } n_to_read = std::min(kklen - 8, 8); - __m256 av_hi = load_float16_n_avx2(ARowPtr + k + kk + 8, n_to_read); + __m256 av_hi; + if constexpr (std::is_same::value) { + av_hi = load_float16_n_avx2(ARowPtr + k + kk + 8, n_to_read); + } else { + av_hi = load_float_n_avx2(ARowPtr + k + kk + 8, n_to_read); + } UnrolledLoop([&](size_t i) { // SubBlkLen = 16: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | @@ -785,10 +795,10 @@ ComputeDotProducts_BlkLen16_CompFp32_avx2( } // TODO: flow MlasQ4GemmKernelBlkLen16Avx512f to improve perf -template +template void SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( - const MLAS_FP16* A, + const AType* A, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -803,7 +813,7 @@ SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( constexpr size_t BlkBitWidth4 = 4; constexpr size_t NCols4 = 4; - const MLAS_FP16* ARowPtr = A; + const AType* ARowPtr = A; float* CRowPtr = C; const size_t BlockCountK = BlockStrideQuantB; @@ -823,7 +833,7 @@ SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( int64_t nblk = static_cast(CountN) - NCols4; while (nblk >= 0) { - ComputeDotProducts_BlkLen16_CompFp32_avx2( + ComputeDotProducts_BlkLen16_CompFp32_avx2( BlkLen16, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -847,7 +857,7 @@ SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( // left over columns less than `NCols`? nblk += NCols4; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkLen16_CompFp32_avx2<1, HasZeroPoint>( + ComputeDotProducts_BlkLen16_CompFp32_avx2( BlkLen16, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -868,11 +878,11 @@ SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( } // TODO: flow MlasQ4GemmKernelBlkLen32PlusAvx512f to improve perf -template +template MLAS_FORCEINLINE void ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( size_t BlkLen, - const MLAS_FP16* ARowPtr, + const AType* ARowPtr, const std::byte* QuantBDataColPtr, const float* QuantBScaleColPtr, const std::byte* QuantBZeroPointColPtr, @@ -940,16 +950,36 @@ ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( // Load 4 float8 from A int n_to_read = std::min(kklen, 8); - __m256 av0_8_ps = load_float16_n_avx2(ARowPtr + k + kk, n_to_read); + __m256 av0_8_ps; + if constexpr (std::is_same::value) { + av0_8_ps = load_float16_n_avx2(ARowPtr + k + kk, n_to_read); + } else { + av0_8_ps = load_float_n_avx2(ARowPtr + k + kk, n_to_read); + } n_to_read = std::min(kklen - 8, 8); - __m256 av1_8_ps = load_float16_n_avx2(ARowPtr + k + kk + 8, n_to_read); + __m256 av1_8_ps; + if constexpr (std::is_same::value) { + av1_8_ps = load_float16_n_avx2(ARowPtr + k + kk + 8, n_to_read); + } else { + av1_8_ps = load_float_n_avx2(ARowPtr + k + kk + 8, n_to_read); + } n_to_read = std::min(kklen - 16, 8); - __m256 av2_8_ps = load_float16_n_avx2(ARowPtr + k + kk + 16, n_to_read); + __m256 av2_8_ps; + if constexpr (std::is_same::value) { + av2_8_ps = load_float16_n_avx2(ARowPtr + k + kk + 16, n_to_read); + } else { + av2_8_ps = load_float_n_avx2(ARowPtr + k + kk + 16, n_to_read); + } n_to_read = std::min(kklen - 24, 8); - __m256 av3_8_ps = load_float16_n_avx2(ARowPtr + k + kk + 24, n_to_read); + __m256 av3_8_ps; + if constexpr (std::is_same::value) { + av3_8_ps = load_float16_n_avx2(ARowPtr + k + kk + 24, n_to_read); + } else { + av3_8_ps = load_float_n_avx2(ARowPtr + k + kk + 24, n_to_read); + } if constexpr (IsBlkLen64Layout) { count_half_4 = 4 * (int)((kk % (2 * SubBlkLen32)) / SubBlkLen32); @@ -1045,11 +1075,11 @@ ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( } // TODO: flow MlasQ4GemmKernelBlkLen16Avx512f to improve perf -template +template void SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( size_t BlkLen, - const MLAS_FP16* A, + const AType* A, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -1063,7 +1093,7 @@ SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( constexpr size_t BlkBitWidth4 = 4; constexpr size_t NCols4 = 4; - const MLAS_FP16* ARowPtr = A; + const AType* ARowPtr = A; float* CRowPtr = C; const size_t BlockCountK = BlockStrideQuantB; @@ -1083,14 +1113,14 @@ SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( int64_t nblk = static_cast(CountN) - NCols4; while (nblk >= 0) { if (BlkLen >= 64) { - ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr ); } else { - ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -1116,14 +1146,14 @@ SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( nblk += NCols4; for (int64_t n = 0; n < nblk; ++n) { if (BlkLen >= 64) { - ComputeDotProducts_BlkLen32Plus_CompFp32_avx2<1, HasZeroPoint, true>( + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr ); } else { - ComputeDotProducts_BlkLen32Plus_CompFp32_avx2<1, HasZeroPoint, false>( + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -1144,10 +1174,11 @@ SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( } } +template MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompFp32_avx2( size_t BlkLen, - const MLAS_FP16* A, + const AType* A, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -1160,7 +1191,7 @@ SQ4BitGemmM1Kernel_CompFp32_avx2( { if (BlkLen == 16) { if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( A, QuantBData, QuantBScale, @@ -1172,7 +1203,7 @@ SQ4BitGemmM1Kernel_CompFp32_avx2( Bias ); } else { - SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( A, QuantBData, QuantBScale, @@ -1186,7 +1217,7 @@ SQ4BitGemmM1Kernel_CompFp32_avx2( } } else { if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( BlkLen, A, QuantBData, @@ -1199,7 +1230,7 @@ SQ4BitGemmM1Kernel_CompFp32_avx2( Bias ); } else { - SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( BlkLen, A, QuantBData, @@ -1215,10 +1246,11 @@ SQ4BitGemmM1Kernel_CompFp32_avx2( } } +template void MLASCALL QuantizeARow_CompInt8_avx2( size_t BlkLen, - const MLAS_FP16* A, + const AType* A, size_t CountK, std::byte* QuantA, float* QuantAScale, @@ -1239,7 +1271,12 @@ QuantizeARow_CompInt8_avx2( for (size_t kk = 0; kk < step; kk += 8) { const int klen = std::min(8, (int)(step - kk)); - __m256 v0 = load_float16_n_avx2(A + k + kk, klen); + __m256 v0; + if constexpr (std::is_same::value) { + v0 = load_float16_n_avx2(A + k + kk, klen); + } else { + v0 = load_float_n_avx2(A + k + kk, klen); + } // Compute max(abs(e)) for the block maxAbs = _mm256_max_ps(maxAbs, _mm256_andnot_ps(signBit, v0)); @@ -1264,7 +1301,12 @@ QuantizeARow_CompInt8_avx2( const int klen = std::min(16, (int)(step - kk)); int n_to_read = std::min(klen, 8); - __m256 v0 = load_float16_n_avx2(A + k + kk, n_to_read); + __m256 v0; + if constexpr (std::is_same::value) { + v0 = load_float16_n_avx2(A + k + kk, n_to_read); + } else { + v0 = load_float_n_avx2(A + k + kk, n_to_read); + } v0 = _mm256_mul_ps(v0, mul); v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); @@ -1273,7 +1315,11 @@ QuantizeARow_CompInt8_avx2( if (n_to_read <= 0) { v1 = _mm256_setzero_ps(); } else { - v1 = load_float16_n_avx2(A + k + kk + 8, n_to_read); + if constexpr (std::is_same::value) { + v1 = load_float16_n_avx2(A + k + kk + 8, n_to_read); + } else { + v1 = load_float_n_avx2(A + k + kk + 8, n_to_read); + } v1 = _mm256_mul_ps(v1, mul); v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); } @@ -1335,11 +1381,13 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; - d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; + d.SQ4BitGemmM1Kernel_CompFp32_ATypeFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; + d.SQ4BitGemmM1Kernel_CompFp32_ATypeFp16 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; - d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; + d.QuantizeARowComputeBlkSum_CompInt8_ATypeFp32 = QuantizeARow_CompInt8_avx2; + d.QuantizeARowComputeBlkSum_CompInt8_ATypeFp16 = QuantizeARow_CompInt8_avx2; return d; }(); @@ -1354,11 +1402,13 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; - d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; + d.SQ4BitGemmM1Kernel_CompFp32_ATypeFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; + d.SQ4BitGemmM1Kernel_CompFp32_ATypeFp16 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; - d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; + d.QuantizeARowComputeBlkSum_CompInt8_ATypeFp32 = QuantizeARow_CompInt8_avx2; + d.QuantizeARowComputeBlkSum_CompInt8_ATypeFp16 = QuantizeARow_CompInt8_avx2; return d; }(); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index c287cda864b65..7b65786ce96e2 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -16,9 +16,13 @@ Module Name: #include "test_util.h" #include "test_fp16.h" +#include "core/framework/float16.h" #include "mlas_q4.h" +#include "mlas.h" #include "mlas_qnbit.h" +#include + static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType) { switch (ComputeType) { case CompFp32: @@ -34,10 +38,10 @@ static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE Compu * @brief Test class for n-bit int block quantized GEMM * Note: only 2-D matmul supported for now */ -template +template class MlasSQNBitGemmTest : public MlasTestBase { private: - MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferA; MatrixGuardBuffer BufferQuantAData; MatrixGuardBuffer BufferQuantAScale; MatrixGuardBuffer BufferB; @@ -54,7 +58,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { void CallGemm(size_t M, size_t N, size_t K, - const MLFp16* A, + const AType* A, size_t lda, const void* /*QuantBData*/, const void* PackedQuantBDataWorkspace, @@ -67,7 +71,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, MLAS_THREADPOOL* Threadpool) { MLAS_SQNBIT_GEMM_DATA_PARAMS params; - params.A = reinterpret_cast(A); + params.A = A; params.lda = lda; params.Bias = Bias; params.C = C; @@ -82,16 +86,20 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.QuantBZeroPoint = QuantBZeroPoint; params.PostProcessor = nullptr; - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); + if constexpr (std::is_same::value) { + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); + } else { + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); + } } - void QuantizeA(size_t M, size_t K, const MLFp16* A, int8_t* QuantAData, float* QuantAScale) { + void QuantizeA(size_t M, size_t K, const AType* A, int8_t* QuantAData, float* QuantAScale) { const size_t BlockCountK = (K + BlkLen - 1) / BlkLen; const size_t lda = K; for (size_t m = 0; m < M; ++m) { for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) { const size_t local_blk_len = std::min(K - k, BlkLen); - MLFp16 blk_a[BlkLen]{}; + AType blk_a[BlkLen]{}; std::copy_n(A + m * lda + k, local_blk_len, blk_a); float amax = 0.0f; // max of absolute values of A block @@ -118,10 +126,11 @@ class MlasSQNBitGemmTest : public MlasTestBase { } } + template void CallReferenceGemm_CompInt8(size_t M, size_t N, size_t K, - const MLFp16* A, + const AType* A, const uint8_t* QuantBData, const float* QuantBScale, const uint8_t* QuantBZeroPoint, @@ -168,10 +177,11 @@ class MlasSQNBitGemmTest : public MlasTestBase { } } + template void CallReferenceGemm_CompFp32(size_t M, size_t N, size_t K, - const MLFp16* A, + const AType* A, const uint8_t* QuantBData, const float* QuantBScale, const uint8_t* QuantBZeroPoint, @@ -185,7 +195,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { for (size_t m = 0; m < M; m++) { for (size_t n = 0; n < N; n++) { - const MLFp16* a = A + m * K; + const AType* a = A + m * K; const float* b = DequantizedBData + n * K; float* c = C + (m * N) + n; @@ -206,7 +216,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { bool WithThreadpool, bool Symmetric, bool WithBias) { MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr; - const MLFp16* A = BufferA.GetBuffer(K * M); + const AType* A = BufferA.GetBuffer(K * M); const float* B = BufferB.GetBuffer(N * K); @@ -322,7 +332,8 @@ class MlasSQNBitGemmTest : public MlasTestBase { static const char* GetTestSuiteName() { static std::string suite_name = std::string("SQNBitGemm") + "BlkBitWidth" + std::to_string(BlkBitWidth) + - "BlkLen" + std::to_string(BlkLen); + "BlkLen" + std::to_string(BlkLen) + + "AType" + std::string(std::is_same::value ? "Fp32" : "Fp16"); return suite_name.c_str(); } }; @@ -330,8 +341,8 @@ class MlasSQNBitGemmTest : public MlasTestBase { // // Short Execute() test helper to register each test separately by all parameters. // -template -class SQNBitGemmShortExecuteTest : public MlasTestFixture> { +template +class SQNBitGemmShortExecuteTest : public MlasTestFixture> { public: explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, @@ -346,7 +357,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture>::mlas_tester->Test( + MlasTestFixture>::mlas_tester->Test( M_, N_, K_, ComputeType_, WithThreadpool_, Symmetric_, WithBias_); } @@ -361,18 +372,19 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture::value ? "Fp32" : "Fp16") << "/computeType" << ComputeTypeName(ComputeType); auto test_name = ss.str(); testing::RegisterTest( - MlasSQNBitGemmTest::GetTestSuiteName(), + MlasSQNBitGemmTest::GetTestSuiteName(), test_name.c_str(), nullptr, test_name.c_str(), __FILE__, __LINE__, // Important to use the fixture type as the return type here. - [=]() -> MlasTestFixture>* { + [=]() -> MlasTestFixture>* { return new SQNBitGemmShortExecuteTest( M, N, K, ComputeType, WithThreadpool, Symmetric, WithBias); }); @@ -430,14 +442,15 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture static size_t SQNBitGemmRegisterAllShortExecuteTests() { size_t count = 0; - count += SQNBitGemmShortExecuteTest<4, 16>::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<4, 32>::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<4, 64>::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<4, 128>::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<4, 256>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest::RegisterShortExecuteTests(); return count; } @@ -445,7 +458,9 @@ static size_t SQNBitGemmRegisterAllShortExecuteTests() { static UNUSED_VARIABLE bool added_to_main = AddTestRegister( [](bool is_short_execute) -> size_t { if (is_short_execute) { - return SQNBitGemmRegisterAllShortExecuteTests(); + // using MLAS_FP16 = onnxruntime::MLFloat16; + return SQNBitGemmRegisterAllShortExecuteTests() + SQNBitGemmRegisterAllShortExecuteTests(); + //return SQNBitGemmRegisterAllShortExecuteTests(); } return 0; });