From d76d77487447f29019cfe78899e77bcb6f174c98 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Wed, 21 Aug 2024 18:21:54 +0000
Subject: [PATCH] Check in Phi3 MoE kernel changes

---
 .../cuda/collective/sharded_moe.cc            |   2 +-
 .../contrib_ops/cuda/cuda_contrib_kernels.cc  |   6 +-
 .../moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu |  31 +
 .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 157 ++++-
 .../contrib_ops/cuda/moe/ft_moe/moe_kernel.h  |   5 +-
 onnxruntime/contrib_ops/cuda/moe/moe.cc       |   2 +-
 onnxruntime/contrib_ops/cuda/moe/moe_base.h   |   7 +
 .../cuda/quantization/moe_quantization.cc     |  27 +-
 .../cuda/quantization/moe_quantization.h      |   1 +
 .../core/graph/contrib_ops/collective_defs.cc |   4 +
 .../core/graph/contrib_ops/contrib_defs.cc    |  62 +-
 onnxruntime/core/graph/contrib_ops/ms_opset.h |   6 +-
 onnxruntime/test/contrib_ops/moe_test.cc      |   2 +-
 .../transformers/test_parity_phi3_moe.py      | 545 ++++++++++++++++++
 14 files changed, 821 insertions(+), 36 deletions(-)
 create mode 100644 onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu
 create mode 100644 onnxruntime/test/python/transformers/test_parity_phi3_moe.py

diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
index 013b7e1779773..1a4a63de38790 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -79,7 +79,7 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
   ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
                     "num_experts should be divisible by world_size");
 
   ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr,
-                                                       normalize_routing_weights_);
+                                                       normalize_routing_weights_, use_sparse_mixer_);
 
   size_t ws_size = moe_runner.getWorkspaceSize(
       static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size),
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 21bd5eb91c20f..8a2c2dd16c69c 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -87,7 +87,8 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop);
 class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE);
-class CUDA_MS_OP_CLASS_NAME(1, QMoE);
+class CUDA_MS_OP_CLASS_NAME(1, QMoE4Bits);
+class CUDA_MS_OP_CLASS_NAME(1, QMoE8Bits);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MultiHeadAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MultiHeadAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention);
@@ -291,7 +292,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu
new file mode 100644
index 0000000000000..b0a72a1d2506a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif +namespace ort_fastertransformer { +template class MoeGemmRunner; +} // namespace ort_fastertransformer + diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 5f26de4810c42..a6ea9f4b61271 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -127,7 +127,7 @@ __launch_bounds__(TPB) __global__ const int block_row = blockIdx.x; const bool should_process_row = finished ? !finished[block_row] : true; - const int thread_read_offset = blockIdx.x * num_experts; + const int thread_row_offset = blockIdx.x * num_experts; float output_row_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { thread_kvp.key = 0; @@ -135,7 +135,7 @@ __launch_bounds__(TPB) __global__ cub_kvp inp_kvp; for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { - const int idx = thread_read_offset + expert; + const int idx = thread_row_offset + expert; inp_kvp.key = expert; inp_kvp.value = inputs_after_softmax[idx]; @@ -169,6 +169,107 @@ __launch_bounds__(TPB) __global__ } #endif +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 +template +__launch_bounds__(TPB) __global__ void sparse_mixer_top2(const T *, T *, int *, int *, const float) { + // Does not support pre-Kepler architectures + ; +} +#else + +template +__launch_bounds__(TPB) __global__ + void sparse_mixer_top2(const T *inputs, T *output, int *indices, int *source_rows, const float jitter_eps) { + static constexpr int K = 2; + + using cub_kvp = cub::KeyValuePair; + using KVBlockReduce = cub::BlockReduce; + + __shared__ float result_kvp_value[K]; + __shared__ typename KVBlockReduce::TempStorage kvTmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const int thread_row_offset = blockIdx.x * NUM_EXPERTS; + + float factor[K]; + bool logits_mask[K]; + +#pragma unroll + for (int k_idx = 0; k_idx < K; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); + + cub_kvp inp_kvp; +#pragma unroll + for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) { + const int idx = thread_row_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[K * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = KVBlockReduce(kvTmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = K * block_row + k_idx; + result_kvp_value[k_idx] = 
(float)result_kvp.value; + indices[idx] = result_kvp.key; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + +#pragma unroll + for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) { + const int idx = thread_row_offset + expert; + factor[k_idx] = max(abs((float)inputs[idx]), result_kvp_value[k_idx]); + logits_mask[k_idx] = (result_kvp_value[k_idx] - (float)inputs[idx]) > (2 * jitter_eps * factor[k_idx]); + if (k_idx == 1 && expert == indices[K * block_row]) { + logits_mask[1] = true; + } + } + } + +#pragma unroll + for (int k_idx = 0; k_idx < K; ++k_idx) { + float row_sum(0); + +#pragma unroll + for (int ii = threadIdx.x; ii < NUM_EXPERTS; ii += TPB) { + const int idx = thread_row_offset + ii; + row_sum += logits_mask[k_idx] ? 0 : exp((static_cast(inputs[idx]) - result_kvp_value[k_idx])); + } + +#pragma unroll + for (int mask = NUM_EXPERTS / 2; mask > 0; mask /= 2) { + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, NUM_EXPERTS); + } + + const float normalizing_factor = 1.f / row_sum; + + const int idx = K * block_row + k_idx; + if (threadIdx.x == indices[idx]) { + const int input_idx = thread_row_offset + threadIdx.x; + output[idx] = logits_mask[k_idx] ? 0 + : exp((static_cast(inputs[input_idx]) - result_kvp_value[k_idx])) * + normalizing_factor; + } + } +} +#endif + // ====================== TopK softmax things =============================== /* @@ -406,9 +507,30 @@ void topk_gating_softmax_launcher_helper(const T *input, const bool *finished, T template void topk_gating_softmax_kernelLauncher(const T *input, const bool *finished, T *output, T *softmax_temp_output, int *indices, int *source_row, int num_rows, int num_experts, int k, - bool normalize_routing_weights, cudaStream_t stream) { + bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; + if (use_sparse_mixer) { + static constexpr int TPB = WARP_SIZE * WARPS_PER_TB; + static constexpr float jitter_eps = 0.01f; + + switch (num_experts) { + case 8: { + sparse_mixer_top2<<>>(input, output, indices, source_row, jitter_eps); + break; + } + case 16: { + sparse_mixer_top2<<>>(input, output, indices, source_row, jitter_eps); + break; + } + + default: { + ORT_THROW("Sparse mixer only supports 8 and 16 experts"); + } + } + return; + } + switch (num_experts) { case 2: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, @@ -542,9 +664,9 @@ __global__ void dispatch_activations_kernel(int64_t *total_rows_before_expert, i template CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3, - bool normalize_routing_weights) + bool normalize_routing_weights, bool use_sparse_mixer) : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0), - normalize_routing_weights_(normalize_routing_weights) { + normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) { moe_gemm_runner_.initialize(sm_version); } @@ -729,7 +851,8 @@ void CutlassMoeFCRunner::run_moe_fc( configure_ws_ptrs(workspace_ptr, static_cast(num_rows), static_cast(hidden_size), static_cast(inter_size), static_cast(num_experts), static_cast(k)); topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, - source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream); + source_rows_, num_rows, num_experts, k, normalize_routing_weights_, + use_sparse_mixer_, stream); const int sorter_ws_size_bytes = 
static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); sorter_.run(reinterpret_cast(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, @@ -748,7 +871,8 @@ void CutlassMoeFCRunner::run_moe_fc( stream); } - // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, expanded_active_expert_rows); + // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, + // expanded_active_expert_rows); moe_gemm_runner_.moe_gemm_bias_act( permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index, @@ -868,9 +992,9 @@ void CutlassMoeFCRunner::get_total_rows_info(int64_t expe // experts in the end. // Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0, -// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input -// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus -// of the expanded index. +// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... +// (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we +// simply take the modulus of the expanded index. template __global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *permuted_output, @@ -878,9 +1002,9 @@ __global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *perm int *expanded_source_row_to_expanded_dest_row, int num_rows, int active_rows, int cols) { // Reverse permutation map. - // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the - // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 - // thread block will be responsible for all k summations. + // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need + // the reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in + // MoE. 1 thread block will be responsible for all k summations. 
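+  // Example: with k = 2 and num_rows = 4, expanded rows 1 and 5 are both
+  // copies of source row 1 (5 % 4 == 1), so this reverse map lets a single
+  // thread block locate and reduce all k copies of a row without atomics.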
const int expanded_dest_row = blockIdx.x; const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; if (threadIdx.x == 0) { @@ -1014,14 +1138,15 @@ void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *red // ========================= TopK Softmax specializations =========================== template void topk_gating_softmax_kernelLauncher(const float *, const bool *, float *, float *, int *, int *, int, int, - int, bool, cudaStream_t); + int, bool, bool, cudaStream_t); template void topk_gating_softmax_kernelLauncher(const half *, const bool *, half *, half *, int *, int *, int, int, - int, bool, cudaStream_t); + int, bool, bool, cudaStream_t); // ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner; // ===================== Specializations for init routing ========================= template void initialize_moe_routing_kernelLauncher(const float *, float *, const int *, int *, int, int, int, int, @@ -1043,4 +1168,4 @@ template void finalize_moe_routing_kernelLauncher(const float *, float *, const template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *, const half *, const int *, const int *, int, int, int, cudaStream_t); -} // namespace ort_fastertransformer +} // namespace ort_fastertransformer \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index 18a26e6a43382..c457b608decbf 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -109,7 +109,7 @@ template class CutlassMoeFCRunner { public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights); + CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k); @@ -161,6 +161,7 @@ class CutlassMoeFCRunner { bool has_fc3_; bool normalize_routing_weights_; + bool use_sparse_mixer_; // Cuda events contrib::cuda::AutoDestoryCudaEvent cuda_event_; @@ -175,7 +176,7 @@ class CutlassMoeFCRunner { template class CutlassMoeFCRunner::value>> { public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights); + CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) { return 0; diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 6aa75840e6dc0..c5352d931ce2c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -49,7 +49,7 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { const int sm = device_prop.major * 10 + device_prop.minor; ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_); + normalize_routing_weights_, use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h 
index 4a407fa1b2159..6b65557444a66 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -22,6 +22,7 @@ enum class MoEParallelType { enum class MoEQuantType { None = 0, UINT4 = 1, + UINT8 = 2, }; struct MoEParameters { @@ -225,9 +226,15 @@ class MoEBase { } normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault("normalize_routing_weights", 0) == 1; + + use_sparse_mixer_ = op_kernel_info.GetAttrOrDefault("use_sparse_mixer", 0) == 1; + if (use_sparse_mixer_) { + ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2"); + } } bool normalize_routing_weights_; + bool use_sparse_mixer_; int64_t k_; ort_fastertransformer::ActivationType activation_type_; }; diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 571cc59dec75c..a9b67004f2c58 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -15,12 +15,18 @@ namespace contrib { namespace cuda { #define REGISTER_KERNEL() \ - ONNX_OPERATOR_KERNEL_EX(QMoE, kMSDomain, 1, kCudaExecutionProvider, \ + ONNX_OPERATOR_KERNEL_EX(QMoE4Bits, kMSDomain, 1, kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .MayInplace(0, 0) \ .TypeConstraint("T", BuildKernelDefConstraints()) \ .TypeConstraint("T1", BuildKernelDefConstraints()), \ - QMoE); + QMoE); \ + ONNX_OPERATOR_KERNEL_EX(QMoE8Bits, kMSDomain, 1, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(0, 0) \ + .TypeConstraint("T", BuildKernelDefConstraints()) \ + .TypeConstraint("T1", BuildKernelDefConstraints()), \ + QMoE); REGISTER_KERNEL() @@ -39,9 +45,11 @@ struct ToCudaTypeWrapper { }; } // anonymous namespace -QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {} +template +QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {} -Status QMoE::ComputeInternal(OpKernelContext* context) const { +template +Status QMoE::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* router_probs = context->Input(1); const Tensor* fc1_experts_weights = context->Input(2); @@ -60,18 +68,16 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { #endif MoEParameters moe_params; - MoEQuantType quant_type = MoEQuantType::UINT4; + MoEQuantType quant_type = USE_QUINT4x2 ? MoEQuantType::UINT4 : MoEQuantType::UINT8; ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional, fc3_experts_bias_optional)); ORT_RETURN_IF_ERROR(CheckInputScales(fc1_scales, fc2_scales, fc3_scales_optional, moe_params.num_experts, moe_params.hidden_size, moe_params.inter_size)); - // Support int4 only at the moment. We can add uint8 if needed. 
- static constexpr bool use_quint4x2 = true; using T = MLFloat16; using CudaT = typename ToCudaType::MappedType; - using CudaWeightT = typename ToCudaTypeWrapper::MappedType; + using CudaWeightT = typename ToCudaTypeWrapper::MappedType; auto stream = context->GetComputeStream(); @@ -79,7 +85,8 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { const int sm = device_prop.major * 10 + device_prop.minor; ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_); + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), @@ -149,4 +156,4 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { } // namespace cuda } // namespace contrib -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h index 7b68d2d082de8..7b7ff72dad959 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h @@ -14,6 +14,7 @@ namespace cuda { using namespace onnxruntime::cuda; +template class QMoE final : public CudaKernel, public MoEBase { public: explicit QMoE(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index a0ca2e45f153a..7b4f3611f7cdf 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -95,6 +95,10 @@ void RegisterCollectiveOps() { "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) + .Attr("use_sparse_mixer", + "Whether to use sparse mixer", + AttributeProto::INT, + static_cast(0)) .Attr("local_experts_start_index", "The start index of local experts", AttributeProto::INT, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index aebe726afe711..fdbb78d3a4b32 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1395,6 +1395,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. 
Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) + .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0)) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") @@ -1408,7 +1409,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); ONNX_MS_OPERATOR_SET_SCHEMA( - QMoE, 1, + QMoE4Bits, 1, OpSchema() .SetDoc("Int4 MoE") .Attr("activation_type", @@ -1465,6 +1466,65 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("T1", {"tensor(uint8)"}, "Constrain weights type to uint8 tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +ONNX_MS_OPERATOR_SET_SCHEMA( + QMoE8Bits, 1, + OpSchema() + .SetDoc("Int8 MoE") + .Attr("activation_type", + "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", + AttributeProto::STRING, + std::string("relu")) + .Attr("k", + "Number of top experts to select from expert pool", + AttributeProto::INT, + static_cast(1)) + .Attr("normalize_routing_weights", + "Whether to normalize routing weights", + AttributeProto::INT, + static_cast(0)) + .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0)) + .Input(0, + "input", + "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " + "(batch_size, sequence_length, hidden_size)", + "T") + .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") + .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T1") + .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T") + .Input(4, + "fc1_experts_bias", + "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(5, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T1") + .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T") + .Input(7, + "fc2_experts_bias", + "2D optional input tensor with shape (num_experts, hidden_size)", + "T", + OpSchema::Optional) + .Input(8, + "fc3_experts_weights", + "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", + "T1", + OpSchema::Optional) + .Input(9, + "fc3_scales", + "2D optional input tensor with shape (num_experts, inter_size)", + "T", + OpSchema::Optional) + .Input(10, + "fc3_experts_bias", + "2D optional input tensor with shape (num_experts, inter_size)", + "T", + OpSchema::Optional) + .Output(0, + "output", + "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " + "(batch_size, sequence_length, hidden_size)", + "T") + .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeConstraint("T1", {"tensor(uint8)"}, "Constrain weights type to uint8 tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + 
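+// Note: QMoE8Bits stores one uint8 per weight (no nibble packing), so the
+// expert weight tensors keep the same logical shapes as the float16 MoE
+// weights, and each fcN_scales tensor holds one float16 scale per output
+// channel. The parity test in
+// onnxruntime/test/python/transformers/test_parity_phi3_moe.py feeds weights
+// preprocessed by TensorRT-LLM's symmetric last-axis quantizer.
+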
ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, OpSchema() .Input(0, "X", "input", "T") diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index a9a89f756b071..f54dd14e5cd9f 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -84,7 +84,8 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4); #endif class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MoE); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE4Bits); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE8Bits); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); @@ -194,7 +195,8 @@ class OpSet_Microsoft_ver1 { #endif fn(GetOpSchema()); fn(GetOpSchema()); - fn(GetOpSchema()); + fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 42f62981cb52b..44e0e8071e1e5 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -97,7 +97,7 @@ static void RunQMoETest(const std::vector& input, const std::vector("k", static_cast(top_k)); tester.AddAttribute("activation_type", activation_type); tester.AddAttribute("normalize_routing_weights", static_cast(normalize_routing_weights)); diff --git a/onnxruntime/test/python/transformers/test_parity_phi3_moe.py b/onnxruntime/test/python/transformers/test_parity_phi3_moe.py new file mode 100644 index 0000000000000..be88c987c7bda --- /dev/null +++ b/onnxruntime/test/python/transformers/test_parity_phi3_moe.py @@ -0,0 +1,545 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +import unittest +from collections import OrderedDict + +import numpy +import torch +import torch.nn.functional as F +from onnx import TensorProto, helper +from torch import nn + +import onnxruntime + +torch.manual_seed(42) +numpy.random.seed(42) + +ORT_DTYPE = TensorProto.FLOAT16 +NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 +USE_QUANT = True +THRESHOLD = 3e-1 if USE_QUANT else 3e-2 + + +def value_string_of(numpy_array): + arr = numpy_array.flatten() + lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)] + return "{\n " + "f,\n ".join(lines) + "f}" + + +def print_tensor(name, numpy_array): + print(f"const std::vector {name} = {value_string_of(numpy_array)};") + + +def quant_dequant(weights, quant_mode: bool = True): + # use the test version `_symmetric_...` to get the non-interleaved weights + type = torch.quint4x2 if quant_mode else torch.int8 + import tensorrt_llm + + quant_weights, processed_q_weight, torch_weight_scales = ( + torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(weights.T.cpu().contiguous(), type) + ) + + # Unpack the int4s int int8s + if quant_mode: + upper = quant_weights >> 4 + lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends + quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape) + + quant_weights = quant_weights.to(dtype=weights.dtype) + result = torch.multiply(quant_weights, torch_weight_scales.unsqueeze(0)).T.contiguous() + return torch_weight_scales.to(torch.float16), processed_q_weight, result.to(device=weights.device) + + +def create_moe_onnx_graph( + num_rows, + num_experts, + hidden_size, + inter_size, + fc1_experts_weights, + fc2_experts_weights, + fc3_experts_weights, + fc1_scales, + fc2_scales, + fc3_scales, + topk, +): + use_quant = USE_QUANT + if use_quant: + assert fc1_experts_weights.dtype == torch.int8 + assert fc2_experts_weights.dtype == torch.int8 + assert fc3_experts_weights.dtype == torch.int8 + assert fc1_scales is not None + assert fc2_scales is not None + assert fc3_scales is not None + assert fc1_scales.dtype == torch.float16 + assert fc2_scales.dtype == torch.float16 + assert fc3_scales.dtype == torch.float16 + + nodes = [ + helper.make_node( + "MoE" if not use_quant else "QMoE8Bits", + ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "", + "fc2_experts_weights", + "", + "fc3_experts_weights", + ] + if not use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + "fc3_experts_weights", + "fc3_scales", + "", + ] + ), + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=0, + use_sparse_mixer=1, + activation_type="silu", + domain="com.microsoft", + ), + ] + + feature_size_modifier = 1 if not use_quant else 1 + + fc1_shape = [num_experts, hidden_size, inter_size // feature_size_modifier] + fc2_shape = [num_experts, inter_size, hidden_size // feature_size_modifier] + fc3_shape = [num_experts, hidden_size, inter_size // feature_size_modifier] + + torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + numpy_type = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 + if use_quant: + numpy_type = numpy.uint8 + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + ORT_DTYPE if not use_quant else TensorProto.UINT8, + fc1_shape, + 
fc1_experts_weights.flatten().numpy().astype(numpy_type).tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_weights", + ORT_DTYPE if not use_quant else TensorProto.UINT8, + fc2_shape, + fc2_experts_weights.flatten().numpy().astype(numpy_type).tolist(), + raw=False, + ), + helper.make_tensor( + "fc3_experts_weights", + ORT_DTYPE if not use_quant else TensorProto.UINT8, + fc3_shape, + fc3_experts_weights.flatten().numpy().astype(numpy_type).tolist(), + raw=False, + ), + ] + + if use_quant: + fc1_scale_shape = [num_experts, inter_size] + fc2_scale_shape = [num_experts, hidden_size] + fc3_scale_shape = [num_experts, inter_size] + initializers.extend( + [ + helper.make_tensor( + "fc1_scales", + ORT_DTYPE, + fc1_scale_shape, + fc1_scales.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_scales", + ORT_DTYPE, + fc2_scale_shape, + fc2_scales.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc3_scales", + ORT_DTYPE, + fc3_scale_shape, + fc3_scales.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + ORT_DTYPE, + [num_rows, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "silu": nn.SiLU, +} +ACT2FN = ClassInstantier(ACT2CLS) + + +class PhiMoEConfig: + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + hidden_dropout=0.0, + expert_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.01, + input_jitter_noise=0.01, + attention_bias=False, + lm_head_bias=False, + drop_reg=0.0, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.rope_theta = rope_theta + self.sliding_window = sliding_window + self.attention_dropout = attention_dropout + self.hidden_dropout = hidden_dropout + self.expert_dropout = expert_dropout + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = 
output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + self.input_jitter_noise = input_jitter_noise + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + self.drop_reg = drop_reg + + +class PhiMoEBlockSparseTop2MLP(nn.Module): + def __init__(self, config: PhiMoEConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.expert_dropout = nn.Dropout(config.expert_dropout) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.expert_dropout(current_hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +def masked_sampling_omp_inference(scores, top_k, jitter_eps, training): + assert top_k == 2 + assert training == False + + mask_logits_threshold, selected_experts = torch.topk(scores, 2) + + mask_logits_threshold_1 = mask_logits_threshold[:, 0].unsqueeze(-1) + + factor = scores.abs().clamp(min=mask_logits_threshold_1) + logits_mask = ((mask_logits_threshold_1 - scores) / factor) > (2 * jitter_eps) + + multiplier_1 = torch.softmax(scores.masked_fill(logits_mask, float("-inf")), dim=-1).gather( + dim=-1, index=selected_experts[:, 0].unsqueeze(-1) + ) + + ################ second expert gating ################ + + mask_logits_threshold_2 = mask_logits_threshold[:, 1].unsqueeze(-1) + + factor = scores.abs().clamp(min=mask_logits_threshold_2) + logits_mask = ((mask_logits_threshold_2 - scores) / factor) > (2 * jitter_eps) + + multiplier_2 = torch.softmax( + torch.scatter(scores, -1, selected_experts[:, 0].unsqueeze(-1), float("-inf")).masked_fill( + logits_mask, float("-inf") + ), + dim=-1, + ).gather(dim=-1, index=selected_experts[:, 1].unsqueeze(-1)) + + multiplier = torch.concat((multiplier_1, multiplier_2), dim=-1) + + return ( + multiplier, + selected_experts, + ) + + +class PhiMoESparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. 
+ """ + + def __init__(self, config, batch_size, sequence_length): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + self.router_jitter_noise = config.router_jitter_noise + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([PhiMoEBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + w1_list = [] + w2_list = [] + w3_list = [] + w1_scale_list = [] + w2_scale_list = [] + w3_scale_list = [] + if not USE_QUANT: + for i in range(self.num_experts): + w1_list.append(self.experts[i].w1.weight) + w2_list.append(self.experts[i].w2.weight) + w3_list.append(self.experts[i].w3.weight) + else: + for i in range(self.num_experts): + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, False) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, False) + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, False) + + self.experts[i].w1.weight.data = w1_qdq + self.experts[i].w2.weight.data = w2_qdq + self.experts[i].w3.weight.data = w3_qdq + + w1_list.append(pre_qweight1) + w2_list.append(pre_qweight2) + w3_list.append(pre_qweight3) + w1_scale_list.append(w1_scale) + w2_scale_list.append(w2_scale) + w3_scale_list.append(w3_scale) + + self.moe_experts_weight1 = torch.stack(w1_list, dim=0) + self.moe_experts_weight2 = torch.stack(w2_list, dim=0) + self.moe_experts_weight3 = torch.stack(w3_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if USE_QUANT else None + moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if USE_QUANT else None + moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if USE_QUANT else None + + self.batch_size = batch_size + self.sequence_length = sequence_length + self.moe_onnx_graph = create_moe_onnx_graph( + self.batch_size * self.sequence_length, + self.num_experts, + self.hidden_dim, + self.ffn_dim, + self.moe_experts_weight1, + self.moe_experts_weight2, + self.moe_experts_weight3, + moe_experts_weight_scale1, + moe_experts_weight_scale2, + moe_experts_weight_scale3, + self.top_k, + ) + + self.ort_sess = self.create_ort_session() + + def create_ort_session(self): + from onnxruntime import InferenceSession, SessionOptions + + sess_options = SessionOptions() + + cuda_providers = ["CUDAExecutionProvider"] + if cuda_providers[0] not in onnxruntime.get_available_providers(): + return None + + sess_options.log_severity_level = 2 + ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) + + return ort_session + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + routing_weights, selected_experts = masked_sampling_omp_inference( + router_logits, + top_k=self.top_k, + jitter_eps=self.router_jitter_noise, + training=False, + ) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over 
all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            if top_x.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            top_x_list = top_x.tolist()
+            idx_list = idx.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
+
+            # However `index_add_` only supports torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+
+        return final_hidden_states  # , router_logits
+
+    def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        ort_inputs = {
+            "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)),
+            "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)),
+        }
+
+        ort_output = None
+        if self.ort_sess is not None:
+            ort_output = self.ort_sess.run(None, ort_inputs)
+            return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1)  # , router_logits
+
+        # print_tensor("input", ort_inputs["input"])
+        # print_tensor("router_probs", ort_inputs["router_probs"])
+        # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy())
+        # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy())
+        # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy())
+        # print_tensor("output", ort_output[0])
+
+        return None
+
+    def parity_check(self):
+        hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
+        torch_output = self.forward(hidden_state)
+        ort_output = self.ort_forward(hidden_state)
+        if ort_output is not None:
+            assert torch.allclose(torch_output, ort_output.to(torch.float32), rtol=THRESHOLD, atol=THRESHOLD)
+            print(
+                "batch_size:",
+                self.batch_size,
+                " sequence_length:",
+                self.sequence_length,
+                " max_diff:",
+                (torch_output - ort_output).abs().max(),
+                " parity: OK",
+            )
+
+
+class TestPhi3MoE(unittest.TestCase):
+    def test_phi3_moe_parity(self):
+        for batch_size in [1, 16]:
+            for sequence_length in [32, 128, 512, 2048]:
+                # use small sizes to speed up the test
+                config = PhiMoEConfig(hidden_size=1024, intermediate_size=2048)
+                phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length)
+                phi3_moe.parity_check()
+
+
+if __name__ == "__main__":
+    unittest.main()
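
Note: quant_dequant() in the new test depends on tensorrt_llm for
torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix. As a rough
stand-in where that package is unavailable, symmetric per-output-channel int8
quant/dequant can be sketched in pure PyTorch. This reproduces only the scales
and the dequantized reference values, not the interleaved processed_q_weight
layout the CUDA kernels consume, and it assumes plain round-to-nearest
(trtllm's rounding may differ); the helper name is illustrative, not part of
the patch.

import torch

def symmetric_quant_dequant_int8(weights: torch.Tensor):
    """Sketch of symmetric last-axis int8 quant/dequant.

    `weights` is an nn.Linear weight of shape (out_features, in_features),
    matching what quant_dequant() above receives.
    """
    w = weights.T.cpu().contiguous().to(torch.float32)  # (in, out)
    # One symmetric scale per output channel (the last axis of w).
    scales = w.abs().amax(dim=0).clamp(min=1e-8) / 127.0
    q = torch.round(w / scales).clamp(-127, 127).to(torch.int8)  # (in, out)
    dq = (q.to(torch.float32) * scales).T.contiguous()  # back to (out, in)
    return scales.to(torch.float16), q, dq.to(weights.dtype)

# The dequantization error stays within ~0.6 of a quantization step
# (0.5 from rounding plus a little from float16 storage of dq).
w = torch.randn(16, 32, dtype=torch.float16)
scales, q, dq = symmetric_quant_dequant_int8(w)
assert (w.float() - dq.float()).abs().max() <= 0.6 * scales.float().max()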