Commit: Add linux support
eralmual committed Jul 9, 2024
1 parent 1d46d17 commit 70ef759
Showing 3 changed files with 152 additions and 6 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -532,6 +532,7 @@ else()
${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S
${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S
${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S
${MLAS_SRC_DIR}/x86_64/cvtfp16a.S
${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
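One inferred piece of context before the new assembly file: this diff does not show a header change, yet the call sites added in cast_op.cc pass four arguments. The MLAS entry point therefore presumably looks roughly like the sketch below; the parameter names and the fourth flag are assumptions, and the authoritative declaration lives in core/mlas/inc/mlas.h.

// Hypothetical prototype inferred from the call sites in this commit; the
// real declaration in core/mlas/inc/mlas.h may differ.
void
MLASCALL
MlasConvertHalfToFloatBuffer(
    const unsigned short* Source,  // half-precision input buffer
    float* Destination,            // single-precision output buffer
    size_t Count,                  // number of elements to convert
    bool UseAvxNeConvert           // assumed CPU-feature dispatch flag
    );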
125 changes: 125 additions & 0 deletions onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S
@@ -0,0 +1,125 @@
/*++ Routine Description:

This routine converts the source buffer of half-precision floats to the
destination buffer of single-precision floats.

This implementation uses AVX2 instructions.

Arguments:

Source (rdi) - Supplies the address of the source buffer of half-precision
floats.

Destination (rsi) - Supplies the address of the destination buffer of
single-precision floats.

Count (rdx) - Supplies the number of elements to convert.

Return Value:

None.

--*/
.data
.equ SINGLE_SIZE, 4
.equ HALF_SIZE, 2
.equ LOW_SELECTOR, 0b00100000
.equ HIGH_SELECTOR, 0b00110001

.text
.globl MlasConvertHalfToFloatBuffer
.intel_syntax noprefix

MlasConvertHalfToFloatBuffer:
test rdx, rdx // Check if we have any elements to convert
jz ExitRoutine

AVX_NE_CONVERT:
cmp rdx, 8
jb ConvertMaskedVectors
cmp rdx, 16
jb Convert128Vectors

Convert256Vectors:
vcvtneeph2ps ymm0, ymmword PTR [rdi] // Convert the even-indexed halves
vcvtneoph2ps ymm1, ymmword PTR [rdi] // Convert the odd-indexed halves
vunpcklps ymm2, ymm0, ymm1 // Interleave low part
vunpckhps ymm1, ymm0, ymm1 // Interleave high part
vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR // Fix the order
vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR // Fix the order
vmovups ymmword PTR [rsi], ymm0 // Store the low part
vmovups ymmword PTR [rsi + 8*SINGLE_SIZE], ymm1 // Store the high part

add rdi, 16*HALF_SIZE // Advance src ptr by 16 elements
add rsi, 16*SINGLE_SIZE // Advance dest ptr by 16 elements
sub rdx, 16 // Reduce the counter by 16 elements

jz ExitRoutine // If we are done, exit
cmp rdx, 16 // If enough elements remain, go around again
jae Convert256Vectors



Convert128Vectors:
vcvtneeph2ps xmm2, xmmword PTR [rdi] // Convert the even-indexed halves
vcvtneoph2ps xmm1, xmmword PTR [rdi] // Convert the odd-indexed halves
vunpcklps xmm0, xmm2, xmm1 // Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 // Interleave high part to fix order
vmovups xmmword PTR [rsi], xmm0 // Store the low part
vmovups xmmword PTR [rsi + 4*SINGLE_SIZE], xmm1 // Store the high part

add rdi, 8*HALF_SIZE // Advance src ptr by 8 elements
add rsi, 8*SINGLE_SIZE // Advance dest ptr by 8 elements
sub rdx, 8 // Reduce the counter by 8 elements

jz ExitRoutine // If we are done, exit



ConvertMaskedVectors:
vcvtneeph2ps xmm2, xmmword PTR [rdi] // Convert the even-indexed halves
vcvtneoph2ps xmm1, xmmword PTR [rdi] // Convert the odd-indexed halves
vunpcklps xmm0, xmm2, xmm1 // Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 // Interleave high part to fix order

cmp rdx, 4 // Check if we can store the complete lower vector
jae ConvertLowerVector

vpcmpeqw xmm2, xmm2, xmm2 // Initialize the mask to all ones
cmp rdx, 2 // Check how many conversions we need
jb ConvertLower1
ja ConvertLower3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 // Shift the mask so the store writes two values
jmp ConvertLowerMaskedVector
ConvertLower1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 // Shift the mask so the store writes only one value
jmp ConvertLowerMaskedVector
ConvertLower3:
vpsrldq xmm2, xmm2, SINGLE_SIZE // Shift the mask so the store writes three values
ConvertLowerMaskedVector:
vmaskmovps xmmword PTR [rsi], xmm2, xmm0 // Store the masked data; vpsrldq shifts the mask in byte multiples
jmp ExitRoutine // Reaching any of the cases above means we are done after storing
ConvertLowerVector:
vmovups xmmword PTR [rsi], xmm0 // Store the low part
sub rdx, 4 // Reduce the counter and check whether elements remain
jz ExitRoutine


add rsi, 4*SINGLE_SIZE // Advance dest ptr by 4 elements
vpcmpeqw xmm2, xmm2, xmm2 // Initialize the mask to all ones
cmp rdx, 2 // Check how many conversions we need
jb ConvertUpper1
ja ConvertUpper3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 // Shift the mask so the store writes two values
jmp ConvertMaskedUpperVector
ConvertUpper1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 // Shift the mask so the store writes only one value
jmp ConvertMaskedUpperVector
ConvertUpper3:
vpsrldq xmm2, xmm2, SINGLE_SIZE // Shift the mask so the store writes three values
ConvertMaskedUpperVector:
vmaskmovps xmmword PTR [rsi], xmm2, xmm1 // Store the masked data; vpsrldq shifts the mask in byte multiples

jmp ExitRoutine
ExitRoutine:
ret
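A note on how the kernel above works: vcvtneeph2ps and vcvtneoph2ps convert the even- and odd-indexed half-precision elements respectively, so the vunpcklps/vunpckhps and vperm2f128 steps only re-interleave the two results back into source order; the tail paths build an all-ones mask with vpcmpeqw, trim it with vpsrldq, and store through vmaskmovps so 1-7 leftover elements never write past the end of the destination. In scalar terms the routine computes the following — a minimal sketch assuming IEEE binary16 input, not the commit's own code:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Scalar model of MlasConvertHalfToFloatBuffer (illustrative only; the real
// kernel above converts 16- and 8-element blocks plus a masked tail).
static float HalfToFloat(uint16_t h) {
  const uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  uint32_t exponent = (h >> 10) & 0x1Fu;
  uint32_t mantissa = h & 0x3FFu;
  uint32_t bits;
  if (exponent == 0x1Fu) {
    bits = sign | 0x7F800000u | (mantissa << 13);  // Inf or NaN, payload widened
  } else if (exponent != 0) {
    bits = sign | ((exponent + 112u) << 23) | (mantissa << 13);  // normal: rebias 15 -> 127
  } else if (mantissa == 0) {
    bits = sign;  // signed zero
  } else {
    uint32_t shift = 0;  // normalize a subnormal half
    while ((mantissa & 0x400u) == 0) {
      mantissa <<= 1;
      ++shift;
    }
    mantissa &= 0x3FFu;
    bits = sign | ((113u - shift) << 23) | (mantissa << 13);  // value is 1.m * 2^(-14 - shift)
  }
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

void ConvertHalfToFloatReference(const uint16_t* source, float* destination, size_t count) {
  for (size_t i = 0; i < count; ++i) {
    destination[i] = HalfToFloat(source[i]);
  }
}

The masked-tail trick corresponds roughly to this intrinsics sketch (StoreTail is a hypothetical helper, not a function in the codebase):

#include <immintrin.h>

// Stores only the first `remaining` lanes, mirroring the vpcmpeqw + vpsrldq +
// vmaskmovps sequence above.
static void StoreTail(float* destination, __m128 values, size_t remaining /* 1..3 */) {
  __m128i mask = _mm_set1_epi32(-1);  // all ones, like vpcmpeqw xmm2, xmm2, xmm2
  switch (remaining) {
    case 1: mask = _mm_srli_si128(mask, 12); break;  // keep one 4-byte lane
    case 2: mask = _mm_srli_si128(mask, 8); break;   // keep two lanes
    case 3: mask = _mm_srli_si128(mask, 4); break;   // keep three lanes
  }
  _mm_maskstore_ps(destination, mask, values);  // vmaskmovps: write unmasked lanes only
}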
32 changes: 26 additions & 6 deletions onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -22,9 +22,7 @@
#include "Eigen/src/Core/arch/Default/BFloat16.h"
#include "Eigen/src/Core/arch/Default/Half.h"

#if defined(_M_AMD64) && !defined(_M_ARM64EC)
#include "core/mlas/inc/mlas.h"
#endif

#include "core/common/cpuid_info.h"

@@ -255,20 +253,42 @@ struct TensorCasterNoSat<std::string, DstType> {

#endif

#if defined(_M_AMD64) && !defined(_M_ARM64EC)
// specializations to use optimized and Windows x64-specific
// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion

// tensor MLFloat16 -> float
#if (defined(_M_AMD64) && !defined(_M_ARM64EC)) || defined(__x86_64__)
template <>
struct TensorCaster<MLFloat16, float> {
  void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
#if defined(_MSC_VER)
    auto out_data = out.MutableData<float>();
    auto in_data = in.Data<MLFloat16>();
    const size_t shape_size = narrow<size_t>(shape.Size());
    MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size, onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT());
#else
    if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()) {
      auto out_data = out.MutableData<float>();
      auto in_data = in.Data<MLFloat16>();
      const size_t shape_size = narrow<size_t>(shape.Size());
      MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size, true);
    } else {
      // Use the generic function
      using SrcEigenCastType = typename EigenCastType<MLFloat16>::type;
      using DstEigenCastType = typename EigenCastType<float>::type;

      const std::ptrdiff_t shape_size = narrow<std::ptrdiff_t>(shape.Size());
      const auto in_vector =
          ConstEigenVectorMap<SrcEigenCastType>(reinterpret_cast<const SrcEigenCastType*>(in.Data<MLFloat16>()), shape_size);
      auto out_vector =
          EigenVectorMap<DstEigenCastType>(reinterpret_cast<DstEigenCastType*>(out.MutableData<float>()), shape_size);
      out_vector = in_vector.template cast<DstEigenCastType>();
    }
#endif
  }
};
#endif

#if defined(_M_AMD64) && !defined(_M_ARM64EC)
// specializations to use optimized and Windows x64-specific
// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion

Tensor GetIntermediateMLFloat16ToFloatTensor(
const OpKernelContext& context, const TensorShape& shape, const Tensor& in) {
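Taken together, the non-MSVC cast path added above boils down to the following dispatch — a simplified sketch in which an element-wise loop stands in for the Eigen map actually used, and MLFloat16::ToFloat() is assumed for the scalar conversion:

// Simplified model of TensorCaster<MLFloat16, float>::Cast on non-MSVC builds;
// illustrative only, see cast_op.cc above for the real implementation.
void CastHalfToFloat(const MLFloat16* in, float* out, size_t n) {
  if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()) {
    // Fast path: the AVX-NE-CONVERT kernel added in cvtfp16a.S.
    MlasConvertHalfToFloatBuffer(&in[0].val, out, n, true);
  } else {
    // Portable fallback, standing in for the Eigen cast in the #else branch.
    for (size_t i = 0; i < n; ++i) {
      out[i] = in[i].ToFloat();
    }
  }
}

On MSVC the entry point is called unconditionally and the CPUID result is forwarded as the fourth argument, so the Windows build presumably keeps its dispatch inside MLAS rather than in the caster.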
