From 1d46d17f3a48472c93e35bd48d2b77268a2274bc Mon Sep 17 00:00:00 2001
From: Erick Munoz
Date: Tue, 2 Jul 2024 09:49:42 -0700
Subject: [PATCH 1/3] Merge cast implementations

---
 cmake/onnxruntime_mlas.cmake                  |   1 -
 onnxruntime/core/mlas/inc/mlas.h              |  11 +-
 .../core/mlas/lib/amd64/cvtfp16Avx2.asm       | 148 ------------------
 onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm  | 120 +++++++++++++-
 .../core/providers/cpu/tensor/cast_op.cc      |   7 +-
 5 files changed, 115 insertions(+), 172 deletions(-)
 delete mode 100644 onnxruntime/core/mlas/lib/amd64/cvtfp16Avx2.asm

diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 4abed59695add..304aa77f5473c 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -200,7 +200,6 @@ function(setup_mlas_source_for_windows)
     ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm
     ${MLAS_SRC_DIR}/amd64/sgemma.asm
     ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm
-    ${MLAS_SRC_DIR}/amd64/cvtfp16Avx2.asm
     ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm
     ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm
     ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index b6718f36e9b41..99ab3b2b2484e 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1034,15 +1034,8 @@ MLASCALL
 MlasConvertHalfToFloatBuffer(
     const unsigned short* Source,
     float* Destination,
-    size_t Count
-    );
-
-extern "C" void
-MLASCALL
-MlasConvertHalfToFloatBufferAVX2(
-    const unsigned short* Source,
-    float* Destination,
-    size_t Count
+    size_t Count,
+    bool useAVX
     );
 
 //
diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx2.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx2.asm
deleted file mode 100644
index 1fdeebee66d63..0000000000000
--- a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx2.asm
+++ /dev/null
@@ -1,148 +0,0 @@
-;++
-;
-; Copyright (c) Intel Corporation. All rights reserved.
-;
-; Licensed under the MIT License.
-;
-; Module Name:
-;
-;   cvtfp16Avx2.asm
-;
-; Abstract:
-;
-;   This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA.
-;
-;--
-
-        .xlist
-INCLUDE mlasi.inc
-        .list
-
-        .const
-SINGLE_SIZE     equ     4
-HALF_SIZE       equ     2
-LOW_SELECTOR    equ     00100000b
-HIGH_SELECTOR   equ     00110001b
-        SUBTTL  "Convert buffer of half-precision floats to single-precision floats"
-;++
-;
-; Routine Description:
-;
-;   This routine converts the source buffer of half-precision floats to the
-;   destination buffer of single-precision floats.
-;
-;   This implementation uses AVX2 instructions.
-;
-; Arguments:
-;
-;   Source (rcx) - Supplies the address of the source buffer of half-precision
-;       floats.
-;
-;   Destination (rdx) - Supplies the address of the destination buffer of
-;       single-precision floats.
-;
-;   Count (r8) - Supplies the number of elements to convert.
-;
-; Return Value:
-;
-;   None.
-;
-;--
-
-        LEAF_ENTRY MlasConvertHalfToFloatBufferAVX2, _TEXT
-
-        test    r8, r8
-        jz      ExitRoutine
-        cmp     r8, 8
-        jb      ConvertMaskedVectors
-        cmp     r8, 16
-        jb      Convert128Vectors
-
-
-
-Convert256Vectors:
-        vcvtneeph2ps    ymm0, ymmword PTR [rcx]             ; Load even indexes
-        vcvtneoph2ps    ymm1, ymmword PTR [rcx]             ; Load odd indexes
-        vunpcklps       ymm2, ymm0, ymm1                    ; Interleave low part
-        vunpckhps       ymm1, ymm0, ymm1                    ; Interleave high part
-        vperm2f128      ymm0, ymm2, ymm1, LOW_SELECTOR      ; Fix the order
-        vperm2f128      ymm1, ymm2, ymm1, HIGH_SELECTOR     ; Fix the order
-        vmovups         ymmword PTR [rdx], ymm0             ; Store the low part
-        vmovups         ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part
-
-        add     rcx, 16*HALF_SIZE                           ; Advance src ptr by 16 elements
-        add     rdx, 16*SINGLE_SIZE                         ; Advance dest ptr by 16 elements
-        sub     r8, 16                                      ; Reduce the counter by 16 elements
-
-        jz      ExitRoutine                                 ; If we are done, exit
-        cmp     r8, 16                                      ; If the vector is big enough, we go again
-        jae     Convert256Vectors
-
-
-
-Convert128Vectors:
-        vcvtneeph2ps    xmm2, xmmword PTR [rcx]             ; Load even indexes
-        vcvtneoph2ps    xmm1, xmmword PTR [rcx]             ; Load odd indexes
-        vunpcklps       xmm0, xmm2, xmm1                    ; Interleave low part to fix order
-        vunpckhps       xmm1, xmm2, xmm1                    ; Interleave high part to fix order
-        vmovups         xmmword PTR [rdx], xmm0             ; Store the low part
-        vmovups         xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part
-
-        add     rcx, 8*HALF_SIZE                            ; Advance src ptr by 8 elements
-        add     rdx, 8*SINGLE_SIZE                          ; Advance dest ptr by 8 elements
-        sub     r8, 8                                       ; Reduce the counter by 8 elements
-
-        jz      ExitRoutine                                 ; If we are done, exit
-
-
-
-ConvertMaskedVectors:
-        vcvtneeph2ps    xmm2, xmmword PTR [rcx]             ; Load even indexes
-        vcvtneoph2ps    xmm1, xmmword PTR [rcx]             ; Load odd indexes
-        vunpcklps       xmm0, xmm2, xmm1                    ; Interleave low part to fix order
-        vunpckhps       xmm1, xmm2, xmm1                    ; Interleave high part to fix order
-
-        cmp     r8, 4                                       ; Chek if we can store the complete lower vector
-        jae     ConvertLowerVector
-
-        vpcmpeqw        xmm2, xmm2, xmm2                    ; Initialize the mask full of ones
-        cmp     r8, 2                                       ; Check how many converts we need
-        jb      ConvertLower1
-        ja      ConvertLower3
-        vpsrldq xmm2, xmm2, SINGLE_SIZE*2                   ; Shift the memory store two values
-        jmp     ConvertLowerMaskedVector
-ConvertLower1:
-        vpsrldq xmm2, xmm2, SINGLE_SIZE*3                   ; Shift the memory store only one value
-        jmp     ConvertLowerMaskedVector
-ConvertLower3:
-        vpsrldq xmm2, xmm2, SINGLE_SIZE                     ; Shift the memory store three values
-ConvertLowerMaskedVector:
-        vmaskmovps      xmmword PTR [rdx], xmm2, xmm0       ; Store the masked data, the shift is done in 8bit multiples
-        jmp     ExitRoutine                                 ; If we ran into any of the cases above, means we are done after storing
-ConvertLowerVector:
-        vmovups xmmword PTR [rdx], xmm0                     ; Store the low part
-        sub     r8, 4                                       ; Check if we still need to convert
-        jz      ExitRoutine
-
-
-        add     rdx, 4*SINGLE_SIZE                          ; Advance dest ptr by 4 elements
-        vpcmpeqw        xmm2, xmm2, xmm2                    ; Initialize the mask full of ones
-        cmp     r8, 2                                       ; Check how many converts we need
-        jb      ConvertUpper1
-        ja      ConvertUpper3
-        vpsrldq xmm2, xmm2, SINGLE_SIZE*2                   ; Shift the memory store two values
-        jmp     ConvertMaskedUpperVector
-ConvertUpper1:
-        vpsrldq xmm2, xmm2, SINGLE_SIZE*3                   ; Shift the memory store only one value
-        jmp     ConvertMaskedUpperVector
-ConvertUpper3:
-        vpsrldq xmm2, xmm2, SINGLE_SIZE                     ; Shift the memory store three values
-ConvertMaskedUpperVector:
-        vmaskmovps      xmmword PTR [rdx], xmm2, xmm1       ; Store the masked data, the shift is done in 8bit multiples
-
-ExitRoutine:
-        ret
-
-        LEAF_END MlasConvertHalfToFloatBufferAVX2, _TEXT
-
-        END
diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm
index 50315146ca79b..bc3efff1f481e 100644
--- a/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm
+++ b/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm
@@ -10,7 +10,8 @@
 ;
 ; Abstract:
 ;
-;   This module implements routines to convert between FP16 and FP32 formats.
+;   This module implements routines to convert between FP16 and FP32 formats: one path uses the
+;   legacy SSE instructions and the other uses the AVX_NE_CONVERT ISA.
 ;
 ;--
 
@@ -21,42 +22,145 @@ INCLUDE mlasi.inc
         .const
 
         ALIGN 16
+; Legacy implementation constants
 MlasFp16MaskSign                    DD  4 DUP (00007FFFh)
 MlasFp16CompareInfinity             DD  4 DUP (00007C00h)
 MlasFp16CompareSmallest             DD  4 DUP (00000400h)
 MlasFp16AdjustExponent              DD  4 DUP (38000000h)
 MlasFp16MagicDenormal               DD  4 DUP (38800000h)
+; AVX implementation constants
+SINGLE_SIZE     equ     4
+HALF_SIZE       equ     2
+LOW_SELECTOR    equ     00100000b
+HIGH_SELECTOR   equ     00110001b
+
         SUBTTL  "Convert buffer of half-precision floats to single-precision floats"
 ;++
 ;
 ; Routine Description:
 ;
-;   This routine converts the source buffer of half-precision floats to the
-;   destination buffer of single-precision floats.
-;
-;   This implementation uses SSE2 instructions.
+;   This routine calls the implementation of the cast operator depending on the ISA flag.
 ;
 ; Arguments:
 ;
 ;   Source (rcx) - Supplies the address of the source buffer of half-precision
 ;       floats.
 ;
-;   Destination (edx) - Supplies the address of the destination buffer of
+;   Destination (rdx) - Supplies the address of the destination buffer of
 ;       single-precision floats.
 ;
 ;   Count (r8) - Supplies the number of elements to convert.
 ;
+;   ISA flag (r9) - Determines whether to use AVX_NE_CONVERT or not.
+;
 ; Return Value:
 ;
 ;   None.
 ;
 ;--
 
-        LEAF_ENTRY MlasConvertHalfToFloatBuffer, _TEXT
-
-        test r8,r8
+LEAF_ENTRY MlasConvertHalfToFloatBuffer, _TEXT
+
+        test    r8, r8                                      ; Check if we have any elements to convert
         jz      ExitRoutine
+        test    r9, r9                                      ; Check if we need to use AVX_NE_CONVERT
+        jz      SSE
+
+AVX_NE_CONVERT:
+        cmp     r8, 8
+        jb      ConvertMaskedVectors
+        cmp     r8, 16
+        jb      Convert128Vectors
+
+
+
+Convert256Vectors:
+        vcvtneeph2ps    ymm0, ymmword PTR [rcx]             ; Load even indexes
+        vcvtneoph2ps    ymm1, ymmword PTR [rcx]             ; Load odd indexes
+        vunpcklps       ymm2, ymm0, ymm1                    ; Interleave low part
+        vunpckhps       ymm1, ymm0, ymm1                    ; Interleave high part
+        vperm2f128      ymm0, ymm2, ymm1, LOW_SELECTOR      ; Fix the order
+        vperm2f128      ymm1, ymm2, ymm1, HIGH_SELECTOR     ; Fix the order
+        vmovups         ymmword PTR [rdx], ymm0             ; Store the low part
+        vmovups         ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part
+
+        add     rcx, 16*HALF_SIZE                           ; Advance src ptr by 16 elements
+        add     rdx, 16*SINGLE_SIZE                         ; Advance dest ptr by 16 elements
+        sub     r8, 16                                      ; Reduce the counter by 16 elements
+
+        jz      ExitRoutine                                 ; If we are done, exit
+        cmp     r8, 16                                      ; If the vector is big enough, we go again
+        jae     Convert256Vectors
+
+
+
+Convert128Vectors:
+        vcvtneeph2ps    xmm2, xmmword PTR [rcx]             ; Load even indexes
+        vcvtneoph2ps    xmm1, xmmword PTR [rcx]             ; Load odd indexes
+        vunpcklps       xmm0, xmm2, xmm1                    ; Interleave low part to fix order
+        vunpckhps       xmm1, xmm2, xmm1                    ; Interleave high part to fix order
+        vmovups         xmmword PTR [rdx], xmm0             ; Store the low part
+        vmovups         xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part
+
+        add     rcx, 8*HALF_SIZE                            ; Advance src ptr by 8 elements
+        add     rdx, 8*SINGLE_SIZE                          ; Advance dest ptr by 8 elements
+        sub     r8, 8                                       ; Reduce the counter by 8 elements
+
+        jz      ExitRoutine                                 ; If we are done, exit
+
+
+
+ConvertMaskedVectors:
+        vcvtneeph2ps    xmm2, xmmword PTR [rcx]             ; Load even indexes
+        vcvtneoph2ps    xmm1, xmmword PTR [rcx]             ; Load odd indexes
+        vunpcklps       xmm0, xmm2, xmm1                    ; Interleave low part to fix order
+        vunpckhps       xmm1, xmm2, xmm1                    ; Interleave high part to fix order
+
+        cmp     r8, 4                                       ; Check if we can store the complete lower vector
+        jae     ConvertLowerVector
+
+        vpcmpeqw        xmm2, xmm2, xmm2                    ; Initialize the mask full of ones
+        cmp     r8, 2                                       ; Check how many converts we need
+        jb      ConvertLower1
+        ja      ConvertLower3
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*2                   ; Shift the mask to store two values
+        jmp     ConvertLowerMaskedVector
+ConvertLower1:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*3                   ; Shift the mask to store only one value
+        jmp     ConvertLowerMaskedVector
+ConvertLower3:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE                     ; Shift the mask to store three values
+ConvertLowerMaskedVector:
+        vmaskmovps      xmmword PTR [rdx], xmm2, xmm0       ; Store the masked data, the shift is done in 8bit multiples
+        jmp     ExitRoutine                                 ; If we ran into any of the cases above, it means we are done after storing
+ConvertLowerVector:
+        vmovups xmmword PTR [rdx], xmm0                     ; Store the low part
+        sub     r8, 4                                       ; Check if we still need to convert
+        jz      ExitRoutine
+
+
+        add     rdx, 4*SINGLE_SIZE                          ; Advance dest ptr by 4 elements
+        vpcmpeqw        xmm2, xmm2, xmm2                    ; Initialize the mask full of ones
+        cmp     r8, 2                                       ; Check how many converts we need
+        jb      ConvertUpper1
+        ja      ConvertUpper3
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*2                   ; Shift the mask to store two values
+        jmp     ConvertMaskedUpperVector
+ConvertUpper1:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*3                   ; Shift the mask to store only one value
+        jmp     ConvertMaskedUpperVector
+ConvertUpper3:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE                     ; Shift the mask to store three values
+ConvertMaskedUpperVector:
+        vmaskmovps      xmmword PTR [rdx], xmm2, xmm1       ; Store the masked data, the shift is done in 8bit multiples
+
+        jmp     ExitRoutine
+
+
+
+SSE:
         cmp r8,4
         jb LoadPartialVector
diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index e837c3427432f..4f97b89f101bb 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -266,12 +266,7 @@ struct TensorCaster<MLFloat16, float> {
     auto out_data = out.MutableData<float>();
     auto in_data = in.Data<MLFloat16>();
     const size_t shape_size = narrow<size_t>(shape.Size());
-    if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()) {
-      MlasConvertHalfToFloatBufferAVX2(&in_data[0].val, out_data, shape_size);
-    }
-    else {
-      MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
-    }
+    MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size, onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT());
   }
 };
 

From 70ef759dec0309be239f8863969ccdab4119680a Mon Sep 17 00:00:00 2001
From: wos
Date: Tue, 9 Jul 2024 14:25:43 -0700
Subject: [PATCH 2/3] Add Linux support

---
 cmake/onnxruntime_mlas.cmake                |   1 +
 onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S | 125 ++++++++++++++
 .../core/providers/cpu/tensor/cast_op.cc    |  32 ++++-
 3 files changed, 152 insertions(+), 6 deletions(-)
 create mode 100644 onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S

diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 304aa77f5473c..0139652cdb008 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -532,6 +532,7 @@ else()
     ${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S
     ${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S
     ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S
+    ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S
     ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
     ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
     ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S
new file mode 100644
index 0000000000000..d122d5df9a507
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S
@@ -0,0 +1,125 @@
+/*++ Routine Description:
+
+    This routine converts the source buffer of half-precision floats to the
+    destination buffer of single-precision floats.
+
+    This implementation uses AVX_NE_CONVERT instructions.
+
+    Arguments:
+
+    Source (rdi) - Supplies the address of the source buffer of half-precision
+        floats.
+
+    Destination (rsi) - Supplies the address of the destination buffer of
+        single-precision floats.
+
+    Count (rdx) - Supplies the number of elements to convert.
+
+    Return Value:
+
+    None.
+
+--*/
+.data
+.equ SINGLE_SIZE, 4
+.equ HALF_SIZE, 2
+.equ LOW_SELECTOR, 0b00100000
+.equ HIGH_SELECTOR, 0b00110001
+
+.text
+.globl MlasConvertHalfToFloatBuffer
+.intel_syntax noprefix
+
+MlasConvertHalfToFloatBuffer:
+        test    rdx, rdx                                    // Check if we have any elements to convert
+        jz      ExitRoutine
+
+AVX_NE_CONVERT:
+        cmp     rdx, 8
+        jb      ConvertMaskedVectors
+        cmp     rdx, 16
+        jb      Convert128Vectors
+
+Convert256Vectors:
+        vcvtneeph2ps    ymm0, ymmword PTR [rdi]             // Load even indexes
+        vcvtneoph2ps    ymm1, ymmword PTR [rdi]             // Load odd indexes
+        vunpcklps       ymm2, ymm0, ymm1                    // Interleave low part
+        vunpckhps       ymm1, ymm0, ymm1                    // Interleave high part
+        vperm2f128      ymm0, ymm2, ymm1, LOW_SELECTOR      // Fix the order
+        vperm2f128      ymm1, ymm2, ymm1, HIGH_SELECTOR     // Fix the order
+        vmovups         ymmword PTR [rsi], ymm0             // Store the low part
+        vmovups         ymmword PTR [rsi + 8*SINGLE_SIZE], ymm1 // Store the high part
+
+        add     rdi, 16*HALF_SIZE                           // Advance src ptr by 16 elements
+        add     rsi, 16*SINGLE_SIZE                         // Advance dest ptr by 16 elements
+        sub     rdx, 16                                     // Reduce the counter by 16 elements
+
+        jz      ExitRoutine                                 // If we are done, exit
+        cmp     rdx, 16                                     // If the vector is big enough, we go again
+        jae     Convert256Vectors
+
+
+
+Convert128Vectors:
+        vcvtneeph2ps    xmm2, xmmword PTR [rdi]             // Load even indexes
+        vcvtneoph2ps    xmm1, xmmword PTR [rdi]             // Load odd indexes
+        vunpcklps       xmm0, xmm2, xmm1                    // Interleave low part to fix order
+        vunpckhps       xmm1, xmm2, xmm1                    // Interleave high part to fix order
+        vmovups         xmmword PTR [rsi], xmm0             // Store the low part
+        vmovups         xmmword PTR [rsi + 4*SINGLE_SIZE], xmm1 // Store the high part
+
+        add     rdi, 8*HALF_SIZE                            // Advance src ptr by 8 elements
+        add     rsi, 8*SINGLE_SIZE                          // Advance dest ptr by 8 elements
+        sub     rdx, 8                                      // Reduce the counter by 8 elements
+
+        jz      ExitRoutine                                 // If we are done, exit
+
+
+
+ConvertMaskedVectors:
+        vcvtneeph2ps    xmm2, xmmword PTR [rdi]             // Load even indexes
+        vcvtneoph2ps    xmm1, xmmword PTR [rdi]             // Load odd indexes
+        vunpcklps       xmm0, xmm2, xmm1                    // Interleave low part to fix order
+        vunpckhps       xmm1, xmm2, xmm1                    // Interleave high part to fix order
+
+        cmp     rdx, 4                                      // Check if we can store the complete lower vector
+        jae     ConvertLowerVector
+
+        vpcmpeqw        xmm2, xmm2, xmm2                    // Initialize the mask full of ones
+        cmp     rdx, 2                                      // Check how many converts we need
+        jb      ConvertLower1
+        ja      ConvertLower3
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*2                   // Shift the mask to store two values
+        jmp     ConvertLowerMaskedVector
+ConvertLower1:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*3                   // Shift the mask to store only one value
+        jmp     ConvertLowerMaskedVector
+ConvertLower3:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE                     // Shift the mask to store three values
+ConvertLowerMaskedVector:
+        vmaskmovps      xmmword PTR [rsi], xmm2, xmm0       // Store the masked data, the shift is done in 8bit multiples
+        jmp     ExitRoutine                                 // If we ran into any of the cases above, it means we are done after storing
+ConvertLowerVector:
+        vmovups xmmword PTR [rsi], xmm0                     // Store the low part
+        sub     rdx, 4                                      // Check if we still need to convert
+        jz      ExitRoutine
+
+
+        add     rsi, 4*SINGLE_SIZE                          // Advance dest ptr by 4 elements
+        vpcmpeqw        xmm2, xmm2, xmm2                    // Initialize the mask full of ones
+        cmp     rdx, 2                                      // Check how many converts we need
+        jb      ConvertUpper1
+        ja      ConvertUpper3
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*2                   // Shift the mask to store two values
+        jmp     ConvertMaskedUpperVector
+ConvertUpper1:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE*3                   // Shift the mask to store only one value
+        jmp     ConvertMaskedUpperVector
+ConvertUpper3:
+        vpsrldq xmm2, xmm2, SINGLE_SIZE                     // Shift the mask to store three values
+ConvertMaskedUpperVector:
+        vmaskmovps      xmmword PTR [rsi], xmm2, xmm1       // Store the masked data, the shift is done in 8bit multiples
+
+        jmp     ExitRoutine
+ExitRoutine:
+        ret
diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index 4f97b89f101bb..329518d42e38b 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -22,9 +22,7 @@
 #include "Eigen/src/Core/arch/Default/BFloat16.h"
 #include "Eigen/src/Core/arch/Default/Half.h"
 
-#if defined(_M_AMD64) && !defined(_M_ARM64EC)
 #include "core/mlas/inc/mlas.h"
-#endif
 
 #include "core/common/cpuid_info.h"
 
@@ -255,20 +253,42 @@ struct TensorCasterNoSat {
 #endif
 
-#if defined(_M_AMD64) && !defined(_M_ARM64EC)
-// specializations to use optimized and Windows x64-specific
-// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion
-
 // tensor MLFloat16 -> float
+#if (defined(_M_AMD64) && !defined(_M_ARM64EC)) || defined(__x86_64__)
 template <>
 struct TensorCaster<MLFloat16, float> {
   void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
+#if defined(_MSC_VER)
     auto out_data = out.MutableData<float>();
    auto in_data = in.Data<MLFloat16>();
     const size_t shape_size = narrow<size_t>(shape.Size());
     MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size, onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT());
+#else
+    if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()) {
+      auto out_data = out.MutableData<float>();
+      auto in_data = in.Data<MLFloat16>();
+      const size_t shape_size = narrow<size_t>(shape.Size());
+      MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size, true);
+    } else {
+      // Use the generic function
+      using SrcEigenCastType = typename EigenCastType<MLFloat16>::type;
+      using DstEigenCastType = typename EigenCastType<float>::type;
+
+      const std::ptrdiff_t shape_size = narrow<std::ptrdiff_t>(shape.Size());
+      const auto in_vector =
+          ConstEigenVectorMap<SrcEigenCastType>(reinterpret_cast<const SrcEigenCastType*>(in.Data<MLFloat16>()), shape_size);
+      auto out_vector =
+          EigenVectorMap<DstEigenCastType>(reinterpret_cast<DstEigenCastType*>(out.MutableData<float>()), shape_size);
+      out_vector = in_vector.template cast<DstEigenCastType>();
+    }
+#endif
   }
 };
+#endif
+
+#if defined(_M_AMD64) && !defined(_M_ARM64EC)
+// specializations to use optimized and Windows x64-specific
+// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion
 
 Tensor GetIntermediateMLFloat16ToFloatTensor(
     const OpKernelContext& context, const TensorShape& shape, const Tensor& in) {

From 8a1a28c581d1a924a3d69c745ec405af0298ae85 Mon Sep 17 00:00:00 2001
From: Erick Munoz Alvarado
Date: Thu, 11 Jul 2024 11:15:39 -0600
Subject: [PATCH 3/3] Linter suggestions

---
 onnxruntime/core/common/cpuid_info.h             | 2 +-
 onnxruntime/core/providers/cpu/tensor/cast_op.cc | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h
index c49bafd2113ef..5943e2ef3c730 100644
--- a/onnxruntime/core/common/cpuid_info.h
+++ b/onnxruntime/core/common/cpuid_info.h
@@ -21,7 +21,7 @@ class CPUIDInfo {
   bool HasAVX512f() const { return has_avx512f_; }
   bool HasAVX512_BF16() const { return has_avx512_bf16_; }
   bool HasAVX512Skylake() const { return has_avx512_skylake_; }
-  bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/
+  bool HasF16C() const { return has_f16c_; }/*fp16 conversion inst*/
   bool HasAVX_NE_CONVERT() const { return has_avx_ne_convert_; } /*fp16/bf16 conversion inst*/
   bool HasSSE3() const { return has_sse3_; }
   bool HasSSE4_1() const { return has_sse4_1_; }
diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index 329518d42e38b..9cf937f5ebc54 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -264,7 +264,7 @@ struct TensorCaster<MLFloat16, float> {
     const size_t shape_size = narrow<size_t>(shape.Size());
     MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size, onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT());
 #else
-    if(onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()){
+    if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()) {
       auto out_data = out.MutableData<float>();
       auto in_data = in.Data<MLFloat16>();
       const size_t shape_size = narrow<size_t>(shape.Size());
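
Note: a minimal sketch of the call pattern the merged entry point expects after this series. The helper function and variable names below are illustrative only and are not part of the patch; on Windows the ISA check is forwarded into the assembly routine (which falls back to its SSE path when the flag is false), while on Linux the cast operator checks CPUID in C++ before calling the AVX_NE_CONVERT routine.

    #include "core/common/cpuid_info.h"
    #include "core/mlas/inc/mlas.h"

    // Hypothetical helper: convert a buffer of FP16 values (raw uint16 bits) to FP32.
    void ConvertHalfBufferToFloat(const unsigned short* src, float* dst, size_t count) {
      // The caller decides whether the AVX_NE_CONVERT path may be used and passes
      // that decision as the bool argument added to MlasConvertHalfToFloatBuffer in PATCH 1/3.
      const bool use_avx_ne_convert = onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT();
      MlasConvertHalfToFloatBuffer(src, dst, count, use_avx_ne_convert);
    }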