Pass MLAS tests for AType fp32 and fp16
Signed-off-by: Liqun Fu <[email protected]>
liqunfu committed Aug 15, 2024
1 parent 1446187 commit b311401
Showing 5 changed files with 324 additions and 92 deletions.
3 changes: 2 additions & 1 deletion onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -43,7 +43,7 @@ typedef enum {
* @brief Data parameters for float/n-bit quantized int GEMM routine.
*/
struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
const MLAS_FP16* A = nullptr; ///< address of A (float32 matrix)
const void* A = nullptr; ///< address of A (float32 or float16 matrix)
size_t lda = 0; ///< leading dimension of A
const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values)
const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data
@@ -85,6 +85,7 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
buffer with at least that many bytes. Otherwise, it may be nullptr.
* @param[in] ThreadPool optional thread pool to use
*/
template<typename T>
void MLASCALL
MlasSQNBitGemmBatch(
size_t M,
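The two header changes above type-erase `A` to `const void*` and add a template parameter carrying the A element type, so fp32 and fp16 activations share one entry point. A minimal caller-side sketch (an illustration, not code from this commit; the surrounding variables and the `CompFp32` compute-type value are assumptions based on the header's docs):

    // Sketch only: same entry point, selected by the template argument.
    // Quantized-B fields (PackedQuantBData, QuantBScale, ...) omitted for brevity.
    MLAS_SQNBIT_GEMM_DATA_PARAMS params{};
    params.A = a_fp16;        // const MLAS_FP16*, passed through const void*
    params.lda = K;
    params.C = c;             // float* output
    params.ldc = N;
    MlasSQNBitGemmBatch<MLAS_FP16>(M, N, K, /*BatchN*/ 1, /*BlkBitWidth*/ 4, BlkLen,
                                   CompFp32, &params, workspace, thread_pool);
    // fp32 activations: the identical call via MlasSQNBitGemmBatch<float>,
    // with a const float* stored in params.A.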
152 changes: 125 additions & 27 deletions onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -76,13 +76,13 @@ MlasIsSQNBitGemmAvailable(

switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompFp32: {
return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr &&
return Dispatch->SQ4BitGemmM1Kernel_CompFp32_ATypeFp32 != nullptr &&
Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
}
case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8
return
(Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) ||
(Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr);
(Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8_ATypeFp32 != nullptr);
}
default: {
return false;
@@ -295,7 +295,35 @@ AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t
}
}

typedef void(SQNBitGemmFn)(
MLAS_FORCEINLINE
void
ConvertFp16ToFp32(const MLAS_FP16* a_row, std::vector<float>& a_row_fp32)
{
size_t size = a_row_fp32.size();
size_t i = 0;

// Process 16 elements at a time using AVX2
for (; i + 15 < size; i += 16) {
// Load 16 FP16 values into an AVX2 register
__m256i fp16_values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a_row + i));

// Convert FP16 values to FP32
__m256 fp32_values1 = _mm256_cvtph_ps(_mm256_castsi256_si128(fp16_values));
__m256 fp32_values2 = _mm256_cvtph_ps(_mm256_extracti128_si256(fp16_values, 1));

// Store the converted FP32 values into the output vector
_mm256_storeu_ps(a_row_fp32.data() + i, fp32_values1);
_mm256_storeu_ps(a_row_fp32.data() + i + 8, fp32_values2);
}

// Process any remaining elements
for (; i < size; ++i) {
a_row_fp32[i] = a_row[i].ToFloat();
}
}

template <typename AType>
using SQNBitGemmFn = void(
size_t BlkLen,
size_t K,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
@@ -306,6 +334,7 @@ typedef void(SQNBitGemmFn)(
size_t RangeCountN
);

template<typename AType>
void
SQ4BitGemm_CompFp32(
const size_t BlkLen,
@@ -337,7 +366,7 @@ SQ4BitGemm_CompFp32(
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);

const MLAS_FP16* A = DataParams->A + RangeStartM * lda;
const AType* A = (AType*)(DataParams->A) + RangeStartM * lda;

const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->PackedQuantBData) + RangeStartN * ldb;
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
@@ -355,19 +384,15 @@
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, size_t{128});

const MLAS_FP16* a_row = A;
const AType* a_row = A;
const std::byte* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const std::byte* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32(
BlkLen,
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
);

GetMlasPlatform().SQNBitGemmDispatch->CallSQ4BitGemmM1Kernel_CompFp32_Fn<AType>(BlkLen, a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias);
if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM, RangeStartN + n,
@@ -393,7 +418,7 @@ SQ4BitGemm_CompFp32(
//
// Step through each slice of matrix A along the M dimension.
//
const MLAS_FP16* a_row = A;
const AType* a_row = A;
const std::byte* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const std::byte* b_col_zp =
@@ -405,15 +430,28 @@
BlkLen,
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
);

#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
std::vector<float> a_row_fp32_v;
float* a_row_fp32 = nullptr;
if constexpr (std::is_same<AType, MLAS_FP16>::value) {
a_row_fp32_v.resize(lda * RangeCountM);
ConvertFp16ToFp32(a_row, a_row_fp32_v);
a_row_fp32 = &a_row_fp32_v[0];
}
#endif
size_t RowsRemaining = RangeCountM;
while (RowsRemaining > 0) {
#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
assert(false);
auto RowsHandled = 0;
//auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
// a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true
//);
int64_t RowsHandled = 0;
if constexpr (std::is_same<AType, MLAS_FP16>::value) {
RowsHandled = GetMlasPlatform().GemmFloatKernel(
a_row_fp32, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true
);
} else {
RowsHandled = GetMlasPlatform().GemmFloatKernel(
a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true
);
}
#else
auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f);
#endif
@@ -430,11 +468,17 @@ SQ4BitGemm_CompFp32(

c_blk += ldc * RowsHandled;
a_row += lda * RowsHandled;
#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
if constexpr (std::is_same<AType, MLAS_FP16>::value) {
a_row_fp32 += lda * RowsHandled;
}
#endif
RowsRemaining -= RowsHandled;
}
}
}

template<typename AType>
void
SQ4BitGemm_CompInt8(
const size_t BlkLen,
@@ -564,7 +608,8 @@ SQ4BitGemm_CompInt8(
}
}

typedef void(InitializeWorkspaceFn)(
template <typename AType>
using InitializeWorkspaceFn = void(
size_t M,
size_t N,
size_t K,
@@ -576,6 +621,7 @@ typedef void(InitializeWorkspaceFn)(
MLAS_THREADPOOL* ThreadPool
);

template<typename AType>
void
InitializeWorkspace_CompInt8(
size_t M,
@@ -592,7 +638,7 @@ InitializeWorkspace_CompInt8(
MLAS_UNREFERENCED_PARAMETER(N);

const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8;
const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8;
const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->GetQuantizeARowComputeBlkSum_CompInt8_Fn<AType>();

const size_t BlockCountK = MlasDivRoundup(K, BlkLen);

@@ -614,7 +660,7 @@
} else {
MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
const auto& data = DataParams[gemm_idx];
const MLAS_FP16* ARowPtr = data.A;
const AType* ARowPtr = static_cast<const AType*>(data.A);

void* PerGemmWorkspace = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen);
@@ -633,22 +679,47 @@
}

struct Operations {
InitializeWorkspaceFn* InitializeWorkspace = nullptr;
SQNBitGemmFn* SQNBitGemm = nullptr;
InitializeWorkspaceFn<float>* InitializeWorkspace_ATypeFp32 = nullptr;
InitializeWorkspaceFn<MLAS_FP16>* InitializeWorkspace_ATypeFp16 = nullptr;
template <typename AType>
InitializeWorkspaceFn<AType>*
GetInitializeWorkspaceFn() const {
if constexpr (std::is_same<AType, MLAS_FP16>::value) {
return InitializeWorkspace_ATypeFp16;
} else {
return InitializeWorkspace_ATypeFp32;
}
}
SQNBitGemmFn<float>* SQNBitGemm_ATypeFp32 = nullptr;
SQNBitGemmFn<MLAS_FP16>* SQNBitGemm_ATypeFp16 = nullptr;
template <typename AType>
SQNBitGemmFn<AType>*
GetSQNBitGemmFn() const
{
if constexpr (std::is_same<AType, MLAS_FP16>::value) {
return SQNBitGemm_ATypeFp16;
} else {
return SQNBitGemm_ATypeFp32;
}
}
};

constexpr auto OperationMap = []() {
std::array<Operations, SQNBitGemmVariantCount> ops;

ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32;
ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm_ATypeFp32 = SQ4BitGemm_CompFp32<float>;
ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm_ATypeFp16 = SQ4BitGemm_CompFp32<MLAS_FP16>;

ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8;
ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8;
ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace_ATypeFp32 = InitializeWorkspace_CompInt8<float>;
ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace_ATypeFp16 = InitializeWorkspace_CompInt8<MLAS_FP16>;
ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm_ATypeFp32 = SQ4BitGemm_CompInt8<float>;
ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm_ATypeFp16 = SQ4BitGemm_CompInt8<MLAS_FP16>;

return ops;
}();
} // namespace

template<typename AType>
void MLASCALL
MlasSQNBitGemmBatch(
const size_t M,
@@ -679,14 +750,14 @@

const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType);

if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace;
if (const auto InitializeWorkspaceOperation = OperationMap[Variant].GetInitializeWorkspaceFn<AType>();
InitializeWorkspaceOperation != nullptr) {
InitializeWorkspaceOperation(
M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool
);
}

const auto ComputeOperation = OperationMap[Variant].SQNBitGemm;
const auto ComputeOperation = OperationMap[Variant].GetSQNBitGemmFn<AType>();

const size_t BlockCountK = MlasDivRoundup(K, BlkLen);

Expand Down Expand Up @@ -779,3 +850,30 @@ MlasSQNBitGemmBatch(
}
});
}

// Explicit template instantiations
template void MLASCALL MlasSQNBitGemmBatch<float>(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const size_t BlkBitWidth,
const size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
void* Workspace,
MLAS_THREADPOOL* ThreadPool
);

template void MLASCALL MlasSQNBitGemmBatch<MLAS_FP16>(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const size_t BlkBitWidth,
const size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
void* Workspace,
MLAS_THREADPOOL* ThreadPool
);
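For reference, the F16C/AVX2 path in `ConvertFp16ToFp32` above widens IEEE 754 binary16 to binary32; a portable scalar equivalent is sketched below (an illustration, not code from this commit; it takes the raw half bits as `uint16_t`, whereas the kernel reads them through `MLAS_FP16`):

    #include <cstdint>
    #include <cstring>

    // Scalar binary16 -> binary32 widening: handles zeros, subnormals, Inf/NaN.
    static float HalfBitsToFloat(uint16_t h)
    {
        uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
        uint32_t exp = (h >> 10) & 0x1Fu;
        uint32_t mant = h & 0x3FFu;
        uint32_t bits;
        if (exp == 0) {
            if (mant == 0) {
                bits = sign;  // signed zero
            } else {
                // Subnormal half: renormalize, adjusting the exponent per shift.
                int e = 1;
                while ((mant & 0x400u) == 0) { mant <<= 1; --e; }
                mant &= 0x3FFu;
                bits = sign | (static_cast<uint32_t>(e + 112) << 23) | (mant << 13);
            }
        } else if (exp == 0x1Fu) {
            bits = sign | 0x7F800000u | (mant << 13);  // Inf or NaN (payload shifted)
        } else {
            bits = sign | ((exp + 112u) << 23) | (mant << 13);  // re-bias 15 -> 127
        }
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }

The row-buffer conversion exists because the CompFp32 fallback path multiplies with `GemmFloatKernel`, which consumes fp32 activations; converting each A panel once keeps that kernel unchanged for both activation types.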