Phi3 MoE cuda kernel #21819

Merged Aug 27, 2024 (10 commits)
Changes from 4 commits
72 changes: 70 additions & 2 deletions docs/ContribOperators.md
@@ -80,7 +80,8 @@ Do not modify directly.*
* <a href="#com.microsoft.QLinearSigmoid">com.microsoft.QLinearSigmoid</a>
* <a href="#com.microsoft.QLinearSoftmax">com.microsoft.QLinearSoftmax</a>
* <a href="#com.microsoft.QLinearWhere">com.microsoft.QLinearWhere</a>
* <a href="#com.microsoft.QMoE">com.microsoft.QMoE</a>
* <a href="#com.microsoft.QMoE4Bits">com.microsoft.QMoE4Bits</a>
* <a href="#com.microsoft.QMoE8Bits">com.microsoft.QMoE8Bits</a>
* <a href="#com.microsoft.QOrderedAttention">com.microsoft.QOrderedAttention</a>
* <a href="#com.microsoft.QOrderedGelu">com.microsoft.QOrderedGelu</a>
* <a href="#com.microsoft.QOrderedLayerNormalization">com.microsoft.QOrderedLayerNormalization</a>
@@ -3081,6 +3082,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
<dt><tt>use_sparse_mixer</tt> : int</dt>
<dd>Whether to use sparse mixer</dd>
</dl>
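
The `k` and `normalize_routing_weights` attributes control how tokens are routed to experts: the `k` largest gate values select the experts, and the selected gates are optionally renormalized to sum to 1. The sketch below is a minimal NumPy illustration of that routing step only, assuming the gates come from a softmax over the router outputs; it is not the CUDA implementation, and the `use_sparse_mixer` path (an alternative gating scheme) is not shown.

```python
import numpy as np

def topk_routing(router_logits, k, normalize_routing_weights):
    """Illustrative per-token top-k expert routing (not the CUDA kernel)."""
    # Softmax over the expert dimension (assumed here for illustration).
    probs = np.exp(router_logits - router_logits.max())
    probs /= probs.sum()
    # Pick the k experts with the largest gate values.
    topk_idx = np.argsort(probs)[::-1][:k]
    topk_weights = probs[topk_idx]
    if normalize_routing_weights:
        # Renormalize the selected gates so they sum to 1.
        topk_weights = topk_weights / topk_weights.sum()
    return topk_idx, topk_weights

# Example: 8 experts, route one token to its top 2 experts.
idx, w = topk_routing(np.random.randn(8).astype(np.float32), k=2, normalize_routing_weights=1)
```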

#### Inputs (5 - 8)
@@ -4394,7 +4397,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.QMoE"></a><a name="com.microsoft.qmoe">**com.microsoft.QMoE**</a>
### <a name="com.microsoft.QMoE4Bits"></a><a name="com.microsoft.qmoe4bits">**com.microsoft.QMoE4Bits**</a>

Int4 MoE

@@ -4457,6 +4460,71 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.QMoE8Bits"></a><a name="com.microsoft.qmoe8bits">**com.microsoft.QMoE8Bits**</a>

Int8 MoE

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>activation_type</tt> : string</dt>
<dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
<dt><tt>k</tt> : int</dt>
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
<dt><tt>use_sparse_mixer</tt> : int</dt>
<dd>Whether to use sparse mixer</dd>
</dl>

#### Inputs (7 - 11)

<dl>
<dt><tt>input</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>router_probs</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size)</dd>
<dt><tt>fc1_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc2_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size)</dd>
<dt><tt>fc2_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size)</dd>
<dt><tt>fc3_scales</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>2D output tensor with shape (num_rows, hidden_size) or 3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16)</dt>
<dd>Constrain input and output types to float16 tensors.</dd>
<dt><tt>T1</tt> : tensor(uint8)</dt>
<dd>Constrain weights type to uint8 tensors.</dd>
</dl>
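
To make the input layout above concrete, here is a minimal sketch of constructing a `QMoE8Bits` node with `onnx.helper`. The dimension comments, the chosen attribute values, and the decision to omit the optional bias and fc3 inputs (empty input names) are illustrative assumptions, not requirements of the operator.

```python
from onnx import helper

# Illustrative sizes only; any consistent values work:
# num_experts = 8, hidden_size = 4096, inter_size = 6400

node = helper.make_node(
    "QMoE8Bits",
    inputs=[
        "input",                # (num_rows, hidden_size), float16
        "router_probs",         # (num_rows, num_experts), float16
        "fc1_experts_weights",  # (num_experts, hidden_size, inter_size), uint8
        "fc1_scales",           # (num_experts, inter_size), float16
        "",                     # fc1_experts_bias omitted (optional)
        "fc2_experts_weights",  # (num_experts, inter_size, hidden_size), uint8
        "fc2_scales",           # (num_experts, hidden_size), float16
    ],
    outputs=["output"],         # (num_rows, hidden_size), float16
    domain="com.microsoft",
    activation_type="silu",
    k=2,
    normalize_routing_weights=1,
)
```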


### <a name="com.microsoft.QOrderedAttention"></a><a name="com.microsoft.qorderedattention">**com.microsoft.QOrderedAttention**</a>

Quantized version of simplified Multi-Head Self Attention(using int8 with specific matrix Layout).
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
@@ -890,7 +890,8 @@ Do not modify directly.*
|PackedAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PackedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float16)<br/> **T1** = tensor(uint8)|
|QMoE4Bits|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float16)<br/> **T1** = tensor(uint8)|
|QMoE8Bits|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float16)<br/> **T1** = tensor(uint8)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* attention_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale:**F**<br> *in* B:**F**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -79,7 +79,7 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);
normalize_routing_weights_, use_sparse_mixer_);

size_t ws_size = moe_runner.getWorkspaceSize(
static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -87,7 +87,8 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE);
class CUDA_MS_OP_CLASS_NAME(1, QMoE);
class CUDA_MS_OP_CLASS_NAME(1, QMoE4Bits);
class CUDA_MS_OP_CLASS_NAME(1, QMoE8Bits);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention);
@@ -291,7 +292,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QMoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QMoE4Bits)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QMoE8Bits)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention)>,
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4100)
#pragma warning(disable : 4244)
#pragma warning(disable : 4200)
#endif

#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"

#if defined(_MSC_VER)
#pragma warning(pop)
#endif
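// Explicit instantiation of the MoE GEMM runner for half-precision activations
// with uint8_t (8-bit quantized) expert weights.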
namespace ort_fastertransformer {
template class MoeGemmRunner<half, uint8_t>;
} // namespace ort_fastertransformer
