Skip to content

Commit

Permalink
Enable AVX NE CONVERT for FP16 to FP32 cast
Browse files Browse the repository at this point in the history
* Enable AVX_NE_CONVERT detection via CPUID.
* Developed x86 and amd64 assembly kernel using the new ISA.
* Integrated kernel.
  • Loading branch information
eralmual committed Jul 26, 2024
1 parent 3c0b407 commit b1325e0
Show file tree
Hide file tree
Showing 7 changed files with 343 additions and 7 deletions.
12 changes: 12 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm
${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm
)
# cvtfp16Avx.asm is written with AVX_NE_CONVERT instructions, so it is only
# added for MSVC_VERSION >= 1933. This mirrors the `_MSC_VER >= 1933` guard on
# the MlasConvertHalfToFloatBufferAVX declaration in mlas.h.
# NOTE(review): 1933 is presumably the first toolset whose assembler accepts
# these mnemonics - confirm.
if(MSVC_VERSION GREATER_EQUAL 1933)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm
)
endif()

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
Expand Down Expand Up @@ -536,6 +542,12 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
)
# cvtfp16Avx.S is written with AVX_NE_CONVERT mnemonics, and the matching
# declaration in mlas.h is only visible when `__GNUC__ >= 13`. Gate the source
# on the same toolchain: a bare version comparison would also match other
# compilers (e.g. Clang 14+) whose version number exceeds 13 but which do not
# satisfy the header guard and may not be able to assemble the file.
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13)
  list(APPEND mlas_platform_srcs_avx2
    ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S
  )
endif()
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")

