# Add patch for ONNX 1.16.0 shape inference bug (#20316)
### Description
- Adds a patch that fixes a shape inference bug that caused a segfault: onnx/onnx#6080
- Fixes the documentation describing why the QLinearMatMul tests are currently being skipped.

### Motivation and Context
The [PR for integrating with ONNX
1.16.0](#19745) disabled
various Python quantization tests due to a shape inference bug. This PR
applies the ONNX fix as a patch. We still can't enable the tests because
some of our CIs pip install onnx-1.16.0, which doesn't include the fix.
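
For context, here is a minimal sketch of the failure mode the patch guards against (an illustration written for this summary, not code from the PR): running ONNX shape inference on a `DequantizeLinear` node whose input carries no shape information. With the `hasInputShape` guard added in the patch below, inference leaves the output shape unknown instead of crashing.

```python
# Illustrative repro sketch (assumption: this mirrors the crash scenario in
# onnx/onnx#6080; it is not code from this PR). The input "x" deliberately
# has no shape, so the unpatched DequantizeLinear inference function calls
# getInputShape() on a missing shape.
import onnx
from onnx import TensorProto, helper

for opset in (19, 21):  # both opsets receive the same guard in the patch
    node = helper.make_node("DequantizeLinear", ["x", "x_scale", "x_zero_point"], ["y"])
    graph = helper.make_graph(
        [node],
        "dql_no_input_shape",
        [helper.make_tensor_value_info("x", TensorProto.UINT8, None)],  # no shape set
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, None)],
        initializer=[
            helper.make_tensor("x_scale", TensorProto.FLOAT, [], [0.5]),
            helper.make_tensor("x_zero_point", TensorProto.UINT8, [], [0]),
        ],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset)])
    # Patched onnx: returns a model whose output has element type FLOAT but no
    # inferred shape. Unpatched onnx 1.16.0: may crash here instead.
    inferred = onnx.shape_inference.infer_shapes(model)
    print(opset, inferred.graph.output[0].type.tensor_type.elem_type)
```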
adrianlizarraga authored Apr 17, 2024
1 parent bb19722 commit 0a19025
Showing 5 changed files with 75 additions and 39 deletions.
94 changes: 65 additions & 29 deletions cmake/patches/onnx/onnx.patch
@@ -1,16 +1,16 @@
-diff --git a/CMakeLists.txt b/CMakeLists.txt
-index 4dd56b6e..018da488 100644
+diff --git a/CMakeLists.txt b/CMakeLists.txt
+index 6d7ca846..69aa622f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
-@@ -397,6 +397,7 @@ if (MSVC)
+@@ -499,6 +499,7 @@ if (MSVC)
endif()
else()
# On non-Windows, hide all symbols we don't need
+ set(EXTRA_FLAGS "-Wno-unused-parameter")
set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)")
set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden)
set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1)
-@@ -548,20 +549,9 @@ endif()
+@@ -653,20 +654,9 @@ endif()
if(MSVC)
target_compile_options(onnx_proto
PRIVATE /MP
@@ -31,14 +31,72 @@ index 4dd56b6e..018da488 100644
${EXTRA_FLAGS})
if(ONNX_USE_PROTOBUF_SHARED_LIBS)
target_compile_options(onnx_proto
+diff --git a/onnx/common/file_utils.h b/onnx/common/file_utils.h
+index b847798e..a6c31904 100644
+--- a/onnx/common/file_utils.h
++++ b/onnx/common/file_utils.h
+@@ -6,7 +6,6 @@
+
+ #pragma once
+
+-#include <filesystem>
+ #include <fstream>
+ #include <string>
+
+@@ -17,8 +16,7 @@ namespace ONNX_NAMESPACE {
+
+ template <typename T>
+ void LoadProtoFromPath(const std::string proto_path, T& proto) {
+-  std::filesystem::path proto_u8_path = std::filesystem::u8path(proto_path);
+-  std::fstream proto_stream(proto_u8_path, std::ios::in | std::ios::binary);
++  std::fstream proto_stream(proto_path, std::ios::in | std::ios::binary);
+   if (!proto_stream.good()) {
+     fail_check("Unable to open proto file: ", proto_path, ". Please check if it is a valid proto. ");
+   }
+diff --git a/onnx/defs/quantization/defs.cc b/onnx/defs/quantization/defs.cc
+index 70b4a4db..98c11545 100644
+--- a/onnx/defs/quantization/defs.cc
++++ b/onnx/defs/quantization/defs.cc
+@@ -200,6 +200,9 @@ ONNX_OPERATOR_SET_SCHEMA(
+         .SetDoc(DequantizeLinear_ver21_doc)
+         .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+           propagateElemTypeFromInputToOutput(ctx, 1, 0);
++          if (!hasInputShape(ctx, 0)) {
++            return;
++          }
+           auto& input_shape = getInputShape(ctx, 0);
+           updateOutputShape(ctx, 0, input_shape);
+         }));
+diff --git a/onnx/defs/quantization/old.cc b/onnx/defs/quantization/old.cc
+index 3f2d6384..d2f7cfd8 100644
+--- a/onnx/defs/quantization/old.cc
++++ b/onnx/defs/quantization/old.cc
+@@ -130,6 +130,9 @@ ONNX_OPERATOR_SET_SCHEMA(
+         .SetDoc(DequantizeLinear_ver19_doc)
+         .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+           propagateElemTypeFromInputToOutput(ctx, 1, 0);
++          if (!hasInputShape(ctx, 0)) {
++            return;
++          }
+           auto& input_shape = getInputShape(ctx, 0);
+           updateOutputShape(ctx, 0, input_shape);
+         }));
+@@ -181,7 +184,6 @@ ONNX_OPERATOR_SET_SCHEMA(
+           if (!hasInputShape(ctx, 0)) {
+             return;
+           }
+-
+           auto& input_shape = getInputShape(ctx, 0);
+           updateOutputShape(ctx, 0, input_shape);
+         }));
diff --git a/onnx/onnx_pb.h b/onnx/onnx_pb.h
-index 0aab3e26..0f859267 100644
+index 0aab3e26..398ac2d6 100644
--- a/onnx/onnx_pb.h
+++ b/onnx/onnx_pb.h
@@ -47,10 +47,28 @@
#define ONNX_API ONNX_IMPORT
#endif

+#if defined(__GNUC__)
+#pragma GCC diagnostic push
+
@@ -58,34 +116,12 @@ index 0aab3e26..0f859267 100644
#else
#include "onnx/onnx.pb.h"
#endif

+#if defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
+
#endif // ! ONNX_ONNX_PB_H
-diff --git a/onnx/common/file_utils.h b/onnx/common/file_utils.h
-index b847798e..a6c31904 100644
---- a/onnx/common/file_utils.h
-+++ b/onnx/common/file_utils.h
-@@ -6,7 +6,6 @@
-
- #pragma once
-
--#include <filesystem>
- #include <fstream>
- #include <string>
-
-@@ -17,8 +16,7 @@ namespace ONNX_NAMESPACE {
-
- template <typename T>
- void LoadProtoFromPath(const std::string proto_path, T& proto) {
--  std::filesystem::path proto_u8_path = std::filesystem::u8path(proto_path);
--  std::fstream proto_stream(proto_u8_path, std::ios::in | std::ios::binary);
-+  std::fstream proto_stream(proto_path, std::ios::in | std::ios::binary);
-   if (!proto_stream.good()) {
-     fail_check("Unable to open proto file: ", proto_path, ". Please check if it is a valid proto. ");
-   }
diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc
index fab1faf2..8723dcd4 100644
--- a/onnx/shape_inference/implementation.cc
...
6 changes: 3 additions & 3 deletions onnxruntime/test/python/quantization/test_op_gemm.py
@@ -318,7 +318,7 @@ def test_quantize_gemm(self):
weight_type=QuantType.QUInt8,
)

@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
def test_quantize_qop_gemm_s8s8(self):
np.random.seed(1)
model_fp32_path = "gemm_fp32.onnx"
@@ -366,7 +366,7 @@ def test_quantize_qdq_gemm_e4m3fn_same(self):
calibrate_method=CalibrationMethod.Distribution,
)

@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
def test_quantize_qop_gemm_e4m3fn_same(self):
np.random.seed(1)
model_fp32_path = "gemm_fp32.onnx"
@@ -397,7 +397,7 @@ def test_quantize_qdq_gemm_e4m3fn_p3(self):
calibrate_method=CalibrationMethod.Distribution,
)

@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
def test_quantize_qop_gemm_e4m3fn_p3(self):
np.random.seed(1)
model_fp32_path = "gemm_fp32.onnx"
10 changes: 5 additions & 5 deletions onnxruntime/test/python/quantization/test_op_matmul.py
@@ -347,7 +347,7 @@ def quantize_matmul_u8u8(self, tt, opset, ir_version):
def test_quantize_matmul_u8u8(self):
self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT, 18, 8)

@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
@skip_if_new_opset_exception_raised
def test_quantize_matmul_u8u8_f16(self):
self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT16, 21, 9)
@@ -393,22 +393,22 @@ def test_quantize_matmul_s8s8_percentile(self):
def test_quantize_matmul_s8s8_distribution(self):
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Distribution)

@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
@skip_if_new_opset_exception_raised
def test_quantize_matmul_s8s8_f16(self):
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9)

@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
@skip_if_new_opset_exception_raised
def test_quantize_matmul_s8s8_f16_entropy(self):
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Entropy)

@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
@skip_if_new_opset_exception_raised
def test_quantize_matmul_s8s8_f16_percentile(self):
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Percentile)

@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
@skip_if_new_opset_exception_raised
def test_quantize_matmul_s8s8_f16_distribution(self):
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Distribution)
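
The updated skip reason above points at a missing ORT kernel rather than a shape inference bug. A hedged sketch of what those skipped tests would hit (the model layout and error handling here are illustrative assumptions; only the claim that QLinearMatMul(21) with float16 is not implemented in ORT comes from this PR):

```python
# Hedged sketch: a QLinearMatMul model at opset 21 whose scales are float16.
# Session creation is expected to fail because ORT has no matching kernel.
import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper, numpy_helper

node = helper.make_node(
    "QLinearMatMul",
    ["a", "a_scale", "a_zp", "b", "b_scale", "b_zp", "y_scale", "y_zp"],
    ["y"],
)
inits = [
    numpy_helper.from_array(np.array(0.5, dtype=np.float16), "a_scale"),
    numpy_helper.from_array(np.array(0, dtype=np.uint8), "a_zp"),
    numpy_helper.from_array(np.random.randint(0, 255, (4, 3), dtype=np.uint8), "b"),
    numpy_helper.from_array(np.array(0.25, dtype=np.float16), "b_scale"),
    numpy_helper.from_array(np.array(0, dtype=np.uint8), "b_zp"),
    numpy_helper.from_array(np.array(1.0, dtype=np.float16), "y_scale"),
    numpy_helper.from_array(np.array(0, dtype=np.uint8), "y_zp"),
]
graph = helper.make_graph(
    [node],
    "qlinear_matmul_f16_scales",
    [helper.make_tensor_value_info("a", TensorProto.UINT8, [2, 4])],
    [helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3])],
    initializer=inits,
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])

try:
    ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
except Exception as e:  # expected: ORT reports it cannot find a kernel
    print(type(e).__name__, e)
```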
2 changes: 1 addition & 1 deletion onnxruntime/test/python/quantization/test_op_relu.py
@@ -194,7 +194,7 @@ def test_quantize_gemm(self):
weight_type=QuantType.QUInt8,
)

@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
def test_quantize_qop_relu_s8s8(self):
np.random.seed(1)
model_fp32_path = "relu_fp32.onnx"
@@ -348,7 +348,7 @@ def test_get_input_output_names():


# Fails in ONNX 1.16.0 due to potential shape inference bug for custom ops.
-# Potential ONNX fix: https://github.com/onnx/onnx/pull/6049
+# Potential ONNX fix: https://github.com/onnx/onnx/pull/6080
# Error log: LookupError: The provided name onnx::linear.output::171 is not a graph value info or a graph output.
@pytest.mark.skipif(
pv.Version(onnx.__version__) == pv.Version("1.16.0"), reason="Shape inference bug for custom ops in ONNX 1.16.0"