[QNN EP] Support per-channel quantized weights (#20154)
### Description
- Adds general support for per-channel quantized weights to QNN EP (HTP
backend).
- Adds QNN EP unit tests for per-channel Conv.
- Updates the quantization tool to allow selecting which ops are quantized
per-channel (and along which axis) via tensor-level overrides. Currently,
setting `per_channel=True` assumes all Conv, MatMul, Gemm,
InstanceNormalization, and LayerNormalization ops should be quantized
per-channel along an assumed default axis (see the sketch below).
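A minimal sketch of that default path, for contrast with the override-based example further below. This uses the generic `quantize_static` API and is illustrative rather than code from this PR; the input name `"input"`, the shape `(1, 3, 224, 224)`, and the file paths are assumptions.

```python
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

class RandomDataReader(CalibrationDataReader):
    """Toy calibration reader: yields a single random sample. Replace with real calibration data."""
    def __init__(self, input_name: str, shape: tuple):
        self._batches = iter([{input_name: np.random.rand(*shape).astype(np.float32)}])

    def get_next(self):
        return next(self._batches, None)

# per_channel=True quantizes the weights of the supported op types (Conv, MatMul, Gemm, ...)
# per-channel along each op's assumed default axis; no tensor-level overrides are needed.
# "input" and (1, 3, 224, 224) are assumptions about the model being quantized.
quantize_static(
    "model.onnx",
    "model.qdq.per_channel.onnx",
    RandomDataReader("input", (1, 3, 224, 224)),
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    per_channel=True,
)
```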

#### Example: creating a QDQ model with a per-channel Conv weight
```python
from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config, qnn_preprocess_model

class DataReader(CalibrationDataReader):
    # TODO: See ONNX Runtime QNN docs for example of a data reader
    # https://onnxruntime.ai/docs/execution-providers/QNN-ExecutionProvider.html#generating-a-quantized-model-x64
    pass

if __name__ == "__main__":
    input_model_path = "model.onnx"

    # Pre-process the original float32 model.
    preproc_model_path = "model.preproc.onnx"
    model_changed = qnn_preprocess_model(input_model_path, preproc_model_path)
    model_to_quantize = preproc_model_path if model_changed else input_model_path

    # Create the calibration data reader for the model that will be quantized.
    my_data_reader = DataReader(model_to_quantize)

    # RELEVANT TO THIS PR:
    # Make sure the Conv weight input (the initializer named 'weight' in this model) is quantized
    # to int8/symmetric/per-channel with axis == 0 (the output-channel axis for Conv weights).
    # The presence of the 'axis' key indicates that this is a per-channel quantized weight.
    init_overrides = {'weight': [{'axis': 0, 'quant_type': QuantType.QInt8, 'symmetric': True}]}

    qnn_config = get_qnn_qdq_config(model_to_quantize,
                                    my_data_reader,
                                    init_overrides=init_overrides,
                                    activation_type=QuantType.QUInt16, # uint16 activations
                                    weight_type=QuantType.QUInt8)      # uint8 weights by default

    quantize(model_to_quantize, "model.qdq.onnx", qnn_config)
```
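As a follow-up sanity check (not part of this PR), the resulting QDQ model can be inspected with the `onnx` package to confirm that the Conv weight's `DequantizeLinear` node carries an `axis` attribute and one scale per output channel. The output path and the `weight` initializer name are taken from the example above; quantized initializer names may gain a suffix, hence the prefix match.

```python
import onnx
from onnx import numpy_helper

model = onnx.load("model.qdq.onnx")
initializers = {init.name: init for init in model.graph.initializer}

for node in model.graph.node:
    # A per-channel DequantizeLinear has an 'axis' attribute and a 1-D scale tensor
    # whose length equals the size of the quantized axis.
    if node.op_type == "DequantizeLinear" and node.input[0].startswith("weight"):
        axis = next((attr.i for attr in node.attribute if attr.name == "axis"), None)
        scale = initializers.get(node.input[1])
        num_scales = numpy_helper.to_array(scale).size if scale is not None else 0
        print(f"{node.input[0]}: axis={axis}, num_scales={num_scales}")
```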

float32 model:
<img width="683" alt="image"
src="https://github.com/microsoft/onnxruntime/assets/19691973/ca650e49-1ad0-47d8-8c46-17fbc224ca39">

QDQ model (per-channel Conv weight):
<img width="748" alt="image"
src="https://github.com/microsoft/onnxruntime/assets/19691973/6bd469f2-968b-4d11-9526-09b3e71f98e7">

### Motivation and Context
Support more models, especially models with int4 quantized weights.
adrianlizarraga authored Apr 16, 2024
1 parent 08d208b commit f644ff9
Showing 45 changed files with 2,064 additions and 581 deletions.
10 changes: 9 additions & 1 deletion onnxruntime/core/framework/node_unit.cc
@@ -109,8 +109,16 @@ std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::Node
// If we can find the node index in the dq or q nodes this is a quantized input/output
if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) {
const auto node_inputs = node.InputDefs();
const auto& node_attrs = node.GetAttributes();

// Get the Q or DQ axis attribute if available.
std::optional<int64_t> axis;
if (auto entry = node_attrs.find("axis"); entry != node_attrs.end()) {
axis = entry->second.i();
}

// quantization scale and zp are always the input[1, 2]
NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? node_inputs[2] : nullptr};
NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? node_inputs[2] : nullptr, axis};

if (is_input) {
// DQ is input to the target node, use the DstArgIndex
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/node_unit.h
@@ -41,10 +41,11 @@ struct NodeGroup {
// If the optional quant_param is present, then this is a quantized input,
// otherwise this is a regular input
struct NodeUnitIODef {
// The quantization parameter, scale is manadatory, and zero_point is optional
// The quantization parameter. Scale is mandatory. Zero-point and axis are optional.
struct QuantParam {
const NodeArg& scale;
const NodeArg* zero_point{nullptr};
std::optional<int64_t> axis{std::nullopt};
};

const NodeArg& node_arg;
@@ -65,8 +65,9 @@ Status BaseOpBuilder::ProcessInput(QnnModelWrapper& qnn_model_wrapper,
}

Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name);
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, input_info.qnn_data_type, input_info.quant_param,
std::move(input_info.shape), std::move(unpacked_tensor));
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, input_info.qnn_data_type,
std::move(input_info.quant_param), std::move(input_info.shape),
std::move(unpacked_tensor));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
input_names.push_back(input_name);

@@ -129,7 +130,7 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper,
TensorInfo output_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(outputs[output_i], output_info));

if (output_info.quant_param.encodingDefinition == QNN_DEFINITION_DEFINED) {
if (output_info.quant_param.IsQuantized()) {
ORT_RETURN_IF_ERROR(OverrideOutputQuantParam(qnn_model_wrapper, node_unit, logger, input_names,
output_i, output_info.qnn_data_type, output_info.quant_param));
}
@@ -143,7 +144,7 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper,
QnnTensorWrapper cast_input_tensorwrapper(cast_input_name,
QNN_TENSOR_TYPE_NATIVE,
supported_qnn_data_type,
output_info.quant_param,
output_info.quant_param.Copy(),
std::move(cast_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_input_tensorwrapper)), "Failed to add tensor.");
output_names.push_back(cast_input_name);
@@ -156,7 +157,7 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper,
QnnTensorWrapper output_tensorwrapper(output_name,
tensor_type,
output_info.qnn_data_type,
output_info.quant_param,
std::move(output_info.quant_param),
std::move(output_info.shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
}
@@ -189,15 +190,15 @@ Status BaseOpBuilder::SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper&
size_t input_index,
size_t output_index,
Qnn_DataType_t qnn_data_type,
Qnn_QuantizeParams_t& quant_param) const {
QnnQuantParamsWrapper& quant_param) const {
const QnnTensorWrapper& input_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[input_index]);
ORT_RETURN_IF_NOT(input_tensor_wrapper.GetTensorDataType() == qnn_data_type,
"Input and output data types do not match");
Qnn_QuantizeParams_t input_quant_param = GetQnnTensorQParams(input_tensor_wrapper.GetQnnTensor());
const QnnQuantParamsWrapper& input_quant_param = input_tensor_wrapper.GetQnnQuantParams();

float scale_diff = 0.0f;
int32_t offset_diff = 0;
ORT_RETURN_IF_ERROR(CompareQnnQuantParams(quant_param, input_quant_param, scale_diff, offset_diff));
ORT_RETURN_IF_ERROR(CompareQnnQuantParams(quant_param.Get(), input_quant_param.Get(), scale_diff, offset_diff));
constexpr float NEARLY_EQUAL_THRESHOLD = 1e-9f;
constexpr float WARN_THRESHOLD = 1e-6f;

@@ -6,6 +6,7 @@
#include "core/providers/shared/utils/utils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder.h"
#include "core/providers/qnn/builder/qnn_quant_params_wrapper.h"
#include "core/framework/allocator.h"

#include "QnnOpDef.h"
@@ -57,7 +58,7 @@ class BaseOpBuilder : public IOpBuilder {
const std::vector<std::string>& input_names,
size_t output_index,
Qnn_DataType_t qnn_data_type,
Qnn_QuantizeParams_t& quant_param) const ORT_MUST_USE_RESULT {
QnnQuantParamsWrapper& quant_param) const ORT_MUST_USE_RESULT {
// Do nothing by default. Op builders like Split implement this function to override output quant params.
ORT_UNUSED_PARAMETER(qnn_model_wrapper);
ORT_UNUSED_PARAMETER(node_unit);
@@ -110,7 +111,7 @@ class BaseOpBuilder : public IOpBuilder {
size_t input_index,
size_t output_index,
Qnn_DataType_t qnn_data_type,
Qnn_QuantizeParams_t& quant_param) const ORT_MUST_USE_RESULT;
QnnQuantParamsWrapper& quant_param) const ORT_MUST_USE_RESULT;

static const std::string& GetQnnOpType(const std::string& onnx_op_type) {
// TODO: Use QNN operator names defined in "QnnOpDef.h"
@@ -320,6 +321,8 @@ class BaseOpBuilder : public IOpBuilder {

private:
std::string op_builder_type_;

protected:
const std::vector<size_t> nchw2nhwc_perm{0, 2, 3, 1};
const std::vector<size_t> nchw2hwcn_perm{2, 3, 1, 0};
const std::vector<size_t> cnhw2hwcn_perm{2, 3, 0, 1};
@@ -260,13 +260,16 @@ class BatchNormOpBuilder : public BaseOpBuilder {
uint32_t channel = mean_info.shape[0];
mean_out.resize(channel);
ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(mean_info.qnn_data_type, channel, mean_raw_ptr_length));
ORT_RETURN_IF_NOT(!is_npu_backend || mean_info.quant_param.IsPerTensor(),
"BatchNormalization's input_mean does not support per-channel quantization");
int i = 0;
int offset = 0;
const Qnn_QuantizeParams_t& quant_param = mean_info.quant_param.Get();
for (; i < static_cast<int>(channel); ++i) {
double mean_value = 0.0;
ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(mean_info.qnn_data_type, mean_raw_ptr + offset, mean_value, offset));
mean_out[i] = (is_npu_backend) ? utils::Dequantize(mean_info.quant_param.scaleOffsetEncoding.offset,
mean_info.quant_param.scaleOffsetEncoding.scale,
mean_out[i] = (is_npu_backend) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
quant_param.scaleOffsetEncoding.scale,
mean_value)
: mean_value;
}
@@ -283,13 +286,16 @@ class BatchNormOpBuilder : public BaseOpBuilder {
uint32_t channel = var_info.shape[0];
std_out.resize(channel);
ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(var_info.qnn_data_type, channel, var_raw_ptr_length));
ORT_RETURN_IF_NOT(!is_npu_backend || var_info.quant_param.IsPerTensor(),
"BatchNormalization's input_var does not support per-channel quantization");
int i = 0;
int offset = 0;
const Qnn_QuantizeParams_t& quant_param = var_info.quant_param.Get();
for (; i < static_cast<int>(channel); ++i) {
double var_value = 0.0;
ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(var_info.qnn_data_type, var_raw_ptr + offset, var_value, offset));
std_out[i] = (is_npu_backend) ? utils::Dequantize(var_info.quant_param.scaleOffsetEncoding.offset,
var_info.quant_param.scaleOffsetEncoding.scale,
std_out[i] = (is_npu_backend) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
quant_param.scaleOffsetEncoding.scale,
var_value)
: var_value;
std_out[i] = std::sqrt(std_out[i] + static_cast<double>(epsilon));
@@ -309,13 +315,16 @@ class BatchNormOpBuilder : public BaseOpBuilder {
uint32_t channel = scale_info.shape[0];
scale_out.resize(channel);
ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(scale_info.qnn_data_type, channel, scale_raw_ptr_length));
ORT_RETURN_IF_NOT(!is_npu_backend || scale_info.quant_param.IsPerTensor(),
"BatchNormalization's scale input does not support per-channel quantization");
int i = 0;
int offset = 0;
const Qnn_QuantizeParams_t& quant_param = scale_info.quant_param.Get();
for (; i < static_cast<int>(channel); ++i) {
double scale_value = 0.0;
ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(scale_info.qnn_data_type, scale_raw_ptr + offset, scale_value, offset));
scale_out[i] = (is_npu_backend) ? utils::Dequantize(scale_info.quant_param.scaleOffsetEncoding.offset,
scale_info.quant_param.scaleOffsetEncoding.scale,
scale_out[i] = (is_npu_backend) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
quant_param.scaleOffsetEncoding.scale,
scale_value)
: scale_value;
scale_out[i] = scale_out[i] / std_double_tensor[i];
@@ -338,13 +347,16 @@ class BatchNormOpBuilder : public BaseOpBuilder {
uint32_t channel = bias_info.shape[0];
bias_out.resize(channel);
ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(bias_info.qnn_data_type, channel, bias_raw_ptr_length));
ORT_RETURN_IF_NOT(!is_npu_backend || bias_info.quant_param.IsPerTensor(),
"BatchNormalization's bias input does not support per-channel quantization");
int i = 0;
int offset = 0;
const Qnn_QuantizeParams_t& quant_param = bias_info.quant_param.Get();
for (; i < static_cast<int>(channel); ++i) {
double bias_value = 0.0;
ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(bias_info.qnn_data_type, bias_raw_ptr + offset, bias_value, offset));
bias_out[i] = (is_npu_backend) ? utils::Dequantize(bias_info.quant_param.scaleOffsetEncoding.offset,
bias_info.quant_param.scaleOffsetEncoding.scale,
bias_out[i] = (is_npu_backend) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
quant_param.scaleOffsetEncoding.scale,
bias_value)
: bias_value;
bias_out[i] = bias_out[i] - (mean_double_tensor[i] * scale_double_tensor[i]);
@@ -359,7 +371,7 @@ class BatchNormOpBuilder : public BaseOpBuilder {
const std::vector<double>& double_tensor,
const double rmax,
const double rmin,
Qnn_QuantizeParams_t& quant_param,
QnnQuantParamsWrapper& quant_param,
std::vector<uint8_t>& raw_tensor) const {
if (is_npu_backend) {
raw_tensor.resize(double_tensor.size());
@@ -370,8 +382,7 @@ class BatchNormOpBuilder : public BaseOpBuilder {
info.qnn_data_type,
scale,
zero_point));
quant_param = QNN_QUANTIZE_PARAMS_INIT;
utils::InitializeQuantizeParam(quant_param, true, scale, zero_point);
quant_param = QnnQuantParamsWrapper(scale, zero_point);
for (size_t i = 0; i < double_tensor.size(); ++i) {
// onnx only supports 8 bits quantization
int quant_value_int = 0;
@@ -382,6 +393,7 @@ class BatchNormOpBuilder : public BaseOpBuilder {
int8_t quant_value = static_cast<int8_t>(quant_value_int);
raw_tensor[i] = *reinterpret_cast<uint8_t*>(&quant_value);
} else {
// TODO(adrianlizarraga): Should support 16-bit quantization as well.
ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", info.qnn_data_type);
}
}
@@ -545,7 +557,7 @@ Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,

if (!qnn_model_wrapper.IsQnnTensorWrapperExist(scale_name)) {
std::vector<uint8_t> scale_raw_tensor;
Qnn_QuantizeParams_t scale_quant_param = scale_info.quant_param;
QnnQuantParamsWrapper scale_quant_param = scale_info.quant_param;
ORT_RETURN_IF_ERROR(Postprocess(scale_info,
is_npu_backend,
scale_double_tensor,
@@ -554,15 +566,16 @@ Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
scale_quant_param,
scale_raw_tensor));
Qnn_TensorType_t scale_tensor_type = GetInputTensorType(qnn_model_wrapper, scale_name);
QnnTensorWrapper input_tensorwrapper(scale_name, scale_tensor_type, scale_info.qnn_data_type, scale_quant_param,
std::move(scale_info.shape), std::move(scale_raw_tensor));
QnnTensorWrapper input_tensorwrapper(scale_name, scale_tensor_type, scale_info.qnn_data_type,
std::move(scale_quant_param), std::move(scale_info.shape),
std::move(scale_raw_tensor));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
}
input_names.push_back(scale_name);

if (!qnn_model_wrapper.IsQnnTensorWrapperExist(bias_name)) {
std::vector<uint8_t> bias_raw_tensor;
Qnn_QuantizeParams_t bias_quant_param = bias_info.quant_param;
QnnQuantParamsWrapper bias_quant_param = bias_info.quant_param;
ORT_RETURN_IF_ERROR(Postprocess(bias_info,
is_npu_backend,
bias_double_tensor,
@@ -571,8 +584,9 @@ Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
bias_quant_param,
bias_raw_tensor));
Qnn_TensorType_t bias_tensor_type = GetInputTensorType(qnn_model_wrapper, bias_name);
QnnTensorWrapper input_tensorwrapper(bias_name, bias_tensor_type, bias_info.qnn_data_type, bias_quant_param,
std::move(bias_info.shape), std::move(bias_raw_tensor));
QnnTensorWrapper input_tensorwrapper(bias_name, bias_tensor_type, bias_info.qnn_data_type,
std::move(bias_quant_param), std::move(bias_info.shape),
std::move(bias_raw_tensor));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
}
input_names.push_back(bias_name);
@@ -70,7 +70,7 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
type_proto,
qnn_data_type));

QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, QNN_QUANTIZE_PARAMS_INIT,
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, QnnQuantParamsWrapper(),
std::move(input_shape), std::move(unpacked_tensor));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)),
"Failed to add input tensor for QNN Cast node.");
@@ -106,7 +106,7 @@ Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
QnnTensorWrapper output_tensorwrapper(output_name,
tensor_type,
qnn_data_type,
QNN_QUANTIZE_PARAMS_INIT,
QnnQuantParamsWrapper(),
std::move(output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)),
"Failed to add output tensor for QNN Cast node.");