From 70ef759dec0309be239f8863969ccdab4119680a Mon Sep 17 00:00:00 2001
From: wos
Date: Tue, 9 Jul 2024 14:25:43 -0700
Subject: [PATCH] Add Linux support

---
 cmake/onnxruntime_mlas.cmake                |   1 +
 onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S | 125 ++++++++++++++++++
 .../core/providers/cpu/tensor/cast_op.cc    |  32 ++++-
 3 files changed, 152 insertions(+), 6 deletions(-)
 create mode 100644 onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S

diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 304aa77f5473c..0139652cdb008 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -532,6 +532,7 @@ else()
     ${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S
     ${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S
     ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S
+    ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S
     ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
     ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
     ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S
new file mode 100644
index 0000000000000..d122d5df9a507
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S
@@ -0,0 +1,125 @@
+/*++
+
+Routine Description:
+
+    This routine converts the source buffer of half-precision floats to the
+    destination buffer of single-precision floats.
+
+    This implementation uses AVX_NE_CONVERT instructions.
+
+Arguments:
+
+    Source (rdi) - Supplies the address of the source buffer of half-precision
+        floats.
+
+    Destination (rsi) - Supplies the address of the destination buffer of
+        single-precision floats.
+
+    Count (rdx) - Supplies the number of elements to convert.
+
+Return Value:
+
+    None.
+
+--*/
+
+.data
+.equ SINGLE_SIZE, 4
+.equ HALF_SIZE, 2
+.equ LOW_SELECTOR, 0b00100000
+.equ HIGH_SELECTOR, 0b00110001
+
+.text
+.globl MlasConvertHalfToFloatBuffer
+.intel_syntax noprefix
+
+MlasConvertHalfToFloatBuffer:
+    test rdx, rdx                           // Check if we have any elements to convert
+    jz ExitRoutine
+
+AVX_NE_CONVERT:
+    cmp rdx, 8
+    jb ConvertMaskedVectors
+    cmp rdx, 16
+    jb Convert128Vectors
+
+Convert256Vectors:
+    vcvtneeph2ps ymm0, ymmword PTR [rdi]    // Load even-indexed elements
+    vcvtneoph2ps ymm1, ymmword PTR [rdi]    // Load odd-indexed elements
+    vunpcklps ymm2, ymm0, ymm1              // Interleave the low part of each lane
+    vunpckhps ymm1, ymm0, ymm1              // Interleave the high part of each lane
+    vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR   // Fix the order: low lanes
+    vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR  // Fix the order: high lanes
+    vmovups ymmword PTR [rsi], ymm0         // Store the low part
+    vmovups ymmword PTR [rsi + 8*SINGLE_SIZE], ymm1 // Store the high part
+
+    add rdi, 16*HALF_SIZE                   // Advance src ptr by 16 elements
+    add rsi, 16*SINGLE_SIZE                 // Advance dest ptr by 16 elements
+    sub rdx, 16                             // Reduce the counter by 16 elements
+
+    jz ExitRoutine                          // If we are done, exit
+    cmp rdx, 16                             // If at least 16 elements remain, go again
+    jae Convert256Vectors
+
+Convert128Vectors:
+    vcvtneeph2ps xmm2, xmmword PTR [rdi]    // Load even-indexed elements
+    vcvtneoph2ps xmm1, xmmword PTR [rdi]    // Load odd-indexed elements
+    vunpcklps xmm0, xmm2, xmm1              // Interleave the low part to fix the order
+    vunpckhps xmm1, xmm2, xmm1              // Interleave the high part to fix the order
+    vmovups xmmword PTR [rsi], xmm0         // Store the low part
+    vmovups xmmword PTR [rsi + 4*SINGLE_SIZE], xmm1 // Store the high part
+
+    add rdi, 8*HALF_SIZE                    // Advance src ptr by 8 elements
+    add rsi, 8*SINGLE_SIZE                  // Advance dest ptr by 8 elements
+    sub rdx, 8                              // Reduce the counter by 8 elements
+
+    jz ExitRoutine                          // If we are done, exit
+
+ConvertMaskedVectors:
+    vcvtneeph2ps xmm2, xmmword PTR [rdi]    // Load even-indexed elements
+    vcvtneoph2ps xmm1, xmmword PTR [rdi]    // Load odd-indexed elements
+    vunpcklps xmm0, xmm2, xmm1              // Interleave the low part to fix the order
+    vunpckhps xmm1, xmm2, xmm1              // Interleave the high part to fix the order
+
+    cmp rdx, 4                              // Check if we can store the complete lower vector
+    jae ConvertLowerVector
+
+    vpcmpeqw xmm2, xmm2, xmm2               // Initialize the mask to all ones
+    cmp rdx, 2                              // Check how many conversions we need
+    jb ConvertLower1
+    ja ConvertLower3
+    vpsrldq xmm2, xmm2, SINGLE_SIZE*2       // Shift the mask so the store writes two values
+    jmp ConvertLowerMaskedVector
+ConvertLower1:
+    vpsrldq xmm2, xmm2, SINGLE_SIZE*3       // Shift the mask so the store writes only one value
+    jmp ConvertLowerMaskedVector
+ConvertLower3:
+    vpsrldq xmm2, xmm2, SINGLE_SIZE         // Shift the mask so the store writes three values
+ConvertLowerMaskedVector:
+    vmaskmovps xmmword PTR [rsi], xmm2, xmm0    // Store the masked data; vpsrldq shifts in byte multiples
+    jmp ExitRoutine                         // Any of the masked cases above means we are done after storing
+ConvertLowerVector:
+    vmovups xmmword PTR [rsi], xmm0         // Store the low part
+    sub rdx, 4                              // Reduce the counter by 4; exit if nothing remains
+    jz ExitRoutine
+
+    add rsi, 4*SINGLE_SIZE                  // Advance dest ptr by 4 elements
+    vpcmpeqw xmm2, xmm2, xmm2               // Initialize the mask to all ones
+    cmp rdx, 2                              // Check how many conversions we need
+    jb ConvertUpper1
+    ja ConvertUpper3
+    vpsrldq xmm2, xmm2, SINGLE_SIZE*2       // Shift the mask so the store writes two values
+    jmp ConvertMaskedUpperVector
+ConvertUpper1:
+    vpsrldq xmm2, xmm2, SINGLE_SIZE*3       // Shift the mask so the store writes only one value
+    jmp ConvertMaskedUpperVector
+ConvertUpper3:
+    vpsrldq xmm2, xmm2, SINGLE_SIZE         // Shift the mask so the store writes three values
+ConvertMaskedUpperVector:
+    vmaskmovps xmmword PTR [rsi], xmm2, xmm1    // Store the masked data; vpsrldq shifts in byte multiples
+
+    jmp ExitRoutine
+ExitRoutine:
+    ret
diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index 4f97b89f101bb..329518d42e38b 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -22,9 +22,7 @@
 #include "Eigen/src/Core/arch/Default/BFloat16.h"
 #include "Eigen/src/Core/arch/Default/Half.h"
 
-#if defined(_M_AMD64) && !defined(_M_ARM64EC)
 #include "core/mlas/inc/mlas.h"
-#endif
 
 #include "core/common/cpuid_info.h"
 
@@ -255,20 +253,42 @@ struct TensorCasterNoSat {
 #endif
 
-#if defined(_M_AMD64) && !defined(_M_ARM64EC)
-// specializations to use optimized and Windows x64-specific
-// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion
-
 // tensor MLFloat16 -> float
+#if (defined(_M_AMD64) && !defined(_M_ARM64EC)) || defined(__x86_64__)
 template <>
 struct TensorCaster<MLFloat16, float> {
   void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
+#if defined(_MSC_VER)
     auto out_data = out.MutableData<float>();
     auto in_data = in.Data<MLFloat16>();
     const size_t shape_size = narrow<size_t>(shape.Size());
     MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size,
                                  onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT());
+#else
+    if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()) {
+      auto out_data = out.MutableData<float>();
+      auto in_data = in.Data<MLFloat16>();
+      const size_t shape_size = narrow<size_t>(shape.Size());
+      MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size, true);
+    } else {
+      // Use the generic Eigen cast
+      using SrcEigenCastType = typename EigenCastType<MLFloat16>::type;
+      using DstEigenCastType = typename EigenCastType<float>::type;
+
+      const std::ptrdiff_t shape_size = narrow<std::ptrdiff_t>(shape.Size());
+      const auto in_vector =
+          ConstEigenVectorMap<SrcEigenCastType>(reinterpret_cast<const SrcEigenCastType*>(in.Data<MLFloat16>()), shape_size);
+      auto out_vector =
+          EigenVectorMap<DstEigenCastType>(reinterpret_cast<DstEigenCastType*>(out.MutableData<float>()), shape_size);
+      out_vector = in_vector.template cast<DstEigenCastType>();
+    }
+#endif
   }
 };
+#endif
+
+#if defined(_M_AMD64) && !defined(_M_ARM64EC)
+// specializations to use optimized and Windows x64-specific
+// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion
 
 Tensor GetIntermediateMLFloat16ToFloatTensor(
     const OpKernelContext& context, const TensorShape& shape, const Tensor& in) {
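
Reviewer note (illustrative, not part of the patch): the subtle step in
Convert256Vectors is putting the results of the even/odd converts back into
source order. vcvtneeph2ps/vcvtneoph2ps convert the even- and odd-indexed
halves separately, and vunpcklps/vunpckhps plus vperm2f128 must recombine the
16 results in their original order. The scalar C++ sketch below models that
shuffle at the index level so the LOW_SELECTOR/HIGH_SELECTOR constants can be
sanity-checked off-target; all names here are local to the sketch and nothing
is taken from onnxruntime.

    // Scalar model of one Convert256Vectors iteration; "lane" means a
    // 128-bit half of a ymm register (4 floats).
    #include <array>
    #include <cstdio>
    #include <numeric>

    int main() {
        // 16 fp16 elements, modeled by their indices 0..15.
        std::array<int, 16> src{};
        std::iota(src.begin(), src.end(), 0);

        // vcvtneeph2ps / vcvtneoph2ps: convert even- / odd-indexed elements.
        std::array<int, 8> even{}, odd{};
        for (int i = 0; i < 8; ++i) {
            even[i] = src[2 * i];
            odd[i] = src[2 * i + 1];
        }

        // vunpcklps / vunpckhps interleave within each 128-bit lane.
        std::array<int, 8> lo{}, hi{};
        for (int lane = 0; lane < 2; ++lane) {
            const int b = 4 * lane;
            lo[b + 0] = even[b + 0]; lo[b + 1] = odd[b + 0];
            lo[b + 2] = even[b + 1]; lo[b + 3] = odd[b + 1];
            hi[b + 0] = even[b + 2]; hi[b + 1] = odd[b + 2];
            hi[b + 2] = even[b + 3]; hi[b + 3] = odd[b + 3];
        }

        // vperm2f128: LOW_SELECTOR (0x20) picks lane 0 of each source,
        //             HIGH_SELECTOR (0x31) picks lane 1 of each source.
        std::array<int, 16> out{};
        for (int i = 0; i < 4; ++i) {
            out[0 + i] = lo[0 + i];    // low result, lane 0
            out[4 + i] = hi[0 + i];    // low result, lane 1
            out[8 + i] = lo[4 + i];    // high result, lane 0
            out[12 + i] = hi[4 + i];   // high result, lane 1
        }

        for (int i = 0; i < 16; ++i) {
            std::printf("%d ", out[i]);  // expect: 0 1 2 ... 15
        }
        std::printf("\n");
        return 0;
    }

If the printed sequence is anything other than 0..15, the corresponding store
order in the assembly kernel would be wrong as well.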