Skip to content

Commit

Permalink
Merge amd64 and add x86 implementations
Browse files Browse the repository at this point in the history
Fp162fp32 v2
  • Loading branch information
eralmual authored Jul 11, 2024
2 parents f5bc5d7 + 8a1a28c commit 803c718
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 177 deletions.
2 changes: 1 addition & 1 deletion cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm
${MLAS_SRC_DIR}/amd64/sgemma.asm
${MLAS_SRC_DIR}/amd64/cvtfp16a.asm
${MLAS_SRC_DIR}/amd64/cvtfp16Avx2.asm
${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm
${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm
${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm
Expand Down Expand Up @@ -533,6 +532,7 @@ else()
${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S
${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S
${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S
${MLAS_SRC_DIR}/x86_64/cvtfp16a.S
${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/common/cpuid_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class CPUIDInfo {
bool HasAVX512f() const { return has_avx512f_; }
bool HasAVX512_BF16() const { return has_avx512_bf16_; }
bool HasAVX512Skylake() const { return has_avx512_skylake_; }
bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/
bool HasF16C() const { return has_f16c_; }/*fp16 conversion inst*/
bool HasAVX_NE_CONVERT() const { return has_avx_ne_convert_; } /*fp16/bf16 conversion inst*/
bool HasSSE3() const { return has_sse3_; }
bool HasSSE4_1() const { return has_sse4_1_; }
Expand Down
11 changes: 2 additions & 9 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1034,15 +1034,8 @@ MLASCALL
MlasConvertHalfToFloatBuffer(
const unsigned short* Source,
float* Destination,
size_t Count
);

extern "C" void
MLASCALL
MlasConvertHalfToFloatBufferAVX2(
const unsigned short* Source,
float* Destination,
size_t Count
size_t Count,
bool useAVX
);

//
Expand Down
148 changes: 0 additions & 148 deletions onnxruntime/core/mlas/lib/amd64/cvtfp16Avx2.asm

This file was deleted.

120 changes: 112 additions & 8 deletions onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
;
; Abstract:
;
; This module implements routines to convert between FP16 and FP32 formats.
; This module implements routines to convert between FP16 and FP32 formats, one using old SSE
; instructions and new AVX_NE_CONVERT as well.
;
;--

Expand All @@ -21,42 +22,145 @@ INCLUDE mlasi.inc
.const

ALIGN 16
; Legacy implementation constants
MlasFp16MaskSign DD 4 DUP (00007FFFh)
MlasFp16CompareInfinity DD 4 DUP (00007C00h)
MlasFp16CompareSmallest DD 4 DUP (00000400h)
MlasFp16AdjustExponent DD 4 DUP (38000000h)
MlasFp16MagicDenormal DD 4 DUP (38800000h)
; AVX implementation constants
SINGLE_SIZE equ 4
HALF_SIZE equ 2
LOW_SELECTOR equ 00100000b
HIGH_SELECTOR equ 00110001b


SUBTTL "Convert buffer of half-precision floats to single-precision floats"
;++
;
; Routine Description:
;
; This routine converts the source buffer of half-precision floats to the
; destination buffer of single-precision floats.
;
; This implementation uses SSE2 instructions.
; This routine calls the implementation of the cast operator depending on the ISA flag.
;
; Arguments:
;
; Source (rcx) - Supplies the address of the source buffer of half-precision
; floats.
;
; Destination (edx) - Supplies the address of the destination buffer of
; Destination (rdx) - Supplies the address of the destination buffer of
; single-precision floats.
;
; Count (r8) - Supplies the number of elements to convert.
;
; ISA flag (r9) - Determines whether to use AVX_NE_CONVERT or not.
;
; Return Value:
;
; None.
;
;--

LEAF_ENTRY MlasConvertHalfToFloatBuffer, _TEXT

test r8,r8
LEAF_ENTRY MlasConvertHalfToFloatBuffer, _TEXT

test r8, r8 ; Check if we have any elements to convert
jz ExitRoutine
test r9, r9 ; Check if we need to use AVX_NE_CONVERT
jz SSE

AVX_NE_CONVERT:
cmp r8, 8
jb ConvertMaskedVectors
cmp r8, 16
jb Convert128Vectors



Convert256Vectors:
vcvtneeph2ps ymm0, ymmword PTR [rcx] ; Load even indexes
vcvtneoph2ps ymm1, ymmword PTR [rcx] ; Load odd indexes
vunpcklps ymm2, ymm0, ymm1 ; Interleave low part
vunpckhps ymm1, ymm0, ymm1 ; Interleave high part
vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR ; Fix the order
vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR ; Fix the order
vmovups ymmword PTR [rdx], ymm0 ; Store the low part
vmovups ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part

add rcx, 16*HALF_SIZE ; Advance src ptr by 16 elements
add rdx, 16*SINGLE_SIZE ; Advance dest ptr by 16 elements
sub r8, 16 ; Reduce the counter by 16 elements

jz ExitRoutine ; If we are done, exit
cmp r8, 16 ; If the vector is big enough, we go again
jae Convert256Vectors



Convert128Vectors:
vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes
vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes
vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order
vmovups xmmword PTR [rdx], xmm0 ; Store the low part
vmovups xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part

add rcx, 8*HALF_SIZE ; Advance src ptr by 8 elements
add rdx, 8*SINGLE_SIZE ; Advance dest ptr by 8 elements
sub r8, 8 ; Reduce the counter by 8 elements

jz ExitRoutine ; If we are done, exit



ConvertMaskedVectors:
vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes
vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes
vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order

cmp r8, 4 ; Chek if we can store the complete lower vector
jae ConvertLowerVector

vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones
cmp r8, 2 ; Check how many converts we need
jb ConvertLower1
ja ConvertLower3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values
jmp ConvertLowerMaskedVector
ConvertLower1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value
jmp ConvertLowerMaskedVector
ConvertLower3:
vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values
ConvertLowerMaskedVector:
vmaskmovps xmmword PTR [rdx], xmm2, xmm0 ; Store the masked data, the shift is done in 8bit multiples
jmp ExitRoutine ; If we ran into any of the cases above, means we are done after storing
ConvertLowerVector:
vmovups xmmword PTR [rdx], xmm0 ; Store the low part
sub r8, 4 ; Check if we still need to convert
jz ExitRoutine


add rdx, 4*SINGLE_SIZE ; Advance dest ptr by 4 elements
vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones
cmp r8, 2 ; Check how many converts we need
jb ConvertUpper1
ja ConvertUpper3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values
jmp ConvertMaskedUpperVector
ConvertUpper1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value
jmp ConvertMaskedUpperVector
ConvertUpper3:
vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values
ConvertMaskedUpperVector:
vmaskmovps xmmword PTR [rdx], xmm2, xmm1 ; Store the masked data, the shift is done in 8bit multiples

jmp ExitRoutine



SSE:
cmp r8,4
jb LoadPartialVector

Expand Down
Loading

0 comments on commit 803c718

Please sign in to comment.