Phi3 MoE CUDA kernel (#21819)
### Description
Adds a CUDA kernel path for Phi-3 MoE:
- a new `use_sparse_mixer` attribute on the `MoE` and `QMoE` contrib ops, backed by a top-2 `sparse_mixer_top2` routing kernel;
- 8-bit expert-weight quantization for `QMoE`, via a new `expert_weight_bits` attribute, a `MoeGemmRunner<half, uint8_t>` instantiation, and a `MoEQuantType::UINT8` enum value.

### Motivation and Context
Phi-3 MoE routes tokens with sparse-mixer top-2 selection rather than plain top-k softmax, and its expert weights can be quantized to 8 bits; neither path was previously supported by the CUDA MoE kernels.

---------

Co-authored-by: Your Name <[email protected]>
wangyems and Your Name authored Aug 27, 2024
1 parent 2522220 commit 1d059b8
Showing 13 changed files with 1,075 additions and 590 deletions.
14 changes: 10 additions & 4 deletions docs/ContribOperators.md
@@ -3083,6 +3083,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
<dt><tt>use_sparse_mixer</tt> : int</dt>
<dd>Whether to use sparse mixer</dd>
</dl>

#### Inputs (5 - 8)
@@ -4398,7 +4400,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

### <a name="com.microsoft.QMoE"></a><a name="com.microsoft.qmoe">**com.microsoft.QMoE**</a>

Int4 MoE
Quantized MoE

#### Version

@@ -4409,10 +4411,14 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dl>
<dt><tt>activation_type</tt> : string</dt>
<dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
<dt><tt>expert_weight_bits</tt> : int</dt>
<dd>Number of bits used in quantized weights. Default is 4 bits</dd>
<dt><tt>k</tt> : int</dt>
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
<dt><tt>use_sparse_mixer</tt> : int</dt>
<dd>Whether to use sparse mixer</dd>
</dl>

#### Inputs (7 - 11)
@@ -4423,19 +4429,19 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dt><tt>router_probs</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
<dt><tt>fc1_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc2_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size / 2)</dd>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
<dt><tt>fc2_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
<dt><tt>fc3_scales</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
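An editorial note on the dual shapes above: with `expert_weight_bits = 4`, two weight values are packed into each `uint8` element, halving the last dimension, while 8-bit weights are stored one per element. A minimal sketch of the expected packed dimension under that assumption (the helper is hypothetical, not part of this commit):

#include <cstdint>

// Hypothetical helper: expected last dimension of a QMoE expert-weight
// tensor, assuming two int4 values share one uint8 element and int8
// weights are unpacked.
int64_t PackedLastDim(int64_t logical_dim, int64_t expert_weight_bits) {
  return expert_weight_bits == 4 ? logical_dim / 2 : logical_dim;
}
// e.g. fc1_experts_weights has shape
// (num_experts, hidden_size, PackedLastDim(inter_size, expert_weight_bits)).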
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -79,7 +79,7 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);
normalize_routing_weights_, use_sparse_mixer_);

size_t ws_size = moe_runner.getWorkspaceSize(
static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
31 changes: 31 additions & 0 deletions (new file)
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4100)
#pragma warning(disable : 4244)
#pragma warning(disable : 4200)
#endif

#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"

#if defined(_MSC_VER)
#pragma warning(pop)
#endif
namespace ort_fastertransformer {
template class MoeGemmRunner<half, uint8_t>;
} // namespace ort_fastertransformer

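The explicit instantiation above adds the fp16/uint8 GEMM path used by 8-bit QMoE; keeping each weight type in its own translation unit bounds per-file compile time. As a hedged sketch, this mapping is inferred from the `MoEQuantType::UINT8` addition in moe_base.h further below; the function itself is illustrative and not in this commit:

#include <cstdint>

enum class MoEQuantType { None = 0, UINT4 = 1, UINT8 = 2 };

// Illustrative mapping from the QMoE expert_weight_bits attribute to the
// quantization type: the 4-bit path pairs with
// MoeGemmRunner<half, cutlass::uint4b_t>, the 8-bit path with
// MoeGemmRunner<half, uint8_t> (instantiated in the new file above).
MoEQuantType QuantTypeFromBits(int64_t expert_weight_bits) {
  return expert_weight_bits == 4 ? MoEQuantType::UINT4 : MoEQuantType::UINT8;
}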
157 changes: 141 additions & 16 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -127,15 +127,15 @@ __launch_bounds__(TPB) __global__
const int block_row = blockIdx.x;

const bool should_process_row = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
const int thread_row_offset = blockIdx.x * num_experts;
float output_row_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f);

cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
const int idx = thread_row_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs_after_softmax[idx];

@@ -169,6 +169,107 @@
}
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
template <typename T, int TPB, int NUM_EXPERTS>
__launch_bounds__(TPB) __global__ void sparse_mixer_top2(const T *, T *, int *, int *, const float) {
  // No-op stub: architectures below SM 5.3 lack the native half-precision support this kernel needs.
}
#else

template <typename T, int TPB, int NUM_EXPERTS>
__launch_bounds__(TPB) __global__
void sparse_mixer_top2(const T *inputs, T *output, int *indices, int *source_rows, const float jitter_eps) {
static constexpr int K = 2;

using cub_kvp = cub::KeyValuePair<int, T>;
using KVBlockReduce = cub::BlockReduce<cub_kvp, TPB>;

__shared__ float result_kvp_value[K];
__shared__ typename KVBlockReduce::TempStorage kvTmpStorage;

cub_kvp thread_kvp;
cub::ArgMax arg_max;

int num_rows = gridDim.x;
const int block_row = blockIdx.x;

const int thread_row_offset = blockIdx.x * NUM_EXPERTS;

float factor[K];
bool logits_mask[K];

#pragma unroll
for (int k_idx = 0; k_idx < K; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f);

cub_kvp inp_kvp;
#pragma unroll
for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) {
const int idx = thread_row_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs[idx];

for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const int prior_winning_expert = indices[K * block_row + prior_k];

if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}

thread_kvp = arg_max(inp_kvp, thread_kvp);
}

const cub_kvp result_kvp = KVBlockReduce(kvTmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = K * block_row + k_idx;
result_kvp_value[k_idx] = (float)result_kvp.value;
indices[idx] = result_kvp.key;
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();

#pragma unroll
for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) {
const int idx = thread_row_offset + expert;
factor[k_idx] = max(abs((float)inputs[idx]), result_kvp_value[k_idx]);
logits_mask[k_idx] = (result_kvp_value[k_idx] - (float)inputs[idx]) > (2 * jitter_eps * factor[k_idx]);
if (k_idx == 1 && expert == indices[K * block_row]) {
logits_mask[1] = true;
}
}
}

#pragma unroll
for (int k_idx = 0; k_idx < K; ++k_idx) {
float row_sum(0);

#pragma unroll
for (int ii = threadIdx.x; ii < NUM_EXPERTS; ii += TPB) {
const int idx = thread_row_offset + ii;
row_sum += logits_mask[k_idx] ? 0 : exp((static_cast<float>(inputs[idx]) - result_kvp_value[k_idx]));
}

#pragma unroll
for (int mask = NUM_EXPERTS / 2; mask > 0; mask /= 2) {
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, NUM_EXPERTS);
}

const float normalizing_factor = 1.f / row_sum;

const int idx = K * block_row + k_idx;
if (threadIdx.x == indices[idx]) {
const int input_idx = thread_row_offset + threadIdx.x;
output[idx] = logits_mask[k_idx] ? 0
: exp((static_cast<float>(inputs[input_idx]) - result_kvp_value[k_idx])) *
normalizing_factor;
}
}
}
#endif
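// Editorial summary of the sparse-mixer math implemented above (not from
// the original source): for the k-th selected expert, with row maximum m_k
// over logits x_i,
//   mask_i = (m_k - x_i) > 2 * jitter_eps * max(|x_i|, m_k)
//   Z_k    = sum_{i : !mask_i} exp(x_i - m_k)
//   w_k    = exp(x_winner - m_k) / Z_k
// The winner itself is never masked (m_k - x_winner == 0), so w_k is its
// masked-softmax probability; masked experts contribute neither to Z_k nor
// to the output weight.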

// ====================== TopK softmax things ===============================

/*
@@ -406,9 +507,30 @@ void topk_gating_softmax_launcher_helper(const T *input, const bool *finished, T
template <typename T>
void topk_gating_softmax_kernelLauncher(const T *input, const bool *finished, T *output, T *softmax_temp_output,
int *indices, int *source_row, int num_rows, int num_experts, int k,
bool normalize_routing_weights, cudaStream_t stream) {
bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;

if (use_sparse_mixer) {
static constexpr int TPB = WARP_SIZE * WARPS_PER_TB;
static constexpr float jitter_eps = 0.01f;

switch (num_experts) {
case 8: {
sparse_mixer_top2<T, TPB, 8><<<num_rows, TPB, 0, stream>>>(input, output, indices, source_row, jitter_eps);
break;
}
case 16: {
sparse_mixer_top2<T, TPB, 16><<<num_rows, TPB, 0, stream>>>(input, output, indices, source_row, jitter_eps);
break;
}

default: {
ORT_THROW("Sparse mixer only supports 8 and 16 experts");
}
}
return;
}

switch (num_experts) {
case 2: {
topk_gating_softmax_launcher_helper<T, 2, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
@@ -542,9 +664,9 @@ __global__ void dispatch_activations_kernel(int64_t *total_rows_before_expert, i

template <typename T, typename WeightType, typename Enable>
CutlassMoeFCRunner<T, WeightType, Enable>::CutlassMoeFCRunner(int sm_version, bool has_fc3,
bool normalize_routing_weights)
bool normalize_routing_weights, bool use_sparse_mixer)
: has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0),
normalize_routing_weights_(normalize_routing_weights) {
normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) {
moe_gemm_runner_.initialize(sm_version);
}

@@ -729,7 +851,8 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
configure_ws_ptrs(workspace_ptr, static_cast<size_t>(num_rows), static_cast<size_t>(hidden_size),
static_cast<size_t>(inter_size), static_cast<size_t>(num_experts), static_cast<size_t>(k));
topk_gating_softmax_kernelLauncher<T>(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream);
source_rows_, num_rows, num_experts, k, normalize_routing_weights_,
use_sparse_mixer_, stream);

const int sorter_ws_size_bytes = static_cast<int>(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)));
sorter_.run(reinterpret_cast<void *>(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_,
@@ -748,7 +871,8 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
stream);
}

// moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, expanded_active_expert_rows);
// moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size,
// expanded_active_expert_rows);
moe_gemm_runner_.moe_gemm_bias_act(
permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases,
fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index,
@@ -868,19 +992,19 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::get_total_rows_info(int64_t expe
// experts in the end.

// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0,
// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input
// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus
// of the expanded index.
// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ...
// (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we
// simply take the modulus of the expanded index.

template <typename T>
__global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *permuted_output,
const int *expanded_dest_row_to_expanded_source_row,
int *expanded_source_row_to_expanded_dest_row, int num_rows,
int active_rows, int cols) {
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the
// reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
// I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need
// the reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in
// MoE. 1 thread block will be responsible for all k summations.
const int expanded_dest_row = blockIdx.x;
const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row];
if (threadIdx.x == 0) {
@@ -1014,14 +1138,15 @@ void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *red

// ========================= TopK Softmax specializations ===========================
template void topk_gating_softmax_kernelLauncher(const float *, const bool *, float *, float *, int *, int *, int, int,
int, bool, cudaStream_t);
int, bool, bool, cudaStream_t);
template void topk_gating_softmax_kernelLauncher(const half *, const bool *, half *, half *, int *, int *, int, int,
int, bool, cudaStream_t);
int, bool, bool, cudaStream_t);

// ==================== Variable batched GEMM specializations ==================================
template class CutlassMoeFCRunner<float, float>;
template class CutlassMoeFCRunner<half, half>;
template class CutlassMoeFCRunner<half, cutlass::uint4b_t>;
template class CutlassMoeFCRunner<half, uint8_t>;

// ===================== Specializations for init routing =========================
template void initialize_moe_routing_kernelLauncher(const float *, float *, const int *, int *, int, int, int, int,
@@ -1043,4 +1168,4 @@ template void finalize_moe_routing_kernelLauncher(const float *, float *, const
template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *,
const half *, const int *, const int *, int, int, int, cudaStream_t);

} // namespace ort_fastertransformer
} // namespace ort_fastertransformer
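One detail worth calling out from the kernel above: the per-row normalizer `row_sum` is reduced with an XOR-butterfly shuffle rather than shared memory, which works because the supported expert counts (8 and 16) are powers of two no larger than a warp. A standalone sketch of that pattern (the helper name is illustrative, not part of this commit):

// Illustrative device helper: XOR-butterfly sum across each group of
// group_size consecutive lanes. group_size must be a power of two and at
// most the warp size of 32.
__device__ float GroupSum(float v, int group_size) {
  for (int mask = group_size / 2; mask > 0; mask /= 2) {
    v += __shfl_xor_sync(0xFFFFFFFF, v, mask, group_size);
  }
  return v;  // every lane in the group now holds the group's sum
}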
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -109,7 +109,7 @@ template <typename T, /*The type used for activations/scales/compute*/
typename Enable = void>
class CutlassMoeFCRunner {
public:
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);

size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k);

@@ -161,6 +161,7 @@ class CutlassMoeFCRunner {

bool has_fc3_;
bool normalize_routing_weights_;
bool use_sparse_mixer_;

// Cuda events
contrib::cuda::AutoDestoryCudaEvent cuda_event_;
@@ -175,7 +176,7 @@ template <typename WeightType>
template <typename WeightType>
class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_same<float, WeightType>::value>> {
public:
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);

size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) {
return 0;
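With the widened constructor, callers opt into sparse-mixer routing when building the runner. A minimal usage sketch mirroring the call sites in moe.cc and sharded_moe.cc below (the literal values are placeholders; the real code derives them from cudaDeviceProp and node attributes):

ort_fastertransformer::CutlassMoeFCRunner<half, half> moe_runner(
    /*sm_version=*/80,                    // e.g. device_prop.major * 10 + device_prop.minor
    /*has_fc3=*/false,                    // true when fc3_experts_weights is provided
    /*normalize_routing_weights=*/false,
    /*use_sparse_mixer=*/true);           // requires k == 2 (enforced in moe_base.h)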
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -49,7 +49,7 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
const int sm = device_prop.major * 10 + device_prop.minor;

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);
normalize_routing_weights_, use_sparse_mixer_);

size_t ws_size = moe_runner.getWorkspaceSize(
static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/moe_base.h
@@ -22,6 +22,7 @@ enum class MoEParallelType {
enum class MoEQuantType {
None = 0,
UINT4 = 1,
UINT8 = 2,
};

struct MoEParameters {
@@ -225,9 +226,15 @@ class MoEBase {
}

normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault<int64_t>("normalize_routing_weights", 0) == 1;

use_sparse_mixer_ = op_kernel_info.GetAttrOrDefault<int64_t>("use_sparse_mixer", 0) == 1;
if (use_sparse_mixer_) {
ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2");
}
}

bool normalize_routing_weights_;
bool use_sparse_mixer_;
int64_t k_;
ort_fastertransformer::ActivationType activation_type_;
};