diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index 4dee1c14b3761..174942b9033d0 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -109,8 +109,16 @@ std::vector GetQDQIODefs(const Node& target_node, const QDQ::Node // If we can find the node index in the dq or q nodes this is a quantized input/output if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) { const auto node_inputs = node.InputDefs(); + const auto& node_attrs = node.GetAttributes(); + + // Get the Q or DQ axis attribute if available. + std::optional axis; + if (auto entry = node_attrs.find("axis"); entry != node_attrs.end()) { + axis = entry->second.i(); + } + // quantization scale and zp are always the input[1, 2] - NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? node_inputs[2] : nullptr}; + NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? node_inputs[2] : nullptr, axis}; if (is_input) { // DQ is input to the target node, use the DstArgIndex diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index 66afaec8ee1e2..a168495f12ebf 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -41,10 +41,11 @@ struct NodeGroup { // If the optional quant_param is present, then this is a quantized input, // otherwise this is a regular input struct NodeUnitIODef { - // The quantization parameter, scale is manadatory, and zero_point is optional + // The quantization parameter. Scale is mandatory. Zero-point and axis are optional. struct QuantParam { const NodeArg& scale; const NodeArg* zero_point{nullptr}; + std::optional axis{std::nullopt}; }; const NodeArg& node_arg; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 08c9a8449cc33..ba86e08822a94 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -65,8 +65,9 @@ Status BaseOpBuilder::ProcessInput(QnnModelWrapper& qnn_model_wrapper, } Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name); - QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, input_info.qnn_data_type, input_info.quant_param, - std::move(input_info.shape), std::move(unpacked_tensor)); + QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, input_info.qnn_data_type, + std::move(input_info.quant_param), std::move(input_info.shape), + std::move(unpacked_tensor)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); input_names.push_back(input_name); @@ -129,7 +130,7 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, TensorInfo output_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(outputs[output_i], output_info)); - if (output_info.quant_param.encodingDefinition == QNN_DEFINITION_DEFINED) { + if (output_info.quant_param.IsQuantized()) { ORT_RETURN_IF_ERROR(OverrideOutputQuantParam(qnn_model_wrapper, node_unit, logger, input_names, output_i, output_info.qnn_data_type, output_info.quant_param)); } @@ -143,7 +144,7 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, QnnTensorWrapper cast_input_tensorwrapper(cast_input_name, QNN_TENSOR_TYPE_NATIVE, supported_qnn_data_type, - output_info.quant_param, + output_info.quant_param.Copy(), std::move(cast_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_input_tensorwrapper)), "Failed to add tensor."); output_names.push_back(cast_input_name); @@ -156,7 +157,7 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, output_info.qnn_data_type, - output_info.quant_param, + std::move(output_info.quant_param), std::move(output_info.shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); } @@ -189,15 +190,15 @@ Status BaseOpBuilder::SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper& size_t input_index, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const { + QnnQuantParamsWrapper& quant_param) const { const QnnTensorWrapper& input_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[input_index]); ORT_RETURN_IF_NOT(input_tensor_wrapper.GetTensorDataType() == qnn_data_type, "Input and output data types do not match"); - Qnn_QuantizeParams_t input_quant_param = GetQnnTensorQParams(input_tensor_wrapper.GetQnnTensor()); + const QnnQuantParamsWrapper& input_quant_param = input_tensor_wrapper.GetQnnQuantParams(); float scale_diff = 0.0f; int32_t offset_diff = 0; - ORT_RETURN_IF_ERROR(CompareQnnQuantParams(quant_param, input_quant_param, scale_diff, offset_diff)); + ORT_RETURN_IF_ERROR(CompareQnnQuantParams(quant_param.Get(), input_quant_param.Get(), scale_diff, offset_diff)); constexpr float NEARLY_EQUAL_THRESHOLD = 1e-9f; constexpr float WARN_THRESHOLD = 1e-6f; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 4eb599eb50175..8e4e05be82457 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -6,6 +6,7 @@ #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder.h" +#include "core/providers/qnn/builder/qnn_quant_params_wrapper.h" #include "core/framework/allocator.h" #include "QnnOpDef.h" @@ -57,7 +58,7 @@ class BaseOpBuilder : public IOpBuilder { const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const ORT_MUST_USE_RESULT { + QnnQuantParamsWrapper& quant_param) const ORT_MUST_USE_RESULT { // Do nothing by default. Op builders like Split implement this function to override output quant params. ORT_UNUSED_PARAMETER(qnn_model_wrapper); ORT_UNUSED_PARAMETER(node_unit); @@ -110,7 +111,7 @@ class BaseOpBuilder : public IOpBuilder { size_t input_index, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const ORT_MUST_USE_RESULT; + QnnQuantParamsWrapper& quant_param) const ORT_MUST_USE_RESULT; static const std::string& GetQnnOpType(const std::string& onnx_op_type) { // TODO: Use QNN operator names defined in "QnnOpDef.h" @@ -320,6 +321,8 @@ class BaseOpBuilder : public IOpBuilder { private: std::string op_builder_type_; + + protected: const std::vector nchw2nhwc_perm{0, 2, 3, 1}; const std::vector nchw2hwcn_perm{2, 3, 1, 0}; const std::vector cnhw2hwcn_perm{2, 3, 0, 1}; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index 294aa659872c4..70ad00b90c9dd 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -260,13 +260,16 @@ class BatchNormOpBuilder : public BaseOpBuilder { uint32_t channel = mean_info.shape[0]; mean_out.resize(channel); ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(mean_info.qnn_data_type, channel, mean_raw_ptr_length)); + ORT_RETURN_IF_NOT(!is_npu_backend || mean_info.quant_param.IsPerTensor(), + "BatchNormalization's input_mean does not support per-channel quantization"); int i = 0; int offset = 0; + const Qnn_QuantizeParams_t& quant_param = mean_info.quant_param.Get(); for (; i < static_cast(channel); ++i) { double mean_value = 0.0; ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(mean_info.qnn_data_type, mean_raw_ptr + offset, mean_value, offset)); - mean_out[i] = (is_npu_backend) ? utils::Dequantize(mean_info.quant_param.scaleOffsetEncoding.offset, - mean_info.quant_param.scaleOffsetEncoding.scale, + mean_out[i] = (is_npu_backend) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, mean_value) : mean_value; } @@ -283,13 +286,16 @@ class BatchNormOpBuilder : public BaseOpBuilder { uint32_t channel = var_info.shape[0]; std_out.resize(channel); ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(var_info.qnn_data_type, channel, var_raw_ptr_length)); + ORT_RETURN_IF_NOT(!is_npu_backend || var_info.quant_param.IsPerTensor(), + "BatchNormalization's input_var does not support per-channel quantization"); int i = 0; int offset = 0; + const Qnn_QuantizeParams_t& quant_param = var_info.quant_param.Get(); for (; i < static_cast(channel); ++i) { double var_value = 0.0; ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(var_info.qnn_data_type, var_raw_ptr + offset, var_value, offset)); - std_out[i] = (is_npu_backend) ? utils::Dequantize(var_info.quant_param.scaleOffsetEncoding.offset, - var_info.quant_param.scaleOffsetEncoding.scale, + std_out[i] = (is_npu_backend) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, var_value) : var_value; std_out[i] = std::sqrt(std_out[i] + static_cast(epsilon)); @@ -309,13 +315,16 @@ class BatchNormOpBuilder : public BaseOpBuilder { uint32_t channel = scale_info.shape[0]; scale_out.resize(channel); ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(scale_info.qnn_data_type, channel, scale_raw_ptr_length)); + ORT_RETURN_IF_NOT(!is_npu_backend || scale_info.quant_param.IsPerTensor(), + "BatchNormalization's scale input does not support per-channel quantization"); int i = 0; int offset = 0; + const Qnn_QuantizeParams_t& quant_param = scale_info.quant_param.Get(); for (; i < static_cast(channel); ++i) { double scale_value = 0.0; ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(scale_info.qnn_data_type, scale_raw_ptr + offset, scale_value, offset)); - scale_out[i] = (is_npu_backend) ? utils::Dequantize(scale_info.quant_param.scaleOffsetEncoding.offset, - scale_info.quant_param.scaleOffsetEncoding.scale, + scale_out[i] = (is_npu_backend) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, scale_value) : scale_value; scale_out[i] = scale_out[i] / std_double_tensor[i]; @@ -338,13 +347,16 @@ class BatchNormOpBuilder : public BaseOpBuilder { uint32_t channel = bias_info.shape[0]; bias_out.resize(channel); ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(bias_info.qnn_data_type, channel, bias_raw_ptr_length)); + ORT_RETURN_IF_NOT(!is_npu_backend || bias_info.quant_param.IsPerTensor(), + "BatchNormalization's bias input does not support per-channel quantization"); int i = 0; int offset = 0; + const Qnn_QuantizeParams_t& quant_param = bias_info.quant_param.Get(); for (; i < static_cast(channel); ++i) { double bias_value = 0.0; ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(bias_info.qnn_data_type, bias_raw_ptr + offset, bias_value, offset)); - bias_out[i] = (is_npu_backend) ? utils::Dequantize(bias_info.quant_param.scaleOffsetEncoding.offset, - bias_info.quant_param.scaleOffsetEncoding.scale, + bias_out[i] = (is_npu_backend) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, bias_value) : bias_value; bias_out[i] = bias_out[i] - (mean_double_tensor[i] * scale_double_tensor[i]); @@ -359,7 +371,7 @@ class BatchNormOpBuilder : public BaseOpBuilder { const std::vector& double_tensor, const double rmax, const double rmin, - Qnn_QuantizeParams_t& quant_param, + QnnQuantParamsWrapper& quant_param, std::vector& raw_tensor) const { if (is_npu_backend) { raw_tensor.resize(double_tensor.size()); @@ -370,8 +382,7 @@ class BatchNormOpBuilder : public BaseOpBuilder { info.qnn_data_type, scale, zero_point)); - quant_param = QNN_QUANTIZE_PARAMS_INIT; - utils::InitializeQuantizeParam(quant_param, true, scale, zero_point); + quant_param = QnnQuantParamsWrapper(scale, zero_point); for (size_t i = 0; i < double_tensor.size(); ++i) { // onnx only supports 8 bits quantization int quant_value_int = 0; @@ -382,6 +393,7 @@ class BatchNormOpBuilder : public BaseOpBuilder { int8_t quant_value = static_cast(quant_value_int); raw_tensor[i] = *reinterpret_cast(&quant_value); } else { + // TODO(adrianlizarraga): Should support 16-bit quantization as well. ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", info.qnn_data_type); } } @@ -545,7 +557,7 @@ Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, if (!qnn_model_wrapper.IsQnnTensorWrapperExist(scale_name)) { std::vector scale_raw_tensor; - Qnn_QuantizeParams_t scale_quant_param = scale_info.quant_param; + QnnQuantParamsWrapper scale_quant_param = scale_info.quant_param; ORT_RETURN_IF_ERROR(Postprocess(scale_info, is_npu_backend, scale_double_tensor, @@ -554,15 +566,16 @@ Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, scale_quant_param, scale_raw_tensor)); Qnn_TensorType_t scale_tensor_type = GetInputTensorType(qnn_model_wrapper, scale_name); - QnnTensorWrapper input_tensorwrapper(scale_name, scale_tensor_type, scale_info.qnn_data_type, scale_quant_param, - std::move(scale_info.shape), std::move(scale_raw_tensor)); + QnnTensorWrapper input_tensorwrapper(scale_name, scale_tensor_type, scale_info.qnn_data_type, + std::move(scale_quant_param), std::move(scale_info.shape), + std::move(scale_raw_tensor)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); } input_names.push_back(scale_name); if (!qnn_model_wrapper.IsQnnTensorWrapperExist(bias_name)) { std::vector bias_raw_tensor; - Qnn_QuantizeParams_t bias_quant_param = bias_info.quant_param; + QnnQuantParamsWrapper bias_quant_param = bias_info.quant_param; ORT_RETURN_IF_ERROR(Postprocess(bias_info, is_npu_backend, bias_double_tensor, @@ -571,8 +584,9 @@ Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, bias_quant_param, bias_raw_tensor)); Qnn_TensorType_t bias_tensor_type = GetInputTensorType(qnn_model_wrapper, bias_name); - QnnTensorWrapper input_tensorwrapper(bias_name, bias_tensor_type, bias_info.qnn_data_type, bias_quant_param, - std::move(bias_info.shape), std::move(bias_raw_tensor)); + QnnTensorWrapper input_tensorwrapper(bias_name, bias_tensor_type, bias_info.qnn_data_type, + std::move(bias_quant_param), std::move(bias_info.shape), + std::move(bias_raw_tensor)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); } input_names.push_back(bias_name); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc index 000a94f888e97..ee2bb8a7f5b85 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc @@ -70,7 +70,7 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, type_proto, qnn_data_type)); - QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, QNN_QUANTIZE_PARAMS_INIT, + QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, QnnQuantParamsWrapper(), std::move(input_shape), std::move(unpacked_tensor)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add input tensor for QNN Cast node."); @@ -106,7 +106,7 @@ Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, qnn_data_type, - QNN_QUANTIZE_PARAMS_INIT, + QnnQuantParamsWrapper(), std::move(output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add output tensor for QNN Cast node."); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index 84b6cad9c41c1..a1966168a81a8 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -116,6 +116,22 @@ Status ConvOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, } } + // Validate that weight is signed type for per-channel quantization (required by QNN docs). + if (is_npu_backend) { + const auto& input_1 = inputs[1]; // weight + bool is_per_axis_quant = false; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.IsPerChannelQuantized(input_1, is_per_axis_quant)); + + if (is_per_axis_quant) { + int32_t elem_data_type = 0; + ORT_RETURN_IF_ERROR(utils::GetOnnxTensorElemDataType(input_1.node_arg, elem_data_type)); + + const bool is_signed_type = (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) || + (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT16); + ORT_RETURN_IF_NOT(is_signed_type, "Conv weights must be of a signed quantized type if quantized per-channel"); + } + } + return Status::OK(); } @@ -171,7 +187,7 @@ Status ConvOpBuilder::ProcessConv2DInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); // - // Input 1: weight + // Input 1: weight. This input must be transposed manually by QNN EP. // { const std::string& input1_name = inputs[1].node_arg.Name(); @@ -203,8 +219,18 @@ Status ConvOpBuilder::ProcessConv2DInputs(QnnModelWrapper& qnn_model_wrapper, } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } + + // Transpose quantization parameter's axis if this is using per-channel quantization. + if (input_info.quant_param.IsPerChannel()) { + const std::vector& perm = conv_type == OnnxConvType::kConv ? nchw2hwcn_perm : cnhw2hwcn_perm; + std::vector perm_inv(perm.size()); + ORT_RETURN_IF_ERROR(utils::InvertPerm(perm, perm_inv)); + ORT_RETURN_IF_ERROR(input_info.quant_param.HandleTranspose(perm_inv)); + } } else { // Add transpose node above weight input. + ORT_RETURN_IF(input_info.quant_param.IsPerChannel(), + "Non-constant Conv inputs only support per-tensor quantization"); bool is_graph_input = qnn_model_wrapper.IsGraphInput(input1_name); LOGS(logger, VERBOSE) << "Add HWCN Transpose node after input: " << input1_name; @@ -234,7 +260,8 @@ Status ConvOpBuilder::ProcessConv2DInputs(QnnModelWrapper& qnn_model_wrapper, } Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, actual_name); - QnnTensorWrapper input_tensorwrapper(actual_name, tensor_type, input_info.qnn_data_type, input_info.quant_param, + QnnTensorWrapper input_tensorwrapper(actual_name, tensor_type, input_info.qnn_data_type, + std::move(input_info.quant_param), std::move(actual_shape), std::move(unpacked_tensor)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); } @@ -288,6 +315,9 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, }; if (!input0_info.is_initializer) { + ORT_RETURN_IF(input0_info.quant_param.IsPerChannel(), + "Non-constant Conv inputs only support per-tensor quantization"); + // Add Reshape node to transform 1D input to 2D (i.e., set height to 1). // We don't need to do this for initializers, because the number of elements does not change. We can just // modify the shape dimensions. @@ -300,11 +330,15 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, input0_info.quant_param, do_op_validation, is_graph_input)); + } else if (input0_info.quant_param.IsPerChannel()) { + // The reshape (unsqueeze) may require us to shift the quant parameter's axis. + ORT_RETURN_IF_ERROR(input0_info.quant_param.HandleUnsqueeze(input0_info.shape, shape)); } Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, conv_input0_name); - QnnTensorWrapper input_tensorwrapper(conv_input0_name, tensor_type, input0_info.qnn_data_type, input0_info.quant_param, - std::move(shape), std::move(unpacked_tensor)); + QnnTensorWrapper input_tensorwrapper(conv_input0_name, tensor_type, input0_info.qnn_data_type, + std::move(input0_info.quant_param), std::move(shape), + std::move(unpacked_tensor)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); } else { LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input0_name; @@ -370,6 +404,11 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, ONNX_NAMESPACE::TensorProto reshaped_initializer = onnxruntime::utils::TensorToTensorProto(tensor_2d, reshape_output); + // The reshape (unsqueeze) may require us to shift the quant parameter's axis. + if (input_info.quant_param.IsPerChannel()) { + ORT_RETURN_IF_ERROR(input_info.quant_param.HandleUnsqueeze(input_info.shape, shape_2d)); + } + // // Get transposed initializer bytes. // @@ -380,8 +419,19 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } + + // Transpose quantization parameter's axis if this is using per-channel quantization. + if (input_info.quant_param.IsPerChannel()) { + const std::vector& perm = conv_type == OnnxConvType::kConv ? nchw2hwcn_perm : cnhw2hwcn_perm; + std::vector perm_inv(perm.size()); + ORT_RETURN_IF_ERROR(utils::InvertPerm(perm, perm_inv)); + ORT_RETURN_IF_ERROR(input_info.quant_param.HandleTranspose(perm_inv)); + } } else { // Dynamic weight: Add nodes to reshape to 2D, and then transpose. + ORT_RETURN_IF(input_info.quant_param.IsPerChannel(), + "Non-constant Conv inputs only support per-tensor quantization"); + bool is_graph_input = qnn_model_wrapper.IsGraphInput(input1_name); LOGS(logger, VERBOSE) << "Adding Reshape (to 2D) and HWCN Transpose node after input: " << input1_name; ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(input1_name, @@ -419,7 +469,8 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, conv_weight_input_name); QnnTensorWrapper input_tensorwrapper(conv_weight_input_name, tensor_type, input_info.qnn_data_type, - input_info.quant_param, std::move(final_shape), std::move(unpacked_tensor)); + std::move(input_info.quant_param), std::move(final_shape), + std::move(unpacked_tensor)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); } @@ -648,17 +699,13 @@ Status ConvOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra const std::string& output_node_type = is_depthwise_conv2d ? QNN_OP_DEPTH_WISE_CONV_2D : GetQnnOpType(node_unit.OpType()); - Qnn_QuantizeParams_t output_quantize_param = QNN_QUANTIZE_PARAMS_INIT; + QnnQuantParamsWrapper output_quantize_param; + ORT_RETURN_IF_ERROR(output_quantize_param.Init(qnn_model_wrapper, outputs[0])); bool is_quantized_tensor = outputs[0].quant_param.has_value(); - utils::InitializeQuantizeParam(output_quantize_param, is_quantized_tensor); const auto* type_proto = outputs[0].node_arg.TypeAsProto(); Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; ORT_RETURN_IF_ERROR(utils::GetQnnDataType(is_quantized_tensor, type_proto, qnn_data_type)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.ProcessQuantizationParameter(outputs[0].quant_param, - output_quantize_param.scaleOffsetEncoding.scale, - output_quantize_param.scaleOffsetEncoding.offset), - "Cannot get quantization parameter"); if (is_1d_conv) { const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name); @@ -669,8 +716,8 @@ Status ConvOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra output_shape[2], // C }; const std::string conv_output_name = output_name + "_ort_qnn_ep_conv2d"; - QnnTensorWrapper output_tensorwrapper(conv_output_name, QNN_TENSOR_TYPE_NATIVE, qnn_data_type, output_quantize_param, - std::vector(output_shape_2d)); + QnnTensorWrapper output_tensorwrapper(conv_output_name, QNN_TENSOR_TYPE_NATIVE, qnn_data_type, + output_quantize_param.Copy(), std::vector(output_shape_2d)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, @@ -693,8 +740,8 @@ Status ConvOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } else { const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name); Qnn_TensorType_t tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; - QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, qnn_data_type, output_quantize_param, - std::move(output_shape)); + QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, qnn_data_type, + std::move(output_quantize_param), std::move(output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc index 90e18e9fd0496..9e31cf9cae21a 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc @@ -30,7 +30,7 @@ class ExpandOpBuilder : public BaseOpBuilder { const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; + QnnQuantParamsWrapper& quant_param) const override ORT_MUST_USE_RESULT; }; template @@ -75,7 +75,7 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, bool is_quantized_tensor = inputs[0].quant_param.has_value(); Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; const auto* type_proto = inputs[0].node_arg.TypeAsProto(); - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; + QnnQuantParamsWrapper quantize_param; if (is_quantized_tensor) { ORT_RETURN_IF_ERROR(utils::GetQnnDataType(true, type_proto, qnn_data_type)); float scale = 0.0f; @@ -87,7 +87,7 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, qnn_data_type, scale, zero_point)); - utils::InitializeQuantizeParam(quantize_param, true, scale, zero_point); + quantize_param = QnnQuantParamsWrapper(scale, zero_point); int quant_value_int = 0; double ini_value = 1.0; ORT_RETURN_IF_ERROR(utils::Quantize(ini_value, scale, zero_point, qnn_data_type, quant_value_int)); @@ -129,8 +129,9 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const std::string& output_name = node_unit.Outputs()[0].node_arg.Name(); std::string shape_input_name(input_name + "_" + output_name); - QnnTensorWrapper input_tensorwrapper(shape_input_name, QNN_TENSOR_TYPE_STATIC, qnn_data_type, quantize_param, - std::move(input_shape), std::move(shape_data)); + QnnTensorWrapper input_tensorwrapper(shape_input_name, QNN_TENSOR_TYPE_STATIC, qnn_data_type, + std::move(quantize_param), std::move(input_shape), + std::move(shape_data)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); input_names.push_back(shape_input_name); @@ -144,7 +145,11 @@ Status ExpandOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrap const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const { + QnnQuantParamsWrapper& quant_param) const { + if (!quant_param.IsPerTensor()) { + return Status::OK(); + } + // Force Expand output to use the same quantization parameters as the input if they are nearly equal. // This enables the HTP backend to employ certain optimizations. return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc index 9f396a27369e7..40bfc70158899 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc @@ -80,14 +80,11 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, qnn_data_type = QNN_DATATYPE_INT_32; } - // Even for Quantized model, Gather indices use int32 without quantization - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; - Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name); std::vector input_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, input_shape), "Cannot get shape"); std::vector cast_output_shape(input_shape); - QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, quantize_param, + QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, QnnQuantParamsWrapper(), std::move(input_shape), std::move(gather_indices)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); @@ -96,8 +93,8 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, if (qnn_data_type == QNN_DATATYPE_INT_64) { // Add Cast node for indices indices_input_name = input_name + "_ort_qnn_ep_cast"; - QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32, quantize_param, - std::move(cast_output_shape)); + QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32, + QnnQuantParamsWrapper(), std::move(cast_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name, QNN_OP_PACKAGE_NAME_QTI_AISW, @@ -157,18 +154,14 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w const auto& gather_output = node_unit.Outputs()[0]; const auto& output_name = gather_output.node_arg.Name(); - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; - bool is_quantized_tensor = gather_output.quant_param.has_value(); - utils::InitializeQuantizeParam(quantize_param, is_quantized_tensor); + QnnQuantParamsWrapper quantize_param; + ORT_RETURN_IF_ERROR(quantize_param.Init(qnn_model_wrapper, gather_output)); const auto* type_proto = gather_output.node_arg.TypeAsProto(); Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; - ORT_RETURN_IF_ERROR(utils::GetQnnDataType(is_quantized_tensor, type_proto, qnn_data_type)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.ProcessQuantizationParameter(gather_output.quant_param, - quantize_param.scaleOffsetEncoding.scale, - quantize_param.scaleOffsetEncoding.offset), - "Cannot get quantization parameter"); - if (is_quantized_tensor) { + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(quantize_param.IsQuantized(), type_proto, qnn_data_type)); + + if (quantize_param.IsPerTensor()) { // Make sure the output quantization parameters are equal to the input. ORT_RETURN_IF_ERROR(SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, 0 /*input_index*/, 0 /*output_index*/, qnn_data_type, @@ -183,7 +176,7 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w bool reshape_required = (qnn_output_shape.size() != target_output_shape.size()); std::string gather_output_name = output_name + (reshape_required ? "_ort_qnn_ep_reshape" : ""); Qnn_TensorType_t tensor_type = (!reshape_required && is_graph_output) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; - QnnTensorWrapper gather_output_wrapper(gather_output_name, tensor_type, qnn_data_type, quantize_param, + QnnTensorWrapper gather_output_wrapper(gather_output_name, tensor_type, qnn_data_type, quantize_param.Copy(), std::move(qnn_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(gather_output_wrapper)), "Failed to add tensor."); @@ -199,7 +192,7 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w if (reshape_required) { // Add Reshape Node after Gather. Qnn_TensorType_t reshape_tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; - QnnTensorWrapper reshape_output(output_name, reshape_tensor_type, qnn_data_type, quantize_param, + QnnTensorWrapper reshape_output(output_name, reshape_tensor_type, qnn_data_type, std::move(quantize_param), std::move(target_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_output)), "Failed to add tensor."); const static std::string qnn_node_type = "Reshape"; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc index 338e46765736f..bf409b8f508de 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc @@ -87,10 +87,10 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const auto& inputs = node_unit.Inputs(); for (size_t input_i = 0; input_i < inputs.size(); ++input_i) { - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; - bool is_quantized_tensor = inputs[input_i].quant_param.has_value(); - utils::InitializeQuantizeParam(quantize_param, is_quantized_tensor); + QnnQuantParamsWrapper quantize_param; + ORT_RETURN_IF_ERROR(quantize_param.Init(qnn_model_wrapper, inputs[input_i])); + bool is_quantized_tensor = inputs[input_i].quant_param.has_value(); const auto& input_name = inputs[input_i].node_arg.Name(); // Only skip if the input tensor has already been added (by producer op) *and* we don't need @@ -107,16 +107,12 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, std::vector input_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[input_i].node_arg, input_shape), "Cannot get shape"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.ProcessQuantizationParameter(inputs[input_i].quant_param, - quantize_param.scaleOffsetEncoding.scale, - quantize_param.scaleOffsetEncoding.offset), - "Cannot get quantization parameter"); - std::vector unpacked_tensor; bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); if (is_initializer_input) { const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); if (1 == input_trans_flag.at(input_i)) { + ORT_RETURN_IF_ERROR(quantize_param.HandleTranspose(std::vector({1, 0}))); ORT_RETURN_IF_ERROR(TwoDimensionTranspose(qnn_model_wrapper, input_shape, *input_tensor, @@ -128,6 +124,8 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, std::string input_tensor_name = input_name; if (1 == input_trans_flag.at(input_i) && !is_initializer_input) { + ORT_RETURN_IF(quantize_param.IsPerChannel(), "Non-constant Gemm inputs only support per-tensor quantization"); + // Add Transpose node std::vector old_input_shape(input_shape); input_shape[0] = old_input_shape[1]; @@ -148,7 +146,7 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, input_names.push_back(input_tensor_name); Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_tensor_name); - QnnTensorWrapper input_tensorwrapper(input_tensor_name, tensor_type, qnn_data_type, quantize_param, + QnnTensorWrapper input_tensorwrapper(input_tensor_name, tensor_type, qnn_data_type, std::move(quantize_param), std::move(input_shape), std::move(unpacked_tensor)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc index 38172caa03768..5fe5e3bedd6eb 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc @@ -119,6 +119,9 @@ Status InstanceNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, }; if (!input0_info.is_initializer) { + ORT_RETURN_IF(input0_info.quant_param.IsPerChannel(), + "Non-constant InstanceNormalization inputs only support per-tensor quantization"); + // Add Reshape node to transform 1D input to 2D (i.e., set height to 1). // We don't need to do this for initializers, because the element layout does not change. We can just // modify the shape dimensions. @@ -131,11 +134,15 @@ Status InstanceNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, input0_info.quant_param, do_op_validation, is_graph_input)); + } else if (input0_info.quant_param.IsPerChannel()) { + // The reshape (unsqueeze) may require us to shift the quant parameter's axis. + ORT_RETURN_IF_ERROR(input0_info.quant_param.HandleUnsqueeze(input0_info.shape, op_shape)); } Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, op_input0_name); - QnnTensorWrapper input_tensorwrapper(op_input0_name, tensor_type, input0_info.qnn_data_type, input0_info.quant_param, - std::move(op_shape), std::move(initializer_data)); + QnnTensorWrapper input_tensorwrapper(op_input0_name, tensor_type, input0_info.qnn_data_type, + std::move(input0_info.quant_param), std::move(op_shape), + std::move(initializer_data)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); } else { ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); // Input 0 @@ -197,7 +204,7 @@ Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_m }; QnnTensorWrapper output_tensorwrapper(op_output_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type, - output_info.quant_param, std::vector(op_output_shape)); + output_info.quant_param.Copy(), std::vector(op_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc index d6752f76ef478..3f73ef76e9def 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc @@ -77,47 +77,50 @@ Status ProcessConstantValue(QnnModelWrapper& qnn_model_wrapper, if (input.quant_param.has_value()) { // QNN prefers pad_constant_value quantized with quantization params same as in[0], and data stored as 32-bit signed integer // Onnx doesn't guarantee it has same quantization parameter as in[0], so get back the float32 value and use non-quantized data directly + ORT_RETURN_IF_NOT(input_info.quant_param.IsPerTensor(), + "Pad's constant value must use per-tensor quantization"); + const Qnn_QuantizeParams_t& quant_param = input_info.quant_param.Get(); constant_value_qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32; float constant_value = 0; switch (input_info.qnn_data_type) { case QNN_DATATYPE_SFIXED_POINT_8: { auto int8_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale, + constant_value = static_cast(utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, static_cast(int8_span.data()[0]))); break; } case QNN_DATATYPE_SFIXED_POINT_16: { auto int16_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale, + constant_value = static_cast(utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, static_cast(int16_span.data()[0]))); break; } case QNN_DATATYPE_SFIXED_POINT_32: { auto int32_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale, + constant_value = static_cast(utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, static_cast(int32_span.data()[0]))); break; } case QNN_DATATYPE_UFIXED_POINT_8: { - constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale, + constant_value = static_cast(utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, static_cast(unpacked_tensor.data()[0]))); break; } case QNN_DATATYPE_UFIXED_POINT_16: { auto uint16_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale, + constant_value = static_cast(utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, static_cast(uint16_span.data()[0]))); break; } case QNN_DATATYPE_UFIXED_POINT_32: { auto uint32_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale, + constant_value = static_cast(utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, static_cast(uint32_span.data()[0]))); break; } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 872d9682b8355..ef1990ad8e69a 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -4,6 +4,7 @@ #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/common/safeint.h" @@ -35,7 +36,7 @@ class PoolOpBuilder : public BaseOpBuilder { const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; + QnnQuantParamsWrapper& quant_param) const override ORT_MUST_USE_RESULT; private: Status SetCommonPoolParams(const NodeAttrHelper& node_helper, std::vector& filter_size, @@ -250,10 +251,10 @@ Status PoolOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrappe const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const { + QnnQuantParamsWrapper& quant_param) const { // Force MaxPool outputs to use the same quantization parameters as the input if they are nearly equal. // This helps the HTP backend employ certain optimizations. - if (node_unit.OpType() == "MaxPool") { + if (node_unit.OpType() == "MaxPool" && quant_param.IsPerTensor()) { return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, 0 /*input_index*/, output_index, qnn_data_type, quant_param); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc index 4b06df6a0e632..b6f414da950d8 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc @@ -4,6 +4,7 @@ #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" @@ -29,7 +30,7 @@ class ReshapeOpBuilder : public BaseOpBuilder { const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; + QnnQuantParamsWrapper& quant_param) const override ORT_MUST_USE_RESULT; }; Status ReshapeOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, @@ -57,7 +58,11 @@ Status ReshapeOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wra const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const { + QnnQuantParamsWrapper& quant_param) const { + if (!quant_param.IsPerTensor()) { + return Status::OK(); + } + // Force Reshape output to use the same quantization parameters as the input if nearly equal. // This helps the HTP backend emply certain optimizations. return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc index a021c4cf0c735..e1c9a391459b2 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc @@ -48,7 +48,7 @@ class ResizeOpBuilder : public BaseOpBuilder { const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; + QnnQuantParamsWrapper& quant_param) const override ORT_MUST_USE_RESULT; private: // Info for each ONNX attribute of interest (attribute name + default value) @@ -376,7 +376,11 @@ Status ResizeOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrap const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const { + QnnQuantParamsWrapper& quant_param) const { + if (!quant_param.IsPerTensor()) { + return Status::OK(); + } + // Force Resize op's output to use the same quantization parameters as the input if nearly equal. // This helps the HTP backend employ certain optimizations. return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index dd678ab5467ed..82d71bb3e9dde 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -38,7 +38,7 @@ class SimpleOpBuilder : public BaseOpBuilder { const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; + QnnQuantParamsWrapper& quant_param) const override ORT_MUST_USE_RESULT; private: Status ExplicitOpCheck(const NodeUnit& node_unit) const; @@ -69,21 +69,19 @@ Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(qnn::utils::GetQminQmax(input_qnn_data_type, qmin, qmax)); double value_min = qnn::utils::Dequantize(input_offset, input_scale, qmin); double value_max = qnn::utils::Dequantize(input_offset, input_scale, qmax); - - Qnn_QuantizeParams_t convert_output_quant_param = QNN_QUANTIZE_PARAMS_INIT; - convert_output_quant_param.encodingDefinition = QNN_DEFINITION_DEFINED; - convert_output_quant_param.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + float scale = 0.0f; + int32_t offset = 0; ORT_RETURN_IF_ERROR(qnn::utils::GetQuantParams(static_cast(value_min), static_cast(value_max), output_qnn_data_type, - convert_output_quant_param.scaleOffsetEncoding.scale, - convert_output_quant_param.scaleOffsetEncoding.offset)); + scale, + offset)); std::vector output_shape_copy = output_shape; QnnTensorWrapper convert_output_tensorwrapper(convert_output_name, QNN_TENSOR_TYPE_NATIVE, output_qnn_data_type, - convert_output_quant_param, + QnnQuantParamsWrapper(scale, offset), std::move(output_shape_copy)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(convert_output_tensorwrapper)), "Failed to add tensor."); @@ -116,6 +114,9 @@ Status SimpleOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, if (!input0_info.is_initializer && !input1_info.is_initializer && input0_info.qnn_data_type == input1_info.qnn_data_type && input0_info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + ORT_RETURN_IF_NOT(input1_info.quant_param.IsPerTensor(), + "MatMul's activation inputs only support per-tensor quantization"); + const Qnn_QuantizeParams_t& quant_param = input1_info.quant_param.Get(); // insert Convert op after input1 std::string convert_input_name = input_names.back(); input_names.pop_back(); @@ -126,8 +127,8 @@ Status SimpleOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, convert_output_name, input1_info.qnn_data_type, QNN_DATATYPE_UFIXED_POINT_8, - input1_info.quant_param.scaleOffsetEncoding.offset, - input1_info.quant_param.scaleOffsetEncoding.scale, + quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, input1_info.shape, do_op_validation)); input_names.push_back(convert_output_name); @@ -218,7 +219,7 @@ Status ProcessAlphaAttributeAsInput(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const std::string input_name) { NodeAttrHelper node_helper(node_unit); - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; + QnnQuantParamsWrapper quantize_param; Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; union { float alpha; @@ -236,14 +237,14 @@ Status ProcessAlphaAttributeAsInput(QnnModelWrapper& qnn_model_wrapper, GetQuantizationParameter(&tensor_data.alpha, num_of_elements, scale, zero_point, thread_pool); unpacked_data.resize(1); ParQuantizeLinearStd(&tensor_data.alpha, unpacked_data.data(), num_of_elements, scale, zero_point, thread_pool); - utils::InitializeQuantizeParam(quantize_param, is_quantized_tensor, scale, static_cast(zero_point)); + quantize_param = QnnQuantParamsWrapper(scale, static_cast(zero_point)); qnn_data_type = QNN_DATATYPE_UFIXED_POINT_8; } else { unpacked_data.assign(tensor_data.unpack, tensor_data.unpack + sizeof(float)); } std::vector input_shape{1}; Qnn_TensorType_t tensor_type = QNN_TENSOR_TYPE_STATIC; - QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, quantize_param, + QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, std::move(quantize_param), std::move(input_shape), std::move(unpacked_data)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); return Status::OK(); @@ -443,7 +444,7 @@ Status SimpleOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrap const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const { + QnnQuantParamsWrapper& quant_param) const { ORT_UNUSED_PARAMETER(input_names); const std::string& op_type = node_unit.OpType(); @@ -458,10 +459,10 @@ Status SimpleOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrap const auto& output = node_unit.Outputs()[0]; const std::string& output_name = output.node_arg.Name(); - if (quant_param.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { - if (OverrideQuantParams(op_type, qnn_data_type, quant_param.scaleOffsetEncoding)) { - const int32_t offset = quant_param.scaleOffsetEncoding.offset; - const float scale = quant_param.scaleOffsetEncoding.scale; + if (quant_param.IsPerTensor(/*include_bw*/ false)) { + if (OverrideQuantParams(op_type, qnn_data_type, quant_param.Get().scaleOffsetEncoding)) { + const int32_t offset = quant_param.Get().scaleOffsetEncoding.offset; + const float scale = quant_param.Get().scaleOffsetEncoding.scale; LOGS(logger, VERBOSE) << "QNN requires that 16-bit quantized " << op_type << " operators use offset/scale values " diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc index 9059f7459200a..b0b2dc6164e8e 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc @@ -140,8 +140,8 @@ Status SoftmaxOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, is_graph_input)); Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, op_input_name); - QnnTensorWrapper input_tensorwrapper(op_input_name, tensor_type, input_info.qnn_data_type, input_info.quant_param, - std::move(op_input_shape), {}); + QnnTensorWrapper input_tensorwrapper(op_input_name, tensor_type, input_info.qnn_data_type, + std::move(input_info.quant_param), std::move(op_input_shape), {}); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); return Status::OK(); @@ -199,8 +199,8 @@ Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_ op_output_shape[output_rank - 1] = output_info.shape[axis]; op_output_shape[axis] = output_info.shape[output_rank - 1]; - QnnTensorWrapper output_tensorwrapper(op_output_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type, output_info.quant_param, - std::vector(op_output_shape)); + QnnTensorWrapper output_tensorwrapper(op_output_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type, + output_info.quant_param.Copy(), std::vector(op_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc index 9849a05db329c..1a7411eb5136a 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc @@ -3,6 +3,7 @@ #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/cpu/tensor/slice_helper.h" @@ -37,7 +38,7 @@ class SplitOpBuilder : public BaseOpBuilder { const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; + QnnQuantParamsWrapper& quant_param) const override ORT_MUST_USE_RESULT; }; Status SplitOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, @@ -149,7 +150,11 @@ Status SplitOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapp const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const { + QnnQuantParamsWrapper& quant_param) const { + if (!quant_param.IsPerTensor()) { + return Status::OK(); + } + // Force Split outputs to use the same quantization parameters as the input if nearly equal. // This helps the HTP backend employ certain optimizations. // diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc index 721db9dd2670e..851ca84dce075 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc @@ -3,6 +3,7 @@ #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/cpu/tensor/slice_helper.h" @@ -37,7 +38,7 @@ class TileOpBuilder : public BaseOpBuilder { const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; + QnnQuantParamsWrapper& quant_param) const override ORT_MUST_USE_RESULT; }; Status TileOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, @@ -100,7 +101,11 @@ Status TileOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrappe const std::vector& input_names, size_t output_index, Qnn_DataType_t qnn_data_type, - Qnn_QuantizeParams_t& quant_param) const { + QnnQuantParamsWrapper& quant_param) const { + if (!quant_param.IsPerTensor()) { + return Status::OK(); + } + // Force the Tile operator output to use the same quantization parameters as the input if nearly equal. // This helps the HTP backend employ certain optimizations. return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc index e69067ba8b0c6..c71ae4435f8bc 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc @@ -96,7 +96,7 @@ Status TransposeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_mode QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, input_tensor_wrapper.GetTensorDataType(), - GetQnnTensorQParams(input_tensor_wrapper.GetQnnTensor()), + input_tensor_wrapper.GetQnnQuantParams().Copy(), std::move(output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index 55e72670a6971..a1b4dc8bbb716 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -115,18 +115,10 @@ void SetQnnTensorClientBufData(Qnn_Tensor_t& qnn_tensor, void* client_buf_data) } void SetQnnTensorQParams(Qnn_Tensor_t& qnn_tensor, const Qnn_QuantizeParams_t& quantize_params) { - Qnn_QuantizationEncoding_t encoding = quantize_params.quantizationEncoding; - if (encoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET || - encoding == QNN_QUANTIZATION_ENCODING_UNDEFINED) { - if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { - qnn_tensor.v1.quantizeParams = quantize_params; - } else { - ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); - } - } else if (encoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { - ORT_THROW("Axis scale offset quantization parameter is not supported."); + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.quantizeParams = quantize_params; } else { - ORT_THROW("quantizationEncoding incorrect value."); + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); } } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index cb6344b4e7902..7d76006ed9b19 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -11,6 +11,7 @@ #include #include "core/graph/basic_types.h" #include "core/common/common.h" +#include "core/providers/qnn/builder/qnn_quant_params_wrapper.h" namespace onnxruntime { namespace qnn { @@ -144,12 +145,13 @@ class QnnTensorWrapper { QnnTensorWrapper(const std::string& name, Qnn_TensorType_t tensor_type, Qnn_DataType_t data_type, - const Qnn_QuantizeParams_t& quantize_params, + QnnQuantParamsWrapper&& quantize_params, std::vector&& shape, std::vector&& client_buf = {}, Qnn_TensorMemType_t mem_type = QNN_TENSORMEMTYPE_RAW) : tensor_name_(name), dimensions_(std::move(shape)), - client_buf_(std::move(client_buf)) { + client_buf_(std::move(client_buf)), + quant_params_(quantize_params) { SetQnnTensorType(qnn_tensor_, tensor_type); SetQnnTensorName(qnn_tensor_, tensor_name_.c_str()); SetQnnTensorDataType(qnn_tensor_, data_type); @@ -163,31 +165,36 @@ class QnnTensorWrapper { ORT_THROW("mem_type not supported for now."); } - SetQnnTensorQParams(qnn_tensor_, quantize_params); + SetQnnTensorQParams(qnn_tensor_, quant_params_.Get()); } - QnnTensorWrapper(const Qnn_Tensor_t& qnn_tensor) : tensor_name_(GetQnnTensorName(qnn_tensor)), - client_buf_{} { + // Initialize from a raw Qnn_Tensor_t. This method is currently used for graph inputs/outputs + // when deserializing from cached context object. Possible return errors due to: + // - Unexpected Qnn_TensorType_t: only handle graph inputs/outputs, not static initializers with data buffers. + // - Unexpected quantization encoding. + Status Init(const Qnn_Tensor_t& qnn_tensor) { + Qnn_TensorType_t tensor_type = GetQnnTensorType(qnn_tensor); + ORT_RETURN_IF(tensor_type == QNN_TENSOR_TYPE_STATIC, + "QnnTensorWrapper::Init(const Qnn_Tensor_t&) does not support static initializers"); + + tensor_name_ = GetQnnTensorName(qnn_tensor); + client_buf_.clear(); + qnn_tensor_ = qnn_tensor; SetQnnTensorName(qnn_tensor_, tensor_name_.c_str()); - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; - const auto& src_quantize_param = GetQnnTensorQParams(qnn_tensor); - // quantization only support SCALE_OFFSET encoding - quantize_param.encodingDefinition = src_quantize_param.encodingDefinition; - quantize_param.quantizationEncoding = src_quantize_param.quantizationEncoding; - quantize_param.scaleOffsetEncoding = src_quantize_param.scaleOffsetEncoding; - SetQnnTensorQParams(qnn_tensor_, quantize_param); + const Qnn_QuantizeParams_t& src_quantize_param = GetQnnTensorQParams(qnn_tensor); + ORT_RETURN_IF_ERROR(quant_params_.Init(src_quantize_param)); + SetQnnTensorQParams(qnn_tensor_, quant_params_.Get()); uint32_t shape_rank = GetQnnTensorRank(qnn_tensor); uint32_t* shape_data = GetQnnTensorDims(qnn_tensor); dimensions_.assign(shape_data, shape_data + shape_rank); SetQnnTensorDim(qnn_tensor_, dimensions_); - // This method is only used for graph inputs/outputs when desearilize from cached context - // no client buffer should be set - SetQnnTensorMemType(qnn_tensor_, QNN_TENSORMEMTYPE_RAW); + + return Status::OK(); } QnnTensorWrapper() = default; @@ -198,10 +205,12 @@ class QnnTensorWrapper { std::swap(tensor_name_, other.tensor_name_); std::swap(dimensions_, other.dimensions_); std::swap(client_buf_, other.client_buf_); + std::swap(quant_params_, other.quant_params_); std::swap(qnn_tensor_, other.qnn_tensor_); SetQnnTensorName(qnn_tensor_, tensor_name_.c_str()); SetQnnTensorDim(qnn_tensor_, dimensions_); SetQnnTensorClientBuf(qnn_tensor_, client_buf_); + SetQnnTensorQParams(qnn_tensor_, quant_params_.Get()); } ~QnnTensorWrapper() = default; @@ -214,6 +223,14 @@ class QnnTensorWrapper { return qnn_tensor_; } + const QnnQuantParamsWrapper& GetQnnQuantParams() const { + return quant_params_; + } + + QnnQuantParamsWrapper& GetQnnQuantParams() { + return quant_params_; + } + const std::string& GetName() const { return tensor_name_; } Qnn_TensorType_t GetTensorType() const { return GetQnnTensorType(qnn_tensor_); } @@ -231,22 +248,11 @@ class QnnTensorWrapper { } private: - void CopyQuantizationEncoding(Qnn_QuantizeParams_t& dst, const Qnn_QuantizeParams_t& src) { - Qnn_QuantizationEncoding_t encoding = src.quantizationEncoding; - if (encoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET || - encoding == QNN_QUANTIZATION_ENCODING_UNDEFINED) { - dst = src; - } else if (encoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { - ORT_THROW("Axis scale offset quantization parameter is not supported."); - } else { - ORT_THROW("quantizationEncoding incorrect value."); - } - } - std::string tensor_name_; std::vector dimensions_; std::vector client_buf_; Qnn_Tensor_t qnn_tensor_ = QNN_TENSOR_INIT; + QnnQuantParamsWrapper quant_params_; }; class QnnParamWrapper { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index b3501dfec1ba8..109ec743d8483 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -349,14 +349,16 @@ Status QnnModel::DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_Graph // Copy graph input Qnn_Tensor_t* input_tensors = qnn_sys_ctx_graph_info.graphInfoV1.graphInputs; for (size_t i = 0; i < graph_input_num; ++i) { - QnnTensorWrapper tensorwrapper(input_tensors[i]); + QnnTensorWrapper tensorwrapper; + ORT_RETURN_IF_ERROR(tensorwrapper.Init(input_tensors[i])); input_tensor_wrappers.push_back(std::move(tensorwrapper)); } // Copy graph output Qnn_Tensor_t* output_tensors = qnn_sys_ctx_graph_info.graphInfoV1.graphOutputs; for (size_t i = 0; i < graph_output_num; ++i) { - QnnTensorWrapper tensorwrapper(output_tensors[i]); + QnnTensorWrapper tensorwrapper; + ORT_RETURN_IF_ERROR(tensorwrapper.Init(output_tensors[i])); output_tensor_wrappers.push_back(std::move(tensorwrapper)); } } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index a422434205c68..750fe24882302 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -276,113 +276,139 @@ bool QnnModelWrapper::GetOnnxShape(const NodeArg& node_arg, std::vector& zero_points) const { const auto& graph_initializers = GetInitializerTensors(); - auto offset_it = graph_initializers.find(offset_name); - if (offset_it == graph_initializers.end()) { - LOGS(logger_, ERROR) << "Not able to find initializer: " << offset_name; - return false; - } - const auto offset_tensor = offset_it->second; - const int32_t onnx_data_type = offset_tensor->data_type(); + auto iter = graph_initializers.find(initializer_name); + ORT_RETURN_IF(iter == graph_initializers.end(), "Unable to find initializer for zero-point(s): ", + initializer_name.c_str()); + gsl::not_null zp_tensor_proto = iter->second; + + ORT_RETURN_IF_NOT(zp_tensor_proto->has_data_type(), "Expected zero-point initializer ", initializer_name.c_str(), + " to have a proto data type."); + + const int32_t onnx_data_type = zp_tensor_proto->data_type(); + std::vector initializer_bytes; + + ORT_RETURN_IF_ERROR(UnpackInitializerData(*zp_tensor_proto, initializer_bytes)); - std::vector unpacked_tensor; - ORT_THROW_IF_ERROR(UnpackInitializerData(*offset_tensor, unpacked_tensor)); switch (onnx_data_type) { // QNN use -offset for some reason case ONNX_NAMESPACE::TensorProto_DataType_INT8: { - auto int8_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - offset_value = -(int8_span.data()[0]); + auto int8_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); + std::transform(int8_span.begin(), int8_span.end(), std::back_inserter(zero_points), + [](int8_t zp) -> int32_t { + return -static_cast(zp); + }); break; } case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { - auto uint8_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - offset_value = 0 - (uint8_span.data()[0]); + auto uint8_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); + std::transform(uint8_span.begin(), uint8_span.end(), std::back_inserter(zero_points), + [](uint8_t zp) -> int32_t { + return -static_cast(zp); + }); break; } case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { - auto uint16_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - offset_value = -static_cast(uint16_span.data()[0]); + auto uint16_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); + std::transform(uint16_span.begin(), uint16_span.end(), std::back_inserter(zero_points), + [](uint16_t zp) -> int32_t { + return -static_cast(zp); + }); break; } case ONNX_NAMESPACE::TensorProto_DataType_INT16: { - auto int16_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - offset_value = -static_cast(int16_span.data()[0]); + auto int16_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); + std::transform(int16_span.begin(), int16_span.end(), std::back_inserter(zero_points), + [](int16_t zp) -> int32_t { + return -static_cast(zp); + }); break; } case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - auto int32_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - offset_value = -(int32_span.data()[0]); + auto int32_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); + std::transform(int32_span.begin(), int32_span.end(), std::back_inserter(zero_points), + [](int32_t zp) -> int32_t { + return -zp; + }); break; } case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { - auto uint32_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - offset_value = 0 - (uint32_span.data()[0]); + auto uint32_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); + std::transform(uint32_span.begin(), uint32_span.end(), std::back_inserter(zero_points), + [](uint32_t zp) -> int32_t { + return -static_cast(zp); + }); break; } default: { - LOGS(logger_, ERROR) << "Data type not supported!"; - return false; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Zero-point ONNX data type `", onnx_data_type, + "` is not supported."); } } - return true; + + return Status::OK(); } -bool QnnModelWrapper::ProcessScale(const std::string& scale_name, - float& scale_value) const { +Status QnnModelWrapper::UnpackScales(const std::string& initializer_name, std::vector& scales) const { const auto& graph_initializers = GetInitializerTensors(); - auto offset_it = graph_initializers.find(scale_name); - if (offset_it == graph_initializers.end()) { - LOGS(logger_, ERROR) << "Not able to find initializer: " << scale_name; - return false; - } - const auto scale_tensor = offset_it->second; - std::vector unpacked_tensor; + auto iter = graph_initializers.find(initializer_name); + ORT_RETURN_IF(iter == graph_initializers.end(), "Unable to find initializer for scale(s): ", + initializer_name.c_str()); + gsl::not_null scale_tensor_proto = iter->second; - ORT_THROW_IF_ERROR(UnpackInitializerData(*scale_tensor, unpacked_tensor)); - const float* scale_data = reinterpret_cast(unpacked_tensor.data()); - scale_value = scale_data[0]; - return true; -} + ORT_RETURN_IF_NOT(scale_tensor_proto->has_data_type(), "Expected scale initializer ", initializer_name.c_str(), + " to have a proto data type."); + ORT_RETURN_IF_NOT(scale_tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + "Expected scale initializer to be of type FLOAT"); -bool QnnModelWrapper::ProcessQuantizationParameter(const std::optional& quant_param, - float& scale_value, - int32_t& offset_value) const { - if (quant_param.has_value()) { - // Parse scale & zero_point - const auto& scale_name = quant_param->scale.Name(); - bool rt = ProcessScale(scale_name, scale_value); - if (!rt) { - return rt; - } + std::vector initializer_bytes; - if (quant_param->zero_point) { - const auto& zero_point_name = quant_param->zero_point->Name(); - return ProcessOffset(zero_point_name, offset_value); - } + ORT_RETURN_IF_ERROR(UnpackInitializerData(*scale_tensor_proto, initializer_bytes)); + + gsl::span src = gsl::make_span(reinterpret_cast(initializer_bytes.data()), + initializer_bytes.size() / sizeof(float)); + + scales.insert(scales.end(), src.begin(), src.end()); + return Status::OK(); +} + +// Checks if a tensor in the ONNX graph is per-channel quantized. +Status QnnModelWrapper::IsPerChannelQuantized(const onnxruntime::NodeUnitIODef& io_def, + /*out*/ bool& is_per_axis) const { + if (!io_def.quant_param) { + is_per_axis = false; + return Status::OK(); } - return true; + + const std::string& scale_name = io_def.quant_param->scale.Name(); + const auto& graph_initializers = GetInitializerTensors(); + auto iter = graph_initializers.find(scale_name); + ORT_RETURN_IF(iter == graph_initializers.end(), "Unable to find initializer for scale(s): ", + scale_name.c_str()); + gsl::not_null scale_tensor_proto = iter->second; + TensorShape scale_shape = onnxruntime::utils::GetTensorShapeFromTensorProto(*scale_tensor_proto); + + // Check the number of scale values to determine if the tensor is per-channel. + // This is consistent with CPU EP's Quant/Dequant logic. We can't use the presence of an axis because even a + // per-channel DQ/Q op may not have an explicit axis attribute (assumed to default to 1 if missing). + const bool is_scalar_or_1_elem_vector = scale_shape.NumDimensions() == 0 || + (scale_shape.NumDimensions() == 1 && scale_shape.Size() == 1); + + is_per_axis = !is_scalar_or_1_elem_vector; + return Status::OK(); } Status QnnModelWrapper::GetTensorInfo(const NodeUnitIODef& input, TensorInfo& tensor_info) const { const std::string& name = input.node_arg.Name(); // Fill in quantization param info. - tensor_info.quant_param = QNN_QUANTIZE_PARAMS_INIT; - bool is_quantized_tensor = input.quant_param.has_value(); - utils::InitializeQuantizeParam(tensor_info.quant_param, is_quantized_tensor); - - if (is_quantized_tensor) { - ORT_RETURN_IF_NOT(ProcessQuantizationParameter(input.quant_param, - tensor_info.quant_param.scaleOffsetEncoding.scale, - tensor_info.quant_param.scaleOffsetEncoding.offset), - "QNN EP: Cannot get quantization parameters for input ", name.c_str()); - } + ORT_RETURN_IF_ERROR(tensor_info.quant_param.Init(*this, input)); // Fill in QNN data type. tensor_info.qnn_data_type = QNN_DATATYPE_FLOAT_32; - ORT_RETURN_IF_ERROR(utils::GetQnnDataType(is_quantized_tensor, input.node_arg.TypeAsProto(), + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(input.quant_param.has_value(), input.node_arg.TypeAsProto(), tensor_info.qnn_data_type)); // Fill in shape. @@ -402,14 +428,19 @@ Status QnnModelWrapper::AddReshapeNode(const std::string& input_name, const std::vector& input_shape, const std::vector& output_shape, const Qnn_DataType_t& tensor_data_type, - const Qnn_QuantizeParams_t& quantize_param, + const QnnQuantParamsWrapper& quantize_param, bool do_op_validation, bool is_for_input, bool is_for_output) { + // Do not allow QNN EP to insert Reshape nodes with per-channel quantization on dynamic tensors. + // We could technically support this by shifting the quantization param's axis value, but + // we don't need this right now. + ORT_RETURN_IF(quantize_param.IsPerChannel(), + "Do not support inserted Reshape nodes with per-channel quantization"); QnnTensorWrapper input_tensorwrapper(input_name, is_for_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_NATIVE, tensor_data_type, - quantize_param, + quantize_param.Copy(), std::vector(input_shape)); ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(input_tensorwrapper)), "QNN EP: Failed to add input tensor for inserted Reshape."); @@ -418,7 +449,7 @@ Status QnnModelWrapper::AddReshapeNode(const std::string& input_name, QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, tensor_data_type, - quantize_param, + quantize_param.Copy(), std::vector(output_shape)); ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(output_tensorwrapper)), "QNN EP: Failed to add output tensor for inserted Reshape."); @@ -442,17 +473,22 @@ Status QnnModelWrapper::AddTransposeNode(NodeIndex node_index, const std::vector& transpose_perm, const std::vector& output_shape, const Qnn_DataType_t& tensor_data_type, - const Qnn_QuantizeParams_t& quantize_param, + const QnnQuantParamsWrapper& quantize_param, bool do_op_validation, bool is_for_input, bool is_for_output) { + // Do not allow QNN EP to insert transpose nodes with per-channel quantization on dynamic tensors. + // We could technically support this by transposing the quantization param's axis value, but + // we don't need this right now. + ORT_RETURN_IF(quantize_param.IsPerChannel(), + "Do not support inserted Transpose nodes with per-channel quantization"); // No need to add this for output nodes as it is added as output tensor for previous node if (is_for_input) { Qnn_TensorType_t tensor_type = QNN_TENSOR_TYPE_APP_WRITE; QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, tensor_data_type, - quantize_param, + quantize_param.Copy(), std::vector(input_shape)); ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); } @@ -469,7 +505,7 @@ Status QnnModelWrapper::AddTransposeNode(NodeIndex node_index, QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, tensor_data_type, - quantize_param, + quantize_param.Copy(), std::move(output_shape_copy)); ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); const static std::string qnn_node_type = "Transpose"; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 1e2993f246ae4..446c082950653 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -14,6 +14,7 @@ #include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" #include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_quant_params_wrapper.h" namespace onnxruntime { namespace qnn { @@ -23,7 +24,7 @@ namespace qnn { struct TensorInfo { std::vector shape; Qnn_DataType_t qnn_data_type; - Qnn_QuantizeParams_t quant_param; + QnnQuantParamsWrapper quant_param; bool is_initializer; const ONNX_NAMESPACE::TensorProto* initializer_tensor; }; @@ -97,16 +98,6 @@ class QnnModelWrapper { static bool GetOnnxShape(const NodeArg& node_arg, std::vector& shape); - bool ProcessOffset(const std::string& offset_name, - int32_t& offset_value) const; - - bool ProcessScale(const std::string& scale_name, - float& scale_value) const; - - bool ProcessQuantizationParameter(const std::optional& quant_param, - float& scale_value, - int32_t& offset_value) const; - bool IsQnnTensorWrapperExist(const std::string& tensor_name) const; bool IsGraphOutput(const std::string& tensor_name) const { @@ -124,7 +115,7 @@ class QnnModelWrapper { const std::vector& input_shape, const std::vector& output_shape, const Qnn_DataType_t& tensor_data_type, - const Qnn_QuantizeParams_t& quantize_param, + const QnnQuantParamsWrapper& quantize_param, bool do_op_validation, bool is_for_input = true, bool is_for_output = false); @@ -136,7 +127,7 @@ class QnnModelWrapper { const std::vector& transpose_perm, const std::vector& output_shape, const Qnn_DataType_t& tensor_data_type, - const Qnn_QuantizeParams_t& quantize_param, + const QnnQuantParamsWrapper& quantize_param, bool do_op_validation, bool is_for_input = true, bool is_for_output = false); @@ -148,7 +139,7 @@ class QnnModelWrapper { const std::vector& input_shape, const std::vector& output_shape, const Qnn_DataType_t& tensor_data_type, - const Qnn_QuantizeParams_t& quantize_param, + const QnnQuantParamsWrapper& quantize_param, bool do_op_validation, bool is_for_input = true, bool is_for_output = false) { @@ -165,7 +156,7 @@ class QnnModelWrapper { const std::vector& input_shape, const std::vector& output_shape, const Qnn_DataType_t& tensor_data_type, - const Qnn_QuantizeParams_t& quantize_param, + const QnnQuantParamsWrapper& quantize_param, bool do_op_validation, bool is_for_input = true, bool is_for_output = false) { @@ -182,6 +173,15 @@ class QnnModelWrapper { const GraphViewer& GetGraphViewer() const { return graph_viewer_; } + // Unpack float scales from initializer (1 scale for per-tensor, > 1 for per-axis). + Status UnpackScales(const std::string& initializer_name, std::vector& scales) const; + + // Unpack zero-points from initializer and convert to int32_t (1 zero-point for per-tensor, > 1 for per-channel). + Status UnpackZeroPoints(const std::string& initializer_name, std::vector& zero_points) const; + + // Checks if a tensor in the ONNX graph is per-axis quantized. + Status IsPerChannelQuantized(const onnxruntime::NodeUnitIODef& io_def, /*out*/ bool& is_per_axis) const; + private: bool CreateQnnInputOutputTensors(const std::string& qnn_node_name, const std::vector& names, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc new file mode 100644 index 0000000000000..401d403c15b01 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_quant_params_wrapper.h" +#include +#include +#include +#include +#include "QnnTypes.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" + +namespace onnxruntime { +namespace qnn { + +QnnQuantParamsWrapper::QnnQuantParamsWrapper(const QnnQuantParamsWrapper& other) + : params_(QNN_QUANTIZE_PARAMS_INIT) { + Status status = Init(other.params_); + assert(status.IsOK()); // Expect other QnnQuantParamsWrapper to always have a supported quantization encoding. +} + +QnnQuantParamsWrapper& QnnQuantParamsWrapper::operator=(const QnnQuantParamsWrapper& other) { + if (this != &other) { + Status status = Init(other.params_); + assert(status.IsOK()); // Expect other QnnQuantParamsWrapper to always have a supported quantization encoding. + } + + return *this; +} + +QnnQuantParamsWrapper::QnnQuantParamsWrapper(float scale, int32_t offset) { + params_.encodingDefinition = QNN_DEFINITION_DEFINED; + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + params_.scaleOffsetEncoding.scale = scale; + params_.scaleOffsetEncoding.offset = offset; +} + +QnnQuantParamsWrapper QnnQuantParamsWrapper::Copy() const { + return QnnQuantParamsWrapper(*this); +} + +Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) { + if (scale_offset_data_) { + scale_offset_data_.reset(nullptr); + params_ = QNN_QUANTIZE_PARAMS_INIT; + } + + if (params.encodingDefinition != QNN_DEFINITION_DEFINED) { + params_ = params; + return Status::OK(); + } + + switch (params.quantizationEncoding) { + case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET: + params_ = params; + break; + case QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: { + params_.encodingDefinition = params.encodingDefinition; + params_.quantizationEncoding = params.quantizationEncoding; + params_.axisScaleOffsetEncoding.axis = params.axisScaleOffsetEncoding.axis; + params_.axisScaleOffsetEncoding.numScaleOffsets = params.axisScaleOffsetEncoding.numScaleOffsets; + + // Deep copy the scaleOffset data. + const uint32_t num_elems = params.axisScaleOffsetEncoding.numScaleOffsets; + + if (num_elems > 0) { + scale_offset_data_ = std::make_unique(num_elems); + gsl::span src_span(params.axisScaleOffsetEncoding.scaleOffset, num_elems); + std::copy(src_span.begin(), src_span.end(), scale_offset_data_.get()); + params_.axisScaleOffsetEncoding.scaleOffset = scale_offset_data_.get(); + } else { + params_.axisScaleOffsetEncoding.scaleOffset = nullptr; + } + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ", params.quantizationEncoding); + } + + return Status::OK(); +} + +Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& io_def) { + const std::optional& ort_quant_params = io_def.quant_param; + + if (scale_offset_data_) { + scale_offset_data_.reset(nullptr); + params_ = QNN_QUANTIZE_PARAMS_INIT; + } + + if (!ort_quant_params.has_value()) { + params_.encodingDefinition = QNN_DEFINITION_UNDEFINED; + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED; + return Status::OK(); + } + + std::vector scales; + std::vector zero_points; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackScales(ort_quant_params->scale.Name(), scales)); + + if (ort_quant_params->zero_point != nullptr) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points)); + } + + const bool is_per_tensor = scales.size() == 1; + + if (is_per_tensor) { + params_.encodingDefinition = QNN_DEFINITION_DEFINED; + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + + // Parse scale & zero_point + params_.scaleOffsetEncoding.scale = scales[0]; + + if (ort_quant_params->zero_point != nullptr) { + ORT_RETURN_IF_NOT(zero_points.size() == 1, "Expected one zero-point value"); + params_.scaleOffsetEncoding.offset = zero_points[0]; + } else { + params_.scaleOffsetEncoding.offset = 0; + } + } else { + // Per-channel quantization. + const auto* io_shape = io_def.node_arg.Shape(); + ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape"); + const int32_t io_rank = io_shape->dim_size(); + + constexpr int64_t DEFAULT_QDQ_AXIS = 1; + int64_t axis = ort_quant_params->axis.value_or(DEFAULT_QDQ_AXIS); + if (axis < 0) { + axis += io_rank; + } + ORT_RETURN_IF_NOT(axis >= 0 && axis < io_rank, + "Quantization axis must be within the range [0, rank - 1]"); + + params_.encodingDefinition = QNN_DEFINITION_DEFINED; + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; + + const size_t num_elems = scales.size(); + const bool no_zero_points = zero_points.empty(); + ORT_RETURN_IF_NOT(num_elems > 1, "Expected more than one scale value"); + ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems, + "Expected the same number of zero-points and scales for per-channel quantization"); + + scale_offset_data_ = std::make_unique(num_elems); + gsl::span data_span(scale_offset_data_.get(), num_elems); + + for (size_t i = 0; i < num_elems; i++) { + data_span[i].scale = scales[i]; + data_span[i].offset = no_zero_points ? 0 : zero_points[i]; + } + + params_.axisScaleOffsetEncoding.axis = static_cast(axis); + params_.axisScaleOffsetEncoding.numScaleOffsets = static_cast(num_elems); + params_.axisScaleOffsetEncoding.scaleOffset = data_span.data(); + } + + return Status::OK(); +} +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h new file mode 100644 index 0000000000000..3cf04da97a8ff --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "QnnTypes.h" +#include "core/common/common.h" +#include "core/common/gsl.h" +#include "core/framework/node_unit.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; // Forward-declare + +class QnnQuantParamsWrapper { + public: + QnnQuantParamsWrapper() : params_(QNN_QUANTIZE_PARAMS_INIT) {} + + QnnQuantParamsWrapper(const QnnQuantParamsWrapper& other); + QnnQuantParamsWrapper& operator=(const QnnQuantParamsWrapper& other); + + QnnQuantParamsWrapper(QnnQuantParamsWrapper&& other) = default; + QnnQuantParamsWrapper& operator=(QnnQuantParamsWrapper&& other) = default; + + // Construct a per-tensor quantization param (SCALE_OFFSET) + QnnQuantParamsWrapper(float scale, int32_t offset); + + Qnn_QuantizeParams_t& Get() { return params_; } + const Qnn_QuantizeParams_t& Get() const { return params_; } + + // Initialize this object from a raw Qnn_QuantizeParam_t object. + Status Init(const Qnn_QuantizeParams_t& params); + + // Initialize this object from a (potentially) quantized ONNX tensor. + // QnnModelWrapper provides utilities for unpacking scale and zero-point ONNX initializers. + Status Init(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& io_def); + + QnnQuantParamsWrapper Copy() const; + + bool IsQuantized() const { + return params_.encodingDefinition == QNN_DEFINITION_DEFINED; + } + + bool IsPerTensor(bool include_bw = false) const { + return params_.encodingDefinition == QNN_DEFINITION_DEFINED && + (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET || + (include_bw && params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET)); + } + + bool IsPerChannel(bool include_bw = false) const { + return params_.encodingDefinition == QNN_DEFINITION_DEFINED && + (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET || + (include_bw && params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET)); + } + + // Handle transposing of a per-channel quantized tensor. The quantization parameter's axis + // must be transposed using the inverse permutation of the Transpose. + template + Status HandleTranspose(gsl::span perm) { + if (!IsPerChannel(true)) { + return Status::OK(); + } + + if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + ORT_RETURN_IF_NOT(static_cast(params_.axisScaleOffsetEncoding.axis) < perm.size(), + "Axis value is out of range of the provided permutation"); + const int32_t new_axis = static_cast(perm[params_.axisScaleOffsetEncoding.axis]); + params_.axisScaleOffsetEncoding.axis = new_axis; + } else if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { + ORT_RETURN_IF_NOT(static_cast(params_.bwAxisScaleOffsetEncoding.axis) < perm.size(), + "Axis value is out of range of the provided permutation"); + const int32_t new_axis = static_cast(perm[params_.bwAxisScaleOffsetEncoding.axis]); + params_.bwAxisScaleOffsetEncoding.axis = new_axis; + } + + return Status::OK(); + } + + // Handle "unsqueeze" of a per-channel quantized tensor. The quantization parameter's axis + // may need to be shifted if the unsqueeze inserted 1s before the quantization axis. + template + Status HandleUnsqueeze(gsl::span orig_shape, + gsl::span new_shape) { + if (!IsPerChannel(true)) { + return Status::OK(); + } + + ORT_RETURN_IF_NOT(orig_shape.size() < new_shape.size(), "Expected unsqueezed shape to have a greater rank."); + + // Get the axis value. + int32_t axis = 0; + if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + axis = params_.axisScaleOffsetEncoding.axis; + } else if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { + axis = params_.bwAxisScaleOffsetEncoding.axis; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unhandled quantization encoding: ", params_.quantizationEncoding); + } + + // Find where the axis was moved to after unsqueeze. + size_t num_found = 0; + size_t j = 0; + for (size_t i = 0; i < orig_shape.size() && j < new_shape.size(); i++) { + while (orig_shape[i] != new_shape[j] && j < new_shape.size()) { + assert(new_shape[j] == 1); + j++; + } + assert(orig_shape[i] == new_shape[j]); + if (num_found == static_cast(axis)) { + break; + } + num_found += 1; + j++; + } + + if (j == static_cast(axis)) { + return Status::OK(); + } + + // Set new axis. + if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + params_.axisScaleOffsetEncoding.axis = static_cast(j); + } else if (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { + params_.bwAxisScaleOffsetEncoding.axis = static_cast(j); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unhandled quantization encoding: ", params_.quantizationEncoding); + } + + return Status::OK(); + } + + private: + Qnn_QuantizeParams_t params_; + std::unique_ptr scale_offset_data_; // Stores per-channel scales and offsets +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index e4074fa6fb60b..7a75b055e7de9 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -216,6 +216,31 @@ std::ostream& operator<<(std::ostream& out, const Qnn_QuantizeParams_t& quantize if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { out << " scale=" << quantize_params.scaleOffsetEncoding.scale; out << " offset=" << quantize_params.scaleOffsetEncoding.offset; + } else if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + out << " axis=" << quantize_params.axisScaleOffsetEncoding.axis; + size_t num_elems = quantize_params.axisScaleOffsetEncoding.numScaleOffsets; + out << " scales=("; + for (size_t i = 0; i < num_elems; i++) { + out << quantize_params.axisScaleOffsetEncoding.scaleOffset[i].scale << (i == num_elems - 1 ? "" : " "); + } + out << ") offsets=("; + for (size_t i = 0; i < num_elems; i++) { + out << quantize_params.axisScaleOffsetEncoding.scaleOffset[i].offset << (i == num_elems - 1 ? "" : " "); + } + out << ")"; + } else if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { + out << " axis=" << quantize_params.bwAxisScaleOffsetEncoding.axis; + out << " bw=" << quantize_params.bwAxisScaleOffsetEncoding.bitwidth; + size_t num_elems = quantize_params.bwAxisScaleOffsetEncoding.numElements; + out << " scales=("; + for (size_t i = 0; i < num_elems; i++) { + out << quantize_params.bwAxisScaleOffsetEncoding.scales[i] << (i == num_elems - 1 ? "" : " "); + } + out << ") offsets=("; + for (size_t i = 0; i < num_elems; i++) { + out << quantize_params.bwAxisScaleOffsetEncoding.offsets[i] << (i == num_elems - 1 ? "" : " "); + } + out << ")"; } else { out << " encoding not supported."; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index edbef7ae92ee0..f61117addd83a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -1,14 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - -#include "QnnTypes.h" -#include "core/session/onnxruntime_cxx_api.h" +#pragma once #include #include -#include #include +#include +#include +#include "QnnTypes.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/node_unit.h" #include "core/util/qmath.h" namespace onnxruntime { @@ -30,11 +32,28 @@ Status GetQnnDataType(const bool is_quantized_tensor, const ONNX_NAMESPACE::Type bool OnnxDataTypeToQnnDataType(const int32_t data_type, Qnn_DataType_t& qnn_data_type, bool is_quantized = false); -inline void InitializeQuantizeParam(Qnn_QuantizeParams_t& quantize_param, bool is_quantized_tensor, float scale = 0.0f, int32_t offset = 0) { - quantize_param.encodingDefinition = is_quantized_tensor ? QNN_DEFINITION_DEFINED : QNN_DEFINITION_UNDEFINED; - quantize_param.quantizationEncoding = is_quantized_tensor ? QNN_QUANTIZATION_ENCODING_SCALE_OFFSET : QNN_QUANTIZATION_ENCODING_UNDEFINED; - quantize_param.scaleOffsetEncoding.scale = scale; - quantize_param.scaleOffsetEncoding.offset = offset; +inline Status GetOnnxTensorElemDataType(const NodeArg& node_arg, /*out*/ int32_t& onnx_data_type) { + auto type_proto = node_arg.TypeAsProto(); + ORT_RETURN_IF_NOT(type_proto != nullptr && type_proto->has_tensor_type() && type_proto->tensor_type().has_elem_type(), + "NodeArg must have a tensor TypeProto"); + onnx_data_type = type_proto->tensor_type().elem_type(); + return Status::OK(); +} + +template +static Status InvertPerm(gsl::span perm, /*out*/ gsl::span perm_inv) { + static_assert(std::is_integral::value, "permutation arrays must contain integer elements"); + + size_t rank = perm.size(); + ORT_RETURN_IF_NOT(perm_inv.size() == rank, "perm.size() != perm_inv.size()"); + + for (size_t i = 0; i < rank; ++i) { + size_t j = static_cast(perm[i]); + ORT_RETURN_IF_NOT(j < rank, "perm element out of range [0, rank - 1]"); + perm_inv[j] = static_cast(i); + } + + return Status::OK(); } // Utility function that checks if an array of strings contains a specific string. diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 80617b7b5edaa..e3d591a0d5bb4 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -24,6 +24,7 @@ QuantType, find_by_name, model_has_infer_metadata, + normalize_axis, quantize_data, quantize_nparray, save_and_reload_model_with_shape_infer, @@ -120,9 +121,9 @@ def __init__( # Get tensor-level quantization overrides and ensure they are valid. self.tensor_quant_overrides = TensorQuantOverridesHelper(self.extra_options.get("TensorQuantOverrides", {})) - initializer_names = {initzer.name for initzer in self.model.initializer()} + self.initializers = {initzer.name: initzer for initzer in self.model.initializer()} overrides_valid, overrides_err = self.tensor_quant_overrides.is_valid( - initializer_names, self.value_infos.keys(), activation_qType + self.initializers, self.value_infos.keys(), activation_qType ) if not overrides_valid: raise ValueError(overrides_err) @@ -252,7 +253,7 @@ def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1 quantized_bias_zp_name = quantized_bias_name + "_zero_point" if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN: packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, self.weight_qType, [1], [0.0]) - elif self.is_per_channel(): + elif bias_scale.size > 1: bias_zp_data = np.zeros(bias_scale.shape, dtype=np.int32).reshape(-1) packed_bias_zp_initializer = onnx.numpy_helper.from_array(bias_zp_data, quantized_bias_zp_name) else: @@ -282,7 +283,7 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa # Quantize weight data. Use quantization overrides if provided by the user. weight_data = tensor_proto_to_array(weight) - quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(weight.name) + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(weight.name, default_val={}) if "quant_type" in quant_overrides: qType = quant_overrides["quant_type"].tensor_type # noqa: N806 @@ -358,20 +359,51 @@ def quantize_weight_per_channel_impl( raise ValueError("{} is not an initializer", weight_name) weights = tensor_proto_to_array(initializer) + weights_rank = len(weights.shape) + is_axis_valid, axis_norm = normalize_axis(channel_axis, weights_rank) + if not is_axis_valid: + raise ValueError( + f"Weight {weight_name} has a per-channel axis with value {channel_axis} that is " + f"out-of-bounds for rank {weights_rank}" + ) + + channel_axis = axis_norm channel_count = weights.shape[channel_axis] - quant_overrides_for_channels = self.tensor_quant_overrides.get_per_channel_overrides(weight_name, channel_count) + quant_overrides_for_channels = self.tensor_quant_overrides.get_per_channel_overrides( + weight_name, default_val=[{"axis": channel_axis}] + ) + + num_channel_overrides = len(quant_overrides_for_channels) + if num_channel_overrides != 1 and num_channel_overrides != channel_count: + raise ValueError( + f"Per-channel tensor quantization overrides for {weight_name} must have " + f"either 1 or {channel_count} elements in the list of dictionaries." + ) - # If user provides per-channel quantization overrides, all channels must use the same quantization type. - # So, just use the first channel's type. + is_axis_override_valid, axis_override = normalize_axis(quant_overrides_for_channels[0]["axis"], weights_rank) + if not is_axis_override_valid or axis_override != channel_axis: + raise ValueError( + f"Tensor quantization overrides for {weight_name} specify an unexpected axis. " + f"Expected {channel_axis}, but got {quant_overrides_for_channels[0]['axis']}." + ) + + # If user provides per-channel quantization overrides, all channels must use the same quant_type, + # axis, symmetric, and reduce_range values. So, just use the first channel's values. if "quant_type" in quant_overrides_for_channels[0]: weight_qType = quant_overrides_for_channels[0]["quant_type"].tensor_type # noqa: N806 + symmetric = quant_overrides_for_channels[0].get( + "symmetric", + (self.is_weight_symmetric or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN)), + ) + reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range) zero_point_list = [] scale_list = [] quantized_per_channel_data_list = [] for i in range(channel_count): per_channel_data = weights.take(i, channel_axis) - channel_quant_overrides = quant_overrides_for_channels[i] + channel_override_index = i if i < num_channel_overrides else 0 + channel_quant_overrides = quant_overrides_for_channels[channel_override_index] if "scale" in channel_quant_overrides and "zero_point" in channel_quant_overrides: zero_point = np.array(channel_quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[weight_qType]) @@ -389,18 +421,11 @@ def quantize_weight_per_channel_impl( ), f"Unexpected type {type(quantized_per_channel_data)}" else: - symmetric = channel_quant_overrides.get( - "symmetric", - ( - self.is_weight_symmetric - or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN) - ), - ) _, _, zero_point, scale, quantized_per_channel_data = quantize_data( per_channel_data.flatten(), weight_qType, symmetric, - reduce_range=channel_quant_overrides.get("reduce_range", self.reduce_range and reduce_range), + reduce_range=reduce_range, min_real_range=self.min_real_range, rmin_override=channel_quant_overrides.get("rmin"), rmax_override=channel_quant_overrides.get("rmax"), diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py index d59a0ec74ca7c..6396e87c73d03 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py @@ -138,7 +138,7 @@ def create_from_model( value_infos.update({it.name: it for it in model.graph.input}) # Ensure that the user-provided initial overrides are actually valid. - valid, err = overrides.is_valid(set(initializers), set(value_infos), default_activation_qtype) + valid, err = overrides.is_valid(initializers, set(value_infos), default_activation_qtype) if not valid: pprint_overrides = overrides.pprint_str(indent=4) logging.error(f"Provided invalid tensor quantization overrides:\n{pprint_overrides}") @@ -233,7 +233,7 @@ def apply( raise ValueError(f"TypeRequest for tensor {tensor_name} has no producer or consumers.") # Done. Check if the overrides are valid. - valid, err = self.overrides.is_valid(set(self.initializers), set(self.value_infos), default_activation_qtype) + valid, err = self.overrides.is_valid(self.initializers, set(self.value_infos), default_activation_qtype) if not valid: pprint_overrides = self.overrides.pprint_str(indent=4) logging.error( diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index 3a217fdfaaffd..3c9b319c78535 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -8,6 +8,7 @@ import copy import logging from pathlib import Path +from typing import Any import numpy as np import onnx @@ -41,18 +42,77 @@ def warn_unable_to_override( def get_qnn_qdq_config( model_input: str | Path | onnx.ModelProto, calibration_data_reader: CalibrationDataReader, - calibrate_method=CalibrationMethod.MinMax, - activation_type=QuantType.QUInt8, - weight_type=QuantType.QUInt8, - per_channel=False, - init_overrides=None, - add_qtype_converts=True, - activation_symmetric=False, - weight_symmetric=None, -): - if per_channel: - raise ValueError("QNN EP does not yet support per-channel quantization.") - + calibrate_method: CalibrationMethod = CalibrationMethod.MinMax, + activation_type: QuantType = QuantType.QUInt8, + weight_type: QuantType = QuantType.QUInt8, + per_channel: bool = False, + init_overrides: dict[str, list[dict[str, Any]]] | None = None, + add_qtype_converts: bool = True, + activation_symmetric: bool = False, + weight_symmetric: bool | None = None, +) -> StaticQuantConfig: + """ + Returns a static quantization configuration suitable for running QDQ models on QNN EP. + This is done primarily by setting tensor-level quantization overrides. + + Params: + model_input: Path to the input model file or ModelProto. + calibration_data_reader: Calibration data reader. + calibrate_methode: The calibration method. Defaults to MinMax. + activation_type: The default activation quantization type. Defaults to QUInt8. + weight_type: The default weight quantization type. Defaults to QUInt8. + per_channel: Global option that determines if a fixed set of operator types should be quantized per-channel. + Defaults to false. Alternatively, use the tensor-level `init_overrides` to select individual operators + and their quantization axes. + + If set, the quantization tool uses per-channel quantization for the following operator types and inputs: + - Conv: + - input[1] on axis 0 + - input[2] (bias) on axis 0 + - ConvTranspose: + - input[1] on axis 1 + - input[2] (bias) on axis 0 + init_overrides: Initial tensor-level quantization overrides. Defaults to None. This function updates of a copy + of these overrides with any necessary adjustments and includes them in the returned + configuration object (i.e., config.extra_options['TensorQuantOverrides']). + + The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list + contains a single dictionary. For per-channel quantization, the list contains either a dictionary for + each channel in the tensor or a single dictionary that is assumed to apply to all channels. An 'axis' + key must be present in the first dictionary for per-channel quantization. + + Each dictionary contains optional overrides with the following keys and values. + 'quant_type' = QuantType : The tensor's quantization data type. + 'axis' = Int : The per-channel axis. Must be present for per-channel weights. + 'scale' = Float : The scale value to use. Must also specify `zero_point` if set. + 'zero_point' = Int : The zero-point value to use. Must also specify `scale` is set. + 'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also + set `scale` or `zero_point`. + 'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also + set `scale` or `zero_point`. Only valid for initializers. + 'rmax' = Float : Override the maximum real tensor value in calibration data. + Invalid if also set `scale` or `zero_point`. + 'rmin' = Float : Override the minimum real tensor value in calibration data. + Invalid if also set `scale` or `zero_point`. + 'convert' = Dict : A nested dictionary with the same keys for an activation + tensor that should be converted to another quantization type. + 'convert["recv_nodes"] = Set : Set of node names that consume the converted activation, + other nodes get the original type. If not specified, + assume all consumer nodes get the converted type. + add_qtype_converts: True if this function should automatically add "convert" entries to the provided + `init_overrides` to ensure that operators use valid input/output types (activations only). + Ex: if you override the output of an Add to 16-bit, this option ensures that the activation inputs + of the Add are also up-converted to 16-bit and that data types for surrounding ops are converted + appropriately. Refer to the documentation in mixed_precision_overrides_utils.py for additional details. + activation_symmetric: True if activations should be quantized symmetrically (i.e, rmax == -rmin) by default. + Defaults to false. For int8 and int16, this results in zero-point values of 0. For uint8 and uin16, + the zero-point values are 128 and 32,768, respectively. + weight_symmetric: True if weights should be quantized symmetrically (i.e., rmax == -rmin) by default. + Defaults to None. If set to None, weight_symmetric is assumed true if the weight_type is a signed int. + + Returns: + A StaticQuantConfig object + """ if weight_symmetric is None: weight_symmetric = weight_type in {QuantType.QInt8, QuantType.QInt16} @@ -88,6 +148,7 @@ def get_qnn_qdq_config( weight_type, activation_symmetric, weight_symmetric, + per_channel, overrides_helper, name_to_initializer, ) @@ -115,6 +176,7 @@ def get_qnn_qdq_config( activation_type=activation_type, weight_type=weight_type, op_types_to_quantize=list(op_types.difference(OP_TYPES_TO_EXCLUDE)), + per_channel=per_channel, use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), extra_options=extra_options, ) @@ -132,6 +194,7 @@ def __init__( default_weight_qtype: QuantType, activation_symmetric: bool, weight_symmetric: bool, + per_channel: bool, overrides: TensorQuantOverridesHelper, initializers: dict[str, onnx.TensorProto], ): @@ -139,6 +202,7 @@ def __init__( self.default_weight_qtype = default_weight_qtype self.activation_symmetric = activation_symmetric self.weight_symmetric = weight_symmetric + self.per_channel = per_channel self.overrides = overrides self.initializers = initializers @@ -155,73 +219,102 @@ def process_node(self, node: onnx.NodeProto): if process_fn is not None: process_fn(node) - def _process_matmul(self, node: onnx.NodeProto): + def _make_static_inputs_use_default_weight_type(self, node: onnx.NodeProto): """ - Overrides MatMul's initializer input(s) to use the default weight type if: + Overrides initializer input(s) to use the default weight type if: - The default weight type is 8-bit - One of the inputs is a 16-bit activation + - The other input is an initializer (per-tensor quantized) + + This is necessary because the quantization tool does not assign MatMul or LayerNorm initializer + inputs the default weight type. Instead, it assigns the default activation type. """ - assert node.op_type == "MatMul", f"Expected MatMul, but got {node.op_type}" if self.default_weight_qtype not in Q8_TYPES: return - input_16bit_act = None - input_wgt = None + input_16bit_act_name = None + input_weight_name = None - for input_name in node.input: - if input_name and input_name not in self.initializers: - qtype = self.overrides.get_node_input_qtype_info( - input_name, node.name, self.default_activation_qtype - ).quant_type - if qtype in Q16_TYPES: - input_16bit_act = input_name - else: - input_wgt = input_name - - # Override initializer to use the default weight type. - if input_16bit_act and input_wgt: + # Loop through first 2 inputs to find a 16-bit activation and a (per-tensor) weight. + for i in range(2): + input_name = node.input[i] + if not input_name: + continue + + is_weight = input_name in self.initializers + qtype_info = self.overrides.get_node_input_qtype_info( + input_name, + node.name, + default_qtype=None if is_weight else self.default_activation_qtype, + ) + + if qtype_info.axis is not None: + return # Don't process MatMul with a per-channel quantized input. + + if ( + is_weight + and qtype_info.quant_type == self.default_weight_qtype + and qtype_info.symmetric == self.weight_symmetric + ): + return # Return. Weight is already overridden to use the desired weight type. + + if is_weight: + input_weight_name = input_name + elif qtype_info.quant_type in Q16_TYPES: + input_16bit_act_name = input_name + + # Override initializer input to use the default weight type. + if input_16bit_act_name and input_weight_name: did_update = self.overrides.update_tensor_overrides( - input_wgt, + input_weight_name, {"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric}, overwrite=False, ) if not did_update: - warn_unable_to_override(node, "quant_type/symmetric", input_wgt, "input weight") + warn_unable_to_override(node, "quant_type/symmetric", input_weight_name, "input weight") + + def _process_matmul(self, node: onnx.NodeProto): + assert node.op_type == "MatMul", f"Expected MatMul, but got {node.op_type}" + + if not self.per_channel: + self._make_static_inputs_use_default_weight_type(node) + return + + # QNN does not support per-channel MatMul. However, the ORT quantization tool attempts to use per-channel + # quantization for MatMul by default *if* the global per_channel setting is enabled. So, we need to + # provide explicit per-tensor quantization overrides for MatMul if per_channel is enabled and + # the user did not provide any other overrides. + for input_name in node.input: + is_weight_no_overrides = input_name in self.initializers and input_name not in self.overrides + if is_weight_no_overrides: + self.overrides.update_tensor_overrides( + input_name, + {"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric}, + ) def _process_layernorm(self, node: onnx.NodeProto): - """ - Overrides LayerNormalization's initializer input(s), except for bias, to use the default weight type if: - - The default weight type is 8-bit - - One of the inputs is a 16-bit activation - """ assert node.op_type == "LayerNormalization", f"Expected LayerNormalization, but got {node.op_type}" - if self.default_weight_qtype not in Q8_TYPES: + + if not self.per_channel: + self._make_static_inputs_use_default_weight_type(node) return - has_q16_activation = False - for input_name in node.input: - if input_name and input_name not in self.initializers: - qtype = self.overrides.get_node_input_qtype_info( - input_name, node.name, self.default_activation_qtype - ).quant_type - if qtype in Q16_TYPES: - has_q16_activation = True - break - - # Override initializers to use the self.default_weight_qtype. Don't override the bias input. - if has_q16_activation: - for i in range(2): - input_name = node.input[i] - if input_name and input_name in self.initializers: - did_update = self.overrides.update_tensor_overrides( - input_name, - {"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric}, - overwrite=False, - ) - - if not did_update: - warn_unable_to_override(node, "quant_type/symmetric", input_name, "input weight") + has_weight_no_overrides = node.input[1] in self.initializers and node.input[1] not in self.overrides + has_bias_no_overrides = ( + len(node.input) > 2 + and node.input[2] + and node.input[2] in self.initializers + and node.input[2] not in self.overrides + ) + + if has_weight_no_overrides or has_bias_no_overrides: + # TODO: Make bias input not per-channel. QNN needs it to be per-tensor, but quantizer + # tries to makes it per-channel if the weight is also per-channel. + raise ValueError( + "get_qnn_qdq_config() does not currently support the global per_channel option with LayerNormalization." + " Please try using custom overrides that make bias per-tensor quantized." + ) def _process_sigmoid(self, node: onnx.NodeProto): """ diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 4b76de6ecf1cb..f84e00abd6105 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -963,7 +963,7 @@ def calculate_quantization_params(self): if not isinstance(td, TensorData): raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.") - quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name) + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name, default_val={}) quant_type = self.activation_qType if "quant_type" in quant_overrides: diff --git a/onnxruntime/python/tools/quantization/operators/conv.py b/onnxruntime/python/tools/quantization/operators/conv.py index b053c65ad6f85..922884a5f6383 100644 --- a/onnxruntime/python/tools/quantization/operators/conv.py +++ b/onnxruntime/python/tools/quantization/operators/conv.py @@ -246,9 +246,11 @@ def quantize(self): if not self.disable_qdq_for_node_output: self.quantizer.quantize_activation_tensor(node.output[0]) - if self.quantizer.is_per_channel(): - axis = 0 if node.op_type == "Conv" else 1 - self.quantizer.quantize_weight_tensor_per_channel(node.input[1], axis) + is_weight_per_channel, weight_axis = self.quantizer.is_tensor_per_channel( + node.input[1], default_axis=0 if node.op_type == "Conv" else 1 + ) + if is_weight_per_channel: + self.quantizer.quantize_weight_tensor_per_channel(node.input[1], weight_axis) else: self.quantizer.quantize_weight_tensor(node.input[1]) diff --git a/onnxruntime/python/tools/quantization/operators/gemm.py b/onnxruntime/python/tools/quantization/operators/gemm.py index df24e256aa7fc..5d7bf6e2cd2d7 100644 --- a/onnxruntime/python/tools/quantization/operators/gemm.py +++ b/onnxruntime/python/tools/quantization/operators/gemm.py @@ -146,8 +146,11 @@ def quantize(self): if not self.disable_qdq_for_node_output: self.quantizer.quantize_activation_tensor(node.output[0]) - if self.quantizer.is_per_channel(): - self.quantizer.quantize_weight_tensor_per_channel(node.input[1], 0 if is_B_transposed(node) else 1) + is_weight_per_channel, weight_axis = self.quantizer.is_tensor_per_channel( + node.input[1], default_axis=0 if is_B_transposed(node) else 1 + ) + if is_weight_per_channel: + self.quantizer.quantize_weight_tensor_per_channel(node.input[1], weight_axis) else: self.quantizer.quantize_weight_tensor(node.input[1]) diff --git a/onnxruntime/python/tools/quantization/operators/matmul.py b/onnxruntime/python/tools/quantization/operators/matmul.py index af76a68f137ab..5d2961581b8b5 100644 --- a/onnxruntime/python/tools/quantization/operators/matmul.py +++ b/onnxruntime/python/tools/quantization/operators/matmul.py @@ -219,9 +219,10 @@ def quantize(self): nodes_to_iterate = itertools.chain(node.input, node.output) for tensor_name in nodes_to_iterate: - # only support per-channel quantization on weight - if self.quantizer.is_per_channel() and find_by_name(tensor_name, self.quantizer.model.initializer()): - channel_axis = self.quantizer.qdq_op_type_per_channel_support_to_axis.get(node.op_type, 1) + is_per_channel, channel_axis = self.quantizer.is_tensor_per_channel( + tensor_name, default_axis=1, op_type=node.op_type + ) + if is_per_channel: self.quantizer.quantize_weight_tensor_per_channel(tensor_name, channel_axis) else: self.quantizer.quantize_activation_tensor(tensor_name) diff --git a/onnxruntime/python/tools/quantization/operators/norm.py b/onnxruntime/python/tools/quantization/operators/norm.py index 3c14c926a7e75..8c4c6c78582ac 100644 --- a/onnxruntime/python/tools/quantization/operators/norm.py +++ b/onnxruntime/python/tools/quantization/operators/norm.py @@ -19,17 +19,20 @@ def quantize(self): # Scale scale_is_initializer = self.quantizer.is_input_a_initializer(node.input[1]) + scale_is_per_channel, scale_channel_axis = self.quantizer.is_tensor_per_channel( + node.input[1], default_axis=1, op_type=node.op_type + ) - if self.quantizer.is_per_channel() and scale_is_initializer: - channel_axis = self.quantizer.qdq_op_type_per_channel_support_to_axis.get(node.op_type, 1) - self.quantizer.quantize_weight_tensor_per_channel(node.input[1], axis=channel_axis) + if scale_is_per_channel: + self.quantizer.quantize_weight_tensor_per_channel(node.input[1], axis=scale_channel_axis) elif scale_is_initializer: self.quantizer.quantize_weight_tensor(node.input[1]) else: self.quantizer.quantize_activation_tensor(node.input[1]) # Bias - self.quantizer.quantize_bias_tensor(node.name, node.input[2], node.input[0], node.input[1]) + if len(node.input) > 2 and node.input[2]: + self.quantizer.quantize_bias_tensor(node.name, node.input[2], node.input[0], node.input[1]) # Output if not self.disable_qdq_for_node_output: diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index c323c6fec545a..2416cf970e466 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -35,6 +35,7 @@ find_by_name, get_qmin_qmax_for_qType, ms_domain, + normalize_axis, tensor_proto_to_array, ) from .registry import CreateQDQQuantizer @@ -335,8 +336,9 @@ def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, be logging.info( f"Quantizing bias tensor '{bias_name}' as a weight due to the presence of user-specified overrides" ) - if self.per_channel: - self.quantize_weight_tensor_per_channel(bias_name, 0) + is_per_channel, axis = self.is_tensor_per_channel(bias_name, default_axis=0) + if is_per_channel: + self.quantize_weight_tensor_per_channel(bias_name, axis) else: self.quantize_weight_tensor(bias_name) return @@ -471,6 +473,7 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): qtype = self.activation_qType if self.activation_qType == onnx.onnx_pb.TensorProto.UINT8: qtype = onnx_proto.TensorProto.INT8 + q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel( weight_name, # Quantization type is forced to be TensorProto.INT8. @@ -930,6 +933,56 @@ def quantize_initializer( self.quantized_value_map[weight.name] = QDQTensorQuantizedValue(quantized_value, None, None) return q_weight_name, zp_name, scale_name + def is_tensor_per_channel( + self, + tensor_name: str, + default_axis: int, + op_type: str | None = None, + ) -> tuple[bool, int | None]: + """ + Checks if a given tensor is configured to be quantized per-channel. If so, also returns the channel axis. + + ORT only supports per-channel quantization on static weights (i.e., ONNX initializers). If the user did not provide + tensor quantization overrides for this tensor, then the value of self.per_channel determines if the weight + is to be quantized per-channel. + + Params: + tensor_name: The name of the tensor to check. + default_axis: The default channel axis. This method checks if the normalized axis is within bounds. + Can be overridden via the extra_options 'QDQOpTypePerChannelSupportToAxis' + and 'TensorQuantOverrides'. + op_type: Optional, defaults to None. The operator type that is the only consumer of this weight. + Used to access the extra option 'QDQOpTypePerChannelSupportToAxis'. + Returns: + A tuple (is_per_channel, axis) in which the first element indicates whether the tensor is + quantized per-channel and the second element is the channel axis. + The returned axis is only None if the tensor is not per-channel or the axis is out of bounds. + """ + weight_initializer = self.initializers.get(tensor_name) + if weight_initializer is None: + return False, None # Only support per-channel weights + + if self.tensor_quant_overrides.has_per_tensor_overrides(tensor_name): + return False, None # User provided per-tensor overrides for this initializer + + has_per_chan_overrides = self.tensor_quant_overrides.has_per_channel_overrides(tensor_name) + if not self.per_channel and not has_per_chan_overrides: + return False, None # global self.per_channel is off and user did not provide per-channel overrides. + + axis = self.qdq_op_type_per_channel_support_to_axis.get(op_type, default_axis) if op_type else default_axis + if has_per_chan_overrides: + per_chan_overrides = self.tensor_quant_overrides.get_per_channel_overrides(tensor_name) + axis = per_chan_overrides[0]["axis"] # Prefer axis from user-specified tensor-level overrides if available + + weight_nparray = tensor_proto_to_array(weight_initializer) + weight_rank = len(weight_nparray.shape) + axis_valid, axis = normalize_axis(axis, weight_rank) + if not axis_valid: + logging.warning(f"Axis {axis} is out-of-range for weight '{tensor_name}' with rank {weight_rank}") + return False, None + + return True, axis + def quantize_weight_per_channel( self, weight_name: str, @@ -1106,7 +1159,7 @@ def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]: if not isinstance(td, TensorData): raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.") - quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name) + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name, default_val={}) original = self.calc_quant_params(td, quant_overrides) converted = None converted_recv_nodes = None diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 131e55458fb86..35b5e1c8ba825 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -1,3 +1,10 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + import logging import os import tempfile @@ -253,7 +260,17 @@ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=Non scale = numpy.array(1.0, dtype=rmax.dtype) zero_point = numpy.array(0, dtype=qmin.dtype) else: - zero_point = numpy.array(numpy.round(qmin - rmin / scale), dtype=qmin.dtype) + if symmetric: + # When symmetric (i.e., rmax == -rmin), the zero_point formula reduces to round((qmax + qmin) / 2.0). + # This simpler formula doesn't depend on scale and guarantees that the zero point values + # for int8, uint8, int16, and uint16 are always 0, 128, 0, and 32768, respectively. + # This is important for per-channel/symmetric QLinearConv on CPU EP, which requires all channels to have + # the exact same zero_point values. + zero_point = numpy.array( + numpy.round((qmin + qmax) / numpy.array(2.0, dtype=numpy.float64)), dtype=qmin.dtype + ) + else: + zero_point = numpy.array(numpy.round(qmin - rmin / scale), dtype=qmin.dtype) scale = scale.astype(rmax.dtype) return [zero_point, scale] @@ -407,6 +424,18 @@ def get_qrange_for_qType(qType, reduce_range=False, symmetric=False): # noqa: N return qmax - qmin +def normalize_axis(axis: int, rank: int) -> tuple[bool, int]: + """ + Helper function that tries to return a normalized axis in the range [0, rank - 1]. + :parameter axis: The axis to normalize. + :parameter rank: The tensor rank (number of dimensions). + :return (is_valid, axis_norm) + """ + axis_norm = axis + rank if axis < 0 else axis + is_valid = axis_norm >= 0 and axis_norm < rank + return is_valid, axis_norm + + class QuantizedInitializer: """ Represents a linearly quantized weight input from ONNX operators diff --git a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py index 793d58cbc4e3e..6050bd2e05ec5 100644 --- a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py +++ b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py @@ -10,7 +10,9 @@ from dataclasses import dataclass from typing import Any -from .quant_utils import QuantType +import onnx + +from .quant_utils import QuantType, tensor_proto_to_array @dataclass @@ -22,6 +24,7 @@ class QuantTypeInfo: quant_type: QuantType symmetric: bool | None = None # If None, assumes default is used. reduce_range: bool | None = None # If None, assumes default is used. + axis: int | None = None # If None, assumes per-tensor quantization def __eq__(self, other: object): if isinstance(other, QuantTypeInfo): @@ -29,20 +32,22 @@ def __eq__(self, other: object): self.quant_type == other.quant_type and (self.symmetric is None or other.symmetric is None or self.symmetric == other.symmetric) and (self.reduce_range is None or other.reduce_range is None or self.reduce_range == other.reduce_range) + and (self.axis == other.axis) ) return NotImplemented @staticmethod def load_from_dict( raw_dict: dict[str, Any], - default_activation_qtype: QuantType | None = None, - default_activation_symmetric: bool | None = None, - default_activation_reduce_range: bool | None = None, + default_qtype: QuantType | None = None, + default_symmetric: bool | None = None, + default_reduce_range: bool | None = None, ) -> QuantTypeInfo: return QuantTypeInfo( - raw_dict.get("quant_type", default_activation_qtype), - raw_dict.get("symmetric", default_activation_symmetric), - raw_dict.get("reduce_range", default_activation_reduce_range), + raw_dict.get("quant_type", default_qtype), + raw_dict.get("symmetric", default_symmetric), + raw_dict.get("reduce_range", default_reduce_range), + raw_dict.get("axis"), ) def save_to_dict(self, raw_dict: dict[str, Any]): @@ -51,6 +56,8 @@ def save_to_dict(self, raw_dict: dict[str, Any]): raw_dict["symmetric"] = self.symmetric if self.reduce_range is not None: raw_dict["reduce_range"] = self.reduce_range + if self.axis is not None: + raw_dict["axis"] = self.axis class TensorQuantOverridesHelper(MutableMapping): @@ -61,29 +68,44 @@ class TensorQuantOverridesHelper(MutableMapping): def __init__(self, raw_overrides: dict[str, list[dict[str, Any]]]): self.overrides = raw_overrides self.quant_types = None + self.keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"} + + def has_per_tensor_overrides(self, tensor_name: str) -> bool: + overrides_list = self.overrides.get(tensor_name) + return overrides_list and "axis" not in overrides_list[0] + + def has_per_channel_overrides(self, tensor_name: str) -> bool: + overrides_list = self.overrides.get(tensor_name) + return overrides_list and "axis" in overrides_list[0] - def get_per_tensor_overrides(self, tensor_name: str) -> dict[str, Any]: - overrides_list = self.overrides.get(tensor_name, [{}]) - num_overrides = len(overrides_list) - if num_overrides > 1: + def get_per_tensor_overrides( + self, + tensor_name: str, + default_val: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: + default_list_val = [default_val] if default_val is not None else None + overrides_list = self.overrides.get(tensor_name, default_list_val) + if overrides_list and "axis" in overrides_list[0]: raise ValueError( f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, " - f"but found {num_overrides} per-channel overrides." + f"but found per-channel overrides." ) - return overrides_list[0] if num_overrides > 0 else {} + return overrides_list[0] if overrides_list else None def get_per_channel_overrides( self, tensor_name: str, - num_channels: int, - ) -> list[dict[str, Any]]: - overrides_list = self.overrides.get(tensor_name, [{} for i in range(num_channels)]) + default_val: list[dict[str, Any]] | None = None, + ) -> list[dict[str, Any]] | None: + overrides_list = self.overrides.get(tensor_name, default_val) + + if not overrides_list: + return None - if len(overrides_list) != num_channels: + if "axis" not in overrides_list[0]: raise ValueError( - f"Expected tensor '{tensor_name}' to have {num_channels} per-channel quantization overrides, " - f"but found {len(overrides_list)} instead." + f"Expected tensor '{tensor_name}' to have per-channel quantization overrides (axis value is missing).", ) return overrides_list @@ -105,9 +127,236 @@ def get_quant_types(self) -> set[QuantType]: return self.quant_types + def _is_valid_per_tensor( + self, + initializers, + default_activation_qtype, + tensor_name: str, + quant_overrides: dict[str, Any], + ) -> tuple[bool, str | None]: + if not isinstance(quant_overrides, dict): + return ( + False, + f"Tensor quantization overrides for '{tensor_name}' are not in a dict", + ) + + is_initializer = tensor_name in initializers + + quant_type = quant_overrides.get("quant_type") + if quant_type: + self.quant_types.add(quant_type) + + has_scale = "scale" in quant_overrides + has_zero_point = "zero_point" in quant_overrides + + if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): + return ( + False, + "Must provide both 'scale' and 'zero_point' if one of the overrides is provided", + ) + + if has_scale: + keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides)) + if keys: + return ( + False, + f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'", + ) + + if "reduce_range" in quant_overrides and not is_initializer: + return ( + False, + f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", + ) + + if "convert" in quant_overrides: + if is_initializer: + return False, "Cannot use 'convert' override for initializers" + + if "quant_type" not in quant_overrides["convert"]: + return False, f"'convert' options (tensor '{tensor_name}') must specify a 'quant_type'" + + if "reduce_range" in quant_overrides["convert"]: + return ( + False, + f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", + ) + + convert_quant_type = quant_overrides["convert"]["quant_type"] + original_quant_type = quant_type if quant_type is not None else default_activation_qtype + if convert_quant_type == original_quant_type: + return ( + False, + f"'convert' quant_type must differ from original quant_type (tensor '{tensor_name}')", + ) + + convert_has_scale = "scale" in quant_overrides["convert"] + convert_has_zero_point = "zero_point" in quant_overrides["convert"] + + if (convert_has_scale and not convert_has_zero_point) or (convert_has_zero_point and not convert_has_scale): + return ( + False, + f"Must provide both 'scale' and 'zero_point' if one of the overrides is provided (tensor '{tensor_name}')", + ) + + if convert_has_scale: + keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides["convert"])) + if keys: + return ( + False, + f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point' " + f"(tensor '{tensor_name}')", + ) + + self.quant_types.add(convert_quant_type) + + return True, None + + def _is_valid_per_channel( + self, + initializers, + tensor_name: str, + quant_overrides_list: list[dict[str, Any]], + ) -> tuple[bool, str | None]: + is_initializer = tensor_name in initializers + + if not is_initializer: + return ( + False, + f"Tensor '{tensor_name}' has per-channel overrides, but is not an initializer", + ) + + axis = quant_overrides_list[0].get("axis") + + if axis is None: + return ( + False, + f"Per-channel overrides for tensor {tensor_name} is missing an 'axis' value in " + "the first channel dictionary.", + ) + + weight_shape = tensor_proto_to_array(initializers[tensor_name]).shape + weight_rank = len(weight_shape) + norm_axis = axis + if norm_axis < 0: + norm_axis += weight_rank + + if norm_axis < 0 or norm_axis >= len(weight_shape): + return ( + False, + f"Axis override value is out-of-bounds for tensor {tensor_name} (rank {len(weight_shape)})", + ) + + if len(quant_overrides_list) > 1 and len(quant_overrides_list) != weight_shape[norm_axis]: + return ( + False, + f"Incorrect number of channel overrides for tensor {tensor_name} (axis {axis}), " + f"expected {weight_shape[axis]}, but found {len(quant_overrides_list)}.", + ) + + if "convert" in quant_overrides_list[0]: + return False, f"Cannot use 'convert' override for initializers, such as {tensor_name}." + + quant_type = quant_overrides_list[0].get("quant_type") + if quant_type: + self.quant_types.add(quant_type) + + symmetric = quant_overrides_list[0].get("symmetric") + reduce_range = quant_overrides_list[0].get("reduce_range") + + has_scale = "scale" in quant_overrides_list[0] + has_zero_point = "zero_point" in quant_overrides_list[0] + has_scale_zp = has_scale and has_zero_point + + if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): + return ( + False, + "Must provide both 'scale' and 'zero_point' if one of the overrides is provided", + ) + + if has_scale_zp: + keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides_list[0])) + if keys: + return ( + False, + f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'", + ) + + has_rmin = "rmin" in quant_overrides_list[0] + has_rmax = "rmax" in quant_overrides_list[0] + has_rmin_rmax = has_rmin and has_rmax + if (has_rmin and not has_rmax) or (not has_rmin and has_rmax): + return ( + False, + "Must provide both 'rmin' and 'rmax' if one is provided", + ) + + for index, quant_overrides in enumerate(quant_overrides_list[1:]): + if not isinstance(quant_overrides, dict): + return ( + False, + f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict", + ) + + if "convert" in quant_overrides: + return False, f"Cannot use 'convert' override for initializers, such as {tensor_name}." + + # For per-channel quantization, all channels must use the same quantization type, axis, symmetric + # and reduce_range values. And, if specified, they must be present in the first channel dict + # (i.e., quant_overrides_list[0]). + if "quant_type" in quant_overrides and quant_type != quant_overrides["quant_type"]: + return ( + False, + "Channel quantization types for tensor '{tensor_name}' do not match at index {index}.", + ) + if "axis" in quant_overrides and axis != quant_overrides["axis"] and norm_axis != quant_overrides["axis"]: + return ( + False, + "Channel axis for tensor '{tensor_name}' does not match at index {index}.", + ) + if "symmetric" in quant_overrides and symmetric != quant_overrides["symmetric"]: + return ( + False, + "Channel symmetric value for tensor '{tensor_name}' does not match at index {index}.", + ) + if "reduce_range" in quant_overrides and reduce_range != quant_overrides["reduce_range"]: + return ( + False, + "Channel reduce_range value for tensor '{tensor_name}' does not match at index {index}.", + ) + + # If override scale/zp, must do so for all channels. + chan_has_scale_zp = "scale" in quant_overrides and "zero_point" in quant_overrides + + if has_scale_zp and not chan_has_scale_zp: + return ( + False, + "Per-channel overrides that specify scale/zero_point must do so for all channels, " + f"but tensor '{tensor_name}' is missing them at index {index}.", + ) + + if chan_has_scale_zp: + keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides)) + if keys: + return ( + False, + f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'", + ) + + # If override rmin/rmax, must do so for all channels. + chan_has_rmin_rmax = "rmin" in quant_overrides and "rmax" in quant_overrides + if has_rmin_rmax and not chan_has_rmin_rmax: + return ( + False, + "Per-channel overrides that specify rmin/rmax must do so for all channels, " + f"but tensor '{tensor_name}' is missing them at index {index}.", + ) + + return True, None + def is_valid( self, - initializer_names: set[str], + initializers: dict[str, onnx.TensorProto], activation_names: set[str], default_activation_qtype, ) -> tuple[bool, str | None]: @@ -115,113 +364,31 @@ def is_valid( # Validate that compatible/valid overrides are provided. if self.overrides: - keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"} - for tensor_name, quant_overrides_list in self.overrides.items(): - if tensor_name not in initializer_names and tensor_name not in activation_names: + if tensor_name not in initializers and tensor_name not in activation_names: return False, f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model" if not isinstance(quant_overrides_list, list): return False, f"Tensor quantization overrides for '{tensor_name}' are not in a list" - is_initializer = tensor_name in initializer_names - if not is_initializer and len(quant_overrides_list) > 1: - return ( - False, - f"Tensor '{tensor_name}' has a list of per-channel overrides, but is not an initializer", - ) + if not quant_overrides_list: + continue + + if not isinstance(quant_overrides_list[0], dict): + return False, f"Tensor quantization overrides at index 0 for '{tensor_name}' are not in a dict" + + if not quant_overrides_list[0]: + continue - quant_type = None - for index, quant_overrides in enumerate(quant_overrides_list): - if not isinstance(quant_overrides, dict): - return ( - False, - f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict", - ) - - # For per-channel quantization, all channels must use the same quantization type. - # Therefore, if the user tries to override the quant_type for a channel, it must match in all - # other channels. - if index == 0: - quant_type = quant_overrides.get("quant_type") - if quant_type: - self.quant_types.add(quant_type) - elif quant_type != quant_overrides.get("quant_type"): - return ( - False, - "Channel quantization types for tensor '{tensor_name}' do not match at index {index}.", - ) - - has_scale = "scale" in quant_overrides - has_zero_point = "zero_point" in quant_overrides - - if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): - return ( - False, - "Must provide both 'scale' and 'zero_point' if one of the overrides is provided", - ) - - if has_scale: - for key in keys_unsupported_with_scale_zp: - if key in quant_overrides: - return ( - False, - f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'", - ) - - if "reduce_range" in quant_overrides and not is_initializer: - return ( - False, - f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", - ) - - if "convert" in quant_overrides: - if index > 0: - return ( - False, - f"Per-channel overrides (tensor '{tensor_name}') do not support 'convert'.", - ) - - if is_initializer: - return False, "Cannot use 'convert' override for initializers" - - if "quant_type" not in quant_overrides["convert"]: - return False, f"'convert' options (tensor '{tensor_name}') must specify a 'quant_type'" - - if "reduce_range" in quant_overrides["convert"]: - return ( - False, - f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", - ) - - convert_quant_type = quant_overrides["convert"]["quant_type"] - original_quant_type = quant_type if quant_type is not None else default_activation_qtype - if convert_quant_type == original_quant_type: - return ( - False, - f"'convert' quant_type must differ from original quant_type (tensor '{tensor_name}')", - ) - - convert_has_scale = "scale" in quant_overrides["convert"] - convert_has_zero_point = "zero_point" in quant_overrides["convert"] - - if (convert_has_scale and not convert_has_zero_point) or ( - convert_has_zero_point and not convert_has_scale - ): - return ( - False, - f"Must provide both 'scale' and 'zero_point' if one of the overrides is provided (tensor '{tensor_name}')", - ) - - if convert_has_scale: - for key in keys_unsupported_with_scale_zp: - if key in quant_overrides["convert"]: - return ( - False, - f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point' (tensor '{tensor_name}')", - ) - - self.quant_types.add(convert_quant_type) + axis = quant_overrides_list[0].get("axis") + is_per_channel = len(quant_overrides_list) > 1 or axis is not None + + if is_per_channel: + return self._is_valid_per_channel(initializers, tensor_name, quant_overrides_list) + + return self._is_valid_per_tensor( + initializers, default_activation_qtype, tensor_name, quant_overrides_list[0] + ) return True, None @@ -266,11 +433,10 @@ def get_node_output_qtype_info( default_qtype: QuantType | None, default_symmetric: bool | None = None, ) -> QuantTypeInfo: + # Outputs are activations, which do not support 'reduce_range' or 'axis' if output_name not in self.overrides: return QuantTypeInfo(default_qtype, default_symmetric) - # Get the first overrides dict in the list. This works for both per-tensor and per-channel - # quantization because all channels must use the same quant type. tensor_overrides = self.overrides[output_name][0] return QuantTypeInfo( @@ -295,14 +461,19 @@ def get_node_input_qtype_info( producer_type = tensor_overrides.get("quant_type", default_qtype) if "convert" not in tensor_overrides: - return QuantTypeInfo(producer_type, default_symmetric, default_reduce_range) + return QuantTypeInfo( + producer_type, + tensor_overrides.get("symmetric", default_symmetric), + tensor_overrides.get("reduce_range", default_reduce_range), + tensor_overrides.get("axis"), + ) # This tensor is converted. Check if the node gets the original qtype or the converted qtype. convert_dict = tensor_overrides["convert"] qtype_info = QuantTypeInfo( producer_type, convert_dict.get("symmetric", default_symmetric), - convert_dict.get("reduce_range", default_reduce_range), + # Converted tensors are not initializers, so do not have 'axis' or 'reduce_range'. ) # Check if all nodes receive the converted type (i.e., recv_nodes is None) or this node diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 63577131480c6..57f10d9a4eb69 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -263,13 +263,14 @@ class ModelTestBuilder { Node& AddNode(const std::string& op_type, const std::vector& input_args, const std::vector& output_args, - const std::string& domain = "") { + const std::string& domain = "", + const NodeAttributes* attributes = nullptr) { return graph_.AddNode(graph_.GenerateNodeName("node"), op_type, "description", input_args, output_args, - nullptr, + attributes, domain); } @@ -299,6 +300,23 @@ class ModelTestBuilder { return AddNode("QuantizeLinear", input_args, {output_arg}, domain); } + template + typename std::enable_if::value, Node&>::type + AddQuantizeLinearNode(NodeArg* input_arg, + const std::vector& input_scales, + const std::vector& input_zero_points, + NodeArg* output_arg, + const NodeAttributes* attributes = nullptr, + bool use_ms_domain = false) { + std::vector input_args; + input_args.push_back(input_arg); + input_args.push_back(Make1DInitializer(input_scales)); + input_args.push_back(Make1DInitializer(input_zero_points)); + + std::string domain = use_ms_domain ? kMSDomain : ""; + return AddNode("QuantizeLinear", input_args, {output_arg}, domain, attributes); + } + Node& AddQuantizeLinearNode(NodeArg* input_arg, float input_scale, NodeArg* output_arg, @@ -311,6 +329,19 @@ class ModelTestBuilder { return AddNode("QuantizeLinear", input_args, {output_arg}, domain); } + Node& AddQuantizeLinearNode(NodeArg* input_arg, + const std::vector& input_scales, + NodeArg* output_arg, + const NodeAttributes* attributes = nullptr, + bool use_ms_domain = false) { + std::vector input_args; + input_args.push_back(input_arg); + input_args.push_back(Make1DInitializer(input_scales)); + + std::string domain = use_ms_domain ? kMSDomain : ""; + return AddNode("QuantizeLinear", input_args, {output_arg}, domain, attributes); + } + template typename std::enable_if::value, Node&>::type AddDequantizeLinearNode(NodeArg* input_arg, @@ -327,6 +358,23 @@ class ModelTestBuilder { return AddNode("DequantizeLinear", input_args, {output_arg}, domain); } + template + typename std::enable_if::value, Node&>::type + AddDequantizeLinearNode(NodeArg* input_arg, + const std::vector& input_scales, + const std::vector& input_zero_points, + NodeArg* output_arg, + const NodeAttributes* attributes = nullptr, + bool use_ms_domain = false) { + std::vector input_args; + input_args.push_back(input_arg); + input_args.push_back(Make1DInitializer(input_scales)); + input_args.push_back(Make1DInitializer(input_zero_points)); + + std::string domain = use_ms_domain ? kMSDomain : ""; + return AddNode("DequantizeLinear", input_args, {output_arg}, domain, attributes); + } + Node& AddDequantizeLinearNode(NodeArg* input_arg, float input_scale, NodeArg* output_arg, @@ -339,6 +387,19 @@ class ModelTestBuilder { return AddNode("DequantizeLinear", input_args, {output_arg}, domain); } + Node& AddDequantizeLinearNode(NodeArg* input_arg, + const std::vector& input_scales, + NodeArg* output_arg, + const NodeAttributes* attributes = nullptr, + bool use_ms_domain = false) { + std::vector input_args; + input_args.push_back(input_arg); + input_args.push_back(Make1DInitializer(input_scales)); + + std::string domain = use_ms_domain ? kMSDomain : ""; + return AddNode("DequantizeLinear", input_args, {output_arg}, domain, attributes); + } + template Node& AddQLinearConvNode(NodeArg* input_arg, float input_scale, diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 5cb4633dadd46..414a0fbeb78f5 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -40,8 +40,21 @@ AddQDQNodePairWithOutputAsGraphOutput(ModelTestBuilder& builder, NodeArg* q_inpu return dq_output; } +template +typename std::enable_if::value, NodeArg*>::type +AddQDQNodePair(ModelTestBuilder& builder, NodeArg* q_input, const std::vector& scales, + const std::vector& zero_points, const NodeAttributes* q_attrs = nullptr, + const NodeAttributes* dq_attrs = nullptr, bool use_ms_domain = false) { + auto* q_output = builder.MakeIntermediate(); + auto* dq_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(q_input, scales, zero_points, q_output, q_attrs, use_ms_domain); + builder.AddDequantizeLinearNode(q_output, scales, zero_points, dq_output, dq_attrs, use_ms_domain); + return dq_output; +} + template -GetQDQTestCaseFn BuildQDQConvTransposeTestCase(const std::vector& input_shape, const std::vector& weights_shape) { +GetQDQTestCaseFn BuildQDQConvTransposeTestCase(const std::vector& input_shape, + const std::vector& weights_shape) { return [input_shape, weights_shape](ModelTestBuilder& builder) { auto* input_arg = builder.MakeInput(input_shape, -1.f, 1.f); auto* output_arg = builder.MakeOutput(); @@ -71,7 +84,8 @@ GetQDQTestCaseFn BuildQDQConvTransposeTestCase(const std::vector& input dq_w_output); auto* dq_bias_output = builder.MakeIntermediate(); - auto* bias = builder.MakeInitializer({weights_shape[0]}, static_cast(0), static_cast(127)); + auto* bias = builder.MakeInitializer({weights_shape[0]}, static_cast(0), + static_cast(127)); builder.AddDequantizeLinearNode(bias, .0012f, 0, dq_bias_output); @@ -126,7 +140,8 @@ GetQDQTestCaseFn BuildQDQConvTestCase(const std::vector& input_shape, use_contrib_qdq); auto* dq_bias_output = builder.MakeIntermediate(); - auto* bias = builder.MakeInitializer({weights_shape[0]}, static_cast(0), static_cast(127)); + auto* bias = builder.MakeInitializer({weights_shape[0]}, static_cast(0), + static_cast(127)); builder.AddDequantizeLinearNode(bias, .0012f, 0, dq_bias_output, @@ -389,7 +404,8 @@ GetQDQTestCaseFn BuildConsolidationTestCase( const int64_t& axis, bool use_contrib_qdq = false) { return [input_shape, axis, use_contrib_qdq](ModelTestBuilder& builder) { - auto* input_arg = builder.MakeInput(input_shape, std::numeric_limits::min(), std::numeric_limits::max()); + auto* input_arg = builder.MakeInput(input_shape, std::numeric_limits::min(), + std::numeric_limits::max()); InputType dq_zp = std::numeric_limits::max() / 2; OutputType q_zp = std::numeric_limits::max() / 2; auto* upper_dq_output = builder.MakeIntermediate(); @@ -447,7 +463,8 @@ GetQDQTestCaseFn BuildDoubleQDQTestCases(Type1 zp_1, Type2 zp_2, Type3 zp_3, Typ template GetQDQTestCaseFn BuildDoubleQDQWithoutLastOutput(int output_index, bool use_contrib_qdq = false) { return [=](ModelTestBuilder& builder) { - auto* input_arg = builder.MakeInput({2, 3, 4}, std::numeric_limits::min(), std::numeric_limits::max()); + auto* input_arg = builder.MakeInput({2, 3, 4}, std::numeric_limits::min(), + std::numeric_limits::max()); T zp = (std::numeric_limits::max() - std::numeric_limits::min()) / 2; float scale = 0.003f; std::vector outputs(4); @@ -632,7 +649,8 @@ GetQDQTestCaseFn BuildQDQConcatTestCase(const std::vector>& GetQDQTestCaseFn BuildQDQConcatTestCaseUnsupportedInputScaleZp(); -GetQDQTestCaseFn BuildQDQMatMulTestCase(const std::vector& input1_shape, const std::vector& input2_shape); +GetQDQTestCaseFn BuildQDQMatMulTestCase(const std::vector& input1_shape, + const std::vector& input2_shape); template GetQDQTestCaseFn BuildQDQGemmTestCase(const std::vector& input1_shape, @@ -673,7 +691,8 @@ GetQDQTestCaseFn BuildQDQGemmTestCase(const std::vector& input1_shape, if (has_bias) { auto* dq_bias_output = builder.MakeIntermediate(); - auto* bias = builder.MakeInitializer({input2_shape[0]}, static_cast(0), static_cast(127)); + auto* bias = builder.MakeInitializer({input2_shape[0]}, static_cast(0), + static_cast(127)); builder.AddDequantizeLinearNode(bias, 0.00156f, 0, dq_bias_output); diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index 1cd8498ea1d37..0eeeccf4fe7d6 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -3,8 +3,10 @@ #if !defined(ORT_MINIMAL_BUILD) +#include #include #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -20,9 +22,10 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons const std::vector& strides, const std::vector& pads, const std::vector& dilations, + std::optional group, const std::string& auto_pad = "NOTSET") { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, - dilations, auto_pad](ModelTestBuilder& builder) { + dilations, group, auto_pad](ModelTestBuilder& builder) { std::vector conv_inputs = { MakeTestInput(builder, input_def), MakeTestInput(builder, weights_def)}; @@ -33,19 +36,23 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons auto* output = builder.MakeOutput(); - Node& convNode = builder.AddNode(conv_op_type, conv_inputs, {output}); - convNode.AddAttribute("auto_pad", auto_pad); + Node& conv_node = builder.AddNode(conv_op_type, conv_inputs, {output}); + conv_node.AddAttribute("auto_pad", auto_pad); + + if (group.has_value()) { + conv_node.AddAttribute("group", group.value()); + } if (!pads.empty() && auto_pad == "NOTSET") { - convNode.AddAttribute("pads", pads); + conv_node.AddAttribute("pads", pads); } if (!strides.empty()) { - convNode.AddAttribute("strides", strides); + conv_node.AddAttribute("strides", strides); } if (!dilations.empty()) { - convNode.AddAttribute("dilations", dilations); + conv_node.AddAttribute("dilations", dilations); } }; } @@ -58,6 +65,7 @@ static void RunCPUConvOpTest(const std::string& conv_op_type, const TestInputDef const std::vector& strides, const std::vector& pads, const std::vector& dilations, + std::optional group, const std::string& auto_pad, ExpectedEPNodeAssignment expected_ep_assignment, int opset = 13, @@ -69,8 +77,9 @@ static void RunCPUConvOpTest(const std::string& conv_op_type, const TestInputDef #else provider_options["backend_path"] = "libQnnCpu.so"; #endif - - RunQnnModelTest(BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, auto_pad), + auto build_fn = BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, + dilations, group, auto_pad); + RunQnnModelTest(build_fn, provider_options, opset, expected_ep_assignment, @@ -86,11 +95,12 @@ static GetTestQDQModelFn BuildQDQConvTestCase(const std::string const std::vector& strides, const std::vector& pads, const std::vector& dilations, + std::optional group, const std::string& auto_pad = "NOTSET", bool use_contrib_qdq = false) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, - dilations, auto_pad, use_contrib_qdq](ModelTestBuilder& builder, - std::vector>& output_qparams) { + dilations, group, auto_pad, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { std::vector conv_inputs; // input -> Q/DQ -> @@ -120,6 +130,104 @@ static GetTestQDQModelFn BuildQDQConvTestCase(const std::string conv_node.AddAttribute("auto_pad", auto_pad); + if (group.has_value()) { + conv_node.AddAttribute("group", group.value()); + } + + if (!pads.empty() && auto_pad == "NOTSET") { + conv_node.AddAttribute("pads", pads); + } + if (!strides.empty()) { + conv_node.AddAttribute("strides", strides); + } + if (!dilations.empty()) { + conv_node.AddAttribute("dilations", dilations); + } + + AddQDQNodePairWithOutputAsGraphOutput(builder, conv_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +template +static GetTestQDQModelFn BuildQDQPerChannelConvTestCase(const std::string& conv_op_type, + const TestInputDef& input_def, + const TestInputDef& weights_def, + const TestInputDef& bias_def, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + std::optional group, + const std::string& auto_pad = "NOTSET", + bool use_contrib_qdq = false) { + return [conv_op_type, input_def, weights_def, bias_def, strides, pads, + dilations, group, auto_pad, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + std::vector conv_inputs; + + // input -> Q/DQ -> + auto* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + auto* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + conv_inputs.push_back(input_qdq); + + // Quantized(weights) -> DQ -> + ORT_ENFORCE(weights_def.IsInitializer() && weights_def.IsRawData()); + int64_t weight_quant_axis = conv_op_type == "Conv" ? 0 : 1; // 0 for Conv, 1 for ConvTranspose + std::vector weight_scales; + std::vector weight_zero_points; + GetTestInputQuantParamsPerChannel(weights_def, weight_scales, weight_zero_points, + static_cast(weight_quant_axis), true); + + TensorShape weights_shape = weights_def.GetTensorShape(); + std::vector quantized_weights(weights_shape.Size()); + QuantizeValues(weights_def.GetRawData(), quantized_weights, weights_shape, + weight_scales, weight_zero_points, weight_quant_axis); + + NodeArg* weights_initializer = builder.MakeInitializer(weights_def.GetShape(), quantized_weights); + NodeArg* weights_dq = builder.MakeIntermediate(); + Node& weights_dq_node = builder.AddDequantizeLinearNode(weights_initializer, weight_scales, + weight_zero_points, weights_dq, + nullptr, use_contrib_qdq); + weights_dq_node.AddAttribute("axis", weight_quant_axis); + conv_inputs.push_back(weights_dq); + + // Quantized(bias) -> DQ -> + if (!bias_def.GetShape().empty()) { + // Bias requirement taken from python quantization tool: onnx_quantizer.py::quantize_bias_static() + // bias_scale = input_scale * weight_scale + // bias_zero_point = 0 + ORT_ENFORCE(bias_def.IsInitializer() && bias_def.IsRawData()); + std::vector bias_scales = weight_scales; + std::vector bias_zero_points(weight_scales.size(), 0); + for (size_t i = 0; i < bias_scales.size(); i++) { + bias_scales[i] *= input_qparams.scale; + } + + TensorShape bias_shape = bias_def.GetTensorShape(); + std::vector quantized_biases(bias_shape.Size()); + QuantizeValues(bias_def.GetRawData(), quantized_biases, bias_shape, bias_scales, + bias_zero_points, 0); + + NodeArg* bias_initializer = builder.MakeInitializer(bias_def.GetShape(), quantized_biases); + NodeArg* bias_dq = builder.MakeIntermediate(); + Node& bias_dq_node = builder.AddDequantizeLinearNode(bias_initializer, bias_scales, bias_zero_points, + bias_dq, nullptr, use_contrib_qdq); + + bias_dq_node.AddAttribute("axis", static_cast(0)); + conv_inputs.push_back(bias_dq); + } + + auto* conv_output = builder.MakeIntermediate(); + Node& conv_node = builder.AddNode(conv_op_type, conv_inputs, {conv_output}); + + conv_node.AddAttribute("auto_pad", auto_pad); + + if (group.has_value()) { + conv_node.AddAttribute("group", group.value()); + } + if (!pads.empty() && auto_pad == "NOTSET") { conv_node.AddAttribute("pads", pads); } @@ -144,6 +252,7 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef const std::vector& strides, const std::vector& pads, const std::vector& dilations, + std::optional group, const std::string& auto_pad, ExpectedEPNodeAssignment expected_ep_assignment, bool use_contrib_qdq = false, @@ -158,16 +267,47 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef #endif TestQDQModelAccuracy(BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, - auto_pad), + group, auto_pad), BuildQDQConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, - auto_pad, use_contrib_qdq), + group, auto_pad, use_contrib_qdq), provider_options, opset, expected_ep_assignment, tolerance); } +// Runs a QDQ Conv model (per-axis quantization on weight/bias) on the QNN HTP backend. +// Checks the graph node assignment, and that inference outputs for QNN EP and CPU EP match. +template +static void RunHTPConvOpPerChannelTest(const std::string& conv_op_type, const TestInputDef& input_def, + const TestInputDef& weights_def, + const TestInputDef& bias_def, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + std::optional group, + const std::string& auto_pad, + ExpectedEPNodeAssignment expected_ep_assignment, + bool use_contrib_qdq = false, + int opset = 13, + QDQTolerance tolerance = QDQTolerance()) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_fn = BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, + group, auto_pad); + auto qdq_fn = BuildQDQPerChannelConvTestCase(conv_op_type, input_def, weights_def, + bias_def, strides, pads, dilations, + group, auto_pad, use_contrib_qdq); + TestQDQModelAccuracy(f32_fn, qdq_fn, provider_options, opset, expected_ep_assignment, tolerance); +} + // Check that QNN compiles DQ -> Conv -> Q as a single unit. // Tests bias as a dynamic input. // TODO: Segfaults when calling graphFinalize(). v2.13 @@ -179,6 +319,7 @@ TEST_F(QnnCPUBackendTests, DISABLED_Convf32_dynamic_bias) { {1, 1}, // default strides {0, 0, 0, 0}, // default pads {1, 1}, // default dilations + 1, // default group "NOTSET", // No auto-padding ExpectedEPNodeAssignment::All); } @@ -193,6 +334,7 @@ TEST_F(QnnCPUBackendTests, Convf32_bias_initializer) { {1, 1}, // default strides {0, 0, 0, 0}, // default pads {1, 1}, // default dilations + 1, // default group "NOTSET", // No auto-padding ExpectedEPNodeAssignment::All); } @@ -206,6 +348,7 @@ TEST_F(QnnCPUBackendTests, Convf32_AutoPadUpper) { {1, 1}, // strides {}, // pads {1, 1}, // dilations + 1, // default group "SAME_UPPER", // auto_pad ExpectedEPNodeAssignment::All); } @@ -219,6 +362,7 @@ TEST_F(QnnCPUBackendTests, ConvTransposef32_AutoPadUpper) { {1, 1}, // strides {}, // pads {1, 1}, // dilations + 1, // default group "SAME_UPPER", // auto_pad ExpectedEPNodeAssignment::All); } @@ -232,6 +376,7 @@ TEST_F(QnnCPUBackendTests, Convf32_AutoPadLower) { {1, 1}, // strides {}, // pads {1, 1}, // dilations + 1, // default group "SAME_LOWER", // auto_pad ExpectedEPNodeAssignment::All); } @@ -245,6 +390,7 @@ TEST_F(QnnCPUBackendTests, ConvTransposef32_AutoPadLower) { {1, 1}, // strides {}, // pads {1, 1}, // dilations + 1, // default group "SAME_LOWER", // auto_pad ExpectedEPNodeAssignment::All); } @@ -258,6 +404,7 @@ TEST_F(QnnCPUBackendTests, Convf32_large_input1_pad_bias_initializer) { {1, 1}, {1, 1, 1, 1}, {1, 1}, + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, 13, @@ -280,6 +427,7 @@ TEST_F(QnnCPUBackendTests, Convf32_large_input2_nopad_bias_initializer) { {1, 1}, {0, 0, 0, 0}, {1, 1}, + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, 13, // opset @@ -296,6 +444,7 @@ TEST_F(QnnCPUBackendTests, Conv1Df32_StaticWeights_DefaultBias) { {1}, // Strides {0, 0}, // Pads {1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All); } @@ -310,6 +459,7 @@ TEST_F(QnnCPUBackendTests, Conv1Df32_DynamicWeights_DefaultBias) { {1}, // Strides {0, 0}, // Pads {1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All); } @@ -324,6 +474,7 @@ TEST_F(QnnCPUBackendTests, ConvTranspose1Df32_StaticWeights_DefaultBias) { {1}, // Strides {0, 0}, // Pads {1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All); } @@ -338,6 +489,7 @@ TEST_F(QnnCPUBackendTests, ConvTranspose1Df32_DynamicWeights_DefaultBias) { {1}, // Strides {0, 0}, // Pads {1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All); } @@ -363,7 +515,8 @@ TEST_F(QnnHTPBackendTests, Test_QDQConvWithDynamicWeightsFromMul) { auto BuildConvMulGraph = [](ModelTestBuilder& builder) { // DQ node for Conv input auto* dq_i_output = builder.MakeIntermediate(); - auto* conv_dq_input = builder.MakeInitializer({1, 32, 16, 113}, static_cast(0), static_cast(127)); + auto* conv_dq_input = builder.MakeInitializer({1, 32, 16, 113}, static_cast(0), + static_cast(127)); // DQ node for Conv bias auto* dq_bias_output = builder.MakeIntermediate(); @@ -375,7 +528,8 @@ TEST_F(QnnHTPBackendTests, Test_QDQConvWithDynamicWeightsFromMul) { auto* mul_input1 = builder.MakeInput({16, 32, 1, 1}, static_cast(0), static_cast(127)); auto* mul_dq2_output = builder.MakeIntermediate(); - auto* mul_input2 = builder.MakeInitializer({16, 1, 1, 1}, static_cast(0), static_cast(127)); + auto* mul_input2 = builder.MakeInitializer({16, 1, 1, 1}, static_cast(0), + static_cast(127)); builder.AddDequantizeLinearNode(mul_input1, .03f, 0, mul_dq1_output); builder.AddDequantizeLinearNode(mul_input2, .03f, 0, mul_dq2_output); @@ -420,6 +574,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_bias_dynamic_input) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops @@ -428,6 +583,170 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_bias_dynamic_input) { QDQTolerance(0.00413f)); } +// Test per-channel QDQ Conv. in0: u8, in1 (weight): s8, in2 (bias): s32, out: u8 +TEST_F(QnnHTPBackendTests, ConvU8S8S32_PerChannel) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(-10.0f, 10.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 13); // opset +} + +// Test per-channel QDQ Conv that maps to QNN's DepthwiseConv2d (input_chans == output_chans == group). +// in0: u8, in1 (weight): s8, in2 (bias): s32, out: u8 +TEST_F(QnnHTPBackendTests, ConvDepthwiseU8S8S32_PerChannel) { + std::vector input_shape = {1, 2, 4, 4}; // (N, C, H, W) + std::vector weight_shape = {2, 1, 2, 2}; // (C, M/group, kH, kW) + std::vector bias_shape = {2}; // (M) + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(-10.0f, 10.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 2, // group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 13); // opset +} + +// Test per-channel QDQ ConvTranspose. in0: u8, in1 (weight): s8, in2 (bias): s32, out: u8 +TEST_F(QnnHTPBackendTests, ConvTransposeU8S8S32_PerChannel) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {2, 3, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(-10.0f, 10.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + RunHTPConvOpPerChannelTest("ConvTranspose", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 13); // opset +} + +// Test per-channel QDQ Conv. in0: u16, in1 (weight): s8, in2 (bias): s32, out: u16 +TEST_F(QnnHTPBackendTests, ConvU16S8S32_PerChannel) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(-10.0f, 10.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // use_qdq_contrib_ops + 13); // opset +} + +// Test per-channel QDQ ConvTranspose. in0: u16, in1 (weight): s8, in2 (bias): s32, out: u16 +TEST_F(QnnHTPBackendTests, ConvTransposeU16S8S32_PerChannel) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {2, 3, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(-10.0f, 10.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + RunHTPConvOpPerChannelTest("ConvTranspose", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // use_qdq_contrib_ops + 13); // opset +} + +// Test per-channel QDQ Conv that maps to QNN's DepthwiseConv2d (input_chans == output_chans == group). +// in0: u16, in1 (weight): s8, in2 (bias): s32, out: u16 +TEST_F(QnnHTPBackendTests, ConvDepthwiseU16S8S32_PerChannel) { + std::vector input_shape = {1, 2, 4, 4}; // (N, C, H, W) + std::vector weight_shape = {2, 1, 2, 2}; // (C, M/group, kH, kW) + std::vector bias_shape = {2}; // (M) + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(-10.0f, 10.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 2, // group + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // use_qdq_contrib_ops + 13); // opset +} + // Tests 16-bit QDQ Conv with dynamic weights and bias (uses QNN's Conv2d) // TODO: Inaccuracy detected for output 'output', element 0. // Output quant params: scale=0.0040235077030956745, zero_point=0. @@ -444,6 +763,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S16S32_DynamicBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, true); // Use com.microsoft QDQ ops for 16-bit @@ -461,6 +781,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_DepthwiseConvU16S16S32_DynamicBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, true); // Use com.microsoft QDQ ops for 16-bit @@ -482,6 +803,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S16S32_NoBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, true); // Use com.microsoft QDQ ops for 16-bit @@ -499,6 +821,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_DepthwiseConvU16S16S32_NoBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, true); // Use com.microsoft QDQ ops for 16-bit @@ -521,6 +844,7 @@ TEST_F(QnnHTPBackendTests, DepthwiseConvU16U8S32_StaticBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, true, // Use com.microsoft QDQ ops for 16-bit @@ -543,6 +867,7 @@ TEST_F(QnnHTPBackendTests, ConvU16U8S32_StaticBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, true, // Use com.microsoft QDQ ops for 16-bit @@ -566,6 +891,7 @@ TEST_F(QnnHTPBackendTests, DepthwiseConvU16U8S32_DynamicBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, true, // Use com.microsoft QDQ ops for 16-bit @@ -588,6 +914,7 @@ TEST_F(QnnHTPBackendTests, ConvU16U8S32_DynamicBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, true, // Use com.microsoft QDQ ops for 16-bit @@ -610,6 +937,7 @@ TEST_F(QnnHTPBackendTests, ConvU16U8S32_NoBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, true, // Use com.microsoft QDQ ops for 16-bit @@ -633,6 +961,7 @@ TEST_F(QnnHTPBackendTests, DepthwiseConvU16U8S32_NoBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, true, // Use com.microsoft QDQ ops for 16-bit @@ -649,6 +978,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_DynamicWeight_NoBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All); } @@ -663,6 +993,7 @@ TEST_F(QnnHTPBackendTests, ConvTransposeU8U8S32_DynamicWeight_NoBias) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All); } @@ -677,6 +1008,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_bias_initializer) { {1, 1}, // Strides {0, 0, 0, 0}, // Pads {1, 1}, // Dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops @@ -695,6 +1027,7 @@ TEST_F(QnnHTPBackendTests, Conv1DU8U8S32_bias_initializer) { {1}, // strides {0, 0}, // pads {1}, // dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All); } @@ -709,6 +1042,7 @@ TEST_F(QnnHTPBackendTests, ConvTranspose1DU8U8S32_bias_initializer) { {1}, // strides {0, 0}, // pads {1}, // dilations + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All); } @@ -722,6 +1056,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_AutoPadUpper) { {1, 1}, // strides {}, // pads {1, 1}, // dilations + 1, // default group "SAME_UPPER", // auto_pad ExpectedEPNodeAssignment::All, false, // use_contrib_qdq @@ -738,6 +1073,7 @@ TEST_F(QnnHTPBackendTests, Conv1DU8U8S32_AutoPadUpper) { {1}, // strides {0}, // pads {1}, // dilations + 1, // default group "SAME_UPPER", // auto_pad ExpectedEPNodeAssignment::All, false, // use_contrib_qdq @@ -754,6 +1090,7 @@ TEST_F(QnnHTPBackendTests, ConvTranspose1DU8U8S32_AutoPadUpper) { {1}, // strides {0}, // pads {1}, // dilations + 1, // default group "SAME_UPPER", // auto_pad ExpectedEPNodeAssignment::All, false, // use_contrib_qdq @@ -769,6 +1106,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_AutoPadLower) { {1, 1}, // strides {}, // pads {1, 1}, // dilations + 1, // default group "SAME_LOWER", // auto_pad ExpectedEPNodeAssignment::All, false, // use_contrib_qdq @@ -784,6 +1122,7 @@ TEST_F(QnnHTPBackendTests, ConvTransposeU8U8S32_AutoPadLower) { {1, 1}, // strides {}, // pads {1, 1}, // dilations + 1, // default group "SAME_LOWER", // auto_pad ExpectedEPNodeAssignment::All, false, // use_contrib_qdq @@ -800,6 +1139,7 @@ TEST_F(QnnHTPBackendTests, Conv1DU8U8S32_AutoPadLower) { {1}, // strides {0}, // pads {1}, // dilations + 1, // default group "SAME_LOWER", // auto_pad ExpectedEPNodeAssignment::All, false, // use_contrib_qdq @@ -816,6 +1156,7 @@ TEST_F(QnnHTPBackendTests, ConvTranspose1DU8U8S32_AutoPadLower) { {1}, // strides {0}, // pads {1}, // dilations + 1, // default group "SAME_LOWER", // auto_pad ExpectedEPNodeAssignment::All, false, // use_contrib_qdq @@ -830,6 +1171,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_large_input1_padding_bias_initializer) { {1, 1}, {1, 1, 1, 1}, {1, 1}, + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops @@ -852,6 +1194,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_large_input2_bias_initializer) { {1, 1}, {0, 0, 0, 0}, {1, 1}, + 1, // default group "NOTSET", ExpectedEPNodeAssignment::All, false, @@ -867,6 +1210,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_LargeInput_Dilations_Pads) { {2, 2}, // strides {3, 3, 3, 3}, // pads {1, 1}, // dilations + 1, // default group "NOTSET", // auto_pad ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index c0cfe3f0342fd..c474e989243ad 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -4,10 +4,12 @@ #pragma once #if !defined(ORT_MINIMAL_BUILD) -#include #include +#include +#include #include #include "core/framework/provider_options.h" +#include "core/framework/tensor_shape.h" #include "core/util/qmath.h" #include "test/optimizer/qdq_test_utils.h" @@ -30,7 +32,7 @@ struct QuantParams { float scale; QType zero_point; - static QuantParams Compute(float rmin, float rmax) { + static QuantParams Compute(float rmin, float rmax, bool symmetric = false) { // Ensure a minimum range of 0.0001 (required by QNN) rmax = std::max(rmax, rmin + 0.0001f); @@ -41,8 +43,23 @@ struct QuantParams { constexpr float qmin = static_cast(std::numeric_limits::min()); constexpr float qmax = static_cast(std::numeric_limits::max()); - const float scale = rmax == rmin ? 1.0f : (rmax - rmin) / (qmax - qmin); - const float initial_zero_point = qmin - (rmin / scale); + if (symmetric) { + const float abs_max = std::max(std::abs(rmin), std::abs(rmax)); + rmax = abs_max; + rmin = -abs_max; + } + + const float scale = (rmax - rmin) / (qmax - qmin); + float initial_zero_point = 0.0f; + + if (symmetric) { + // Symmetric uses same formula for zero-point as asymmetric, but we can cancel out terms for + // increased numerical accuracy. + initial_zero_point = (qmin + qmax) / 2.0f; + } else { + initial_zero_point = qmin - (rmin / scale); + } + const QType zero_point = static_cast(RoundHalfToEven(std::max(qmin, std::min(qmax, initial_zero_point)))); return QuantParams{scale, zero_point}; @@ -55,11 +72,12 @@ struct QuantParams { // range of output values. Note that the function is able to overwrite the output_qparams parameter if necessary // (Example: MaxPool must have identical input and output quantization params). template -using GetTestQDQModelFn = std::function>& output_qparams)>; +using GetTestQDQModelFn = std::function>& output_qparams)>; // Computes quantization parameters for an array of floating-point values. template -inline QuantParams GetDataQuantParams(gsl::span data) { +inline QuantParams GetDataQuantParams(gsl::span data, bool symmetric = false) { // Get min/max of raw data. float min_val = std::numeric_limits::max(); float max_val = std::numeric_limits::min(); @@ -69,7 +87,7 @@ inline QuantParams GetDataQuantParams(gsl::span data) { max_val = std::max(max_val, val); } - return QuantParams::Compute(min_val, max_val); + return QuantParams::Compute(min_val, max_val, symmetric); } /** @@ -150,6 +168,10 @@ struct TestInputDef { return shape_; } + const TensorShape GetTensorShape() const { + return TensorShape(shape_); + } + bool IsInitializer() const { return is_initializer_; } @@ -201,6 +223,42 @@ struct TestInputDef { return range; } + std::vector> GetRangePerChannel(size_t axis) const { + auto which_type = data_info_.index(); + const size_t num_ranges = static_cast(shape_.at(axis)); + + // Random. All axis dims get the same ranges (rand_min -> rand_max) + if (which_type == 1) { + RandomData rand_info = std::get(data_info_); + return std::vector>(num_ranges, std::pair(rand_info.min, rand_info.max)); + } + + // Raw data. Get min/max per axis dim val + assert(which_type == 0); + + const std::vector& raw_data = std::get(data_info_).data; + std::pair init_range(std::numeric_limits::max(), std::numeric_limits::min()); + std::vector> per_axis_ranges(num_ranges, init_range); + TensorShape shape(shape_); + size_t num_blocks = shape.SizeToDimension(axis); + size_t block_size = shape.SizeFromDimension(axis + 1); + + size_t i = 0; + for (size_t n = 0; n < num_blocks; n++) { + for (size_t r = 0; r < num_ranges; r++) { + for (size_t j = 0; j < block_size; j++) { + std::pair& range = per_axis_ranges[r]; + range.first = std::min(range.first, raw_data[i]); + range.second = std::max(range.second, raw_data[i]); + i++; + } + } + } + assert(i == raw_data.size()); + + return per_axis_ranges; + } + private: std::vector shape_; std::variant data_info_; @@ -210,9 +268,64 @@ struct TestInputDef { }; template -inline QuantParams GetTestInputQuantParams(const TestInputDef& input_def) { +inline QuantParams GetTestInputQuantParams(const TestInputDef& input_def, bool symmetric = false) { const std::pair frange = input_def.GetRange(); - return QuantParams::Compute(frange.first, frange.second); + return QuantParams::Compute(frange.first, frange.second, symmetric); +} + +template +static void GetTestInputQuantParamsPerChannel(const TestInputDef& input_def, std::vector& scales, + std::vector& zero_points, size_t axis, bool symmetric = false) { + const auto f32_ranges = input_def.GetRangePerChannel(axis); + + scales.reserve(f32_ranges.size()); + zero_points.reserve(f32_ranges.size()); + + for (const auto& range : f32_ranges) { + QuantParams params = QuantParams::Compute(range.first, range.second, symmetric); + scales.push_back(params.scale); + zero_points.push_back(params.zero_point); + } +} + +template +static void QuantizeValues(gsl::span input, gsl::span output, const TensorShape& shape, + gsl::span scales, gsl::span zero_points, + std::optional axis) { + const size_t input_rank = shape.NumDimensions(); + const size_t num_elems = static_cast(shape.Size()); + ORT_ENFORCE(input.size() == num_elems); + ORT_ENFORCE(output.size() == num_elems); + + size_t block_count = 1; + size_t broadcast_dim = 1; + size_t block_size = num_elems; + + if (axis.has_value()) { + size_t axis_no_neg = *axis < 0 ? static_cast(*axis) + input_rank : static_cast(*axis); + block_count = shape.SizeToDimension(axis_no_neg); + broadcast_dim = shape[axis_no_neg]; + block_size = shape.SizeFromDimension(axis_no_neg + 1); + } + + ORT_ENFORCE(scales.size() == broadcast_dim); + ORT_ENFORCE(zero_points.empty() || zero_points.size() == broadcast_dim); + + size_t i = 0; + + for (size_t n = 0; n < block_count; n++) { + for (size_t bd = 0; bd < broadcast_dim; bd++) { + QuantType zp = zero_points.empty() ? static_cast(0) : zero_points[bd]; + if constexpr (std::is_same_v) { + for (size_t e = 0; e < block_size; e++) { + output[i + e] = static_cast(input[i + e] / scales[bd]) + zp; + } + } else { + ParQuantizeLinearStd(&input[i], &output[i], block_size, scales[bd], zp, nullptr); + } + i += block_size; + } + } } /** @@ -281,8 +394,8 @@ struct QDQTolerance { * \param qnn_options QNN EP provider options. * \param opset_version The opset version. * \param expected_ep_assignment Describes "which nodes" should be assigned to the EP. - * \param tolerance The percent tolerance (as fraction) QNN EP results are allowed to differ from the QDQ model on CPU EP. - * This tolerance is a percentage of the output range. + * \param tolerance The percent tolerance (as fraction) QNN EP results are allowed to differ from the QDQ model + * on CPU EP. This tolerance is a percentage of the output range. * \param log_severity The logger's severity setting. */ template @@ -482,8 +595,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe * \param qnn_options QNN EP provider options. * \param opset_version The opset version. * \param expected_ep_assignment Describes "which nodes" should be assigned to the EP. - * \param tolerance The percent tolerance (as fraction) QNN EP results are allowed to differ from the FP16 model on CPU EP. - * This tolerance is a percentage of the output range. + * \param tolerance The percent tolerance (as fraction) QNN EP results are allowed to differ from the FP16 model + * on CPU EP. This tolerance is a percentage of the output range. * \param log_severity The logger's severity setting. */ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, @@ -708,9 +821,10 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef manual quantization (int32) => DQ => final float bias NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef& bias_def, float bias_scale, @@ -767,12 +881,13 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, * \returns A model building function. */ template -inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_type, - const std::vector>& quant_input_defs, - const std::vector>& non_quant_input_defs, - const std::vector& attrs, - const std::string& op_domain = kOnnxDomain, - bool use_contrib_qdq = false) { +inline GetTestQDQModelFn BuildQDQOpTestCase( + const std::string& op_type, + const std::vector>& quant_input_defs, + const std::vector>& non_quant_input_defs, + const std::vector& attrs, + const std::string& op_domain = kOnnxDomain, + bool use_contrib_qdq = false) { return [op_type, quant_input_defs, non_quant_input_defs, attrs, op_domain, use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { std::vector op_inputs; diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 77f20b3caed96..ff97e04fb7fdf 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -344,11 +344,11 @@ def test_qdq_overrides_per_channel1(self): extra_options={ "TensorQuantOverrides": { "WGT": [ - {"zero_point": zp_vals[0], "scale": scale_vals[0]}, + {"axis": 0, "zero_point": zp_vals[0], "scale": scale_vals[0]}, {"zero_point": zp_vals[1], "scale": scale_vals[1]}, ], "BIAS": [ - {"zero_point": zp_vals[0], "scale": scale_vals[0]}, + {"axis": 0, "zero_point": zp_vals[0], "scale": scale_vals[0]}, {"zero_point": zp_vals[1], "scale": scale_vals[1]}, ], } @@ -373,55 +373,58 @@ def test_qdq_overrides_per_channel2(self): """ Test per-channel overriding of rmin, rmax, reduce_range, and quant_type for Conv weight. """ - rmin_vals = [0.0, 0.2] - rmax_vals = [1.0, 0.8] - quant_type = QuantType.QUInt8 - reduce_ranges = [True, False] - ( - _, - _, - _, - _, - wgt_zp, - wgt_sc, - bias_zp, - bias_sc, - _, - _, - ) = self.perform_qdq_quantization( - "model_per_channel_quant_overrides2.onnx", - extra_options={ - "TensorQuantOverrides": { - "WGT": [ - { - "quant_type": quant_type, - "rmin": np.array(rmin_vals[0], dtype=np.float32), - "rmax": np.array(rmax_vals[0], dtype=np.float32), - "reduce_range": reduce_ranges[0], - }, - { - "quant_type": quant_type, - "rmin": np.array(rmin_vals[1], dtype=np.float32), - "rmax": np.array(rmax_vals[1], dtype=np.float32), - "reduce_range": reduce_ranges[1], - }, - ], - } - }, - per_channel=True, - ) + for reduce_range in (False, True): + with self.subTest(reduce_range=reduce_range): + qdq_model_name = f"model_per_chan_overrides_2_reduce_range_{reduce_range}.onnx" + rmin_vals = [0.0, 0.2] + rmax_vals = [1.0, 0.8] + quant_type = QuantType.QUInt8 + ( + _, + _, + _, + _, + wgt_zp, + wgt_sc, + bias_zp, + bias_sc, + _, + _, + ) = self.perform_qdq_quantization( + qdq_model_name, + extra_options={ + "TensorQuantOverrides": { + "WGT": [ + { + "axis": 0, + "quant_type": quant_type, + "rmin": np.array(rmin_vals[0], dtype=np.float32), + "rmax": np.array(rmax_vals[0], dtype=np.float32), + "reduce_range": reduce_range, + }, + { + "quant_type": quant_type, + "rmin": np.array(rmin_vals[1], dtype=np.float32), + "rmax": np.array(rmax_vals[1], dtype=np.float32), + "reduce_range": reduce_range, + }, + ], + } + }, + per_channel=True, + ) - self.assertEqual(wgt_zp.data_type, quant_type.tensor_type) - for index, (zp, scale) in enumerate(zip(wgt_zp.int32_data, wgt_sc.float_data)): - wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type, reduce_range=reduce_ranges[index]) - expected_zp, expected_scale = compute_scale_zp( - np.array(rmin_vals[index], dtype=np.float32), - np.array(rmax_vals[index], dtype=np.float32), - wgt_qmin, - wgt_qmax, - ) - self.assertEqual(zp, expected_zp) - self.assertEqual(scale, np.float32(expected_scale)) + self.assertEqual(wgt_zp.data_type, quant_type.tensor_type) + for index, (zp, scale) in enumerate(zip(wgt_zp.int32_data, wgt_sc.float_data)): + wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type, reduce_range=reduce_range) + expected_zp, expected_scale = compute_scale_zp( + np.array(rmin_vals[index], dtype=np.float32), + np.array(rmax_vals[index], dtype=np.float32), + wgt_qmin, + wgt_qmax, + ) + self.assertEqual(zp, expected_zp) + self.assertEqual(scale, np.float32(expected_scale)) def test_16bit_overrides_set_ms_domain(self): """ @@ -503,7 +506,7 @@ def test_override_validation_bad_combination(self): }, ) - self.assertIn("option 'rmax' is invalid with 'scale' and 'zero_point'", str(context.exception)) + self.assertIn("option(s) [rmax] are invalid with 'scale' and 'zero_point'", str(context.exception)) with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( @@ -521,7 +524,7 @@ def test_override_validation_bad_combination(self): }, ) - self.assertIn("Tensor override option 'rmax' is invalid with 'scale' and 'zero_point'", str(context.exception)) + self.assertIn("option(s) [rmax] are invalid with 'scale' and 'zero_point'", str(context.exception)) with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( @@ -539,7 +542,7 @@ def test_override_validation_bad_combination(self): }, ) - self.assertIn("option 'symmetric' is invalid with 'scale' and 'zero_point'", str(context.exception)) + self.assertIn("option(s) [symmetric] are invalid with 'scale' and 'zero_point'", str(context.exception)) with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( @@ -557,7 +560,7 @@ def test_override_validation_bad_combination(self): }, ) - self.assertIn("option 'reduce_range' is invalid with 'scale' and 'zero_point'", str(context.exception)) + self.assertIn("option(s) [reduce_range] are invalid with 'scale' and 'zero_point'", str(context.exception)) def test_get_qnn_qdq_config_sigmoid(self): """ @@ -875,6 +878,86 @@ def test_get_qnn_qdq_config_matmul(self): self.assertEqual(weight_is_symmetric, qnn_config.extra_options["WeightSymmetric"]) + def test_get_qnn_qdq_config_matmul_per_channel(self): + """ + When per_channel is enabled, test that the QNN-specific configs explicitly override MatMul's + initializer inputs to use per-tensor quantization (QNN does not support per-channel MatMul). + """ + # Create float model with a Abs --> MatMul + graph = onnx.helper.make_graph( + [ + onnx.helper.make_node("Abs", ["input_0"], ["abs_0_out"], name="Abs_0"), + onnx.helper.make_node("MatMul", ["abs_0_out", "weight"], ["matmul_0_out"], name="MatMul_0"), + onnx.helper.make_node("Abs", ["matmul_0_out"], ["output_0"], name="Abs_1"), + ], + "matmul_graph", + [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (2, 3))], + [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (2, 2))], + initializer=[onnx.numpy_helper.from_array(np.random.random((3, 2)).astype(np.float32), "weight")], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + float_model_path = "model.onnx" + onnx.save_model(model, float_model_path) + + symmetric_wgt_qtypes = {QuantType.QInt8, QuantType.QInt16} + weight_override_16bit = {"weight": [{"quant_type": QuantType.QInt16, "symmetric": True}]} + + # Enumerate subtests (default_wgt_qtype, default_wgt_symmetric, other_override) + subtest_configs = [ + (QuantType.QUInt8, False, {}), + (QuantType.QInt8, True, {}), + (QuantType.QUInt8, None, {}), + (QuantType.QInt8, None, {}), + (QuantType.QInt8, None, weight_override_16bit), + ] + + # Test if MatMul's weight input is overridden to per-tensor correctly. + for default_wgt_qtype, default_wgt_symmetric, other_override in subtest_configs: + with self.subTest( + default_wgt_qtype=default_wgt_qtype, + default_wgt_symmetric=default_wgt_symmetric, + other_override=other_override, + ): + init_overrides = {} + init_overrides.update(other_override) + + qnn_config = get_qnn_qdq_config( + float_model_path, + DummyDataReader([]), + weight_type=default_wgt_qtype, + weight_symmetric=default_wgt_symmetric, + init_overrides=(init_overrides if init_overrides else None), + per_channel=True, + ) + + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "MatMul"}) + weight_is_symmetric = default_wgt_symmetric or default_wgt_qtype in symmetric_wgt_qtypes + + # User did not provide overrides for weight, so get_qnn_qdq_config() should set per-tensor overrides. + if not init_overrides: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("weight", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["weight"], + [ + { + "quant_type": default_wgt_qtype, + "symmetric": weight_is_symmetric, + } + ], + ) + else: + # Should retain user's overrides. + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("weight", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["weight"], weight_override_16bit["weight"] + ) + def test_get_qnn_qdq_config_layernorm(self): """ Test that the QNN-specific configs override LayerNorm's initializer input type to 8-bit if