
Commit

Check in Phi-3 MoE kernel changes
Your Name committed Aug 21, 2024
1 parent fb9ce18 commit d76d774
Showing 14 changed files with 821 additions and 36 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -79,7 +79,7 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);
normalize_routing_weights_, use_sparse_mixer_);

size_t ws_size = moe_runner.getWorkspaceSize(
static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
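The guard above means the experts divide evenly across ranks. A small standalone illustration of that arithmetic (hypothetical values, not part of the commit):

int main() {
  // With 16 experts sharded across a world size of 4, each rank hosts
  // 16 / 4 = 4 local experts; 16 % 4 == 0, so the ORT_RETURN_IF_NOT above passes.
  const int num_experts = 16;
  const int world_size = 4;
  const int local_experts = num_experts / world_size;  // 4 experts per rank
  return local_experts == 4 ? 0 : 1;
}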
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -87,7 +87,8 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE);
class CUDA_MS_OP_CLASS_NAME(1, QMoE);
class CUDA_MS_OP_CLASS_NAME(1, QMoE4Bits);
class CUDA_MS_OP_CLASS_NAME(1, QMoE8Bits);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention);
@@ -291,7 +292,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QMoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QMoE4Bits)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QMoE8Bits)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention)>,
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4100)
#pragma warning(disable : 4244)
#pragma warning(disable : 4200)
#endif

#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"

#if defined(_MSC_VER)
#pragma warning(pop)
#endif
namespace ort_fastertransformer {
template class MoeGemmRunner<half, uint8_t>;
} // namespace ort_fastertransformer
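The new translation unit above exists solely to explicitly instantiate MoeGemmRunner for the fp16-activation / uint8-weight pair that the 8-bit QMoE path needs, so the heavyweight GEMM template is compiled once rather than in every includer. A minimal sketch of the pattern with hypothetical names (Runner, runner.h); only the instantiated specialization and the cuda_fp16.h half type come from the commit:

#include <cstdint>
#include <cuda_fp16.h>  // for half

// runner.h (sketch): other translation units see only this declaration, plus
// `extern template struct Runner<half, uint8_t>;` to suppress implicit
// instantiation and keep their compile times down.
template <typename ActT, typename WeightT>
struct Runner {
  void run() { /* heavyweight GEMM dispatch lives here in the real template header */ }
};

// runner_fp16_uint8.cu (sketch): exactly one translation unit pays the
// compile cost for this activation/weight combination.
template struct Runner<half, uint8_t>;  // explicit instantiation definition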

157 changes: 141 additions & 16 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -127,15 +127,15 @@ __launch_bounds__(TPB) __global__
const int block_row = blockIdx.x;

const bool should_process_row = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
const int thread_row_offset = blockIdx.x * num_experts;
float output_row_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f);

cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
const int idx = thread_row_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs_after_softmax[idx];

@@ -169,6 +169,107 @@ __launch_bounds__(TPB) __global__
}
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
template <typename T, int TPB, int NUM_EXPERTS>
__launch_bounds__(TPB) __global__ void sparse_mixer_top2(const T *, T *, int *, int *, const float) {
// Not supported on devices with compute capability below 5.3 (requires native half-precision support).
}
#else

template <typename T, int TPB, int NUM_EXPERTS>
__launch_bounds__(TPB) __global__
void sparse_mixer_top2(const T *inputs, T *output, int *indices, int *source_rows, const float jitter_eps) {
static constexpr int K = 2;

using cub_kvp = cub::KeyValuePair<int, T>;
using KVBlockReduce = cub::BlockReduce<cub_kvp, TPB>;

__shared__ float result_kvp_value[K];
__shared__ typename KVBlockReduce::TempStorage kvTmpStorage;

cub_kvp thread_kvp;
cub::ArgMax arg_max;

int num_rows = gridDim.x;
const int block_row = blockIdx.x;

const int thread_row_offset = blockIdx.x * NUM_EXPERTS;

float factor[K];
bool logits_mask[K];

#pragma unroll
for (int k_idx = 0; k_idx < K; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f);

cub_kvp inp_kvp;
#pragma unroll
for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) {
const int idx = thread_row_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs[idx];

for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const int prior_winning_expert = indices[K * block_row + prior_k];

if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}

thread_kvp = arg_max(inp_kvp, thread_kvp);
}

const cub_kvp result_kvp = KVBlockReduce(kvTmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = K * block_row + k_idx;
result_kvp_value[k_idx] = static_cast<float>(result_kvp.value);
indices[idx] = result_kvp.key;
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();

#pragma unroll
for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) {
const int idx = thread_row_offset + expert;
factor[k_idx] = max(abs(static_cast<float>(inputs[idx])), result_kvp_value[k_idx]);
logits_mask[k_idx] = (result_kvp_value[k_idx] - static_cast<float>(inputs[idx])) > (2 * jitter_eps * factor[k_idx]);
if (k_idx == 1 && expert == indices[K * block_row]) {
logits_mask[1] = true;
}
}
}

#pragma unroll
for (int k_idx = 0; k_idx < K; ++k_idx) {
float row_sum(0);

#pragma unroll
for (int ii = threadIdx.x; ii < NUM_EXPERTS; ii += TPB) {
const int idx = thread_row_offset + ii;
row_sum += logits_mask[k_idx] ? 0 : exp((static_cast<float>(inputs[idx]) - result_kvp_value[k_idx]));
}

#pragma unroll
for (int mask = NUM_EXPERTS / 2; mask > 0; mask /= 2) {
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, NUM_EXPERTS);
}

const float normalizing_factor = 1.f / row_sum;

const int idx = K * block_row + k_idx;
if (threadIdx.x == indices[idx]) {
const int input_idx = thread_row_offset + threadIdx.x;
output[idx] = logits_mask[k_idx] ? 0
: exp((static_cast<float>(inputs[input_idx]) - result_kvp_value[k_idx])) *
normalizing_factor;
}
}
}
#endif
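To make the routing math above easier to follow, here is a single-row, host-side reference of the same top-2 sparse-mixer selection. It is a hedged sketch for illustration (the CUDA kernel above is authoritative): each round takes the argmax over the experts not yet chosen, masks every expert whose gap to that maximum exceeds 2 * jitter_eps * max(|logit|, max), and normalizes the winner by its round's masked softmax denominator.

#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

// Reference sketch of sparse_mixer_top2 for one row of gating logits.
// Returns (expert index, gate value) for each of the two selected experts.
std::vector<std::pair<int, float>> SparseMixerTop2Ref(const std::vector<float>& logits,
                                                      float jitter_eps) {
  constexpr int K = 2;
  const int num_experts = static_cast<int>(logits.size());
  std::vector<std::pair<int, float>> result;
  int first_winner = -1;

  for (int k_idx = 0; k_idx < K; ++k_idx) {
    // Argmax over experts, skipping the previous round's winner.
    int winner = 0;
    float max_val = -1.f;  // mirrors the kernel's T(-1.f) initializer
    for (int e = 0; e < num_experts; ++e) {
      if (e == first_winner) continue;
      if (logits[e] > max_val) {
        winner = e;
        max_val = logits[e];
      }
    }

    // Masked softmax denominator over the experts that survive the jitter test.
    float denom = 0.f;
    for (int e = 0; e < num_experts; ++e) {
      const float factor = std::max(std::fabs(logits[e]), max_val);
      bool masked = (max_val - logits[e]) > 2.f * jitter_eps * factor;
      if (k_idx == 1 && e == first_winner) masked = true;  // round 2 always masks winner 1
      if (!masked) denom += std::exp(logits[e] - max_val);
    }

    // The winner's own term is exp(max_val - max_val) = 1.
    result.emplace_back(winner, 1.f / denom);
    if (k_idx == 0) first_winner = winner;
  }
  return result;
}

Each round normalizes against its own denominator, so the two gates do not in general sum to 1.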

// ====================== TopK softmax things ===============================

/*
@@ -406,9 +507,30 @@ void topk_gating_softmax_launcher_helper(const T *input, const bool *finished, T
template <typename T>
void topk_gating_softmax_kernelLauncher(const T *input, const bool *finished, T *output, T *softmax_temp_output,
int *indices, int *source_row, int num_rows, int num_experts, int k,
bool normalize_routing_weights, cudaStream_t stream) {
bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;

if (use_sparse_mixer) {
static constexpr int TPB = WARP_SIZE * WARPS_PER_TB;
static constexpr float jitter_eps = 0.01f;

switch (num_experts) {
case 8: {
sparse_mixer_top2<T, TPB, 8><<<num_rows, TPB, 0, stream>>>(input, output, indices, source_row, jitter_eps);
break;
}
case 16: {
sparse_mixer_top2<T, TPB, 16><<<num_rows, TPB, 0, stream>>>(input, output, indices, source_row, jitter_eps);
break;
}

default: {
ORT_THROW("Sparse mixer only supports 8 and 16 experts");
}
}
return;
}

switch (num_experts) {
case 2: {
topk_gating_softmax_launcher_helper<T, 2, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
@@ -542,9 +664,9 @@ __global__ void dispatch_activations_kernel(int64_t *total_rows_before_expert, i

template <typename T, typename WeightType, typename Enable>
CutlassMoeFCRunner<T, WeightType, Enable>::CutlassMoeFCRunner(int sm_version, bool has_fc3,
bool normalize_routing_weights)
bool normalize_routing_weights, bool use_sparse_mixer)
: has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0),
normalize_routing_weights_(normalize_routing_weights) {
normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) {
moe_gemm_runner_.initialize(sm_version);
}

@@ -729,7 +851,8 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
configure_ws_ptrs(workspace_ptr, static_cast<size_t>(num_rows), static_cast<size_t>(hidden_size),
static_cast<size_t>(inter_size), static_cast<size_t>(num_experts), static_cast<size_t>(k));
topk_gating_softmax_kernelLauncher<T>(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream);
source_rows_, num_rows, num_experts, k, normalize_routing_weights_,
use_sparse_mixer_, stream);

const int sorter_ws_size_bytes = static_cast<int>(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)));
sorter_.run(reinterpret_cast<void *>(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_,
@@ -748,7 +871,8 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
stream);
}

// moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, expanded_active_expert_rows);
// moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size,
// expanded_active_expert_rows);
moe_gemm_runner_.moe_gemm_bias_act(
permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases,
fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index,
@@ -868,19 +992,19 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::get_total_rows_info(int64_t expe
// experts in the end.

// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0,
// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input
// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus
// of the expanded index.
// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ...
// (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we
// simply take the modulus of the expanded index.
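A quick standalone illustration of that modulus (not kernel code), with num_rows = 4 and k = 2:

#include <cstdio>

int main() {
  const int num_rows = 4, k = 2;
  // Expanded rows 0 and 4 both read source row 0, rows 1 and 5 read row 1, ...
  for (int expanded_row = 0; expanded_row < k * num_rows; ++expanded_row) {
    std::printf("expanded %d -> source %d\n", expanded_row, expanded_row % num_rows);
  }
  return 0;
}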

template <typename T>
__global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *permuted_output,
const int *expanded_dest_row_to_expanded_source_row,
int *expanded_source_row_to_expanded_dest_row, int num_rows,
int active_rows, int cols) {
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the
// reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
// I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need
// the reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in
// MoE. 1 thread block will be responsible for all k summations.
const int expanded_dest_row = blockIdx.x;
const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row];
if (threadIdx.x == 0) {
@@ -1014,14 +1138,15 @@ void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *red

// ========================= TopK Softmax specializations ===========================
template void topk_gating_softmax_kernelLauncher(const float *, const bool *, float *, float *, int *, int *, int, int,
int, bool, cudaStream_t);
int, bool, bool, cudaStream_t);
template void topk_gating_softmax_kernelLauncher(const half *, const bool *, half *, half *, int *, int *, int, int,
int, bool, cudaStream_t);
int, bool, bool, cudaStream_t);

// ==================== Variable batched GEMM specializations ==================================
template class CutlassMoeFCRunner<float, float>;
template class CutlassMoeFCRunner<half, half>;
template class CutlassMoeFCRunner<half, cutlass::uint4b_t>;
template class CutlassMoeFCRunner<half, uint8_t>;

// ===================== Specializations for init routing =========================
template void initialize_moe_routing_kernelLauncher(const float *, float *, const int *, int *, int, int, int, int,
@@ -1043,4 +1168,4 @@ template void finalize_moe_routing_kernelLauncher(const float *, float *, const
template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *,
const half *, const int *, const int *, int, int, int, cudaStream_t);

}  // namespace ort_fastertransformer

5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -109,7 +109,7 @@ template <typename T, /*The type used for activations/scales/compute*/
typename Enable = void>
class CutlassMoeFCRunner {
public:
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);

size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k);

@@ -161,6 +161,7 @@ class CutlassMoeFCRunner {

bool has_fc3_;
bool normalize_routing_weights_;
bool use_sparse_mixer_;

// Cuda events
contrib::cuda::AutoDestoryCudaEvent cuda_event_;
@@ -175,7 +176,7 @@ template <typename WeightType>
template <typename WeightType>
class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_same<float, WeightType>::value>> {
public:
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);

size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) {
return 0;
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -49,7 +49,7 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
const int sm = device_prop.major * 10 + device_prop.minor;

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);
normalize_routing_weights_, use_sparse_mixer_);

size_t ws_size = moe_runner.getWorkspaceSize(
static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
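Both MoE and ShardedMoE now construct the runner identically; a condensed sketch of the call sequence from the two hunks (buffer setup omitted, and moe_params.inter_size assumed to be a field of MoEParameters):

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(
    sm, /*has_fc3=*/fc3_experts_weights_optional != nullptr,
    normalize_routing_weights_, use_sparse_mixer_);

// Workspace sizing is unchanged by the new flag, which only alters gating.
size_t ws_size = moe_runner.getWorkspaceSize(
    static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
    static_cast<size_t>(moe_params.inter_size), static_cast<size_t>(moe_params.num_experts),
    static_cast<size_t>(k_));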
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/moe_base.h
@@ -22,6 +22,7 @@ enum class MoEParallelType {
enum class MoEQuantType {
None = 0,
UINT4 = 1,
UINT8 = 2,
};
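The new UINT8 mode corresponds to the CutlassMoeFCRunner<half, uint8_t> specialization instantiated in moe_kernel.cu, alongside cutlass::uint4b_t for UINT4. A hypothetical helper (not in the commit) that makes the mapping explicit:

#include <stdexcept>

enum class QuantMode { None = 0, UINT4 = 1, UINT8 = 2 };  // mirrors MoEQuantType

// Bits per packed expert-weight element for each quantization mode,
// assuming fp16 expert weights in the unquantized case.
int WeightBits(QuantMode mode) {
  switch (mode) {
    case QuantMode::None:  return 16;
    case QuantMode::UINT4: return 4;   // CutlassMoeFCRunner<half, cutlass::uint4b_t>
    case QuantMode::UINT8: return 8;   // CutlassMoeFCRunner<half, uint8_t>
  }
  throw std::invalid_argument("unknown QuantMode");
}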

struct MoEParameters {
@@ -225,9 +226,15 @@ class MoEBase {
}

normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault<int64_t>("normalize_routing_weights", 0) == 1;

use_sparse_mixer_ = op_kernel_info.GetAttrOrDefault<int64_t>("use_sparse_mixer", 0) == 1;
if (use_sparse_mixer_) {
ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2");
}
}

bool normalize_routing_weights_;
bool use_sparse_mixer_;
int64_t k_;
ort_fastertransformer::ActivationType activation_type_;
};
