Skip to content

Commit

Permalink
Enable AVX NE CONVERT for FP16 to FP32 cast
Browse files Browse the repository at this point in the history
* Enable AVX_NE_CONVERT detection via CPUID.
* Developed assembly kernel using the new ISA.
* Integrated kernel.
  • Loading branch information
eralmual committed Jun 26, 2024
1 parent 3c0b407 commit f5bc5d7
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 1 deletion.
1 change: 1 addition & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm
${MLAS_SRC_DIR}/amd64/sgemma.asm
${MLAS_SRC_DIR}/amd64/cvtfp16a.asm
${MLAS_SRC_DIR}/amd64/cvtfp16Avx2.asm
${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm
${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm
${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ void CPUIDInfo::X86Init() {
if (max_SubLeaves >= 1) {
GetCPUID(7, 1, data);
has_avx512_bf16_ = has_avx512 && (data[0] & (1 << 5));
// Check for AVX_NE_CONVERT as half precision kernel uses it with AVX2 and F16C
has_avx_ne_convert_ = has_avx2_ && has_f16c_ && (data[3] & (1 << 5));
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/common/cpuid_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class CPUIDInfo {
bool HasAVX512_BF16() const { return has_avx512_bf16_; }
bool HasAVX512Skylake() const { return has_avx512_skylake_; }
bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/
bool HasAVX_NE_CONVERT() const { return has_avx_ne_convert_; } /*fp16/bf16 conversion inst*/
bool HasSSE3() const { return has_sse3_; }
bool HasSSE4_1() const { return has_sse4_1_; }
bool IsHybrid() const { return is_hybrid_; }
Expand Down Expand Up @@ -101,6 +102,7 @@ class CPUIDInfo {
bool has_avx512_bf16_{false};
bool has_avx512_skylake_{false};
bool has_f16c_{false};
bool has_avx_ne_convert_{false};
bool has_sse3_{false};
bool has_sse4_1_{false};
bool is_hybrid_{false};
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,14 @@ MlasConvertHalfToFloatBuffer(
size_t Count
);

extern "C" void
MLASCALL
MlasConvertHalfToFloatBufferAVX2(
const unsigned short* Source,
float* Destination,
size_t Count
);

//
// Transpose routines.
//
Expand Down
148 changes: 148 additions & 0 deletions onnxruntime/core/mlas/lib/amd64/cvtfp16Avx2.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
;++
;
; Copyright (c) Intel Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; cvtfp16Avx2.asm
;
; Abstract:
;
; This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA.
;
;--

.xlist
INCLUDE mlasi.inc
.list

.const
SINGLE_SIZE equ 4
HALF_SIZE equ 2
LOW_SELECTOR equ 00100000b
HIGH_SELECTOR equ 00110001b
SUBTTL "Convert buffer of half-precision floats to single-precision floats"
;++
;
; Routine Description:
;
; This routine converts the source buffer of half-precision floats to the
; destination buffer of single-precision floats.
;
; This implementation uses AVX2 instructions.
;
; Arguments:
;
; Source (rcx) - Supplies the address of the source buffer of half-precision
; floats.
;
; Destination (rdx) - Supplies the address of the destination buffer of
; single-precision floats.
;
; Count (r8) - Supplies the number of elements to convert.
;
; Return Value:
;
; None.
;
;--

LEAF_ENTRY MlasConvertHalfToFloatBufferAVX2, _TEXT

test r8, r8
jz ExitRoutine
cmp r8, 8
jb ConvertMaskedVectors
cmp r8, 16
jb Convert128Vectors



Convert256Vectors:
vcvtneeph2ps ymm0, ymmword PTR [rcx] ; Load even indexes
vcvtneoph2ps ymm1, ymmword PTR [rcx] ; Load odd indexes
vunpcklps ymm2, ymm0, ymm1 ; Interleave low part
vunpckhps ymm1, ymm0, ymm1 ; Interleave high part
vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR ; Fix the order
vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR ; Fix the order
vmovups ymmword PTR [rdx], ymm0 ; Store the low part
vmovups ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part

add rcx, 16*HALF_SIZE ; Advance src ptr by 16 elements
add rdx, 16*SINGLE_SIZE ; Advance dest ptr by 16 elements
sub r8, 16 ; Reduce the counter by 16 elements

jz ExitRoutine ; If we are done, exit
cmp r8, 16 ; If the vector is big enough, we go again
jae Convert256Vectors



Convert128Vectors:
vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes
vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes
vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order
vmovups xmmword PTR [rdx], xmm0 ; Store the low part
vmovups xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part

add rcx, 8*HALF_SIZE ; Advance src ptr by 8 elements
add rdx, 8*SINGLE_SIZE ; Advance dest ptr by 8 elements
sub r8, 8 ; Reduce the counter by 8 elements

jz ExitRoutine ; If we are done, exit



ConvertMaskedVectors:
vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes
vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes
vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order
vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order

cmp r8, 4 ; Chek if we can store the complete lower vector
jae ConvertLowerVector

vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones
cmp r8, 2 ; Check how many converts we need
jb ConvertLower1
ja ConvertLower3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values
jmp ConvertLowerMaskedVector
ConvertLower1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value
jmp ConvertLowerMaskedVector
ConvertLower3:
vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values
ConvertLowerMaskedVector:
vmaskmovps xmmword PTR [rdx], xmm2, xmm0 ; Store the masked data, the shift is done in 8bit multiples
jmp ExitRoutine ; If we ran into any of the cases above, means we are done after storing
ConvertLowerVector:
vmovups xmmword PTR [rdx], xmm0 ; Store the low part
sub r8, 4 ; Check if we still need to convert
jz ExitRoutine


add rdx, 4*SINGLE_SIZE ; Advance dest ptr by 4 elements
vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones
cmp r8, 2 ; Check how many converts we need
jb ConvertUpper1
ja ConvertUpper3
vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values
jmp ConvertMaskedUpperVector
ConvertUpper1:
vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value
jmp ConvertMaskedUpperVector
ConvertUpper3:
vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values
ConvertMaskedUpperVector:
vmaskmovps xmmword PTR [rdx], xmm2, xmm1 ; Store the masked data, the shift is done in 8bit multiples

ExitRoutine:
ret

LEAF_END MlasConvertHalfToFloatBufferAVX2, _TEXT

END
10 changes: 9 additions & 1 deletion onnxruntime/core/providers/cpu/tensor/cast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#include "core/mlas/inc/mlas.h"
#endif

#include "core/common/cpuid_info.h"


namespace onnxruntime {

namespace op_kernel_type_control {
Expand Down Expand Up @@ -263,7 +266,12 @@ struct TensorCaster<MLFloat16, float> {
auto out_data = out.MutableData<float>();
auto in_data = in.Data<MLFloat16>();
const size_t shape_size = narrow<size_t>(shape.Size());
MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasAVX_NE_CONVERT()) {
MlasConvertHalfToFloatBufferAVX2(&in_data[0].val, out_data, shape_size);
}
else {

Check warning on line 272 in onnxruntime/core/providers/cpu/tensor/cast_op.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 An else should appear on the same line as the preceding } [whitespace/newline] [4] Raw Output: onnxruntime/core/providers/cpu/tensor/cast_op.cc:272: An else should appear on the same line as the preceding } [whitespace/newline] [4]

Check warning on line 272 in onnxruntime/core/providers/cpu/tensor/cast_op.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 If an else has a brace on one side, it should have it on both [readability/braces] [5] Raw Output: onnxruntime/core/providers/cpu/tensor/cast_op.cc:272: If an else has a brace on one side, it should have it on both [readability/braces] [5]
MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
}
}
};

Expand Down

0 comments on commit f5bc5d7

Please sign in to comment.