diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 304aa77f5473c..d5024f4fa312a 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -208,6 +208,12 @@ function(setup_mlas_source_for_windows)
     ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm
     ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm
   )
+  if(MSVC_VERSION GREATER_EQUAL 1933)
+    target_sources(onnxruntime_mlas PRIVATE
+      ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm
+    )
+  endif()
+
   if (NOT onnxruntime_ORT_MINIMAL_BUILD)
     target_sources(onnxruntime_mlas PRIVATE
       ${MLAS_SRC_DIR}/q4gemm_avx512.cpp
@@ -536,6 +542,12 @@ else()
           ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
           ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
         )
+        if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 13.0)
+          set(mlas_platform_srcs_avx2
+            ${mlas_platform_srcs_avx2}
+            ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S
+          )
+        endif()
         set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
 
         set(mlas_platform_srcs_avx512f
diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc
index 04172e46e9c6a..e310d7492245f 100644
--- a/onnxruntime/core/common/cpuid_info.cc
+++ b/onnxruntime/core/common/cpuid_info.cc
@@ -135,6 +135,8 @@ void CPUIDInfo::X86Init() {
       if (max_SubLeaves >= 1) {
         GetCPUID(7, 1, data);
         has_avx512_bf16_ = has_avx512 && (data[0] & (1 << 5));
+        // Check for AVX_NE_CONVERT, as the half-precision kernel uses it together with AVX2 and F16C
+        has_avx_ne_convert_ = has_avx2_ && has_f16c_ && (data[3] & (1 << 5));
       }
     }
   }
diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h
index 0bee36e4d10b3..5943e2ef3c730 100644
--- a/onnxruntime/core/common/cpuid_info.h
+++ b/onnxruntime/core/common/cpuid_info.h
@@ -21,7 +21,8 @@ class CPUIDInfo {
   bool HasAVX512f() const { return has_avx512f_; }
   bool HasAVX512_BF16() const { return has_avx512_bf16_; }
   bool HasAVX512Skylake() const { return has_avx512_skylake_; }
-  bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/
+  bool HasF16C() const { return has_f16c_; }/*fp16 conversion inst*/
+  bool HasAVX_NE_CONVERT() const { return has_avx_ne_convert_; } /*fp16/bf16 conversion inst*/
   bool HasSSE3() const { return has_sse3_; }
   bool HasSSE4_1() const { return has_sse4_1_; }
   bool IsHybrid() const { return is_hybrid_; }
@@ -101,6 +102,7 @@ class CPUIDInfo {
   bool has_avx512_bf16_{false};
   bool has_avx512_skylake_{false};
   bool has_f16c_{false};
+  bool has_avx_ne_convert_{false};
   bool has_sse3_{false};
   bool has_sse4_1_{false};
   bool is_hybrid_{false};
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index cdfd283899c8c..caaf569d804f1 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1037,6 +1037,16 @@ MlasConvertHalfToFloatBuffer(
     size_t Count
 );
 
+#if (_MSC_VER >= 1933) || (__GNUC__ >= 13)
+extern "C" void
+MLASCALL
+MlasConvertHalfToFloatBufferAVX(
+    const unsigned short* Source,
+    float* Destination,
+    size_t Count
+    );
+#endif
+
 //
 // Transpose routines.
 //
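A note on the cpuid_info.cc change above: AVX-NE-CONVERT support is reported in CPUID.(EAX=7, ECX=1):EDX bit 5, which is the `data[3] & (1 << 5)` test in the hunk, and the production code additionally requires AVX2 and F16C before the new kernel is considered. The following is a minimal standalone sketch of that probe; it mirrors, but is not, the `CPUIDInfo::X86Init()` logic, and the helper name is illustrative.

```cpp
// Standalone sketch (assumption: illustrative only, not ORT code) of the
// AVX-NE-CONVERT feature probe added to CPUIDInfo::X86Init() in the diff above.
#include <cstdio>
#if defined(_MSC_VER)
#include <intrin.h>
#else
#include <cpuid.h>
#endif

static bool HasAvxNeConvert() {
  unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0;
#if defined(_MSC_VER)
  int regs[4] = {0};
  __cpuidex(regs, 7, 0);                        // leaf 7, sub-leaf 0: EAX = max sub-leaf
  if (regs[0] < 1) return false;
  __cpuidex(regs, 7, 1);                        // leaf 7, sub-leaf 1
  edx = static_cast<unsigned int>(regs[3]);
#else
  if (!__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) return false;
  if (eax < 1) return false;                    // sub-leaf 1 not available
  if (!__get_cpuid_count(7, 1, &eax, &ebx, &ecx, &edx)) return false;
#endif
  return (edx & (1u << 5)) != 0;                // EDX bit 5 = AVX-NE-CONVERT
}

int main() {
  std::printf("AVX-NE-CONVERT supported: %s\n", HasAvxNeConvert() ? "yes" : "no");
  return 0;
}
```

The real `has_avx_ne_convert_` flag also ANDs the result with `has_avx2_` and `has_f16c_`, since the new conversion path is only dispatched alongside those features.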
diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm
new file mode 100644
index 0000000000000..1eb14af601ef3
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm
@@ -0,0 +1,148 @@
+;++
+;
+; Copyright (c) Intel Corporation. All rights reserved.
+;
+; Licensed under the MIT License.
+;
+; Module Name:
+;
+;   cvtfp16Avx.asm
+;
+; Abstract:
+;
+;   This module implements routines to convert between FP16 and FP32 formats
+;   using the AVX_NE_CONVERT ISA.
+;
+;--
+
+        .xlist
+INCLUDE mlasi.inc
+        .list
+
+        .const
+SINGLE_SIZE     equ     4
+HALF_SIZE       equ     2
+LOW_SELECTOR    equ     00100000b
+HIGH_SELECTOR   equ     00110001b
+
+        SUBTTL  "Convert buffer of half-precision floats to single-precision floats"
+;++
+;
+; Routine Description:
+;
+;   This routine converts the source buffer of half-precision floats to the
+;   destination buffer of single-precision floats.
+;
+;   This implementation uses AVX_NE_CONVERT and AVX2 instructions.
+;
+; Arguments:
+;
+;   Source (rcx) - Supplies the address of the source buffer of half-precision
+;       floats.
+;
+;   Destination (rdx) - Supplies the address of the destination buffer of
+;       single-precision floats.
+;
+;   Count (r8) - Supplies the number of elements to convert.
+;
+; Return Value:
+;
+;   None.
+;
+;--
+
+        LEAF_ENTRY MlasConvertHalfToFloatBufferAVX, _TEXT
+
+        test    r8, r8
+        jz      ExitRoutine
+        cmp     r8, 8
+        jb      ConvertMaskedVectors
+        cmp     r8, 16
+        jb      Convert128Vectors
+
+Convert256Vectors:
+        vcvtneeph2ps ymm0, ymmword PTR [rcx]            ; Load even indexes
+        vcvtneoph2ps ymm1, ymmword PTR [rcx]            ; Load odd indexes
+        vunpcklps   ymm2, ymm0, ymm1                    ; Interleave low part
+        vunpckhps   ymm1, ymm0, ymm1                    ; Interleave high part
+        vperm2f128  ymm0, ymm2, ymm1, LOW_SELECTOR      ; Fix the order
+        vperm2f128  ymm1, ymm2, ymm1, HIGH_SELECTOR     ; Fix the order
+        vmovups     ymmword PTR [rdx], ymm0             ; Store the low part
+        vmovups     ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part
+
+        add     rcx, 16*HALF_SIZE                       ; Advance src ptr by 16 elements
+        add     rdx, 16*SINGLE_SIZE                     ; Advance dest ptr by 16 elements
+        sub     r8, 16                                  ; Reduce the counter by 16 elements
+
+        jz      ExitRoutine                             ; If we are done, exit
+        cmp     r8, 16                                  ; If the vector is big enough, we go again
+        jae     Convert256Vectors
+
+Convert128Vectors:
+        vcvtneeph2ps xmm2, xmmword PTR [rcx]            ; Load even indexes
+        vcvtneoph2ps xmm1, xmmword PTR [rcx]            ; Load odd indexes
+        vunpcklps   xmm0, xmm2, xmm1                    ; Interleave low part to fix order
+        vunpckhps   xmm1, xmm2, xmm1                    ; Interleave high part to fix order
+        vmovups     xmmword PTR [rdx], xmm0             ; Store the low part
+        vmovups     xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part
+
+        add     rcx, 8*HALF_SIZE                        ; Advance src ptr by 8 elements
+        add     rdx, 8*SINGLE_SIZE                      ; Advance dest ptr by 8 elements
+        sub     r8, 8                                   ; Reduce the counter by 8 elements
+
+        jz      ExitRoutine                             ; If we are done, exit
+
+ConvertMaskedVectors:
+        vcvtneeph2ps xmm2, xmmword PTR [rcx]            ; Load even indexes
+        vcvtneoph2ps xmm1, xmmword PTR [rcx]            ; Load odd indexes
+        vunpcklps   xmm0, xmm2, xmm1                    ; Interleave low part to fix order
+        vunpckhps   xmm1, xmm2, xmm1                    ; Interleave high part to fix order
+
+        cmp     r8, 4                                   ; Check if we can store the complete lower vector
+        jae     ConvertLowerVector
+
+        vpcmpeqw xmm2, xmm2, xmm2                       ; Initialize the mask full of ones
+        cmp     r8, 2                                   ; Check how many converts we need
+        jb      ConvertLower1
+        ja      ConvertLower3
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*2               ; Shift the mask so the store writes two values
+        jmp     ConvertLowerMaskedVector
+ConvertLower1:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*3               ; Shift the mask so the store writes only one value
+        jmp     ConvertLowerMaskedVector
+ConvertLower3:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE                 ; Shift the mask so the store writes three values
+ConvertLowerMaskedVector:
+        vmaskmovps xmmword PTR [rdx], xmm2, xmm0        ; Store the masked data (the mask was shifted in byte multiples)
+        jmp     ExitRoutine                             ; If we hit any of the cases above, we are done after storing
+ConvertLowerVector:
+        vmovups xmmword PTR [rdx], xmm0                 ; Store the low part
+        sub     r8, 4                                   ; Check if we still need to convert
+        jz      ExitRoutine
+
+        add     rdx, 4*SINGLE_SIZE                      ; Advance dest ptr by 4 elements
+        vpcmpeqw xmm2, xmm2, xmm2                       ; Initialize the mask full of ones
+        cmp     r8, 2                                   ; Check how many converts we need
+        jb      ConvertUpper1
+        ja      ConvertUpper3
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*2               ; Shift the mask so the store writes two values
+        jmp     ConvertMaskedUpperVector
+ConvertUpper1:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*3               ; Shift the mask so the store writes only one value
+        jmp     ConvertMaskedUpperVector
+ConvertUpper3:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE                 ; Shift the mask so the store writes three values
+ConvertMaskedUpperVector:
+        vmaskmovps xmmword PTR [rdx], xmm2, xmm1        ; Store the masked data (the mask was shifted in byte multiples)
+
+        jmp     ExitRoutine
+
+ExitRoutine:
+        ret
+
+        LEAF_END MlasConvertHalfToFloatBufferAVX, _TEXT
+
+        END
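For readers following the Convert256Vectors loop above: vcvtneeph2ps converts the even-indexed halfs of the 256-bit load and vcvtneoph2ps the odd-indexed ones, so the converted floats arrive de-interleaved; the vunpcklps/vunpckhps pair plus vperm2f128 with the 00100000b/00110001b selectors puts them back into source order. A small sketch of just that reordering step, with plain AVX2 floats standing in for the converted values so it runs on hardware without AVX-NE-CONVERT (illustrative only, not MLAS code):

```cpp
// Sketch of the unpack + perm2f128 reordering used by the kernels above.
// The `even`/`odd` arrays simulate the outputs of vcvtneeph2ps / vcvtneoph2ps
// for inputs x0..x15. Build with AVX enabled, e.g. g++ -mavx2 reorder.cc.
#include <immintrin.h>
#include <cstdio>

int main() {
  float even[8] = {0, 2, 4, 6, 8, 10, 12, 14};   // even-indexed elements
  float odd[8]  = {1, 3, 5, 7, 9, 11, 13, 15};   // odd-indexed elements

  __m256 e = _mm256_loadu_ps(even);
  __m256 o = _mm256_loadu_ps(odd);
  __m256 lo = _mm256_unpacklo_ps(e, o);             // x0..x3  | x8..x11
  __m256 hi = _mm256_unpackhi_ps(e, o);             // x4..x7  | x12..x15
  __m256 r0 = _mm256_permute2f128_ps(lo, hi, 0x20); // x0..x7   (LOW_SELECTOR)
  __m256 r1 = _mm256_permute2f128_ps(lo, hi, 0x31); // x8..x15  (HIGH_SELECTOR)

  float out[16];
  _mm256_storeu_ps(out, r0);
  _mm256_storeu_ps(out + 8, r1);
  for (int i = 0; i < 16; ++i) std::printf("%g ", out[i]);  // prints 0..15 in order
  std::printf("\n");
  return 0;
}
```

Running it prints 0 through 15 in order, which is exactly the reordering the two vperm2f128 selectors achieve in the assembly.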
diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S
new file mode 100644
index 0000000000000..1a2855b8278b9
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S
@@ -0,0 +1,125 @@
+/*++ Routine Description:
+
+    This routine converts the source buffer of half-precision floats to the
+    destination buffer of single-precision floats.
+
+    This implementation uses AVX_NE_CONVERT and AVX2 instructions.
+
+Arguments:
+
+    Source (rdi) - Supplies the address of the source buffer of half-precision
+        floats.
+
+    Destination (rsi) - Supplies the address of the destination buffer of
+        single-precision floats.
+
+    Count (rdx) - Supplies the number of elements to convert.
+
+Return Value:
+
+    None.
+
+--*/
+
+.data
+.equ SINGLE_SIZE, 4
+.equ HALF_SIZE, 2
+.equ LOW_SELECTOR, 0b00100000
+.equ HIGH_SELECTOR, 0b00110001
+
+.text
+.globl MlasConvertHalfToFloatBufferAVX
+.intel_syntax noprefix
+
+MlasConvertHalfToFloatBufferAVX:
+        test    rdx, rdx                        // Check if we have any elements to convert
+        jz      ExitRoutine
+
+AVX_NE_CONVERT:
+        cmp     rdx, 8
+        jb      ConvertMaskedVectors
+        cmp     rdx, 16
+        jb      Convert128Vectors
+
+Convert256Vectors:
+        vcvtneeph2ps ymm0, ymmword PTR [rdi]            // Load even indexes
+        vcvtneoph2ps ymm1, ymmword PTR [rdi]            // Load odd indexes
+        vunpcklps   ymm2, ymm0, ymm1                    // Interleave low part
+        vunpckhps   ymm1, ymm0, ymm1                    // Interleave high part
+        vperm2f128  ymm0, ymm2, ymm1, LOW_SELECTOR      // Fix the order
+        vperm2f128  ymm1, ymm2, ymm1, HIGH_SELECTOR     // Fix the order
+        vmovups     ymmword PTR [rsi], ymm0             // Store the low part
+        vmovups     ymmword PTR [rsi + 8*SINGLE_SIZE], ymm1 // Store the high part
+
+        add     rdi, 16*HALF_SIZE                       // Advance src ptr by 16 elements
+        add     rsi, 16*SINGLE_SIZE                     // Advance dest ptr by 16 elements
+        sub     rdx, 16                                 // Reduce the counter by 16 elements
+
+        jz      ExitRoutine                             // If we are done, exit
+        cmp     rdx, 16                                 // If the vector is big enough, we go again
+        jae     Convert256Vectors
+
+Convert128Vectors:
+        vcvtneeph2ps xmm2, xmmword PTR [rdi]            // Load even indexes
+        vcvtneoph2ps xmm1, xmmword PTR [rdi]            // Load odd indexes
+        vunpcklps   xmm0, xmm2, xmm1                    // Interleave low part to fix order
+        vunpckhps   xmm1, xmm2, xmm1                    // Interleave high part to fix order
+        vmovups     xmmword PTR [rsi], xmm0             // Store the low part
+        vmovups     xmmword PTR [rsi + 4*SINGLE_SIZE], xmm1 // Store the high part
+
+        add     rdi, 8*HALF_SIZE                        // Advance src ptr by 8 elements
+        add     rsi, 8*SINGLE_SIZE                      // Advance dest ptr by 8 elements
+        sub     rdx, 8                                  // Reduce the counter by 8 elements
+
+        jz      ExitRoutine                             // If we are done, exit
+
+ConvertMaskedVectors:
+        vcvtneeph2ps xmm2, xmmword PTR [rdi]            // Load even indexes
+        vcvtneoph2ps xmm1, xmmword PTR [rdi]            // Load odd indexes
+        vunpcklps   xmm0, xmm2, xmm1                    // Interleave low part to fix order
+        vunpckhps   xmm1, xmm2, xmm1                    // Interleave high part to fix order
+
+        cmp     rdx, 4                                  // Check if we can store the complete lower vector
+        jae     ConvertLowerVector
+
+        vpcmpeqw xmm2, xmm2, xmm2                       // Initialize the mask full of ones
+        cmp     rdx, 2                                  // Check how many converts we need
+        jb      ConvertLower1
+        ja      ConvertLower3
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*2               // Shift the mask so the store writes two values
+        jmp     ConvertLowerMaskedVector
+ConvertLower1:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*3               // Shift the mask so the store writes only one value
+        jmp     ConvertLowerMaskedVector
+ConvertLower3:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE                 // Shift the mask so the store writes three values
+ConvertLowerMaskedVector:
+        vmaskmovps xmmword PTR [rsi], xmm2, xmm0        // Store the masked data (the mask was shifted in byte multiples)
+        jmp     ExitRoutine                             // If we hit any of the cases above, we are done after storing
+ConvertLowerVector:
+        vmovups xmmword PTR [rsi], xmm0                 // Store the low part
+        sub     rdx, 4                                  // Check if we still need to convert
+        jz      ExitRoutine
+
+        add     rsi, 4*SINGLE_SIZE                      // Advance dest ptr by 4 elements
+        vpcmpeqw xmm2, xmm2, xmm2                       // Initialize the mask full of ones
+        cmp     rdx, 2                                  // Check how many converts we need
+        jb      ConvertUpper1
+        ja      ConvertUpper3
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*2               // Shift the mask so the store writes two values
+        jmp     ConvertMaskedUpperVector
+ConvertUpper1:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*3               // Shift the mask so the store writes only one value
+        jmp     ConvertMaskedUpperVector
+ConvertUpper3:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE                 // Shift the mask so the store writes three values
+ConvertMaskedUpperVector:
+        vmaskmovps xmmword PTR [rsi], xmm2, xmm1        // Store the masked data (the mask was shifted in byte multiples)
+
+        jmp     ExitRoutine
+ExitRoutine:
+        ret
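The ConvertMaskedVectors tails in both kernels build a store mask by filling an XMM register with ones (vpcmpeqw) and byte-shifting it right with vpsrldq, so that vmaskmovps writes only the remaining one to three floats and leaves the rest of the destination untouched. A hedged intrinsics sketch of the same idea (the StoreTail helper is illustrative, not MLAS code):

```cpp
// Sketch of the masked-tail store pattern used above: shift an all-ones mask
// right by whole floats so the masked store writes only `count` elements.
// Build with AVX enabled, e.g. g++ -mavx tail.cc.
#include <immintrin.h>
#include <cstdio>

static void StoreTail(float* dst, __m128 value, int count /* 1..3 */) {
  __m128i mask = _mm_cmpeq_epi16(_mm_setzero_si128(), _mm_setzero_si128()); // all ones
  // _mm_srli_si128 needs an immediate, so mirror the asm's three-way branch.
  if (count == 1) {
    mask = _mm_srli_si128(mask, 12);  // keep 1 dword lane (SINGLE_SIZE*3 bytes shifted out)
  } else if (count == 2) {
    mask = _mm_srli_si128(mask, 8);   // keep 2 dword lanes
  } else {
    mask = _mm_srli_si128(mask, 4);   // keep 3 dword lanes
  }
  _mm_maskstore_ps(dst, mask, value); // lanes with a zero mask are not written
}

int main() {
  float dst[4] = {-1.f, -1.f, -1.f, -1.f};
  StoreTail(dst, _mm_set_ps(3.f, 2.f, 1.f, 0.f), 3);
  std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);  // 0 1 2 -1
  return 0;
}
```

The last element stays at -1, showing that the masked store never touches memory past the requested count, which is what makes the tail handling safe at buffer ends.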
diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index cbc4d8360d4b1..cd78b656d2b7d 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -22,9 +22,10 @@
 #include "Eigen/src/Core/arch/Default/BFloat16.h"
 #include "Eigen/src/Core/arch/Default/Half.h"
 
-#if defined(_M_AMD64) && !defined(_M_ARM64EC)
 #include "core/mlas/inc/mlas.h"
-#endif
+
+#include "core/common/cpuid_info.h"
+
 
 namespace onnxruntime {
 
@@ -252,20 +253,56 @@ struct TensorCasterNoSat {
 
 #endif
 
-#if defined(_M_AMD64) && !defined(_M_ARM64EC)
-// specializations to use optimized and Windows x64-specific
-// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion
-
 // tensor MLFloat16 -> float
+#if (defined(_M_AMD64) && !defined(_M_ARM64EC)) || defined(__x86_64__)
 template <>
 struct TensorCaster<MLFloat16, float> {
+  void Fallback(const TensorShape& shape, const Tensor& in, Tensor& out) const {
+    using SrcEigenCastType = typename EigenCastType<MLFloat16>::type;
+    using DstEigenCastType = typename EigenCastType<float>::type;
+
+    const std::ptrdiff_t shape_size = narrow<std::ptrdiff_t>(shape.Size());
+    const auto in_vector =
+        ConstEigenVectorMap<SrcEigenCastType>(reinterpret_cast<const SrcEigenCastType*>(in.Data<MLFloat16>()), shape_size);
+    auto out_vector =
+        EigenVectorMap<DstEigenCastType>(reinterpret_cast<DstEigenCastType*>(out.MutableData<float>()), shape_size);
+    out_vector = in_vector.template cast<DstEigenCastType>();
+  }
   void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
+#if defined(_MSC_VER)
     auto out_data = out.MutableData<float>();
     auto in_data = in.Data<MLFloat16>();
     const size_t shape_size = narrow<size_t>(shape.Size());
+#if (_MSC_VER >= 1933)
+    if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()) {
+      MlasConvertHalfToFloatBufferAVX(&in_data[0].val, out_data, shape_size);
+    } else {
+      MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
+    }
+#else
     MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
+#endif
+#elif (__GNUC__ >= 13)
+    auto out_data = out.MutableData<float>();
+    auto in_data = in.Data<MLFloat16>();
+    const size_t shape_size = narrow<size_t>(shape.Size());
+    if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()) {
+      MlasConvertHalfToFloatBufferAVX(&in_data[0].val, out_data, shape_size);
+    } else {
+      // Use the generic function
+      Fallback(shape, in, out);
+    }
+#else
+    // Use the generic function
+    Fallback(shape, in, out);
+#endif
   }
 };
+#endif
+
+#if defined(_M_AMD64) && !defined(_M_ARM64EC)
+// specializations to use optimized and Windows x64-specific
+// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion
 
 Tensor GetIntermediateMLFloat16ToFloatTensor(
     const OpKernelContext& context, const TensorShape& shape, const Tensor& in) {
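Since the cast kernel above only takes the new assembly path on AVX-NE-CONVERT hardware with a new enough compiler, a quick way to sanity-check the change is to run the same dispatch against a scalar IEEE half-to-float reference for lengths that exercise the 16-, 8-, 4-wide and masked-tail paths. A rough harness sketch, assuming it is built inside the onnxruntime tree and linked against MLAS (the reference helper and test values are illustrative):

```cpp
// Rough validation sketch (not part of the patch): compare the dispatched
// converter against a scalar fp16 -> fp32 reference for several lengths.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>
#include "core/common/cpuid_info.h"
#include "core/mlas/inc/mlas.h"

static float HalfToFloatRef(unsigned short h) {
  uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  uint32_t exp = (h >> 10) & 0x1Fu;
  uint32_t mant = h & 0x3FFu;
  uint32_t bits;
  if (exp == 0x1Fu) {
    bits = sign | 0x7F800000u | (mant << 13);            // Inf / NaN
  } else if (exp != 0) {
    bits = sign | ((exp + 112u) << 23) | (mant << 13);   // normal: rebias 15 -> 127
  } else if (mant == 0) {
    bits = sign;                                         // +/- zero
  } else {
    int shift = 0;                                       // subnormal: renormalize
    while ((mant & 0x400u) == 0) { mant <<= 1; ++shift; }
    bits = sign | ((113u - shift) << 23) | ((mant & 0x3FFu) << 13);
  }
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  for (size_t n : {size_t{1}, size_t{2}, size_t{3}, size_t{7}, size_t{8}, size_t{16}, size_t{33}}) {
    std::vector<unsigned short> src(n);
    for (size_t i = 0; i < n; ++i) src[i] = static_cast<unsigned short>(0x3C00u + 37u * i);
    std::vector<float> dst(n);
#if (_MSC_VER >= 1933) || (__GNUC__ >= 13)
    if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()) {
      MlasConvertHalfToFloatBufferAVX(src.data(), dst.data(), n);   // new AVX_NE_CONVERT path
    } else {
      MlasConvertHalfToFloatBuffer(src.data(), dst.data(), n);      // existing path
    }
#else
    MlasConvertHalfToFloatBuffer(src.data(), dst.data(), n);
#endif
    for (size_t i = 0; i < n; ++i) {
      const float ref = HalfToFloatRef(src[i]);
      if (dst[i] != ref && !(std::isnan(dst[i]) && std::isnan(ref))) {
        std::printf("mismatch: n=%zu i=%zu got=%g want=%g\n", n, i, dst[i], ref);
      }
    }
  }
  std::printf("done\n");
  return 0;
}
```

The odd lengths (1, 2, 3, 7, 33) are the interesting ones here, since they force the masked-store tails that the new kernels handle with vpsrldq plus vmaskmovps.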