
Commit

[Quant tool] Fix quantized bias's scale dtype to properly handle fp16 bias inputs (#20340)

### Description
- Fix a quantization tool bug that did not correctly set a quantized bias's scale data type to fp16 when the original bias was fp16 (a minimal sketch of the change follows below).
- Re-enable the fp16 ConvTranspose quantization unit tests that had been disabled.
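
As a minimal standalone sketch of the change (the bias values, the scale value, and the initializer name below are illustrative placeholders, not taken from the quantizer): the scale emitted for a quantized bias should inherit the original bias's unquantized dtype instead of being hard-coded to float32, so that an fp16 model's `DequantizeLinear` receives an fp16 scale.

```python
import numpy as np
import onnx

# Hypothetical fp16 bias and scale value, for illustration only.
bias_data = np.array([0.5, -1.25], dtype=np.float16)
bias_scale = 0.02

# Before the fix: the bias scale was always emitted as float32, even for fp16 models.
scale_old = np.asarray(bias_scale, dtype=np.float32).reshape(-1)

# After the fix: the scale dtype follows the original (unquantized) bias dtype.
scale_new = np.asarray(bias_scale, dtype=bias_data.dtype).reshape(-1)

# The resulting scale initializer is FLOAT16 when the original bias was fp16.
scale_init = onnx.numpy_helper.from_array(scale_new, "bias_quantized_scale")
assert scale_init.data_type == onnx.TensorProto.FLOAT16
```

With the previous behavior the scale initializer always came out as FLOAT32, producing a dtype mismatch when the rest of the model is fp16.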



### Motivation and Context
Python quantization tests for fp16 ConvTranspose were originally
disabled due to a shape inference bug. It turns out that we also have a
bug in our quantizer that does not properly handle fp16 bias inputs.
Fixing the bug allows us to re-enable these tests with the latest
version of ONNX.
adrianlizarraga authored Apr 17, 2024
1 parent 0a19025 commit eae7b70
Showing 2 changed files with 4 additions and 4 deletions.
onnxruntime/python/tools/quantization/base_quantizer.py (3 additions & 1 deletion)
@@ -235,7 +235,9 @@ def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1
         bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims)
         packed_bias_initializer = onnx.numpy_helper.from_array(bias_np_data, quantized_bias_name)
         self.model.initializer_extend([packed_bias_initializer])
-        bias_scale_data = np.asarray(bias_scale, dtype=np.float32).reshape(-1)
+
+        # Bias's scale dtype should match the original bias data's unquantized type (float32 or float16).
+        bias_scale_data = np.asarray(bias_scale, dtype=bias_data.dtype).reshape(-1)
         node_type = "DequantizeLinear"
         node_qtype = self.weight_qType
 
Second changed file: ConvTranspose quantization tests (1 addition & 3 deletions)
@@ -149,7 +149,6 @@ def quantize_conv_transpose_u8u8(self, onnx_type, opset, ir_version):
     def test_quantize_conv_transpose_u8u8(self):
         self.quantize_conv_transpose_u8u8(TensorProto.FLOAT, 13, 7)
 
-    @unittest.skip(reason="Shape inference bug, see onnx PR #5709")
     def test_quantize_conv_transpose_u8u8_fp16(self):
         self.quantize_conv_transpose_u8u8(TensorProto.FLOAT16, 19, 9)
 
@@ -160,7 +159,7 @@ def quantize_conv_transpose_s8s8(self, onnx_type, opset, ir_version):
 
         np.random.seed(1)
         model_fp32_path = "conv_transpose_fp32.onnx"
-        self.construct_model(model_fp32_path)
+        self.construct_model(model_fp32_path, onnx_type, opset, ir_version)
         dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_type)
         data_reader = self.input_feeds(1, {"input": [1, 1, 7, 7]}, dtype)
 
@@ -175,7 +174,6 @@ def quantize_conv_transpose_s8s8(self, onnx_type, opset, ir_version):
    def test_quantize_conv_transpose_s8s8(self):
         self.quantize_conv_transpose_s8s8(TensorProto.FLOAT, 13, 7)
 
-    @unittest.skip(reason="Shape inference bug, see onnx PR #5709")
     def test_quantize_conv_transpose_s8s8_fp16(self):
         self.quantize_conv_transpose_s8s8(TensorProto.FLOAT16, 19, 9)
 