set(mlas_platform_srcs_avx512f
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ void CPUIDInfo::X86Init() {
if (max_SubLeaves >= 1) {
GetCPUID(7, 1, data);
has_avx512_bf16_ = has_avx512 && (data[0] & (1 << 5));
// Check for AVX_NE_CONVERT as half precision kernel uses it with AVX2 and F16C
has_avx_ne_convert_ = has_avx2_ && has_f16c_ && (data[3] & (1 << 5));
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/common/cpuid_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ class CPUIDInfo {
bool HasAVX512f() const { return has_avx512f_; }
bool HasAVX512_BF16() const { return has_avx512_bf16_; }
bool HasAVX512Skylake() const { return has_avx512_skylake_; }
bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/
bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/
bool HasAVX_NE_CONVERT() const { return has_avx_ne_convert_; } /*fp16/bf16 conversion inst; only reported when AVX2 and F16C are also present*/
bool HasSSE3() const { return has_sse3_; }
bool HasSSE4_1() const { return has_sse4_1_; }
bool IsHybrid() const { return is_hybrid_; }
Expand Down Expand Up @@ -101,6 +102,7 @@ class CPUIDInfo {
bool has_avx512_bf16_{false};
bool has_avx512_skylake_{false};
bool has_f16c_{false};
bool has_avx_ne_convert_{false};
bool has_sse3_{false};
bool has_sse4_1_{false};
bool is_hybrid_{false};
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,16 @@ MlasConvertHalfToFloatBuffer(
size_t Count
);

//
// Converts a buffer of half-precision floats to single-precision floats using
// the assembly kernel built on the AVX_NE_CONVERT instructions.
//
//   Source      - Supplies the address of the source buffer of half-precision
//                 floats.
//   Destination - Supplies the address of the destination buffer of
//                 single-precision floats.
//   Count       - Supplies the number of elements to convert.
//
// The declaration is only visible for toolchains that can assemble the kernel
// (same version thresholds as the build scripts). The routine performs no CPU
// feature detection itself, so callers must verify AVX_NE_CONVERT support
// before invoking it.
//
// The macros are tested with defined() so that the guard does not rely on
// undefined identifiers evaluating to 0 (keeps -Wundef style warnings quiet).
//
#if (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13))
extern "C" void
MLASCALL
MlasConvertHalfToFloatBufferAVX(
    const unsigned short* Source,
    float* Destination,
    size_t Count
    );
#endif

//
// Transpose routines.
//
Expand Down
148 changes: 148 additions & 0 deletions onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
;++
;
; Copyright (c) Intel Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; cvtfp16Avx.asm
;
; Abstract:
;
; This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA.
;
;--

.xlist
INCLUDE mlasi.inc
.list

.const
SINGLE_SIZE equ 4
HALF_SIZE equ 2
LOW_SELECTOR equ 00100000b
HIGH_SELECTOR equ 00110001b
SUBTTL "Convert buffer of half-precision floats to single-precision floats"
;++
;
; Routine Description:
;
; This routine converts the source buffer of half-precision floats to the
; destination buffer of single-precision floats.
;
; This implementation uses AVX_NE_CONVERT and AVX instructions.
;
; Arguments:
;
; Source (rcx) - Supplies the address of the source buffer of half-precision
; floats.
;
; Destination (rdx) - Supplies the address of the destination buffer of
; single-precision floats.
;
; Count (r8) - Supplies the number of elements to convert.
;
; Return Value:
;
; None.
;
;--

LEAF_ENTRY MlasConvertHalfToFloatBufferAVX, _TEXT

;
; Dispatch on the element count (r8):
;   0       -> nothing to do
;   1..7    -> masked tail only
;   8..15   -> one 128-bit step, then the masked tail for any remainder
;   >= 16   -> 256-bit loop, then the smaller steps for any remainder
;

        test    r8, r8
        jz      ExitRoutine
        cmp     r8, 8
        jb      ConvertMaskedVectors
        cmp     r8, 16
        jb      Convert128Vectors

;
; Main loop: 16 elements per iteration.
;

Convert256Vectors:
        vcvtneeph2ps ymm0, ymmword PTR [rcx]            ; Convert the even-indexed halves
        vcvtneoph2ps ymm1, ymmword PTR [rcx]            ; Convert the odd-indexed halves
        vunpcklps ymm2, ymm0, ymm1                      ; Interleave even/odd, low part of each 128-bit lane
        vunpckhps ymm1, ymm0, ymm1                      ; Interleave even/odd, high part of each 128-bit lane
        vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR       ; Combine the two low lanes: elements 0..7
        vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR      ; Combine the two high lanes: elements 8..15
        vmovups ymmword PTR [rdx], ymm0                 ; Store elements 0..7
        vmovups ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store elements 8..15

        add     rcx, 16*HALF_SIZE                       ; Advance src ptr by 16 elements
        add     rdx, 16*SINGLE_SIZE                     ; Advance dest ptr by 16 elements
        sub     r8, 16                                  ; Reduce the counter by 16 elements

        jz      ExitRoutine                             ; If we are done, exit
        cmp     r8, 16                                  ; At least 16 elements left: go again
        jae     Convert256Vectors

;
; 1..15 elements remain. The 128-bit step below unconditionally stores 8
; floats, so it may only run when at least 8 elements are left; otherwise it
; would write past the end of the destination and drive the counter negative.
;

        cmp     r8, 8
        jb      ConvertMaskedVectors

;
; Exactly one 128-bit step: 8 elements.
;

Convert128Vectors:
        vcvtneeph2ps xmm2, xmmword PTR [rcx]            ; Convert the even-indexed halves
        vcvtneoph2ps xmm1, xmmword PTR [rcx]            ; Convert the odd-indexed halves
        vunpcklps xmm0, xmm2, xmm1                      ; Interleave to restore order: elements 0..3
        vunpckhps xmm1, xmm2, xmm1                      ; Interleave to restore order: elements 4..7
        vmovups xmmword PTR [rdx], xmm0                 ; Store elements 0..3
        vmovups xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store elements 4..7

        add     rcx, 8*HALF_SIZE                        ; Advance src ptr by 8 elements
        add     rdx, 8*SINGLE_SIZE                      ; Advance dest ptr by 8 elements
        sub     r8, 8                                   ; Reduce the counter by 8 elements

        jz      ExitRoutine                             ; If we are done, exit

;
; Tail: 1..7 elements remain. Convert a full vector and store only the
; elements that are still needed using masked stores.
;
; NOTE(review): the two conversions below always read a full 16 bytes
; (8 halves) from the source even though fewer elements remain - confirm that
; callers guarantee those trailing bytes are readable.
;

ConvertMaskedVectors:
        vcvtneeph2ps xmm2, xmmword PTR [rcx]            ; Convert the even-indexed halves
        vcvtneoph2ps xmm1, xmmword PTR [rcx]            ; Convert the odd-indexed halves
        vunpcklps xmm0, xmm2, xmm1                      ; Interleave to restore order: elements 0..3
        vunpckhps xmm1, xmm2, xmm1                      ; Interleave to restore order: elements 4..7

        cmp     r8, 4                                   ; Check if we can store the complete lower vector
        jae     ConvertLowerVector

        vpcmpeqw xmm2, xmm2, xmm2                       ; Initialize the mask full of ones
        cmp     r8, 2                                   ; Check how many elements must be stored
        jb      ConvertLower1
        ja      ConvertLower3
        vpsrldq xmm2, xmm2, SINGLE_SIZE*2               ; Keep the low 8 mask bytes: store two values
        jmp     ConvertLowerMaskedVector
ConvertLower1:
        vpsrldq xmm2, xmm2, SINGLE_SIZE*3               ; Keep the low 4 mask bytes: store one value
        jmp     ConvertLowerMaskedVector
ConvertLower3:
        vpsrldq xmm2, xmm2, SINGLE_SIZE                 ; Keep the low 12 mask bytes: store three values
ConvertLowerMaskedVector:
        vmaskmovps xmmword PTR [rdx], xmm2, xmm0        ; Store only the elements selected by the mask
        jmp     ExitRoutine                             ; Fewer than 4 elements were left, so we are done

ConvertLowerVector:
        vmovups xmmword PTR [rdx], xmm0                 ; Store elements 0..3
        sub     r8, 4                                   ; Check if we still need to convert
        jz      ExitRoutine

        add     rdx, 4*SINGLE_SIZE                      ; Advance dest ptr by 4 elements
        vpcmpeqw xmm2, xmm2, xmm2                       ; Initialize the mask full of ones
        cmp     r8, 2                                   ; Check how many elements must be stored
        jb      ConvertUpper1
        ja      ConvertUpper3
        vpsrldq xmm2, xmm2, SINGLE_SIZE*2               ; Keep the low 8 mask bytes: store two values
        jmp     ConvertMaskedUpperVector
ConvertUpper1:
        vpsrldq xmm2, xmm2, SINGLE_SIZE*3               ; Keep the low 4 mask bytes: store one value
        jmp     ConvertMaskedUpperVector
ConvertUpper3:
        vpsrldq xmm2, xmm2, SINGLE_SIZE                 ; Keep the low 12 mask bytes: store three values
ConvertMaskedUpperVector:
        vmaskmovps xmmword PTR [rdx], xmm2, xmm1        ; Store only the elements selected by the mask

ExitRoutine:
        vzeroupper                                      ; Clear the upper YMM halves before returning to the caller
        ret

LEAF_END MlasConvertHalfToFloatBufferAVX, _TEXT

END
125 changes: 125 additions & 0 deletions onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*++ Routine Description:

This routine converts the source buffer of half-precision floats to the
destination buffer of single-precision floats.

This implementation uses AVX_NE_CONVERT and AVX instructions.

Arguments:

Source (rdi) - Supplies the address of the source buffer of half-precision
floats.

Destination (rsi) - Supplies the address of the destination buffer of
single-precision floats.

Count (rdx) - Supplies the number of elements to convert.

Return Value:

None.

--*/
.data
.equ SINGLE_SIZE, 4
.equ HALF_SIZE, 2
.equ LOW_SELECTOR, 0b00100000
.equ HIGH_SELECTOR, 0b00110001

.text
.globl MlasConvertHalfToFloatBufferAVX
.intel_syntax noprefix

MlasConvertHalfToFloatBufferAVX:
//
// Dispatch on the element count (rdx):
//   0       -> nothing to do
//   1..7    -> masked tail only
//   8..15   -> one 128-bit step, then the masked tail for any remainder
//   >= 16   -> 256-bit loop, then the smaller steps for any remainder
//
        test    rdx, rdx                                // Check if we have any elements to convert
        jz      ExitRoutine
        cmp     rdx, 8
        jb      ConvertMaskedVectors
        cmp     rdx, 16
        jb      Convert128Vectors

//
// Main loop: 16 elements per iteration.
//
Convert256Vectors:
        vcvtneeph2ps ymm0, ymmword PTR [rdi]            // Convert the even-indexed halves
        vcvtneoph2ps ymm1, ymmword PTR [rdi]            // Convert the odd-indexed halves
        vunpcklps ymm2, ymm0, ymm1                      // Interleave even/odd, low part of each 128-bit lane
        vunpckhps ymm1, ymm0, ymm1                      // Interleave even/odd, high part of each 128-bit lane
        vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR       // Combine the two low lanes: elements 0..7
        vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR      // Combine the two high lanes: elements 8..15
        vmovups ymmword PTR [rsi], ymm0                 // Store elements 0..7
        vmovups ymmword PTR [rsi + 8*SINGLE_SIZE], ymm1 // Store elements 8..15

        add     rdi, 16*HALF_SIZE                       // Advance src ptr by 16 elements
        add     rsi, 16*SINGLE_SIZE                     // Advance dest ptr by 16 elements
        sub     rdx, 16                                 // Reduce the counter by 16 elements

        jz      ExitRoutine                             // If we are done, exit
        cmp     rdx, 16                                 // At least 16 elements left: go again
        jae     Convert256Vectors

//
// 1..15 elements remain. The 128-bit step below unconditionally stores 8
// floats, so it may only run when at least 8 elements are left; otherwise it
// would write past the end of the destination and drive the counter negative.
//
        cmp     rdx, 8
        jb      ConvertMaskedVectors

//
// Exactly one 128-bit step: 8 elements.
//
Convert128Vectors:
        vcvtneeph2ps xmm2, xmmword PTR [rdi]            // Convert the even-indexed halves
        vcvtneoph2ps xmm1, xmmword PTR [rdi]            // Convert the odd-indexed halves
        vunpcklps xmm0, xmm2, xmm1                      // Interleave to restore order: elements 0..3
        vunpckhps xmm1, xmm2, xmm1                      // Interleave to restore order: elements 4..7
        vmovups xmmword PTR [rsi], xmm0                 // Store elements 0..3
        vmovups xmmword PTR [rsi + 4*SINGLE_SIZE], xmm1 // Store elements 4..7

        add     rdi, 8*HALF_SIZE                        // Advance src ptr by 8 elements
        add     rsi, 8*SINGLE_SIZE                      // Advance dest ptr by 8 elements
        sub     rdx, 8                                  // Reduce the counter by 8 elements

        jz      ExitRoutine                             // If we are done, exit

//
// Tail: 1..7 elements remain. Convert a full vector and store only the
// elements that are still needed using masked stores.
//
// NOTE(review): the two conversions below always read a full 16 bytes
// (8 halves) from the source even though fewer elements remain - confirm that
// callers guarantee those trailing bytes are readable.
//
ConvertMaskedVectors:
        vcvtneeph2ps xmm2, xmmword PTR [rdi]            // Convert the even-indexed halves
        vcvtneoph2ps xmm1, xmmword PTR [rdi]            // Convert the odd-indexed halves
        vunpcklps xmm0, xmm2, xmm1                      // Interleave to restore order: elements 0..3
        vunpckhps xmm1, xmm2, xmm1                      // Interleave to restore order: elements 4..7

        cmp     rdx, 4                                  // Check if we can store the complete lower vector
        jae     ConvertLowerVector

        vpcmpeqw xmm2, xmm2, xmm2                       // Initialize the mask full of ones
        cmp     rdx, 2                                  // Check how many elements must be stored
        jb      ConvertLower1
        ja      ConvertLower3
        vpsrldq xmm2, xmm2, SINGLE_SIZE*2               // Keep the low 8 mask bytes: store two values
        jmp     ConvertLowerMaskedVector
ConvertLower1:
        vpsrldq xmm2, xmm2, SINGLE_SIZE*3               // Keep the low 4 mask bytes: store one value
        jmp     ConvertLowerMaskedVector
ConvertLower3:
        vpsrldq xmm2, xmm2, SINGLE_SIZE                 // Keep the low 12 mask bytes: store three values
ConvertLowerMaskedVector:
        vmaskmovps xmmword PTR [rsi], xmm2, xmm0        // Store only the elements selected by the mask
        jmp     ExitRoutine                             // Fewer than 4 elements were left, so we are done

ConvertLowerVector:
        vmovups xmmword PTR [rsi], xmm0                 // Store elements 0..3
        sub     rdx, 4                                  // Check if we still need to convert
        jz      ExitRoutine

        add     rsi, 4*SINGLE_SIZE                      // Advance dest ptr by 4 elements
        vpcmpeqw xmm2, xmm2, xmm2                       // Initialize the mask full of ones
        cmp     rdx, 2                                  // Check how many elements must be stored
        jb      ConvertUpper1
        ja      ConvertUpper3
        vpsrldq xmm2, xmm2, SINGLE_SIZE*2               // Keep the low 8 mask bytes: store two values
        jmp     ConvertMaskedUpperVector
ConvertUpper1:
        vpsrldq xmm2, xmm2, SINGLE_SIZE*3               // Keep the low 4 mask bytes: store one value
        jmp     ConvertMaskedUpperVector
ConvertUpper3:
        vpsrldq xmm2, xmm2, SINGLE_SIZE                 // Keep the low 12 mask bytes: store three values
ConvertMaskedUpperVector:
        vmaskmovps xmmword PTR [rsi], xmm2, xmm1        // Store only the elements selected by the mask

ExitRoutine:
        vzeroupper                                      // Clear the upper YMM halves before returning to the caller
        ret
Loading

0 comments on commit b1325e0

Please sign in to comment.