From 034ab4fa04789af9f989975ba6515fc72941b95b Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 23 Oct 2024 11:03:09 +0800 Subject: [PATCH 1/8] [WebNN EP] Fixed a minor bug in ConvTranspose (#22384) For ConvTranspose, the filter should be transposed from iohw -> ohwi if it is NHWC preferred layout. --- .../core/providers/webnn/builders/impl/conv_op_builder.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index f03e5b90ff6db..329db75316e82 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -133,7 +133,7 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, const auto out_t = dims[0], in_t = dims[1], h_t = dims[2], w_t = dims[3]; std::vector dest_shape; - if (is_conv == 1) + if (is_conv) dest_shape = {out_t, h_t, w_t, in_t}; // L_0231 else dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv and convTranspose weight @@ -265,7 +265,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N options.set("inputLayout", emscripten::val("nhwc")); options.set("filterLayout", emscripten::val("ohwi")); if (is_constant_weight) { - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, true, is_conv1d)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false, is_conv1d)); } } } From ba40022ec42a4b60d4b1ef875d6613923e9e8624 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 23 Oct 2024 11:26:34 +0800 Subject: [PATCH 2/8] [WebNN EP] Support axes and fix some validation for Resize (#21952) - Supports arbitrary axes for Resize opset 18+ - Check all inputs and attributes more carefully --------- Co-authored-by: Dwayne Robinson --- js/web/docs/webnn-operators.md | 2 +- .../core/providers/webnn/builders/helper.h | 36 +++ .../webnn/builders/impl/resize_op_builder.cc | 287 +++++++++++------- 3 files changed, 216 insertions(+), 109 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index f696264aeead7..bf0f1dffb83ee 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -78,7 +78,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | ✓ | ✓ | Input 'axes' if present should be a constant | | Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | | | Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported | -| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a constant, 'linear' and 'nearest' modes | +| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, antialias == 0, coordinate_transformation_mode == 'half_pixel', exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant | | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | | | Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | | | Softplus | ai.onnx(7+) | softplus | ✓ | ✓ | | diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 
aecb1f7a03bb9..ec9993bf138ba 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -36,6 +36,31 @@ WebnnDeviceType DeviceTypeFromString(const std::string_view& device_type); // Collects all the initializer tensors in the subGraph and its ancestor graphs. InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer); +inline std::vector convertAxesFromNCHWtoNHWC(const std::vector& axes) { + constexpr std::array nchw_to_nhwc = {0, 3, 1, 2}; + std::vector new_axes; + new_axes.reserve(axes.size()); + for (int64_t axis : axes) { + if (axis >= nchw_to_nhwc.size()) { + ORT_THROW("Invalid axis value: ", axis); + } + new_axes.push_back(nchw_to_nhwc[static_cast(axis)]); + } + return new_axes; +} + +inline std::vector HandleNegativeAxes(const std::vector& axes, size_t input_size) { + std::vector new_axes(axes.size()); + for (size_t i = 0; i < axes.size(); ++i) { + new_axes[i] = HandleNegativeAxis(axes[i], input_size); + } + return new_axes; +} + +inline std::vector GetResolvedAxes(const NodeAttrHelper& helper, size_t input_size) { + return HandleNegativeAxes(helper.Get("axes", std::vector{}), input_size); +} + bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger); template @@ -144,6 +169,17 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va return true; } +inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::string& name) { + if (name.empty() || !Contains(initializers, name)) { + return true; + } + + const auto& tensor = *initializers.at(name); + const auto dims = tensor.dims(); + // An empty tensor contains a 0 in the dimensions list. + return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); +} + bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 9dc79f4f52f46..3442afbc2b3cd 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -38,16 +38,33 @@ class ResizeOpBuilder : public BaseOpBuilder { }; // Helper functions -bool GetResizeScales(const InitializedTensorSet& initializers, - const Node& node, std::vector& scales, - const logging::Logger& logger) { +bool GetResizeScalesAndAxes(const InitializedTensorSet& initializers, + const Node& node, std::vector& scales, + std::vector& axes, const bool is_nhwc, + const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); if (input_defs.size() < 3) return false; + const bool has_axes = !axes.empty(); const auto& scales_tensor = *initializers.at(input_defs[2]->Name()); - if (scales_tensor.dims_size() != 1 || scales_tensor.dims()[0] != 4) + if (scales_tensor.dims_size() != 1) { + LOGS(logger, ERROR) << "'scales' should be a 1D tensor."; return false; + } + + // Number of elements of 'scales' tensor. 
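+ // Per the ONNX Resize spec this equals the input rank (4 here) when 'axes' is absent,
+ // or the length of 'axes' (2 here) when it is present.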
+ const auto num_of_scales = scales_tensor.dims()[0]; + + if (has_axes && num_of_scales != 2) { + LOGS(logger, ERROR) << "When 'axes' is provided, 'scales' should have 2 elements."; + return false; + } + + if (!has_axes && num_of_scales != 4) { + LOGS(logger, ERROR) << "When 'axes' is not provided, 'scales' should have 4 elements."; + return false; + } std::vector unpacked_tensor; auto status = onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor); @@ -56,20 +73,65 @@ bool GetResizeScales(const InitializedTensorSet& initializers, return false; } const float* scales_data = reinterpret_cast(unpacked_tensor.data()); - scales = std::vector{scales_data, scales_data + 4}; + + if (has_axes) { + // 'axes' is specified since opset 18+, 'scales' should have 2 elements. + scales = std::vector{scales_data, scales_data + 2}; + } else { + // Before opset 18, 'scales' should have 4 elements. + // Make sure 'scales' is not trying to scale on N/C channels here. + std::vector onnx_scales{scales_data, scales_data + 4}; + // 'scales' input has been transposed to NHWC layout if it is NHWC preferred layout. + const float scale_n = onnx_scales[0]; + const float scale_c = is_nhwc ? onnx_scales[3] : onnx_scales[1]; + const float scale_h = is_nhwc ? onnx_scales[1] : onnx_scales[2]; + const float scale_w = is_nhwc ? onnx_scales[2] : onnx_scales[3]; + if (scale_n != 1.0f || scale_c != 1.0f) { + LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" + << "Scales of N/C channels are not supported" + << ", scale_n, " << scale_n << ", scale_c, " << scale_c; + return false; + } + + scales = {scale_h, scale_w}; + axes = {2, 3}; + } + + if (is_nhwc) { + // For NHWC preferred layout, we need to convert axes from NCHW to NHWC. + axes = convertAxesFromNCHWtoNHWC(axes); + } + return true; } -bool GetResizeOutputSizes(const InitializedTensorSet& initializers, - const Node& node, std::vector& sizes, - const logging::Logger& logger) { +bool GetResizeSizesAndAxes(const InitializedTensorSet& initializers, + const Node& node, std::vector& sizes, + std::vector& axes, const bool is_nhwc, + const gsl::span& input_shape, + const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); if (input_defs.size() < 4) return false; + const bool has_axes = !axes.empty(); const auto& sizes_tensor = *initializers.at(input_defs[3]->Name()); - if (sizes_tensor.dims_size() != 1 || sizes_tensor.dims()[0] != 4) + if (sizes_tensor.dims_size() != 1) { + LOGS(logger, ERROR) << "'sizes' should be a 1D tensor."; + return false; + } + + // Number of elements of sizes tensor. + const auto num_of_sizes = sizes_tensor.dims()[0]; + if (has_axes && num_of_sizes != 2) { + LOGS(logger, ERROR) << "When 'axes' is provided, 'sizes' should have 2 elements."; + return false; + } + + if (!has_axes && num_of_sizes != 4) { + LOGS(logger, ERROR) << "When 'axes' is not provided, 'sizes' should have 4 elements."; return false; + } std::vector unpacked_tensor; auto status = onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor); @@ -78,7 +140,35 @@ bool GetResizeOutputSizes(const InitializedTensorSet& initializers, return false; } const int64_t* sizes_data = reinterpret_cast(unpacked_tensor.data()); - sizes = std::vector{sizes_data, sizes_data + 4}; + + if (has_axes) { + // 'axes' is specified since opset 18+, 'sizes' should have 2 elements. + sizes = std::vector{sizes_data, sizes_data + 2}; + } else { + // Before opset 18, 'sizes' should have 4 elements. 
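+ // (one target size per input dimension, in the input's current layout).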
+ // Make sure 'sizes' is not trying to resize on N/C channels here. + std::vector onnx_sizes{sizes_data, sizes_data + 4}; + auto size_n = onnx_sizes[0]; + const int c_idx = is_nhwc ? 3 : 1; + if (size_n != input_shape[0] || onnx_sizes[c_idx] != input_shape[c_idx]) { + LOGS(logger, VERBOSE) << "Output sizes of N/C chanel should match the input sizes, " + << "Resize of N/C channels are not supported" + << ", input_size_n, " << input_shape[0] << ", output_size_n, " << size_n + << ". input_size_c, " << input_shape[c_idx] << ", output_size_c, " << onnx_sizes[c_idx]; + return false; + } + // 'sizes' input has been transposed to NHWC layout if it is NHWC preferred layout. + const int64_t sizes_h = is_nhwc ? onnx_sizes[1] : onnx_sizes[2]; + const int64_t sizes_w = is_nhwc ? onnx_sizes[2] : onnx_sizes[3]; + sizes = {sizes_h, sizes_w}; + axes = {2, 3}; + } + + if (is_nhwc) { + // For NHWC preferred layout, we need to convert 'axes' from NCHW to NHWC. + axes = convertAxesFromNCHWtoNHWC(axes); + } + return true; } @@ -103,9 +193,15 @@ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + + const auto& initializers(model_builder.GetInitializerTensors()); + NodeAttrHelper helper(node); + emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - NodeAttrHelper helper(node); const auto mode = helper.Get("mode", "nearest"); if (mode == "linear") { options.set("mode", emscripten::val("linear")); @@ -113,45 +209,30 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("mode", emscripten::val("nearest-neighbor")); } - const auto& input_defs = node.InputDefs(); - const auto& initializers(model_builder.GetInitializerTensors()); - std::vector scales; - std::vector sizes; - std::vector scales_hw; - std::vector sizes_hw; - std::vector axes; - std::string scales_name = GetTensorName(input_defs, 2); + std::vector sizes; + std::vector webnn_sizes; + std::vector axes = GetResolvedAxes(helper, 4); // We already checked input shape is 4D in IsOpSupportedImpl. + std::string sizes_name = GetTensorName(input_defs, 3); const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; - if (!scales_name.empty()) { // Use scales. - ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); - if (is_nhwc) { - scales_hw = {scales[1], scales[2]}; - } else { - scales_hw = {scales[2], scales[3]}; - } - options.set("scales", emscripten::val::array(scales_hw)); - } else { // Use sizes, we already checked inputs in IsOpSupportedImpl. - std::vector output_sizes; - ORT_RETURN_IF_NOT(GetResizeOutputSizes(initializers, node, output_sizes, logger), - "Error getting resize output_sizes"); - std::transform(output_sizes.cbegin(), output_sizes.cend(), - std::back_inserter(sizes), - [](int64_t dim) -> int32_t { return SafeInt(dim); }); - if (is_nhwc) { - sizes_hw = {sizes[1], sizes[2]}; - } else { - sizes_hw = {sizes[2], sizes[3]}; - } - options.set("sizes", emscripten::val::array(sizes_hw)); - } - if (is_nhwc) { - axes = {1, 2}; + // We know we have either a 'scales' or 'sizes' input so this is safe. + // Check for 'sizes' first. 
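+ // ('sizes' must also be a constant initializer so its values are readable at build time.)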
+ // This handles Resize-11 where 'scales' was a required input but 'sizes' were used if provided. + bool using_sizes = !sizes_name.empty() && Contains(initializers, sizes_name); + if (using_sizes) { + ORT_RETURN_IF_NOT(GetResizeSizesAndAxes(initializers, node, sizes, axes, is_nhwc, input_shape, logger), + "Error getting Resize sizes"); + webnn_sizes = GetVecUint32FromVecInt64(sizes); + options.set("sizes", emscripten::val::array(webnn_sizes)); } else { - axes = {2, 3}; + ORT_RETURN_IF_NOT(GetResizeScalesAndAxes(initializers, node, scales, axes, is_nhwc, logger), + "Error getting Resize scales"); + options.set("scales", emscripten::val::array(scales)); } - options.set("axes", emscripten::val::array(axes)); + + std::vector webnn_axes = GetVecUint32FromVecInt64(axes); + options.set("axes", emscripten::val::array(webnn_axes)); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val output = model_builder.GetBuilder().call("resample2d", input, options); @@ -166,6 +247,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) @@ -179,92 +261,81 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers } { // Check attributes. - NodeAttrHelper helper(node); - const auto mode = helper.Get("mode", "nearest"); - bool is_linear_resize = mode == "linear"; - bool is_nearest_resize = mode == "nearest"; - // WebNN only supports "linear" and "nearest" modes. - if (!is_linear_resize && !is_nearest_resize) { - LOGS(logger, VERBOSE) << "Resize does not support input mode: " << mode; + // antialias + if (helper.Get("antialias", 0) != 0) { + LOGS(logger, VERBOSE) << "Resize does not support antialias"; return false; } - const auto exclude_outside = helper.Get("exclude_outside", 0); - if (exclude_outside != 0) { - LOGS(logger, VERBOSE) << "Resize does not support exclude_outside for now"; + // coordinate_transformation_mode + // Spec issue for supporting more coordinate transformation modes: + // https://github.com/webmachinelearning/webnn/issues/270 + const std::string coordinate_transformation_mode = helper.Get("coordinate_transformation_mode", "half_pixel"); + if (coordinate_transformation_mode != "half_pixel") { + LOGS(logger, VERBOSE) << "Resize does not support coordinate_transformation_mode: " + << coordinate_transformation_mode; return false; } - } - { // scales and sizes (if present) must be initializers. 
- const std::string scales_name = GetTensorName(input_defs, 2); - const std::string sizes_name = GetTensorName(input_defs, 3); - - // scales (scales may be empty tensor) - bool has_scales = !scales_name.empty(); - if ((has_scales && !Contains(initializers, scales_name)) || (!has_scales && node.SinceVersion() == 11)) { - LOGS(logger, VERBOSE) << "Input scales of Resize must be known"; + // exclude_outside + const auto exclude_outside = helper.Get("exclude_outside", 0); + if (exclude_outside != 0) { + LOGS(logger, VERBOSE) << "Resize does not support exclude_outside for now"; return false; } - // sizes (sizes may be empty tensor) - bool has_sizes = !sizes_name.empty(); - if (has_sizes && !Contains(initializers, sizes_name)) { - LOGS(logger, VERBOSE) << "Input sizes of Resize must be known"; + // keep_aspect_ratio_policy + const auto keep_aspect_ratio_policy = helper.Get("keep_aspect_ratio_policy", "stretch"); + if (keep_aspect_ratio_policy != "stretch") { + LOGS(logger, VERBOSE) << "Resize does not support keep_aspect_ratio_policy: " << keep_aspect_ratio_policy; return false; } - if (has_scales && has_sizes) { - LOGS(logger, VERBOSE) << "Only one of 'scales' and 'sizes' can be specified"; + // mode + const auto mode = helper.Get("mode", "nearest"); + bool is_linear_resize = mode == "linear"; + bool is_nearest_resize = mode == "nearest"; + // WebNN only supports "linear" and "nearest" modes. + if (!is_linear_resize && !is_nearest_resize) { + LOGS(logger, VERBOSE) << "Resize does not support input mode: " << mode; return false; } + } - const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain; - // We want to check if the scales or sizes are not trying to resize on N/C channels here. - if (has_scales) { // We are using scales. - std::vector scales; - if (!GetResizeScales(initializers, node, scales, logger)) - return false; - - float scale_n = scales[0]; - float scale_c = is_nhwc ? scales[3] : scales[1]; - if (scale_n != 1.0f || scale_c != 1.0f) { - LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" - << "Resize of N/C channels are not supported" - << ", scale_n, " << scale_n << ", scale_c, " << scale_c; - return false; - } + { // 'scales' and 'sizes' (if present) must be non-empty initializers. + const std::string scales_name = GetTensorName(input_defs, 2); + const std::string sizes_name = GetTensorName(input_defs, 3); - // For now we only support upscale, so the scale_h and scale_w should be an integer >= 1. - // TODO support ResizeBilinear. - float scale_h = is_nhwc ? scales[1] : scales[2]; - float scale_w = is_nhwc ? scales[2] : scales[3]; + // Check for 'sizes' first. + // This handles Resize-11 where 'scales' was a required input but 'sizes' were used if provided. + // 'scales' or 'sizes' may be empty tensor. + bool using_sizes = !IsEmptyTensor(initializers, sizes_name); + bool using_scales = !using_sizes && !IsEmptyTensor(initializers, scales_name); - // Onnx spec requires scale to be a positive float, so we are not checking that here. - if (roundf(scale_h) != scale_h) { - LOGS(logger, VERBOSE) << "Resize: scale_h: " << scale_h << " is not a whole number"; - return false; - } + if (!using_scales && !using_sizes) { + LOGS(logger, VERBOSE) << "Resize: only one of 'scales' and 'sizes' can be specified"; + return false; + } - if (roundf(scale_w) != scale_w) { - LOGS(logger, VERBOSE) << "Resize: scale_w: " << scale_w << " is not a whole number"; + // 'axes' is from opset 18 on and allows 'scales' or 'sizes' to have entries for the subset of 'axes'. 
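+ // e.g. axes = [2, 3] with sizes = [384, 384] resizes only the H and W dimensions.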
+ // We fill with default values if necessary so that the processing is consistent across all supported opsets. + std::vector axes = GetResolvedAxes(helper, input_size); + if (!axes.empty()) { // We have 'axes' attribute. + if (axes.size() != 2 || axes[0] >= input_size || axes[1] >= input_size) { + LOGS(logger, VERBOSE) << "Resize: invalid axes attribute"; return false; } } - if (has_sizes) { - // We are using sizes. - std::vector output_sizes; - if (!GetResizeOutputSizes(initializers, node, output_sizes, logger)) + const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain; + if (using_sizes) { // We are using 'sizes'. + std::vector sizes; + if (!GetResizeSizesAndAxes(initializers, node, sizes, axes, is_nhwc, input_shape, logger)) { return false; - - auto output_size_n = output_sizes[0]; - const int c_idx = is_nhwc ? 3 : 1; - if (output_size_n != input_shape[0] || output_sizes[c_idx] != input_shape[c_idx]) { - LOGS(logger, VERBOSE) << "Output sizes of N/C chanel should match the input sizes, " - << "Resize of N/C channels are not supported" - << ", input_size_n, " << input_shape[0] << ", output_size_n, " << output_size_n - << ". input_size_c, " << input_shape[c_idx] << ", output_size_c, " << output_sizes[c_idx]; + } + } else { // We are using 'scales'. + std::vector scales; + if (!GetResizeScalesAndAxes(initializers, node, scales, axes, is_nhwc, logger)) { return false; } } From 0028d3f3328c9147261e28c85dbc70d36319485d Mon Sep 17 00:00:00 2001 From: ivberg Date: Tue, 22 Oct 2024 20:45:44 -0700 Subject: [PATCH 3/8] Fix crash in QNN EP - ResetQnnLogLevel (#22456) ### Description Fix crash with extra checks ResetQnnLogLevel. From the dump it looks like during ETW callbacks, while the provider is stopping, we attempt to reset the QNN log level. While the QNN BackEndMgr (this) is alive logger_ is not valid ### Motivation and Context ORT should not crash --- .../qnn/builder/qnn_backend_manager.cc | 16 +++- .../qnn/builder/qnn_backend_manager.h | 3 + .../providers/qnn/qnn_execution_provider.cc | 90 ++++++++++--------- .../providers/qnn/qnn_execution_provider.h | 2 +- 4 files changed, 63 insertions(+), 48 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index eaffe1e2ac224..34dcbd1d77fca 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -302,13 +302,21 @@ QnnLog_Level_t QnnBackendManager::MapOrtSeverityToQNNLogLevel(logging::Severity } Status QnnBackendManager::ResetQnnLogLevel() { - auto ort_log_level = logger_->GetSeverity(); - LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; - return UpdateQnnLogLevel(ort_log_level); + std::lock_guard lock(logger_mutex_); + + if (backend_setup_completed_ && logger_ != nullptr) { + auto ort_log_level = logger_->GetSeverity(); + LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; + return UpdateQnnLogLevel(ort_log_level); + } + return Status::OK(); } Status QnnBackendManager::UpdateQnnLogLevel(logging::Severity ort_log_level) { ORT_RETURN_IF(nullptr == log_handle_, "Unable to update QNN Log Level. Invalid QNN log handle."); + ORT_RETURN_IF(false == backend_setup_completed_, "Unable to update QNN Log Level. Backend setup not completed."); + ORT_RETURN_IF(nullptr == logger_, "Unable to update QNN Log Level. 
Invalid logger."); + QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(ort_log_level); LOGS(*logger_, INFO) << "Updating Qnn log level to: " << qnn_log_level; @@ -686,6 +694,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t } Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) { + std::lock_guard lock(logger_mutex_); if (backend_setup_completed_) { LOGS(logger, VERBOSE) << "Backend setup already!"; return Status::OK(); @@ -972,6 +981,7 @@ void QnnBackendManager::ReleaseResources() { ORT_THROW("Failed to ShutdownBackend."); } + std::lock_guard lock(logger_mutex_); result = TerminateQnnLog(); if (Status::OK() != result) { ORT_THROW("Failed to TerminateQnnLog."); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index b80f1374fcdc7..43007d4a5c244 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -12,9 +12,11 @@ #endif #include +#include #include #include #include + #include "HTP/QnnHtpDevice.h" #include "QnnLog.h" #include "QnnTypes.h" @@ -233,6 +235,7 @@ class QnnBackendManager { private: const std::string backend_path_; + std::mutex logger_mutex_; const logging::Logger* logger_ = nullptr; QNN_INTERFACE_VER_TYPE qnn_interface_ = QNN_INTERFACE_VER_TYPE_INIT; QNN_SYSTEM_INTERFACE_VER_TYPE qnn_sys_interface_ = QNN_SYSTEM_INTERFACE_VER_TYPE_INIT; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index becb9a728b1e3..6735528bebbf9 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -258,49 +258,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } -#ifdef _WIN32 - auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); - // Register callback for ETW capture state (rundown) - callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( - [&etwRegistrationManager, this]( - LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - ORT_UNUSED_PARAMETER(SourceId); - ORT_UNUSED_PARAMETER(MatchAnyKeyword); - ORT_UNUSED_PARAMETER(MatchAllKeyword); - ORT_UNUSED_PARAMETER(FilterData); - ORT_UNUSED_PARAMETER(CallbackContext); - - if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { - auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); - (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); - } - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { - if (Level != 0) { - // Commenting out Dynamic QNN Profiling for now - // There seems to be a crash in 3rd party QC QnnHtp.dll with this. - // Repro Scenario - start ETW tracing prior to session creation. 
- // Then disable/enable ETW Tracing with the code below uncommented a few times - // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); - // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); - } - } - } - - if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { - // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); - (void)qnn_backend_manager_->ResetQnnLogLevel(); - } - }); - etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); -#endif - // In case ETW gets disabled later auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL); if (profiling_level_pos != provider_options_map.end()) { @@ -440,6 +397,49 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio htp_arch, soc_model, enable_htp_weight_sharing_); + +#ifdef _WIN32 + auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); + // Register callback for ETW capture state (rundown) + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( + [&etwRegistrationManager, this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); + } + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { + if (Level != 0) { + // Commenting out Dynamic QNN Profiling for now + // There seems to be a crash in 3rd party QC QnnHtp.dll with this. + // Repro Scenario - start ETW tracing prior to session creation. 
+ // Then disable/enable ETW Tracing with the code below uncommented a few times + // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); + // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); + } + } + } + + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); + (void)qnn_backend_manager_->ResetQnnLogLevel(); + } + }); + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); +#endif } QNNExecutionProvider::~QNNExecutionProvider() { @@ -453,7 +453,9 @@ QNNExecutionProvider::~QNNExecutionProvider() { // Unregister the ETW callback #ifdef _WIN32 - logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + if (callback_ETWSink_provider_ != nullptr) { + logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + } #endif } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 30e2fd53e9613..35c061de6132c 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -151,7 +151,7 @@ class QNNExecutionProvider : public IExecutionProvider { bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; #ifdef _WIN32 - onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; + onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; #endif qnn::ModelSettings model_settings_ = {}; From ffaddead0ad7014190def4dba2350dc4d935e9c9 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Wed, 23 Oct 2024 08:14:10 -0700 Subject: [PATCH 4/8] Refactor cuda packaging pipeline (#22542) ### Description ### Motivation and Context --- .../build-perf-test-binaries-pipeline.yml | 7 +- .../py-cuda-alt-packaging-pipeline.yml | 27 +++ .../py-cuda-packaging-pipeline.yml | 54 +++--- .../py-dml-packaging-pipeline.yml | 18 ++ .../azure-pipelines/py-packaging-pipeline.yml | 15 +- .../jobs/py-linux-cuda-package-test-job.yml | 8 +- .../py-cpu-packaging-stage.yml} | 170 ++---------------- ...g-stage.yml => py-gpu-packaging-stage.yml} | 53 +++--- .../py-linux-gpu-stage.yml} | 49 +++-- .../py-win-gpu-stage.yml} | 28 +-- .../py-packaging-linux-test-cuda.yml | 4 +- 11 files changed, 172 insertions(+), 261 deletions(-) create mode 100644 tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml create mode 100644 tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml rename tools/ci_build/github/azure-pipelines/{templates/py-packaging-stage.yml => stages/py-cpu-packaging-stage.yml} (64%) rename tools/ci_build/github/azure-pipelines/stages/{py-cuda-packaging-stage.yml => py-gpu-packaging-stage.yml} (68%) rename tools/ci_build/github/azure-pipelines/{templates/py-linux-gpu.yml => stages/py-linux-gpu-stage.yml} (63%) rename tools/ci_build/github/azure-pipelines/{templates/py-win-gpu.yml => stages/py-win-gpu-stage.yml} (93%) diff --git a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml index 50d4d8a912585..4e5d9a70beb66 100644 --- a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml @@ -34,11 +34,8 @@ stages: # build Python packages # Linux 
GPU only - ${{ if parameters.BuildPythonPackages }}: - - template: templates/py-packaging-stage.yml + - template: stages/py-gpu-packaging-stage.yml parameters: enable_linux_gpu: true - enable_linux_cpu: false - enable_windows_cpu: false enable_windows_gpu: false - enable_mac_cpu: false - enable_linux_arm: false + cuda_version: 12.2 diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml new file mode 100644 index 0000000000000..cc2977721d03b --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml @@ -0,0 +1,27 @@ +trigger: none + +parameters: + - name: enable_linux_cuda + type: boolean + default: true + + - name: enable_windows_cuda + type: boolean + default: true + + - name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +stages: + - template: stages/py-gpu-packaging-stage.yml + parameters: + enable_linux_cuda: ${{ parameters.enable_linux_cuda }} + enable_windows_cuda: ${{ parameters.enable_windows_cuda }} + cmake_build_type: ${{ parameters.cmake_build_type }} + cuda_version: '11.8' diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml index 3503857a9233c..7e6b1889687a3 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml @@ -1,12 +1,20 @@ trigger: none - +# The `resources` specify the location and version of the 1ES PT. +resources: + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release parameters: - - name: enable_linux_gpu + - name: enable_linux_cuda type: boolean default: true - - name: enable_windows_gpu + + - name: enable_windows_cuda type: boolean default: true + - name: cmake_build_type type: string default: 'Release' @@ -15,28 +23,22 @@ parameters: - Release - RelWithDebInfo - MinSizeRel - - name: cuda_version - type: string - default: '12.2' - values: - - 11.8 - - 12.2 - - name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - - name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines + parameters: + # Update the pool with your team's 1ES hosted pool. + pool: + name: 'onnxruntime-Win-CPU-2022' # Name of your hosted pool + os: windows # OS of the image. This value cannot be a variable. 
Allowed values: windows, linux, macOS -stages: - - template: stages/py-cuda-packaging-stage.yml - parameters: - enable_linux_gpu: ${{ parameters.enable_linux_gpu }} - enable_windows_gpu: ${{ parameters.enable_windows_gpu }} - cmake_build_type: ${{ parameters.cmake_build_type }} - cuda_version: ${{ parameters.cuda_version }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} + stages: + - template: stages/py-gpu-packaging-stage.yml + parameters: + enable_linux_cuda: ${{ parameters.enable_linux_cuda }} + enable_windows_cuda: ${{ parameters.enable_windows_cuda }} + cmake_build_type: ${{ parameters.cmake_build_type }} + cuda_version: '12.2' diff --git a/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml new file mode 100644 index 0000000000000..0c7c6abeb35da --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml @@ -0,0 +1,18 @@ +trigger: none + +parameters: + - name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +stages: + - template: stages/py-gpu-packaging-stage.yml + parameters: + enable_windows_dml: true + cmake_build_type: ${{ parameters.cmake_build_type }} + publish_symbols: true diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index de17db216da9c..ed992be31257a 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -4,21 +4,11 @@ parameters: type: boolean default: true -- name: enable_linux_gpu - displayName: 'Whether Linux GPU package is built.' - type: boolean - default: true - - name: enable_windows_cpu displayName: 'Whether Windows CPU package is built.' type: boolean default: true -- name: enable_windows_gpu - displayName: 'Whether Windows GPU package is built.' - type: boolean - default: true - - name: enable_mac_cpu displayName: 'Whether Mac CPU package is built.' 
type: boolean @@ -74,12 +64,10 @@ parameters: trigger: none stages: -- template: templates/py-packaging-stage.yml +- template: stages/py-cpu-packaging-stage.yml parameters: - enable_linux_gpu: ${{ parameters.enable_linux_gpu }} enable_linux_cpu: ${{ parameters.enable_linux_cpu }} enable_windows_cpu: ${{ parameters.enable_windows_cpu }} - enable_windows_gpu: ${{ parameters.enable_windows_gpu }} enable_mac_cpu: ${{ parameters.enable_mac_cpu }} enable_linux_arm: ${{ parameters.enable_linux_arm }} enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} @@ -90,3 +78,4 @@ stages: cmake_build_type: ${{ parameters.cmake_build_type }} qnn_sdk_version: ${{ parameters.qnn_sdk_version }} publish_symbols: true + diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml index 9289935b4ef9c..a33f757c24408 100644 --- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml @@ -57,15 +57,15 @@ jobs: - checkout: self - task: DownloadPipelineArtifact@2 inputs: - artifact: 'drop-linux-gpu-x86_64' - targetPath: '$(Build.SourcesDirectory)/drop-linux-gpu-x86_64' + artifact: 'linux_gpu_wheel_x86_64' + targetPath: '$(Build.SourcesDirectory)/linux_gpu_wheel_x86_64' ${{ if ne(parameters.build_id, 'latest') }}: buildType: 'specific' project: '${{ parameters.project }}' pipeline: '${{ parameters.pipeline }}' buildVersionToDownload: 'specific' buildId: '${{ parameters.build_id }}' - displayName: 'Download Build Artifacts - drop-linux-gpu-x86_64' + displayName: 'Download Build Artifacts - linux_gpu_wheel_x86_64' - task: DownloadPipelineArtifact@2 inputs: @@ -82,7 +82,7 @@ jobs: - bash: | set -e -x ls $(Build.SourcesDirectory) - mv "$(Build.SourcesDirectory)/drop-linux-gpu-x86_64" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Build.SourcesDirectory)/linux_gpu_wheel_x86_64" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} mv "$(Build.SourcesDirectory)/onnxruntime_gpu" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml similarity index 64% rename from tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml rename to tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 10d7ce04747d9..e92761e20d9e3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -10,21 +10,11 @@ parameters: type: boolean default: true -- name: enable_linux_gpu - displayName: 'Whether Linux GPU package is built.' - type: boolean - default: true - - name: enable_windows_cpu displayName: 'Whether Windows CPU package is built.' type: boolean default: true -- name: enable_windows_gpu - displayName: 'Whether Windows GPU package is built.' - type: boolean - default: true - - name: enable_mac_cpu displayName: 'Whether Mac CPU package is built.' type: boolean @@ -65,10 +55,6 @@ parameters: - RelWithDebInfo - MinSizeRel -- name: publish_symbols - type: boolean - default: false - # Only applies to QNN packages. 
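# Non-QNN stages ignore this value.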
- name: qnn_sdk_version type: string @@ -128,7 +114,7 @@ stages: clean: true submodules: recursive - - template: telemetry-steps.yml + - template: ../templates/telemetry-steps.yml - task: UsePythonVersion@0 inputs: @@ -142,7 +128,7 @@ stages: tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' appendSourceBranchName: false - - template: set-nightly-build-option-variable-step.yml + - template: ../templates/set-nightly-build-option-variable-step.yml - task: BatchScript@1 displayName: 'setup env' @@ -151,7 +137,7 @@ stages: modifyEnvironment: true workingFolder: '$(Build.BinariesDirectory)' - - template: download-deps.yml + - template: ../templates/download-deps.yml - task: PythonScript@0 displayName: 'Update deps.txt' @@ -180,24 +166,12 @@ stages: --enable_pybind --enable_onnx_tests ${{ parameters.build_py_parameters }} - --parallel --use_binskim_compliant_compile_flags --update + --parallel --use_binskim_compliant_compile_flags --update --build $(TelemetryOption) workingDirectory: '$(Build.BinariesDirectory)' - - task: VSBuild@1 - displayName: 'Build' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\onnxruntime.sln' - platform: $(MsbuildPlatform) - configuration: ${{ parameters.cmake_build_type }} - msbuildArchitecture: $(buildArch) - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}' - createLogFile: true - # Esrp signing - - template: win-esrp-dll.yml + - template: ../templates/win-esrp-dll.yml parameters: FolderPath: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime\capi' DisplayName: 'ESRP - Sign Native dlls' @@ -251,29 +225,8 @@ stages: python onnx_backend_test_series.py workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' displayName: 'Run Python Tests' - - ${{ if eq(parameters.publish_symbols, true) }}: - - task: PublishSymbols@2 - displayName: 'Publish symbols' - condition: and (succeeded(), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))) - inputs: - SymbolsFolder: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' - SearchPattern: | - onnxruntime_pybind11_state.pdb - onnxruntime_providers_shared.pdb - IndexSources: true - SymbolServerType: TeamServices - SymbolExpirationInDays: 3650 - SymbolsArtifactName: 'win_cpu_$(PythonVersion)_$(buildArch)_$(Build.BuildNumber)' - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and(and (succeeded(), and(eq(variables['buildArch'], 'x64'), eq(variables['PythonVersion'], '3.8'))), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - continueOnError: true - - template: component-governance-component-detection-steps.yml + - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' @@ -281,87 +234,6 @@ stages: displayName: 'Clean Agent Directories' condition: always() -- ${{ if eq(parameters.enable_windows_gpu, true) }}: - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' - PYTHON_VERSION: '3.10' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" 
--cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ENV_SETUP_SCRIPT: setup_env_gpu.bat - EP_NAME: gpu - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' - PYTHON_VERSION: '3.11' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ENV_SETUP_SCRIPT: setup_env_gpu.bat - EP_NAME: gpu - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' - PYTHON_VERSION: '3.12' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ENV_SETUP_SCRIPT: setup_env_gpu.bat - EP_NAME: gpu - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' - PYTHON_VERSION: '3.13' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ENV_SETUP_SCRIPT: setup_env_gpu.bat - EP_NAME: gpu - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' - PYTHON_VERSION: '3.10' - EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat - EP_NAME: directml - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' - PYTHON_VERSION: '3.11' - EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat - EP_NAME: directml - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' - PYTHON_VERSION: '3.12' - EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat - EP_NAME: directml - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' - PYTHON_VERSION: '3.13' - EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat - EP_NAME: directml - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - ${{ if eq(parameters.enable_mac_cpu, true) }}: - stage: Python_Packaging_MacOS dependsOn: [] @@ -395,9 +267,9 @@ stages: inputs: versionSpec: $(PythonVersion) - - template: use-xcode-version.yml + - template: ../templates/use-xcode-version.yml - - template: download-deps.yml + - 
template: ../templates/download-deps.yml - task: PythonScript@0 displayName: 'Update deps.txt' @@ -437,7 +309,7 @@ stages: inputs: ArtifactName: onnxruntime - - template: component-governance-component-detection-steps.yml + - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' @@ -446,7 +318,7 @@ stages: - stage: Python_Packaging_Linux_ARM dependsOn: [] jobs: - - template: py-linux.yml + - template: ../templates/py-linux.yml parameters: arch: 'aarch64' machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' @@ -457,30 +329,18 @@ stages: - stage: Python_Packaging_Linux_CPU dependsOn: [] jobs: - - template: py-linux.yml - parameters: - arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' - extra_build_arg: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - - ${{ if eq(parameters.enable_linux_gpu, true) }}: - - template: py-linux-gpu.yml + - template: ../templates/py-linux.yml parameters: arch: 'x86_64' machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} - trt_version: '10.4.0.26-1.cuda11.8' - cuda_version: '11.8' - ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: - stage: Python_Packaging_Windows_ARM64_QNN dependsOn: [] jobs: - - template: py-win-arm64-qnn.yml + - template: ../templates/py-win-arm64-qnn.yml parameters: MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' QNN_SDK: ${{ parameters.qnn_sdk_version }} @@ -490,7 +350,7 @@ stages: - stage: Python_Packaging_Windows_arm64ec_QNN dependsOn: [] jobs: - - template: py-win-arm64ec-qnn.yml + - template: ../templates/py-win-arm64ec-qnn.yml parameters: MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} @@ -500,7 +360,7 @@ stages: - stage: Python_Packaging_Windows_x64_QNN dependsOn: [] jobs: - - template: py-win-x64-qnn.yml + - template: ../templates/py-win-x64-qnn.yml parameters: MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} @@ -510,7 +370,7 @@ stages: - stage: Python_Packaging_Linux_x64_QNN dependsOn: [] jobs: - - template: py-linux-qnn.yml + - template: ../templates/py-linux-qnn.yml parameters: machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml similarity index 68% rename from tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml rename to tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml index ae18687cb9e54..1ae95a296162c 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml @@ -5,15 +5,20 @@ parameters: type: string default: '' -- name: enable_linux_gpu - displayName: 'Whether Linux GPU package is built.' +- name: enable_linux_cuda + displayName: 'Whether Linux CUDA package is built.' type: boolean - default: true + default: false -- name: enable_windows_gpu - displayName: 'Whether Windows GPU package is built.' +- name: enable_windows_cuda + displayName: 'Whether Windows CUDA package is built.' 
type: boolean - default: true + default: false + +- name: enable_windows_dml + displayName: 'Whether Windows DML package is built.' + type: boolean + default: false # TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. - name: cmake_build_type @@ -34,16 +39,6 @@ parameters: - 11.8 - 12.2 -- name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - -- name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' - - name: PythonVersions type: object displayName: 'Python versions to build' @@ -53,23 +48,25 @@ parameters: - '3.12' - '3.13' +- name: publish_symbols + type: boolean + default: false + stages: - - ${{ if eq(parameters.enable_windows_gpu, true) }}: + - ${{ if eq(parameters.enable_windows_cuda, true) }}: - ${{ each python_version in parameters.PythonVersions }}: - - template: ../templates/py-win-gpu.yml + - template: py-win-gpu-stage.yml parameters: PYTHON_VERSION: ${{ python_version }} EP_NAME: gpu CudaVersion: ${{ parameters.cuda_version }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} ${{ if eq(parameters.cuda_version, '11.8') }}: EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ${{ if eq(parameters.cuda_version, '12.2') }}: EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6 --cuda_home=$(Agent.TempDirectory)\v12.2 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - - ${{ if eq(parameters.enable_linux_gpu, true) }}: - - template: ../templates/py-linux-gpu.yml + - ${{ if eq(parameters.enable_linux_cuda, true) }}: + - template: py-linux-gpu-stage.yml parameters: arch: 'x86_64' machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' @@ -82,3 +79,15 @@ stages: ${{ if eq(parameters.cuda_version, '12.2') }}: docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 trt_version: 10.4.0.26-1.cuda12.6 + + - ${{ if eq(parameters.enable_windows_dml, true) }}: + - ${{ each python_version in parameters.PythonVersions }}: + - template: py-win-gpu-stage.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' + PYTHON_VERSION: ${{ python_version }} + EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos + ENV_SETUP_SCRIPT: setup_env.bat + EP_NAME: directml + publish_symbols: ${{ parameters.publish_symbols }} + cmake_build_type: ${{ parameters.cmake_build_type }} \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml similarity index 63% rename from tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml rename to tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml index d19472bcbab5a..f9053cba56835 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml @@ -41,7 +41,27 @@ stages: timeoutInMinutes: 240 workspace: clean: all - pool: ${{ parameters.machine_pool }} + pool: + name: ${{ parameters.machine_pool }} + os: linux + templateContext: + codeSignValidation: + enabled: true + break: true + psscriptanalyzer: + enabled: true + sdl: + 
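        # SDL (Secure Development Lifecycle) scanners injected by the 1ES template.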
binskim: + enabled: true + scanOutputDirectoryOnly: true + targetPathPattern: '\".*.so\"' + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory)/dist + artifactName: onnxruntime_gpu + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory)/${{ parameters.cmake_build_type }} + artifactName: linux_gpu_wheel_${{ parameters.arch }} variables: # The build machine pool doesn't have dotnet, so it can't run CG. - name: skipComponentGovernanceDetection @@ -56,9 +76,9 @@ stages: clean: true submodules: recursive - - template: set-nightly-build-option-variable-step.yml + - template: ../templates/set-nightly-build-option-variable-step.yml - - template: get-docker-image-steps.yml + - template: ../templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda/Dockerfile Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda @@ -73,17 +93,18 @@ stages: filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh arguments: -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} $(extra_build_args) - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - PathtoPublish: '$(Build.BinariesDirectory)/dist' - ArtifactName: onnxruntime_gpu - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Test Binaries' - inputs: - artifactName: 'drop-linux-gpu-${{ parameters.arch }}' - targetPath: '$(Build.BinariesDirectory)/Release' + - script: | + set -e -x + mv $(Build.BinariesDirectory)/${{ parameters.cmake_build_type }} ./${{ parameters.cmake_build_type }} + mv $(Build.BinariesDirectory)/dist ./dist + pushd dist + find . -name \*.whl -exec unzip -qq -o {} \; + popd + pushd ${{ parameters.cmake_build_type }} + find . -name \*.whl -exec unzip -qq -o {} \; + popd + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'Move files' - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml similarity index 93% rename from tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml rename to tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index 71500e4ef9025..0cbcd2b74371e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -28,16 +28,6 @@ parameters: - 11.8 - 12.2 -- name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - -- name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' - - name: cmake_build_type type: string displayName: 'Linux packages cmake build type. Linux Only.' 
@@ -75,7 +65,7 @@ stages: clean: true submodules: recursive - - template: telemetry-steps.yml + - template: ../templates/telemetry-steps.yml - task: UsePythonVersion@0 inputs: @@ -89,10 +79,10 @@ stages: tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' appendSourceBranchName: false - - template: download-deps.yml + - template: ../templates/download-deps.yml - ${{ if ne(parameters.ENV_SETUP_SCRIPT, '') }}: - - template: jobs/set-winenv.yml + - template: ../templates/jobs/set-winenv.yml parameters: EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: @@ -101,7 +91,7 @@ stages: DownloadTRT: true - ${{ if eq(parameters.ENV_SETUP_SCRIPT, '') }}: - - template: jobs/download_win_gpu_library.yml + - template: ../templates/jobs/download_win_gpu_library.yml parameters: CudaVersion: ${{ parameters.CudaVersion }} ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: @@ -123,7 +113,7 @@ stages: workingDirectory: '$(Build.BinariesDirectory)' arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\installed -build_config ${{ parameters.cmake_build_type }} - - template: set-nightly-build-option-variable-step.yml + - template: ../templates/set-nightly-build-option-variable-step.yml - task: PythonScript@0 displayName: 'Generate cmake config' @@ -153,7 +143,7 @@ stages: workingDirectory: '$(Build.BinariesDirectory)' # Esrp signing - - template: win-esrp-dll.yml + - template: ../templates/win-esrp-dll.yml parameters: FolderPath: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime\capi' DisplayName: 'ESRP - Sign Native dlls' @@ -216,7 +206,7 @@ stages: GdnPublishTsaOnboard: false GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - template: component-governance-component-detection-steps.yml + - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' @@ -243,13 +233,11 @@ stages: addToPath: true architecture: 'x64' - - template: flex-downloadPipelineArtifact.yml + - template: ../templates/flex-downloadPipelineArtifact.yml parameters: ArtifactName: onnxruntime_${{ parameters.EP_NAME }} StepName: 'Download Pipeline Artifact - Windows GPU Build' TargetPath: '$(Build.ArtifactStagingDirectory)' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - task: PowerShell@2 displayName: 'Install ONNX' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml index 4ca462bf962f5..6a74d0e7befd3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -61,7 +61,7 @@ jobs: # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - download: build # pipeline resource identifier. - artifact: 'drop-linux-gpu-${{ parameters.arch }}' + artifact: 'linux_gpu_wheel_${{ parameters.arch }}' - download: build # pipeline resource identifier. 
        artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}'
@@ -69,7 +69,7 @@ jobs:
   - bash: |
       set -e -x
       ls $(Pipeline.Workspace)/build
-      mv "$(Pipeline.Workspace)/build/drop-linux-gpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}}
+      mv "$(Pipeline.Workspace)/build/linux_gpu_wheel_${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}}
       mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl"
       cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp
       find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \;

From 33e2f6ad8d335c526358466f81da5515a8ab9352 Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Wed, 23 Oct 2024 23:18:16 +0800
Subject: [PATCH 5/8] [WebNN EP] Support external data (#22263)

### Description
This PR introduces support for registering external data inside the WebNN EP.

### Motivation and Context
- The WebNN EP needs to register initializers at the graph compilation stage. Initializers backed by external data can't go through the general external data loader framework, because the WebNN EP compiles its graph before the external data loader is called.
- Exposes `utils::GetExternalDataInfo`, which the WebNN EP uses to read an external tensor's information.
- Defines a new `registerMLConstant` in JSEP that creates WebNN constants from external data in the WebNN backend, taking the tensor's info as parameters along with `Module.MountedFiles`, which holds all preloaded external files.
---
 js/web/lib/wasm/jsep/backend-webnn.ts         |  63 ++++++++++
 .../core/framework/tensorprotoutils.cc        |  62 +++++-----
 onnxruntime/core/framework/tensorprotoutils.h |  14 +++
 .../webnn/builders/impl/base_op_builder.cc    |  25 ----
 .../providers/webnn/builders/model_builder.cc | 113 ++++++++++--------
 onnxruntime/wasm/pre-jsep.js                  |   5 +
 6 files changed, 178 insertions(+), 104 deletions(-)

diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts
index d13136d252d2a..37eb0e0edc67c 100644
--- a/js/web/lib/wasm/jsep/backend-webnn.ts
+++ b/js/web/lib/wasm/jsep/backend-webnn.ts
@@ -163,6 +163,69 @@ export class WebNNBackend {
     return id;
   }

+  // Register WebNN Constant operands from external data.
+  public registerMLConstant(
+    externalFilePath: string,
+    dataOffset: number,
+    dataLength: number,
+    builder: MLGraphBuilder,
+    desc: MLOperandDescriptor,
+    mountedFiles: Map<string, Uint8Array> | undefined,
+  ): MLOperand {
+    // If available, "Module.MountedFiles" is a Map for all preloaded files.
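+    // The lookup below first normalizes a leading './' in the path, then bounds-checks
+    // [dataOffset, dataOffset + dataLength) against the mounted file's size before wrapping
+    // the sliced bytes in a typed view that matches desc.dataType.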
+    if (!mountedFiles) {
+      throw new Error('External mounted files are not available.');
+    }
+
+    let filePath = externalFilePath;
+    if (externalFilePath.startsWith('./')) {
+      filePath = externalFilePath.substring(2);
+    }
+    const fileData = mountedFiles.get(filePath);
+    if (!fileData) {
+      throw new Error(`File with name ${filePath} not found in preloaded files.`);
+    }
+
+    if (dataOffset + dataLength > fileData.byteLength) {
+      throw new Error('Out of bounds: data offset and length exceed the external file data size.');
+    }
+
+    const buffer = fileData.slice(dataOffset, dataOffset + dataLength).buffer;
+    let bufferView: ArrayBufferView;
+    switch (desc.dataType) {
+      case 'float32':
+        bufferView = new Float32Array(buffer);
+        break;
+      case 'float16':
+        bufferView = new Uint16Array(buffer);
+        break;
+      case 'int32':
+        bufferView = new Int32Array(buffer);
+        break;
+      case 'uint32':
+        bufferView = new Uint32Array(buffer);
+        break;
+      case 'int64':
+        bufferView = new BigInt64Array(buffer);
+        break;
+      case 'uint64':
+        bufferView = new BigUint64Array(buffer);
+        break;
+      case 'int8':
+        bufferView = new Int8Array(buffer);
+        break;
+      case 'uint8':
+        bufferView = new Uint8Array(buffer);
+        break;
+      default:
+        throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`);
+    }
+
+    LOG_DEBUG('verbose', () => `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}}`);
+
+    return builder.constant(desc, bufferView);
+  }
+
   public flush(): void {
     // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
   }
diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc
index 74c359881a1d7..2af9f95ad059e 100644
--- a/onnxruntime/core/framework/tensorprotoutils.cc
+++ b/onnxruntime/core/framework/tensorprotoutils.cc
@@ -165,37 +165,6 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t
 DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2)
 DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2)

-static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
-                                  const std::filesystem::path& tensor_proto_dir,
-                                  std::basic_string<ORTCHAR_T>& external_file_path,
-                                  onnxruntime::FileOffsetType& file_offset,
-                                  SafeInt<size_t>& tensor_byte_size) {
-  ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto),
-                    "Tensor does not have external data to read from.");
-
-  ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto),
-                "External data type cannot be UNDEFINED or STRING.");
-
-  std::unique_ptr<onnxruntime::ExternalDataInfo> external_data_info;
-  ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));
-
-  const auto& location = external_data_info->GetRelPath();
-
-  external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location)
-                                                                                    : (tensor_proto_dir / location);
-
-  ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size));
-  const size_t external_data_length = external_data_info->GetLength();
-  ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size,
-                    "TensorProto: ", tensor_proto.name(),
-                    " external data size mismatch. Computed size: ", *&tensor_byte_size,
-                    ", external_data.length: ", external_data_length);
-
-  file_offset = external_data_info->GetOffset();
-
-  return Status::OK();
-}
-
 // Read external data for tensor in uint8_t* form and return Status::OK() if the data is read successfully.
 // Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr
 // then uses the current directory instead.
@@ -261,6 +230,37 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo

 namespace utils {

+Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
+                           const std::filesystem::path& tensor_proto_dir,
+                           std::basic_string<ORTCHAR_T>& external_file_path,
+                           onnxruntime::FileOffsetType& file_offset,
+                           SafeInt<size_t>& tensor_byte_size) {
+  ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto),
+                    "Tensor does not have external data to read from.");
+
+  ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto),
+                "External data type cannot be UNDEFINED or STRING.");
+
+  std::unique_ptr<onnxruntime::ExternalDataInfo> external_data_info;
+  ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));
+
+  const auto& location = external_data_info->GetRelPath();
+
+  external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location)
+                                                                                    : (tensor_proto_dir / location);
+
+  ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size));
+  const size_t external_data_length = external_data_info->GetLength();
+  ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size,
+                    "TensorProto: ", tensor_proto.name(),
+                    " external data size mismatch. Computed size: ", *&tensor_byte_size,
+                    ", external_data.length: ", external_data_length);
+
+  file_offset = external_data_info->GetOffset();
+
+  return Status::OK();
+}
+
 void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) {
   tensor_proto.set_raw_data(std::move(param));
 }
diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h
index 227ba0706197e..262f7adaca1cb 100644
--- a/onnxruntime/core/framework/tensorprotoutils.h
+++ b/onnxruntime/core/framework/tensorprotoutils.h
@@ -23,6 +23,20 @@ namespace onnxruntime {
 namespace utils {

+/**
+ * This function is used to get the external data info from the given tensor proto.
+ * @param tensor_proto given initializer tensor
+ * @param tensor_proto_dir directory of the tensor proto file
+ * @param external_file_path output external file path
+ * @param file_offset output tensor offset
+ * @param tensor_byte_size output tensor byte size
+ * @returns Status::OK() if the function is executed successfully
+ */
+Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
+                           const std::filesystem::path& tensor_proto_dir,
+                           std::basic_string<ORTCHAR_T>& external_file_path,
+                           onnxruntime::FileOffsetType& file_offset,
+                           SafeInt<size_t>& tensor_byte_size);
 /**
  * This function is used to convert the endianness of Tensor data.
 * Mostly, will be used in big endian system to support the model file
diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
index 8da255a288f17..fffe964e6aaf2 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
@@ -12,27 +12,6 @@
 namespace onnxruntime {
 namespace webnn {
-
-// Shared functions.
-bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node,
-                            const logging::Logger& logger) {
-  for (const auto* node_arg : node.InputDefs()) {
-    const auto& input_name(node_arg->Name());
-    if (!Contains(initializers, input_name))
-      continue;
-
-    const auto& tensor = *initializers.at(input_name);
-    if (tensor.has_data_location() &&
-        tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
-      LOGS(logger, VERBOSE) << "Initializer [" << input_name
-                            << "] with external data location are not currently supported";
-      return true;
-    }
-  }
-
-  return false;
-}
-
 // Add operator related.

 Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
@@ -58,10 +37,6 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
   if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
     return false;

-  // We do not support external initializers for now.
-  if (HasExternalInitializer(initializers, node, logger))
-    return false;
-
   if (!HasSupportedOpSet(node, logger))
     return false;

diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc
index 044baa738e8c4..8a7fea0cde431 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -112,56 +112,73 @@ Status ModelBuilder::RegisterInitializers() {
     auto num_elements = SafeInt<size_t>(Product(shape));
     emscripten::val view = emscripten::val::undefined();
     std::byte* tensor_ptr = nullptr;
-    if (tensor.has_raw_data()) {
-      tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));
+
+    if (utils::HasExternalData(tensor)) {
+      // Create WebNN Constant from external data.
+      std::basic_string<ORTCHAR_T> external_file_path;
+      onnxruntime::FileOffsetType data_offset;
+      SafeInt<size_t> tensor_byte_size;
+      ORT_RETURN_IF_ERROR(utils::GetExternalDataInfo(
+          tensor, graph_viewer_.ModelPath(), external_file_path, data_offset, tensor_byte_size));
+
+      auto jsepRegisterMLConstant = emscripten::val::module_property("jsepRegisterMLConstant");
+      operand = jsepRegisterMLConstant(emscripten::val(external_file_path),
+                                       static_cast<int32_t>(data_offset),
+                                       static_cast<int32_t>(tensor_byte_size),
+                                       wnn_builder_,
+                                       desc);
     } else {
-      // Store temporary unpacked_tensor.
-      unpacked_tensors_.push_back({});
-      std::vector<uint8_t>& unpacked_tensor = unpacked_tensors_.back();
-      ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
-      tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
-    }
-    switch (data_type) {
-      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
-        view = emscripten::val{emscripten::typed_memory_view(num_elements,
-                                                             reinterpret_cast<uint8_t*>(tensor_ptr))};
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_INT8:
-        view = emscripten::val{emscripten::typed_memory_view(num_elements,
-                                                             reinterpret_cast<int8_t*>(tensor_ptr))};
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
-        view = emscripten::val{emscripten::typed_memory_view(num_elements,
-                                                             reinterpret_cast<uint16_t*>(tensor_ptr))};
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
-        view = emscripten::val{emscripten::typed_memory_view(num_elements,
-                                                             reinterpret_cast<float*>(tensor_ptr))};
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
-        view = emscripten::val{emscripten::typed_memory_view(num_elements,
-                                                             reinterpret_cast<int32_t*>(tensor_ptr))};
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
-        view = emscripten::val{emscripten::typed_memory_view(num_elements,
-                                                             reinterpret_cast<int64_t*>(tensor_ptr))};
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
-        view = emscripten::val{emscripten::typed_memory_view(num_elements,
-                                                             reinterpret_cast<uint32_t*>(tensor_ptr))};
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
-        view = emscripten::val{emscripten::typed_memory_view(num_elements,
-                                                             reinterpret_cast<uint64_t*>(tensor_ptr))};
-        break;
-      default:
-        break;
+      if (tensor.has_raw_data()) {
+        tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));
+      } else {
+        // Store temporary unpacked_tensor.
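+        // Note: unpacked_tensors_ is a ModelBuilder member, so the unpacked bytes stay alive
+        // (and tensor_ptr stays valid) until the WebNN constant below has copied the data.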
+        unpacked_tensors_.push_back({});
+        std::vector<uint8_t>& unpacked_tensor = unpacked_tensors_.back();
+        ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
+        tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
+      }
+      switch (data_type) {
+        case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
+        case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
+          view = emscripten::val{emscripten::typed_memory_view(num_elements,
+                                                               reinterpret_cast<uint8_t*>(tensor_ptr))};
+          break;
+        case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+          view = emscripten::val{emscripten::typed_memory_view(num_elements,
+                                                               reinterpret_cast<int8_t*>(tensor_ptr))};
+          break;
+        case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
+          view = emscripten::val{emscripten::typed_memory_view(num_elements,
+                                                               reinterpret_cast<uint16_t*>(tensor_ptr))};
+          break;
+        case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
+          view = emscripten::val{emscripten::typed_memory_view(num_elements,
+                                                               reinterpret_cast<float*>(tensor_ptr))};
+          break;
+        case ONNX_NAMESPACE::TensorProto_DataType_INT32:
+          view = emscripten::val{emscripten::typed_memory_view(num_elements,
+                                                               reinterpret_cast<int32_t*>(tensor_ptr))};
+          break;
+        case ONNX_NAMESPACE::TensorProto_DataType_INT64:
+          view = emscripten::val{emscripten::typed_memory_view(num_elements,
+                                                               reinterpret_cast<int64_t*>(tensor_ptr))};
+          break;
+        case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
+          view = emscripten::val{emscripten::typed_memory_view(num_elements,
+                                                               reinterpret_cast<uint32_t*>(tensor_ptr))};
+          break;
+        case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
+          view = emscripten::val{emscripten::typed_memory_view(num_elements,
+                                                               reinterpret_cast<uint64_t*>(tensor_ptr))};
+          break;
+        default:
+          break;
+      }
+
+      // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
+      // buffers in JS side. Simply create a copy to fix it.
+      operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
     }
-
-    // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
-    // buffers in JS side. Simply create a copy to fix it.
-    operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
   } else {
     // TODO: support other type.
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js
index 68332d07a9782..78d60326dd0a8 100644
--- a/onnxruntime/wasm/pre-jsep.js
+++ b/onnxruntime/wasm/pre-jsep.js
@@ -235,5 +235,10 @@ Module['jsepInit'] = (name, params) => {
     Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => {
       return backend['registerMLTensor'](tensor, dataType, shape);
     }
+
+    Module.jsepRegisterMLConstant = (externalFilePath, dataOffset, dataLength, builder, desc) => {
+      return backend['registerMLConstant'](
+        externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles);
+    }
   }
 };

From fd8ee4894dbf828051dbfe6da7c9ccd9d4a46acb Mon Sep 17 00:00:00 2001
From: Satya Kumar Jandhyala
Date: Wed, 23 Oct 2024 10:14:09 -0700
Subject: [PATCH 6/8] [JS/WebGPU] GroupQueryAttention rewrite (#20946)

### Description
Implement JSEP GroupQueryAttention

### Motivation and Context
Required to enable certain LLM models to run using WebGPU.
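For reference, the core idea of grouped-query attention: `num_heads` query heads share
`kv_num_heads` key/value heads, so each contiguous group of
`n_reps = num_heads / kv_num_heads` query heads attends to the same KV head. A minimal
TypeScript sketch of the head mapping the rewritten shaders rely on (illustrative only;
`kvHeadForQueryHead` is a hypothetical helper, not part of this change):

    // Map a query head index to the KV head it shares, mirroring the WGSL
    // expression `headIdx / uniforms.n_reps` (integer division) used in the shaders below.
    function kvHeadForQueryHead(headIdx: number, numHeads: number, kvNumHeads: number): number {
      const nReps = numHeads / kvNumHeads;
      return Math.floor(headIdx / nReps);
    }

    // Example: numHeads = 8, kvNumHeads = 2 -> nReps = 4;
    // query heads 0..3 read KV head 0, query heads 4..7 read KV head 1.
    console.log(kvHeadForQueryHead(5, 8, 2)); // 1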
--- .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 354 +++-- .../jsep/webgpu/ops/group-query-attention.ts | 301 ++--- .../jsep/webgpu/ops/multihead-attention.ts | 16 +- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 2 +- .../test/data/ops/group-query-attention.jsonc | 1144 +++++++++++++---- .../js/bert/group_query_attention.h | 30 +- 7 files changed, 1304 insertions(+), 547 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index fe824a5c4558a..09c786daa3fcd 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -19,7 +19,7 @@ import { gather, parseGatherAttributes } from './ops/gather'; import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized'; import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements'; import { gemm, parseGemmAttributes } from './ops/gemm'; -import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention'; +import { groupQueryAttention } from './ops/group-query-attention'; import { instanceNorm } from './ops/instance-norm'; import { layerNorm } from './ops/layer-norm'; import { matMul } from './ops/matmul'; @@ -104,7 +104,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], - ['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]], + ['GroupQueryAttention', [groupQueryAttention]], ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], ['InstanceNormalization', [instanceNorm]], ['LayerNormalization', [layerNorm]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 832f6e132901e..6a78c8ae3b190 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -8,6 +8,7 @@ import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramU import { getMaxComponents, + IndicesHelper, inputVariable, outputVariable, ShaderHelper, @@ -65,14 +66,17 @@ export interface AttentionParameters { broadcastResPosBias: boolean; passPastInKv: boolean; qkvFormat: AttentionQkvFormat; - isPastkvBSNH?: boolean; + softcap?: number; + doRotary?: number; + rotaryInterLeaved?: number; + sommoothSoftmax?: number; + localWindowsSize?: number; } export interface AttentionAttrs { numHeads: number; - kvNumHeads?: number; - isUnidirectional?: number; - maskFilterValue?: number; + isUnidirectional: number; + maskFilterValue: number; scale: number; doRotary: number; qkvHiddenSizes: number[]; @@ -258,41 +262,106 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte }; }; -const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => { - const components = getMaxComponents(d); +const initVarStub = ( + seqLensInput: IndicesHelper | undefined, + totalSequenceLengthInput: IndicesHelper | undefined, + initPastSequenceLength: boolean, +) => { + // In the case of GQA, redefine total_sequence_length, present_sequence_length and past_sequence_length based on seqlen_k input + if (totalSequenceLengthInput && seqLensInput) { + return ` + let total_sequence_length_input = u32(${totalSequenceLengthInput.getByOffset('0')}); + let present_sequence_length = 
max(total_sequence_length_input, uniforms.past_sequence_length); + let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input; + let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input; + total_sequence_length = u32(${seqLensInput?.getByOffset('batchIdx')}) + 1; + var past_sequence_length: u32 = 0; + if (is_first_prompt == false) { + past_sequence_length = total_sequence_length - sequence_length; + } + `; + } else { + return ` + ${initPastSequenceLength ? 'let past_sequence_length = uniforms.past_sequence_length' : ''}; + let present_sequence_length = total_sequence_length; + `; + } +}; + +const createInPlaceSoftmaxProgramInfo = ( + input: TensorView, + batchSize: number, + numHeads: number, + pastSequenceLength: number, + sequenceLength: number, + totalSequenceLength: number, + seqLens: TensorView | undefined, + totalSequenceLengthInput: TensorView | undefined, +) => { + // Set components to 1 if seqLens is specified, i.e. GroupQueryAttention. + const components = getMaxComponents(seqLens ? 1 : totalSequenceLength); let WG = 64; - const dComp = d / components; - if (dComp < WG) { + const totalSequenceLengthComp = totalSequenceLength / components; + if (totalSequenceLengthComp < WG) { WG = 32; } - const elementsPerThread = Math.ceil(d / components / WG); + const elementsPerThread = Math.ceil(totalSequenceLength / components / WG); const programUniforms: ProgramUniform[] = [ - { type: DataType.float, data: 1 / d }, - { type: DataType.uint32, data: dComp }, + { type: DataType.uint32, data: batchSize }, + { type: DataType.uint32, data: numHeads }, + { type: DataType.uint32, data: pastSequenceLength }, + { type: DataType.uint32, data: sequenceLength }, + { type: DataType.uint32, data: totalSequenceLengthComp }, { type: DataType.uint32, data: elementsPerThread }, ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); const f32Type = tensorTypeToWsglValueType(DataType.float, components); const inputDependencies: ProgramInputTensorInfoDependency[] = ['type']; + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const getShaderSource = (shaderHelper: ShaderHelper) => { const inputHelper = outputVariable('x', input.dataType, input.dims, components); + const inputHelpers = [inputHelper]; + const seqLensInputHelper = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLensInputHelper) { + inputHelpers.push(seqLensInputHelper); + } + + const totalSequenceLengthInputHelper = totalSequenceLengthInput + ? 
inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims) + : undefined; + if (totalSequenceLengthInputHelper) { + inputHelpers.push(totalSequenceLengthInputHelper); + } const elemValueType = tensorTypeToWsglValueType(input.dataType); const uniforms: UniformsArrayType = [ - { name: 'd_inv', type: 'f32' }, - { name: 'd_comp', type: 'u32' }, + { name: 'batch_size', type: 'u32' }, + { name: 'num_heads', type: 'u32' }, + { name: 'past_sequence_length', type: 'u32' }, + { name: 'sequence_length', type: 'u32' }, + { name: 'total_sequence_length', type: 'u32' }, { name: 'elements_per_thread', type: 'u32' }, ]; return ` var thread_max: array; var thread_sum: array; - ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputHelpers)} ${shaderHelper.mainStart([WG, 1, 1])} + let batchIdx = workgroup_id.z / uniforms.num_heads; + let headIdx = workgroup_id.z % uniforms.num_heads; + let sequence_length = uniforms.sequence_length; + var total_sequence_length = uniforms.total_sequence_length; + ${initVarStub(seqLensInputHelper, totalSequenceLengthInputHelper, false)} let local_offset = local_idx * uniforms.elements_per_thread; - let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset; - + let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset; + let seq_causal_length = ${seqLens ? 'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'}; var thread_max_vector = ${f32Type}(-3.402823e+38f); - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector); } thread_max[local_idx] = ${(() => { @@ -315,7 +384,7 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number } var sum_vector = ${f32Type}(0); - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { sum_vector += exp(${f32Type}(x[offset + i]) - max_value); } thread_sum[local_idx] = ${(() => { @@ -338,15 +407,23 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number } if (sum == 0) { - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { - x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv)); + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { + x[offset + i] = ${inputHelper.type.value}(${elemValueType}(1.0) / ${elemValueType}(seq_causal_length)); } } else { - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { var f32input = ${f32Type}(x[offset + i]); x[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum); } } + ${ + seqLens + ? 
` + for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) { + x[offset + total_seq_id] = ${inputHelper.type.value}(${elemValueType}(0)); + }` + : '' + }; }`; }; @@ -354,7 +431,11 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number name: 'AttentionProbsSoftmax', shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies }, getShaderSource, - getRunData: () => ({ outputs: [], dispatchGroup: { x: n }, programUniforms }), + getRunData: () => ({ + outputs: [], + dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads }, + programUniforms, + }), }; }; @@ -365,19 +446,21 @@ const createAttentionProbsProgramInfo = ( pastKey: TensorView | undefined, attentionBias: TensorView | undefined, parameters: AttentionParameters, - attributes: AttentionAttrs, pastSequenceLength: number, + seqLens: TensorView | undefined, + totalSequenceLengthInput: TensorView | undefined, ) => { const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; - const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey; + const presentKey = outputCount > 1 && pastKey; + const kvNumHeads = parameters.kvNumHeads ? parameters.kvNumHeads : parameters.numHeads; const presentKeyShape = presentKey - ? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] + ? [parameters.batchSize, kvNumHeads, totalSequenceLength, parameters.headSize] : undefined; - + const nReps = parameters.nReps ? parameters.nReps : 1; // TODO: handle mask - const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; + const alpha = parameters.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : parameters.scale; const components = getMaxComponents(parameters.headSize); const vectorizedHeadSize = parameters.headSize / components; const TILE_SIZE = 12; @@ -391,9 +474,11 @@ const createAttentionProbsProgramInfo = ( { type: DataType.uint32, data: vectorizedHeadSize }, { type: DataType.uint32, data: totalSequenceLength }, { type: DataType.uint32, data: parameters.numHeads }, + { type: DataType.uint32, data: parameters.headSize }, { type: DataType.float, data: alpha }, { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: parameters.kvSequenceLength }, + { type: DataType.uint32, data: nReps }, ]; // Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0; @@ -404,6 +489,12 @@ const createAttentionProbsProgramInfo = ( if (attentionBias) { inputDependencies.push('type'); } + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }]; if (presentKey) { outputs.push({ dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default }); @@ -419,6 +510,16 @@ const createAttentionProbsProgramInfo = ( if (attentionBias) { inputVars.push(inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims)); } + const seqLensInputVariable = seqLens ? 
inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLensInputVariable) { + inputVars.push(seqLensInputVariable); + } + const totalSequenceLengthInputVariable = totalSequenceLengthInput + ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims) + : undefined; + if (totalSequenceLengthInputVariable) { + inputVars.push(totalSequenceLengthInputVariable); + } const output = outputVariable('output', q.dataType, probsShape); const outputVars = [output]; if (presentKey) { @@ -431,9 +532,11 @@ const createAttentionProbsProgramInfo = ( { name: 'K', type: 'u32' }, { name: 'N', type: 'u32' }, { name: 'num_heads', type: 'u32' }, + { name: 'head_size', type: 'u32' }, { name: 'alpha', type: 'f32' as UniformDataElementType }, { name: 'past_sequence_length', type: 'u32' }, { name: 'kv_sequence_length', type: 'u32' }, + { name: 'n_reps', type: 'u32' }, ]; return ` const TILE_SIZE = ${TILE_SIZE}u; @@ -443,21 +546,20 @@ const createAttentionProbsProgramInfo = ( ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])} // x holds the N and y holds the M - let headIdx = workgroup_id.z; + let headIdx = workgroup_id.z % uniforms.num_heads; + let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'}; + let kv_num_heads = ${nReps === 1 ? 'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'}; + let batchIdx = workgroup_id.z / uniforms.num_heads; let m = workgroup_id.y * TILE_SIZE; let n = workgroup_id.x * TILE_SIZE; - let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; - ${(() => { - if (feedPastKey && presentKey) { - return ` - let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx; - let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`; - } else { - return ` - let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`; - } - })()} - ${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''} + let sequence_length = uniforms.M; + var total_sequence_length = uniforms.N; + ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)} + let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; + let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; + ${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''}; + let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K; + ${presentKey ? 
'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''} var value = ${f32Type}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) { @@ -468,31 +570,37 @@ const createAttentionProbsProgramInfo = ( ${(() => { if (feedPastKey && presentKey) { return ` - if (n + local_id.y < uniforms.past_sequence_length) { + if (n + local_id.y < past_sequence_length) { tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; - } else { - tileK[idx] = - key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x]; + } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) { + tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x]; }`; } else { - return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];'; + return ` + if (n + local_id.y < uniforms.kv_sequence_length) { + tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; + }`; } })()} ${ - presentKey ? 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : '' + presentKey + ? `if (n + local_id.y < present_sequence_length) { + present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx]; + }` + : '' } } workgroupBarrier(); for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) { - value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]); + value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]); } workgroupBarrier(); } - let headOffset = headIdx * uniforms.M * uniforms.N; - if (global_id.y < uniforms.M && global_id.x < uniforms.N) { + if (global_id.y < uniforms.M && global_id.x < total_sequence_length) { + let headOffset = workgroup_id.z * uniforms.M * uniforms.N; let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x; var sum: f32 = ${(() => { switch (components) { @@ -530,13 +638,16 @@ const createVxAttentionScoreProgramInfo = ( pastValue: TensorView | undefined, params: AttentionParameters, pastSequenceLength: number, + seqLens: TensorView | undefined = undefined, + totalSequenceLengthInput: TensorView | undefined = undefined, ) => { const totalSequenceLength = pastSequenceLength + params.kvSequenceLength; const nReps = params.nReps ? params.nReps : 1; const repeatedVHiddenSize = params.vHiddenSize * nReps; - const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue; + const presentValue = outputCount > 1 && pastValue; + const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads; const presentValueShape = presentValue - ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize] + ? 
[params.batchSize, kvNumHeads, totalSequenceLength, params.headSize] : undefined; const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize]; const TILE_SIZE = 12; @@ -551,9 +662,11 @@ const createVxAttentionScoreProgramInfo = ( { type: DataType.uint32, data: totalSequenceLength }, { type: DataType.uint32, data: params.vHeadSize }, { type: DataType.uint32, data: params.numHeads }, + { type: DataType.uint32, data: params.headSize }, { type: DataType.uint32, data: repeatedVHiddenSize }, { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: params.kvSequenceLength }, + { type: DataType.uint32, data: nReps }, ]; // Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0; @@ -561,6 +674,12 @@ const createVxAttentionScoreProgramInfo = ( if (feedPastValue) { inputDependencies.push('type'); } + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }]; if (presentValue) { outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default }); @@ -572,6 +691,16 @@ const createVxAttentionScoreProgramInfo = ( if (feedPastValue) { inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims)); } + const seqLensInputVariable = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLens) { + inputVars.push(seqLensInputVariable!); + } + const totalSequenceLengthInputVariable = totalSequenceLengthInput + ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims) + : undefined; + if (totalSequenceLengthInput) { + inputVars.push(totalSequenceLengthInputVariable!); + } const output = outputVariable('output', probs.dataType, outputShape); const outputVars = [output]; if (presentValue) { @@ -582,34 +711,32 @@ const createVxAttentionScoreProgramInfo = ( { name: 'K', type: 'u32' }, { name: 'N', type: 'u32' }, { name: 'num_heads', type: 'u32' }, + { name: 'head_size', type: 'u32' }, { name: 'v_hidden_size', type: 'u32' }, { name: 'past_sequence_length', type: 'u32' }, { name: 'kv_sequence_length', type: 'u32' }, + { name: 'n_reps', type: 'u32' }, ]; return ` const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; - var tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; + var tileV: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])} - let headIdx = workgroup_id.z; + let headIdx = workgroup_id.z % uniforms.num_heads; + let batchIdx = workgroup_id.z / uniforms.num_heads; + let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'}; + let kv_num_heads = ${nReps === 1 ? 
'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'}; let m = global_id.y; let n = global_id.x; - - let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K; - ${(() => { - if (feedPastValue && presentValue) { - return ` - let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n; - let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n; - `; - } else { - return ` - let offsetB = headIdx * uniforms.N * uniforms.K + n; - `; - } - })()} - ${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''} + let sequence_length = uniforms.M; + var total_sequence_length = uniforms.K; + ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)} + let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; + let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch + ${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''}; + let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n; + ${presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''} var value = ${probsHelper.type.storage}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { if (m < uniforms.M && w + local_id.x < uniforms.K) { @@ -620,33 +747,39 @@ const createVxAttentionScoreProgramInfo = ( ${(() => { if (feedPastValue && presentValue) { return ` - if (w + local_id.y < uniforms.past_sequence_length) { - tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N]; - } else { - tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N]; + if (w + local_id.y < past_sequence_length) { + tileV[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N]; + } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) { + tileV[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N]; } `; } else { return ` - tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N]; - `; + if (w + local_id.y < uniforms.kv_sequence_length) { + tileV[idx] = v[vOffset + (w + local_id.y) * uniforms.N]; + }`; } })()} - ${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''} + ${ + presentValue + ? 
` + if (w + local_id.y < present_sequence_length) { + present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileV[idx]; + }` + : '' + } } workgroupBarrier(); - for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) { - value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x]; + for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) { + value += tileQ[TILE_SIZE * local_id.y + k] * tileV[TILE_SIZE * k + local_id.x]; } workgroupBarrier(); } // we need to transpose output from BNSH_v to BSND_v - let batchIdx = workgroup_id.z / uniforms.num_heads; - let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads; if (m < uniforms.M && n < uniforms.N) { let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size - + currentBatchHeadNumber * uniforms.N + n; + + headIdx * uniforms.N + n; output[outputIdx] = value; } }`; @@ -671,23 +804,29 @@ export const applyAttention = ( pastValue: TensorView | undefined, attentionBiasInput: TensorView | undefined, parameters: AttentionParameters, - attributes: AttentionAttrs, + seqLens: TensorView | undefined = undefined, + totalSequenceLengthInput: TensorView | undefined = undefined, ) => { - // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists. + // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists. const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0)); - const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0; + const pastSequenceLength = outputCount > 1 ? parameters.pastSequenceLength : 0; const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const attentionBias = attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined; const inputsK = [q, k]; - if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) { + if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) { inputsK.push(pastKey); } if (attentionBias) { inputsK.push(attentionBias); } - + if (seqLens) { + inputsK.push(seqLens); + } + if (totalSequenceLengthInput) { + inputsK.push(totalSequenceLengthInput); + } // Run AttentionProbs const probs = context.compute( createAttentionProbsProgramInfo( @@ -697,31 +836,55 @@ export const applyAttention = ( pastKey, attentionBias, parameters, - attributes, pastSequenceLength, + seqLens, + totalSequenceLengthInput, ), - { inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [-1, 1] : [-1] }, + { inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] }, )[0]; // Run Softmax context.compute( createInPlaceSoftmaxProgramInfo( probs, - parameters.batchSize * parameters.numHeads * parameters.sequenceLength, + parameters.batchSize, + parameters.numHeads, + pastSequenceLength, + parameters.sequenceLength, totalSequenceLength, + seqLens, + totalSequenceLengthInput, ), - { inputs: [probs], outputs: [] }, + { inputs: seqLens && totalSequenceLengthInput ? 
[probs, seqLens, totalSequenceLengthInput] : [probs], outputs: [] }, ); - // Run AttrionScore + // Run AttentionScore const inputsV = [probs, v]; - if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) { + if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) { inputsV.push(pastValue); } - context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), { - inputs: inputsV, - outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0], - }); + if (seqLens) { + inputsV.push(seqLens); + } + if (totalSequenceLengthInput) { + inputsV.push(totalSequenceLengthInput); + } + context.compute( + createVxAttentionScoreProgramInfo( + outputCount, + probs, + v, + pastValue, + parameters, + pastSequenceLength, + seqLens, + totalSequenceLengthInput, + ), + { + inputs: inputsV, + outputs: outputCount > 1 ? [0, 2] : [0], + }, + ); }; const prepare = (context: ComputeContext, parameters: AttentionParameters) => { @@ -857,6 +1020,5 @@ export const attention = (context: ComputeContext, attributes: AttentionAttrs): undefined, context.inputs[5], params, - attributes, ); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index 56291c037b7da..bbe25460d6fd3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -1,31 +1,49 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; -import { ShapeUtil } from '../../util'; import { createAttributeWithCacheKey } from '../attribute-with-cache-key'; -import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; +import { ComputeContext } from '../types'; -import { - applyAttention, - AttentionAttrs, - AttentionMaskType, - AttentionParameters, - AttentionQkvFormat, -} from './attention'; -import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common'; +import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention'; import { maybeTransposeToBNSHAndAddBias } from './multihead-attention'; -import { createTileProgramInfo } from './tile'; +import { createSplitProgramInfo, SplitAttributes } from './split'; import { createTransposeProgramInfo, TransposeAttributes } from './transpose'; - -export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { +export interface GroupQueryAttentionAttributes { + numHeads: number; + kvNumHeads: number; + scale: number; + softcap: number; + doRotary: number; + rotaryInterleaved: number; + smoothSoftmax: boolean; + localWindowSize: number; +} + +export const validateInputs = ( + inputs: readonly TensorView[], + attributes: GroupQueryAttentionAttributes, +): AttentionParameters => { + if (attributes.doRotary && inputs.length <= 7) { + throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified'); + } const query = inputs[0]; const key = inputs[1]; const value = inputs[2]; const pastKey = inputs[3]; const pastValue = inputs[4]; - + if (attributes.localWindowSize !== -1) { + throw new Error('Local attention is not supported'); + } + if (attributes.softcap !== 0) { + throw new Error('Softcap 
is not supported'); + } + if (attributes.rotaryInterleaved !== 0) { + throw new Error('Rotary interleaved is not supported'); + } + if (attributes.smoothSoftmax) { + throw new Error('Smooth softmax is not supported'); + } // Abbreviation and Meanings: // B: batch_size // S: sequence_length (input sequence length of query) @@ -62,17 +80,32 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent const dmmhaPacking = false; const batchSize = query.dims[0]; const sequenceLength = query.dims[1]; - const hiddenSize = + let hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4]; let kvSequenceLength = sequenceLength; let pastSequenceLength = 0; - let maxSequenceLength = 0; - const headSize = Math.floor(hiddenSize / attributes.numHeads); + const packedQKV = !key || key.dims.length === 0; + const headSize = !packedQKV + ? Math.floor(hiddenSize / attributes.numHeads) + : Math.floor(hiddenSize / (attributes.numHeads + 2 * attributes.kvNumHeads)); + if (packedQKV) { + hiddenSize = headSize * attributes.numHeads; + } const hasPastKey = pastKey && pastKey.dims.length !== 0; const hasPastValue = pastValue && pastValue.dims.length !== 0; - // TODO : this should be from attributes. - const isPastkvBSNH = true; + // Currenly the onnxruntime GQA specification only support key/value BNSH format. + const isPastkvBSNH = + hasPastKey && + pastKey.dims.length === 4 && + pastKey.dims[0] === batchSize && + pastKey.dims[1] !== attributes.kvNumHeads && + pastKey.dims[2] === attributes.kvNumHeads && + pastKey.dims[3] === headSize; + + if (isPastkvBSNH) { + throw new Error('BSNH pastKey/pastValue is not supported'); + } if (hasPastKey && hasPastValue) { if (pastKey.dims.length !== 4) { throw new Error('Input "past_key" is expected to have 4 dimensions'); @@ -80,21 +113,13 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent if (pastValue.dims.length !== 4) { throw new Error('Input "past_value" is expected to have 4 dimensions'); } - if (isPastkvBSNH) { - // For BSNH - pastSequenceLength = pastKey.dims[1]; - maxSequenceLength = pastKey.dims[1]; - } else { - // For BNSH - pastSequenceLength = pastKey.dims[2]; - maxSequenceLength = pastKey.dims[2]; - } + pastSequenceLength = pastKey.dims[2]; } else if (hasPastKey || hasPastValue) { throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); } - let qkvFormat: AttentionQkvFormat; - if (key) { + let qkvFormat: AttentionQkvFormat = AttentionQkvFormat.qkvBNSH; + if (key && key.dims.length > 0) { if (query.dims.length !== 3) { throw new Error('Input "query" is expected to have 3 dimensions when key is given'); } @@ -109,7 +134,6 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent if (query.dims[2] % key.dims[2] !== 0) { throw new Error('Dimension 2 of "query" should be a multiple of "key"'); } - qkvFormat = AttentionQkvFormat.qkvBSNH; kvSequenceLength = key.dims[1]; } else if (key.dims.length === 5) { if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) { @@ -118,15 +142,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent if (value) { throw new Error('Expect "value" be none when "key" has packed kv format.'); } - qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; kvSequenceLength = key.dims[1]; } else { // key_dims.size() == 4 (cross-attention with past_key) if (key.dims[1] !== attributes.numHeads || key.dims[3] !== 
headSize) { throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); } - - qkvFormat = AttentionQkvFormat.unknown; kvSequenceLength = key.dims[2]; } } else { @@ -143,8 +164,8 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent const maskType: AttentionMaskType = AttentionMaskType.none; let passPastInKv = false; - let vHiddenSize = hiddenSize; - if (value) { + let vHiddenSize = attributes.kvNumHeads ? headSize * attributes.kvNumHeads : hiddenSize; + if (value && value.dims.length > 0) { if (value.dims.length !== 3 && value.dims.length !== 4) { throw new Error('Input "value" is expected to have 3 or 4 dimensions'); } @@ -166,7 +187,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent passPastInKv = true; } } - const totalSequenceLength = pastSequenceLength + kvSequenceLength; + const seqlLens = inputs.length > 4 ? inputs[5] : undefined; + if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) { + throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size'); + } + const totalSequenceLength = -1; + const maxSequenceLength = -1; const broadcastResPosBias = false; return { @@ -180,181 +206,36 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent hiddenSize, vHiddenSize, headSize, - vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads!), + vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads), numHeads: attributes.numHeads, kvNumHeads: attributes.kvNumHeads, - nReps: attributes.numHeads / attributes.kvNumHeads!, + nReps: attributes.numHeads / attributes.kvNumHeads, pastPresentShareBuffer: false, maskType, scale: attributes.scale, broadcastResPosBias, passPastInKv, qkvFormat, - isPastkvBSNH, }; }; -const createConcatProgramInfo = ( - a: TensorView, - b: TensorView | undefined, - dataType: DataType, - params: AttentionParameters, -): ProgramInfo => { - const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize]; - const component = 4; - const outputSize = ShapeUtil.size(outputShape) / component; - const presentSequenceLength = params.totalSequenceLength; - const output = outputVariable('present_kv', dataType, outputShape.length, component); - const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component); - const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined; - - const H = Math.ceil(params.headSize / component); - const dispatch = { x: presentSequenceLength, y: a.dims[0], z: 1 }; - - const inputDependencies: ProgramInputTensorInfoDependency[] = b ? 
['rank', 'rank'] : ['rank']; - - const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: outputSize }, - { type: DataType.uint32, data: params.pastSequenceLength }, - { type: DataType.uint32, data: params.kvSequenceLength }, - { type: DataType.uint32, data: params.totalSequenceLength }, - ]; - - const inputs = [inputA]; - if (inputB) { - programUniforms.push( - ...createTensorShapeVariables(a.dims), - ...createTensorShapeVariables(b!.dims), - ...createTensorShapeVariables(outputShape), - ); - inputs.push(inputB); - } else { - programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape)); - } - const uniforms: UniformsArrayType = [ - { name: 'output_size', type: 'u32' }, - { name: 'past_seqlen', type: 'u32' }, - { name: 'new_seqlen', type: 'u32' }, - { name: 'present_seqlen', type: 'u32' }, - ]; - - const pastStr = ` let past_batch_stride = uniforms.past_seqlen * num_heads * H; - var past_head_stride = uniforms.past_seqlen * H; - if (is_bsnh) { - past_head_stride = H; - } - let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; - present_kv[out_offset] = past_kv[in_offset];`; - const newStr = ` let new_batch_stride = uniforms.new_seqlen * num_heads * H; - let new_row_stride = num_heads * H; - let new_head_stride = H; - let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; - present_kv[out_offset] = new_kv[in_offset];`; - const concatStr = b - ? `if (s < past_seqlen) { - ${pastStr} - } else if (s < past_seqlen + uniforms.new_seqlen) { - ${newStr} - }` - : `if (s < past_seqlen + uniforms.new_seqlen) { - ${newStr} - }`; - - // TODO: handle H * params.kvNumHeads greater than maxComputeInvocationsPerWorkgroup limit. - const getShaderSource = (shaderHelper: ShaderHelper) => ` - - ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)} - ${shaderHelper.mainStart([H, params.kvNumHeads!, 1])} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - var indices = ${output.offsetToIndices('global_idx')}; - let h = local_id.x; - let n = local_id.y; - let s = workgroup_id.x; - let b = workgroup_id.y; - let num_heads = ${params.kvNumHeads!}u; - let H = ${H}u; - - let present_seqlen = uniforms.present_seqlen; - let present_batch_stride = present_seqlen * num_heads * H; - var row_stride = H; - let is_bsnh = ${params.isPastkvBSNH}; - - if (is_bsnh) { - row_stride = num_heads * H; - } - var present_head_stride = present_seqlen * H; - if (is_bsnh) { - present_head_stride = H; - } - - let past_seqlen = uniforms.past_seqlen; - - let out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; - ${concatStr} - }`; - - return { - name: 'ConcatPastNew', - shaderCache: { hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies }, - getRunData: () => ({ - outputs: [{ dims: outputShape, dataType }], - dispatchGroup: dispatch, - programUniforms, - }), - getShaderSource, - }; -}; - -export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => - createAttributeWithCacheKey({ ...attributes }); - const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] }); -const maybeExpandAndTransposeToBNSH = ( - context: ComputeContext, - input: TensorView, - pastKV: TensorView | undefined, - params: AttentionParameters, - outputIndex: number, -) => { +const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params: 
AttentionParameters) => { let reshapedInput = input; const numHeads = params.kvNumHeads!; - const nReps = params.nReps!; if (input.dims.length === 3 && params.kvSequenceLength !== 0) { reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]); - } - - if (pastKV) { - reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params), { - inputs: [reshapedInput, pastKV], - outputs: [params.isPastkvBSNH ? outputIndex : -1], - })[0]; - } else { - reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params), { - inputs: [reshapedInput], - outputs: [params.isPastkvBSNH ? outputIndex : -1], - })[0]; - } - if (nReps !== 1) { - reshapedInput = context.compute(createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), { + reshapedInput = context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { inputs: [reshapedInput], outputs: [-1], })[0]; - reshapedInput = reshapedInput.reshape([ - params.batchSize, - params.totalSequenceLength, - numHeads * nReps, - params.headSize, - ]); } - return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { - inputs: [reshapedInput], - outputs: [-1], - })[0]; + return reshapedInput; }; -export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { +export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => { const params = validateInputs(context.inputs, attributes); if (context.inputs[0].dims.length === 5) { throw new Error('Packed QKV is not implemented'); @@ -364,19 +245,49 @@ export const groupQueryAttention = (context: ComputeContext, attributes: Attenti throw new Error('Packed KV is not implemented'); } + const q = context.inputs[0]; + const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined; + const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined; + const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined; + const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined; + const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined; + const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined; + const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads; + + // TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead. + + const splitAttributes: SplitAttributes = createAttributeWithCacheKey({ + axis: 2, + numOutputs: 3, + splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize], + }); + const [query, key, value] = + !k && !v + ? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] }) + : [q, k!, v!]; + const Q = maybeTransposeToBNSHAndAddBias( context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, - context.inputs[0], + query, undefined, 0, ); - const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined; - const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? 
context.inputs[4] : undefined; - const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKey, params, 1); - const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValue, params, 2); - applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes); + applyAttention( + context, + Q, + maybeTransposeToBNSH(context, key, params), + maybeTransposeToBNSH(context, value, params), + undefined, + undefined, + pastKey, + pastValue, + undefined, + params, + seqLens, + totalSequenceLengthInput, + ); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index 1a31253905694..db7a4b8e68b79 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -403,19 +403,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio ); if (kvBNSH) { - return applyAttention( - context, - Q, - key, - value, - keyPaddingMask, - undefined, - pastKey, - pastValue, - attentionBias, - params, - attributes, - ); + return applyAttention(context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params); } if (!key || !value) { throw new Error('key and value must be provided'); @@ -442,5 +430,5 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio 2 * params.hiddenSize, ); - applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes); + applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 1dc3a206cf94b..8c39505734e41 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => { }`; }; -const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => { +export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => { const inputShape = inputs[0].dims; const inputSize = ShapeUtil.size(inputShape); const dataType = inputs[0].dataType; diff --git a/js/web/test/data/ops/group-query-attention.jsonc b/js/web/test/data/ops/group-query-attention.jsonc index 2a4b265078456..036069f43eb54 100644 --- a/js/web/test/data/ops/group-query-attention.jsonc +++ b/js/web/test/data/ops/group-query-attention.jsonc @@ -1,6 +1,316 @@ [ { - "name": "GroupQueryAttention Basic", + "name": "GroupQueryAttention 0", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [1, 1, 8], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // 
total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 1, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 1", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [1, 1, 8], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + // past key, BS* + { + "data": [40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + // past value, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + // seqlens_k, unimplemented + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length, unimplemented + { + "data": [2], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [48, 49, 50, 51, 52, 53, 54, 55], + "dims": [1, 1, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [40, 41, 42, 43, 44, 45, 46, 47, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 2, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 2", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + // key, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + // value, BS* + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 3, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 72, 73, 74, 75, 76, 77, 78, 79, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 
65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 1, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 3", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], + "type": "float32" + }, + // key, BS* + { + "data": [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 3, 8], + "type": "float32" + }, + // value, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 4", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ @@ -12,44 +322,293 @@ "name": "T[0]", "inputs": [ { - "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, - 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 - ], - "dims": [1, 3, 16], + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 32], + "type": "float32" + }, + // key, BS* + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + // value, BS* + { + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, + 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + // past key, BNSH + { + "data": [], + "dims": [1, 2, 0, 8], + "type": "float32" + }, + // past value, BNSH + { + "data": [], + "dims": [1, 2, 0, 8], + "type": "float32" + }, + // 
seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, + 76, 77, 78, 79, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 80, 81, 82, 83, 84, 85, + 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 32], + "type": "float32" + }, + { + // present key, BNSH + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 8, 9, 10, 11, 12, + 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47 + ], + "dims": [1, 2, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 64, 65, 66, 67, 68, 69, 70, 71, 80, 81, 82, 83, 84, 85, 86, 87, 56, 57, + 58, 59, 60, 61, 62, 63, 72, 73, 74, 75, 76, 77, 78, 79, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 2, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 5", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], + "type": "float32" + }, + // value, BS* + { + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [24, 25, 26, 27, 28, 29, 30, 31, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 6", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], + "type": "float32" + }, + // key, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + // value, BS* + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 3, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + 
"type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 3, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 1, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 7", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], "type": "float32" }, // key, BS* { - "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], "dims": [1, 3, 8], "type": "float32" }, // value, BS* { - "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], "dims": [1, 3, 8], "type": "float32" }, // past key, BS* { - "data": null, + "data": [96, 97, 98, 99, 100, 101, 102, 103], + "dims": [1, 1, 1, 8], "type": "float32" }, // past value, BS* { - "data": null, + "data": [104, 105, 106, 107, 108, 109, 110, 111], + "dims": [1, 1, 1, 8], "type": "float32" }, // seqlens_k, unimplemented { - "data": [1], + "data": [3], "dims": [1], "type": "int32" }, // total_sequence_length, unimplemented { - "data": [1], + "data": [4], "dims": [1], "type": "int32" } @@ -57,22 +616,28 @@ "outputs": [ { "data": [ - 1, 1, 1, 1, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2, - 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21 + 104, 105, 106, 107, 108, 109, 110, 111, 104, 105, 106, 107, 108, 109, 110, 111, 104, 105, 106, 107, 108, + 109, 110, 111 ], - "dims": [1, 3, 16], + "dims": [1, 3, 8], "type": "float32" }, { - // present key, BS* - "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], - "dims": [1, 3, 2, 4], + // present key, BNSH + "data": [ + 96, 97, 98, 99, 100, 101, 102, 103, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + 65, 66, 67, 68, 69, 70, 71 + ], + "dims": [1, 1, 4, 8], "type": "float32" }, { - // present value, BS* - "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], - "dims": [1, 3, 2, 4], + // present value, BNSH + "data": [ + 104, 105, 106, 107, 108, 109, 110, 111, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, + 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 1, 4, 8], "type": "float32" } ] @@ -80,13 +645,12 @@ ] }, { - "name": "GroupQueryAttention Scale", + "name": " GroupQueryAttention 8", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ { "name": "num_heads", 
"data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" }, - { "name": "scale", "data": 2.0, "type": "float" } + { "name": "kv_num_heads", "data": 2, "type": "int" } ], "cases": [ { @@ -94,38 +658,43 @@ "inputs": [ { "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 4, 8], + "dims": [1, 1, 32], "type": "float32" }, + // key, BS* { - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], "type": "float32" }, + // value, BS* { - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], + "dims": [1, 1, 16], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -135,35 +704,34 @@ "outputs": [ { "data": [ - 1.000006079673767, 1.000006079673767, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, - 1, 1, 1, 1.9820137023925781, 1.9820137023925781, 1.9999991655349731, 1.9999991655349731 + 48, 49, 50, 51, 52, 53, 54, 55, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 4, 8], + "dims": [1, 1, 32], "type": "float32" }, { - // present key, BS* - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present key, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 1, 8], "type": "float32" }, { - // present value, BS* - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present value, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], + "dims": [1, 2, 1, 8], "type": "float32" } ] } ] }, - { - "name": "GroupQueryAttention, different sequence length", + "name": "GroupQueryAttention 9", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, + { "name": "num_heads", "data": 2, "type": "int" }, { "name": "kv_num_heads", "data": 2, "type": "int" } ], "cases": [ @@ -171,39 +739,41 @@ "name": "T[0]", "inputs": [ { - "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 - ], - "dims": [1, 4, 8], + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], "type": "float32" }, + // key, BS* { - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], "type": "float32" }, + // value, BS* { - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], 
"type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -212,23 +782,20 @@ ], "outputs": [ { - "data": [ - 1.014165997505188, 1.014165997505188, 1.0000015497207642, 1.0000015497207642, 1.99828040599823, - 1.99828040599823, 1.9998981952667236, 1.9998981952667236, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, - 1.9995813369750977, 1.9995813369750977, 1.9999752044677734, 1.9999752044677734, 1, 1, 1, 1, - 1.8044296503067017, 1.8044296503067017, 1.9929646253585815, 1.9929646253585815 - ], - "dims": [1, 4, 8], + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], "type": "float32" }, { - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 2, 1, 8], "type": "float32" }, { - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 1, 8], "type": "float32" } ] @@ -236,12 +803,164 @@ ] }, { - "name": "GroupQueryAttention Basic, q k v same head number", + "name": "GroupQueryAttention 10", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 4, "type": "int" } + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 16], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 16], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 16], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 1, 16], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 11", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 2, 8], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 2, 8], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 
45, 46, 47], + "dims": [1, 2, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [2], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [2], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 2, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 12", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } ], "cases": [ { @@ -249,45 +968,49 @@ "inputs": [ { "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, - 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, + // key, BS* { "data": [ - 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, - 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, + // value, BS* { "data": [ - 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1, - 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 32], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 32], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -297,26 +1020,28 @@ "outputs": [ { "data": [ - 1, 12, 21, 131, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, - 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, { + // present key, BNSH "data": [ - 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, - 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 3, 4, 4], + 
"dims": [1, 1, 1, 32], "type": "float32" }, { + // present value, BNSH "data": [ - 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1, - 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 3, 4, 4], + "dims": [1, 1, 1, 32], "type": "float32" } ] @@ -324,12 +1049,12 @@ ] }, { - "name": "GroupQueryAttention, no past kv, used as reference", + "name": "GroupQueryAttention 13", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" } + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } ], "cases": [ { @@ -337,50 +1062,51 @@ "inputs": [ { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, - 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, - 107, 108, 109, 110, 111, 112 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 7, 16], + "dims": [1, 4, 8], "type": "float32" }, + // key, BS* { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 7, 8], + "dims": [1, 4, 8], "type": "float32" }, + // value, BS* { "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 7, 8], + "dims": [1, 4, 8], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { - "data": [1], + "data": [4], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { - "data": [1], + "data": [4], "dims": [1], "type": "int32" } @@ -388,29 +1114,28 @@ "outputs": [ { "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 
81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 7, 16], + "dims": [1, 4, 8], "type": "float32" }, { + // present key, BNSH "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 7, 2, 4], + "dims": [1, 1, 4, 8], "type": "float32" }, { + // present value, BNSH "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 7, 2, 4], + "dims": [1, 1, 4, 8], "type": "float32" } ] @@ -418,12 +1143,12 @@ ] }, { - "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 1", + "name": "GroupQueryAttention PackedQKV 14", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" } + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } ], "cases": [ { @@ -431,52 +1156,41 @@ "inputs": [ { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, - 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, - 107, 108, 109, 110, 111, 112 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 7, 16], + "dims": [1, 1, 32], "type": "float32" }, - // new key, BS* + // key, BS* { - "data": [ - 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, - 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 6, 8], + "data": null, "type": "float32" }, - // new value, BS* + // value, BS* { - "data": [ - 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, - 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 6, 8], + "data": null, "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": [1, 2, 3, 4, 5, 6, 7, 8], - "dims": [1, 1, 2, 4], + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": [0, 1, 2, 3, 4, 5, 6, 7], - "dims": [1, 1, 2, 4], + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -485,30 +1199,20 @@ ], "outputs": [ { - "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 
50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 - ], - "dims": [1, 7, 16], + "data": [24, 25, 26, 27, 28, 29, 30, 31, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], "type": "float32" }, { - "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 7, 2, 4], + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], "type": "float32" }, { - "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 7, 2, 4], + // present value, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 8], "type": "float32" } ] @@ -516,7 +1220,7 @@ ] }, { - "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 2", + "name": "GroupQueryAttention PackedQKV 15", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ @@ -529,54 +1233,48 @@ "inputs": [ { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, - 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, - 107, 108, 109, 110, 111, 112 + 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, + 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, + 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, + 22, 21, 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, + 1, 3, 4, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, + 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, + 2, 131, 22, 21 ], - "dims": [1, 7, 16], + "dims": [1, 3, 64], "type": "float32" }, - // new key, BS* + // key { - "data": [ - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, - 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 5, 8], + "data": null, "type": "float32" }, - // new value, BS* + // value { - "data": [ - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, - 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 5, 8], + "data": null, "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], - "dims": [1, 2, 2, 4], + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - "dims": [1, 2, 2, 4], + "data": 
[], + "dims": [1, 2, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { - "data": [1], + "data": [3], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { - "data": [1], + "data": [3], "dims": [1], "type": "int32" } @@ -584,29 +1282,29 @@ "outputs": [ { "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 + 1, 9, 1, 1, 2, 2, 2, 2, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 1, 12, 21, 131, 22, 21, 2, + 2, 8, 12, 233, 4, 5, 6, 7, 8, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 5, 6, 7, 8, 1, 1, 3, 4, + 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 5, 6, 7, 8, 1, 1, 3, 4, 5, 6, 7, 8, 1, 1, 3, 4 ], - "dims": [1, 7, 16], + "dims": [1, 3, 32], "type": "float32" }, { + // present key, BNSH "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 + 8, 12, 233, 4, 5, 6, 7, 8, 1, 1, 2, 3, 4, 5, 6, 7, 131, 22, 21, 2, 2, 131, 22, 21, 5, 6, 7, 8, 1, 1, 3, 4, + 8, 11, 12, 13, 14, 15, 16, 17, 1, 1, 1, 1, 2, 2, 2, 2 ], - "dims": [1, 7, 2, 4], + "dims": [1, 2, 3, 8], "type": "float32" }, { + // present value, BNSH "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 + 1, 9, 1, 1, 2, 2, 2, 2, 8, 12, 233, 4, 5, 6, 7, 8, 1, 1, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, + 5, 6, 7, 8, 1, 1, 3, 4, 131, 22, 21, 2, 2, 131, 22, 21 ], - "dims": [1, 7, 2, 4], + "dims": [1, 2, 3, 8], "type": "float32" } ] diff --git a/onnxruntime/contrib_ops/js/bert/group_query_attention.h b/onnxruntime/contrib_ops/js/bert/group_query_attention.h index 7553883a2478d..dff8663133c31 100644 --- a/onnxruntime/contrib_ops/js/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/js/bert/group_query_attention.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
#pragma once - +#include "contrib_ops/cpu/bert/gqa_attention_base.h" #include "core/providers/js/js_kernel.h" namespace onnxruntime { @@ -11,31 +11,29 @@ namespace js { using onnxruntime::js::JsKernel; -class GroupQueryAttention : public JsKernel { +class GroupQueryAttention : public JsKernel, GQAAttentionBase { public: explicit GroupQueryAttention(const OpKernelInfo& info) - : JsKernel(info) { - int64_t num_heads = 0; - int64_t kv_num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); - num_heads_ = static_cast(num_heads); - kv_num_heads_ = static_cast(kv_num_heads); - scale_ = info.GetAttrOrDefault("scale", 0.0f); + : JsKernel(info), GQAAttentionBase(info, false) { JSEP_INIT_KERNEL_ATTRIBUTE(GroupQueryAttention, ({ "numHeads" : $1, "kvNumHeads" : $2, "scale" : $3, + "softcap" : $4, + "doRotary" : $5, + "rotaryInterleaved" : $6, + "smoothSoftmax" : $7, + "localWindowSize" : $8 }), static_cast(num_heads_), static_cast(kv_num_heads_), - static_cast(scale_)); + static_cast(scale_), + static_cast(softcap_), + static_cast(do_rotary_), + static_cast(rotary_interleaved_), + static_cast(use_smooth_softmax_), + static_cast(local_window_size_)); } - - protected: - int num_heads_; // number of attention heads - int kv_num_heads_; // number of k and v heads - float scale_; // custom scale will be used if specified. Default value is 1/sqrt(head_size) }; } // namespace js From a25c9315eafd6ad8c6de207108e3c142e5b3b9dd Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 23 Oct 2024 11:57:15 -0700 Subject: [PATCH 7/8] Move ORT Training pipeline to github actions (#22543) Move ORT Training pipeline to github actions and enable CodeQL scan for the code(including inference code). We will move all pull request pipelines to Github Actions. 
--- .github/codeql/codeql-config.yml | 7 ++ .github/workflows/linux_training.yml | 55 +++++++++++ .../orttraining-linux-ci-pipeline.yml | 95 ------------------- .../orttraining-linux-gpu-ci-pipeline.yml | 55 ----------- .../github/linux/build_training_ci.sh | 4 - .../docker/Dockerfile.ubuntu_gpu_training | 60 ------------ 6 files changed, 62 insertions(+), 214 deletions(-) create mode 100644 .github/codeql/codeql-config.yml create mode 100644 .github/workflows/linux_training.yml delete mode 100644 tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml delete mode 100644 tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml delete mode 100755 tools/ci_build/github/linux/build_training_ci.sh delete mode 100644 tools/ci_build/github/linux/docker/Dockerfile.ubuntu_gpu_training diff --git a/.github/codeql/codeql-config.yml b/.github/codeql/codeql-config.yml new file mode 100644 index 0000000000000..6a76f7bcdbcb0 --- /dev/null +++ b/.github/codeql/codeql-config.yml @@ -0,0 +1,7 @@ +name: "CodeQL config" +queries: + - uses: security-extended + - uses: security-and-quality +paths-ignore: + - tests + - build \ No newline at end of file diff --git a/.github/workflows/linux_training.yml b/.github/workflows/linux_training.yml new file mode 100644 index 0000000000000..51af6cd20de7d --- /dev/null +++ b/.github/workflows/linux_training.yml @@ -0,0 +1,55 @@ +name: orttraining-linux-ci-pipeline +on: + push: + branches: + - main + - rel-* + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + orttraining-linux-ci-pipeline: + runs-on: ubuntu-24.04 + permissions: + actions: read + contents: read + security-events: write + steps: + - uses: actions/checkout@v4 + - run: | + python3 -m pip install -r tools/ci_build/github/linux/python/requirements.txt + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + config-file: ./.github/codeql/codeql-config.yml + languages: 'cpp' + - run: | + set -e -x + rm -rf build + python3 tools/ci_build/build.py --build_dir build --config Release --enable_training --skip_submodule_sync --parallel --update --build + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:cpp" + output: sarif-results + upload: failure-only + + - name: filter-sarif + uses: advanced-security/filter-sarif@v1 + with: + patterns: | + +**/*.cc + +**/*.h + -tests/**/*.* + -build/**/*.* + input: sarif-results/cpp.sarif + output: sarif-results/cpp.sarif + + - name: Upload SARIF + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: sarif-results/cpp.sarif \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml deleted file mode 100644 index 5c3273f79bd30..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml +++ /dev/null @@ -1,95 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -jobs: -- job: 
Linux_Build - timeoutInMinutes: 180 - workspace: - clean: all - variables: - skipComponentGovernanceDetection: true - CCACHE_DIR: $(Pipeline.Workspace)/ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - pool: onnxruntime-Ubuntu-2204-Training-CPU - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - checkout: self - clean: true - submodules: none - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile - Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=registry.access.redhat.com/ubi8/ubi" - Repository: onnxruntimecpubuildcentos8x64_packaging - - - task: Cache@2 - inputs: - key: '"$(TODAY)" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' - path: $(CCACHE_DIR) - cacheHitVar: CACHE_RESTORED - restoreKeys: | - "$(TODAY)" | "$(Build.SourceBranch)" - "$(TODAY)" | - displayName: Cach Task - - - task: CmdLine@2 - displayName: 'build' - inputs: - script: | - set -e -x - mkdir -p $HOME/.onnx - mkdir -p $(Pipeline.Workspace)/ccache - docker run --rm \ - --volume /data/onnx:/data/onnx:ro \ - --volume /data/models:/build/models:ro \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ - --volume $(Pipeline.Workspace)/ccache:/cache \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - -e CCACHE_DIR=/cache \ - onnxruntimecpubuildcentos8x64_packaging \ - /onnxruntime_src/tools/ci_build/github/linux/build_training_ci.sh - workingDirectory: $(Build.SourcesDirectory) - - - task: PublishTestResults@2 - displayName: 'Publish unit test results' - inputs: - testResultsFiles: '**/*.results.xml' - searchFolder: '$(Build.BinariesDirectory)' - testRunTitle: 'Unit Test Run' - condition: succeededOrFailed() diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml deleted file mode 100644 index 494035637a79d..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml +++ /dev/null @@ -1,55 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -jobs: -- template: templates/linux-ci.yml - parameters: - AgentPool : 'Onnxruntime-Linux-GPU-NC6sv3' - JobName: 'Onnxruntime_Linux_GPU_Training' - RunDockerBuildArgs: > - -o ubuntu20.04 -d gpu - -t onnxruntime_orttraining_ortmodule_tests_image - -u - -e - -x " - --enable_training - --config Release - --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 - --build_wheel - --enable_nvtx_profile - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=70 - " - RunInjectedPipeline: 'true' - InjectedPipeline: 'orttraining-linux-gpu-test-ci-pipeline.yml' - DockerImageTag: 
-          'onnxruntime_orttraining_ortmodule_tests_image'
-    TimeoutInMinutes: 190
-    # Enable unreleased onnx opsets in CI builds
-    # This facilitates testing the implementation for the new opsets
-    AllowReleasedOpsetOnly: '0'
diff --git a/tools/ci_build/github/linux/build_training_ci.sh b/tools/ci_build/github/linux/build_training_ci.sh
deleted file mode 100755
index 82f75a5cbbc50..0000000000000
--- a/tools/ci_build/github/linux/build_training_ci.sh
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/bin/bash
-set -e -x
-python3.12 -m pip install -r /onnxruntime_src/tools/ci_build/github/linux/python/requirements.txt
-python3.12 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --enable_training --skip_submodule_sync --parallel
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_gpu_training b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_gpu_training
deleted file mode 100644
index 4d11cbbde3354..0000000000000
--- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_gpu_training
+++ /dev/null
@@ -1,60 +0,0 @@
-ARG BASEIMAGE=nvcr.io/nvidia/cuda:11.8.0-cudnn8-devel-ubuntu18.04
-
-FROM $BASEIMAGE
-
-ARG PYTHON_VERSION=3.9
-ARG INSTALL_DEPS_EXTRA_ARGS
-ARG USE_CONDA=false
-
-ADD scripts /tmp/scripts
-RUN /tmp/scripts/install_ubuntu.sh -p $PYTHON_VERSION && \
-    /tmp/scripts/install_os_deps.sh -d gpu $INSTALL_DEPS_EXTRA_ARGS
-
-# If USE_CONDA is false, use root to install python dependencies.
-RUN if [ "$USE_CONDA" = false ] ; \
-    then /tmp/scripts/install_python_deps.sh -p $PYTHON_VERSION -d gpu $INSTALL_DEPS_EXTRA_ARGS ; \
-    fi
-
-WORKDIR /root
-
-# Allow configure to pick up GDK and CuDNN where it expects it.
-# (Note: $CUDNN_VERSION is defined by NVidia's base image)
-RUN _CUDNN_VERSION=$(echo $CUDNN_VERSION | cut -d. -f1-2) && \
-    mkdir -p /usr/local/cudnn-$_CUDNN_VERSION/cuda/include && \
-    ln -s /usr/include/cudnn.h /usr/local/cudnn-$_CUDNN_VERSION/cuda/include/cudnn.h && \
-    mkdir -p /usr/local/cudnn-$_CUDNN_VERSION/cuda/lib64 && \
-    ln -s /etc/alternatives/libcudnn_so /usr/local/cudnn-$_CUDNN_VERSION/cuda/lib64/libcudnn.so && \
-    ln -s /usr/local/cudnn{-$_CUDNN_VERSION,}
-
-ENV LD_LIBRARY_PATH /usr/local/openblas/lib:$LD_LIBRARY_PATH
-
-ARG BUILD_USER=onnxruntimedev
-ARG BUILD_UID=1000
-RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID
-WORKDIR /home/$BUILD_USER
-USER $BUILD_USER
-
-ARG MINICONDA_PREFIX=/home/$BUILD_USER/miniconda3
-RUN if [ "$USE_CONDA" = true ] ; \
-    then MINICONDA=miniconda.sh && \
-    wget --no-verbose https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh -O $MINICONDA && \
-    chmod a+x $MINICONDA && \
-    ./$MINICONDA -b -p $MINICONDA_PREFIX && \
-    rm ./$MINICONDA && \
-    $MINICONDA_PREFIX/bin/conda clean --yes --all && \
-    $MINICONDA_PREFIX/bin/conda install -y python=$PYTHON_VERSION ; \
-    fi
-
-ENV PATH /home/$BUILD_USER/miniconda3/bin:$PATH
-
-# If USE_CONDA is true, use onnxruntimedev user to install python dependencies
-RUN if [ "$USE_CONDA" = true ] ; \
-    then /tmp/scripts/install_python_deps.sh -p $PYTHON_VERSION -d gpu $INSTALL_DEPS_EXTRA_ARGS -c ; \
-    fi
-
-WORKDIR /root
-USER root
-RUN rm -rf /tmp/scripts
-
-WORKDIR /home/$BUILD_USER
-USER $BUILD_USER

From 08cc2612f46484030ad92ddf0baac9ab3fae0f6c Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Wed, 23 Oct 2024 15:04:46 -0700
Subject: [PATCH 8/8] Bump onnx from 1.16.0 to 1.17.0 in
 /onnxruntime/python/tools/transformers/models/stable_diffusion/requirements
 (#22561)

Bumps [onnx](https://github.com/onnx/onnx) from 1.16.0 to 1.17.0.
Release notes

Sourced from onnx's releases.

v1.17.0

ONNX v1.17.0 is now available with exciting new features! We would like to thank everyone who contributed to this release! Please visit onnx.ai to learn more about ONNX and associated projects.

Key Updates

ai.onnx Opset 22
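
As a quick illustration (not taken from the release notes), this is how a model can opt into the new opset; the tiny Relu graph and the names "X"/"Y" are made up for the example:

```python
# A minimal sketch, assuming onnx 1.17.0: build a toy model that declares
# ai.onnx opset 22. The Relu graph here is illustrative, not from the release.
import onnx
from onnx import TensorProto, helper

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4])
graph = helper.make_graph([helper.make_node("Relu", ["X"], ["Y"])], "opset22_demo", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 22)])
onnx.checker.check_model(model)  # passes once the installed onnx knows opset 22
```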

Python Changes

  • Support for numpy >= 2.0
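
A minimal round-trip sketch of what that interop looks like, assuming onnx 1.17.0 with numpy 2.x installed (the array contents and tensor name are illustrative):

```python
# Sketch of the numpy <-> TensorProto round-trip that the numpy >= 2.0
# support covers; the array contents and the name "weights" are made up.
import numpy as np
from onnx import numpy_helper

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
tensor = numpy_helper.from_array(arr, name="weights")  # numpy -> TensorProto
back = numpy_helper.to_array(tensor)                   # TensorProto -> numpy
assert np.array_equal(arr, back)
```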

Bug fixes and infrastructure improvements

  • Fix Check URLs errors (#5972)
  • Use CMAKE_PREFIX_PATH in finding libprotobuf (#5975)
  • Bump main VERSION_NUMBER to 1.17.0 (#5968)
  • Fix source and pip tar.gz builds on s390x systems (#5984)
  • Fix unique_name (#5992)
  • Fix SegFault bug in shape inference (#5990)
  • Fix onnx.compose when connecting subgraphs (#5991)
  • Fix conversion from split 11 to split 18 (#6020)
  • Update error messages for NegativeLogLikelihoodLoss inference function (#6021)
  • Generalize input/output number check in shape inference (#6005)
  • Replace rank inference with shape inference for Einsum op (#6010)
  • build from source instruction with latest cmake change (#6038)
  • Handle OneHot's depth value during shape inference (#5963)
  • Not to install cmake in pyproject.toml on Windows (#6045)
  • fix a skipped shape infer code (#6049)
  • Include the ".onnxtext" extension in supported serialization format (#6051)
  • Allow ReferenceEvaluator to return intermediate results (#6066; usage sketch after this list)
  • Fix 1 typo in numpy_helper.py (#6041)
  • Remove benchmarking code (#6076)
  • Prevent crash on import after GCC 8 builds (#6048)
  • Check graph outputs are defined (#6083)
  • Enable additional ruff rules (#6032)
  • Add missing shape inference check for DequantizeLinear (#6080)
  • Add bfloat16 to all relevant ops (#6099)
  • fix(ci): install python dependencies with --only-binary :all: in manylinux (#6120)
  • fix: install google-re2 with --only-binary option (#6129)
  • Specify axis parameter for DequantizeLinear when input rank is 1 (#6095)
  • Pin onnxruntime to 1.17.3 for release CIs (#6143)
  • Fix INT4 TensorProto byte size is 5x larger than expected with negative values (#6161)
  • Mitigate tarball directory traversal risks (#6164)
  • Fix reference implementation for ScatterND with 4D tensors (#6174)
  • Addition of group > 1 in test and in backend for ConvTranspose (#6175)
  • Support for bfloat16 for binary, unary operators in reference implementation (#6166)
  • Refactor windows workflow to work on standard windows (#6190)
  • Fix a few crashes while running shape inference (#6195)
  • Update onnx to work with numpy>=2.0 (#6196)
  • Use sets to improve performance of dfs search (#6213)

... (truncated)
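
Several items above touch onnx's reference implementation, including ReferenceEvaluator returning intermediate results (#6066). A self-contained usage sketch, assuming onnx 1.17.0 (the one-node Abs model is made up for the example):

```python
# Usage sketch for onnx.reference.ReferenceEvaluator; the one-node Abs model
# is an illustration, not code from this changelog.
import numpy as np
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

inp = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3])
out = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3])
graph = helper.make_graph([helper.make_node("Abs", ["X"], ["Y"])], "abs_demo", [inp], [out])
sess = ReferenceEvaluator(helper.make_model(graph))
print(sess.run(None, {"X": np.array([-1.0, 0.0, 2.5], dtype=np.float32)}))
```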


[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=onnx&package-manager=pip&previous-version=1.16.0&new-version=1.17.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---
Dependabot commands and options

You can trigger Dependabot actions by commenting on this PR:

- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)

You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 .../models/stable_diffusion/requirements/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt
index 8c9f0ba0f21be..8ff5990b7815a 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt
@@ -3,7 +3,7 @@ diffusers==0.28.0
 transformers==4.41.2
 numpy>=1.24.1
 accelerate
-onnx==1.16.0
+onnx==1.17.0
 coloredlogs
 packaging
 # Use newer version of protobuf might cause crash
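
After installing from the updated requirements file, a quick sanity check (an illustrative addition, not part of the patch) confirms the new pin took effect:

```python
# Illustrative check, not part of this patch: confirm the environment picked
# up the onnx==1.17.0 pin from the updated requirements.txt.
import onnx

assert onnx.__version__ == "1.17.0", f"unexpected onnx version: {onnx.__version__}"
```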