diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 304aa77f5473c..4abed59695add 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -200,6 +200,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm ${MLAS_SRC_DIR}/amd64/sgemma.asm ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm + ${MLAS_SRC_DIR}/amd64/cvtfp16Avx2.asm ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 04172e46e9c6a..e310d7492245f 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -135,6 +135,8 @@ void CPUIDInfo::X86Init() { if (max_SubLeaves >= 1) { GetCPUID(7, 1, data); has_avx512_bf16_ = has_avx512 && (data[0] & (1 << 5)); + // Check for AVX_NE_CONVERT as half precision kernel uses it with AVX2 and F16C + has_avx_ne_convert_ = has_avx2_ && has_f16c_ && (data[3] & (1 << 5)); } } } diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 0bee36e4d10b3..c49bafd2113ef 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -22,6 +22,7 @@ class CPUIDInfo { bool HasAVX512_BF16() const { return has_avx512_bf16_; } bool HasAVX512Skylake() const { return has_avx512_skylake_; } bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/ + bool HasAVX_NE_CONVERT() const { return has_avx_ne_convert_; } /*fp16/bf16 conversion inst*/ bool HasSSE3() const { return has_sse3_; } bool HasSSE4_1() const { return has_sse4_1_; } bool IsHybrid() const { return is_hybrid_; } @@ -101,6 +102,7 @@ class CPUIDInfo { bool has_avx512_bf16_{false}; bool has_avx512_skylake_{false}; bool has_f16c_{false}; + bool has_avx_ne_convert_{false}; bool has_sse3_{false}; bool has_sse4_1_{false}; bool is_hybrid_{false}; diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index cdfd283899c8c..b6718f36e9b41 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1037,6 +1037,14 @@ MlasConvertHalfToFloatBuffer( size_t Count ); +extern "C" void +MLASCALL +MlasConvertHalfToFloatBufferAVX2( + const unsigned short* Source, + float* Destination, + size_t Count + ); + // // Transpose routines. // diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx2.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx2.asm new file mode 100644 index 0000000000000..1fdeebee66d63 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx2.asm @@ -0,0 +1,148 @@ +;++ +; +; Copyright (c) Intel Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; cvtfp16Avx2.asm +; +; Abstract: +; +; This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA. +; +;-- + + .xlist +INCLUDE mlasi.inc + .list + + .const +SINGLE_SIZE equ 4 +HALF_SIZE equ 2 +LOW_SELECTOR equ 00100000b +HIGH_SELECTOR equ 00110001b + SUBTTL "Convert buffer of half-precision floats to single-precision floats" +;++ +; +; Routine Description: +; +; This routine converts the source buffer of half-precision floats to the +; destination buffer of single-precision floats. +; +; This implementation uses AVX2 instructions. +; +; Arguments: +; +; Source (rcx) - Supplies the address of the source buffer of half-precision +; floats. +; +; Destination (rdx) - Supplies the address of the destination buffer of +; single-precision floats. +; +; Count (r8) - Supplies the number of elements to convert. +; +; Return Value: +; +; None. +; +;-- + + LEAF_ENTRY MlasConvertHalfToFloatBufferAVX2, _TEXT + + test r8, r8 + jz ExitRoutine + cmp r8, 8 + jb ConvertMaskedVectors + cmp r8, 16 + jb Convert128Vectors + + + +Convert256Vectors: + vcvtneeph2ps ymm0, ymmword PTR [rcx] ; Load even indexes + vcvtneoph2ps ymm1, ymmword PTR [rcx] ; Load odd indexes + vunpcklps ymm2, ymm0, ymm1 ; Interleave low part + vunpckhps ymm1, ymm0, ymm1 ; Interleave high part + vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR ; Fix the order + vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR ; Fix the order + vmovups ymmword PTR [rdx], ymm0 ; Store the low part + vmovups ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part + + add rcx, 16*HALF_SIZE ; Advance src ptr by 16 elements + add rdx, 16*SINGLE_SIZE ; Advance dest ptr by 16 elements + sub r8, 16 ; Reduce the counter by 16 elements + + jz ExitRoutine ; If we are done, exit + cmp r8, 16 ; If the vector is big enough, we go again + jae Convert256Vectors + + + +Convert128Vectors: + vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes + vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes + vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order + vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order + vmovups xmmword PTR [rdx], xmm0 ; Store the low part + vmovups xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part + + add rcx, 8*HALF_SIZE ; Advance src ptr by 8 elements + add rdx, 8*SINGLE_SIZE ; Advance dest ptr by 8 elements + sub r8, 8 ; Reduce the counter by 8 elements + + jz ExitRoutine ; If we are done, exit + + + +ConvertMaskedVectors: + vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes + vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes + vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order + vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order + + cmp r8, 4 ; Chek if we can store the complete lower vector + jae ConvertLowerVector + + vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones + cmp r8, 2 ; Check how many converts we need + jb ConvertLower1 + ja ConvertLower3 + vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values + jmp ConvertLowerMaskedVector +ConvertLower1: + vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value + jmp ConvertLowerMaskedVector +ConvertLower3: + vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values +ConvertLowerMaskedVector: + vmaskmovps xmmword PTR [rdx], xmm2, xmm0 ; Store the masked data, the shift is done in 8bit multiples + jmp ExitRoutine ; If we ran into any of the cases above, means we are done after storing +ConvertLowerVector: + vmovups xmmword PTR [rdx], xmm0 ; Store the low part + sub r8, 4 ; Check if we still need to convert + jz ExitRoutine + + + add rdx, 4*SINGLE_SIZE ; Advance dest ptr by 4 elements + vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones + cmp r8, 2 ; Check how many converts we need + jb ConvertUpper1 + ja ConvertUpper3 + vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values + jmp ConvertMaskedUpperVector +ConvertUpper1: + vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value + jmp ConvertMaskedUpperVector +ConvertUpper3: + vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values +ConvertMaskedUpperVector: + vmaskmovps xmmword PTR [rdx], xmm2, xmm1 ; Store the masked data, the shift is done in 8bit multiples + +ExitRoutine: + ret + + LEAF_END MlasConvertHalfToFloatBufferAVX2, _TEXT + + END diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index cbc4d8360d4b1..e837c3427432f 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -26,6 +26,9 @@ #include "core/mlas/inc/mlas.h" #endif +#include "core/common/cpuid_info.h" + + namespace onnxruntime { namespace op_kernel_type_control { @@ -263,7 +266,12 @@ struct TensorCaster { auto out_data = out.MutableData(); auto in_data = in.Data(); const size_t shape_size = narrow(shape.Size()); - MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size); + if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()) { + MlasConvertHalfToFloatBufferAVX2(&in_data[0].val, out_data, shape_size); + } + else { + MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size); + } } };