Skip to content

Commit

Permalink
upgrade onnx submodule, support qdq
Browse files Browse the repository at this point in the history
Signed-off-by: daquexian <[email protected]>
  • Loading branch information
daquexian committed Apr 4, 2021
1 parent cc8fd26 commit ec8ff8e
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 10 deletions.
42 changes: 33 additions & 9 deletions onnxoptimizer/passes/fuse_add_bias_into_conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,21 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
}
static Node *makeSqueezeOrUnsqueeze(Graph &graph, std::vector<int64_t> &axes,
Value *input, Node *target_node,
BuiltinSymbol k) {
BuiltinSymbol k, bool is_input_qdq) {
assert(k == kSqueeze || k == kUnsqueeze);
Node *squeeze = graph.create(k, 1);
int opset_version = getOpsetVersion(graph);
Node *dequant_node = nullptr;
Node *quant_node = nullptr;
if (is_input_qdq) {
dequant_node = input->node();
quant_node = dequant_node->input(0)->node();
target_node = quant_node;
input = target_node->input(0);
dequant_node->output()->clearMetadata();
quant_node->output()->clearMetadata();
}
squeeze->addInput(input);
int opset_version = getOpsetVersion(graph);
int version_threshold = 13;
if (opset_version < version_threshold && opset_version != 0) {
squeeze->is_(kaxes, std::move(axes));
Expand All @@ -54,7 +64,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
Value *tv = graph.addInitializerAndInput(t);
squeeze->addInput(tv);
}
if (is_input_qdq) {
quant_node->replaceInput(0, squeeze->output());
}
squeeze->insertBefore(target_node);
if (is_input_qdq) {
return dequant_node;
}
return squeeze;
}
bool runTransform(Node *n, Graph &graph,
Expand Down Expand Up @@ -115,13 +131,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
if (bias_shape.size() > 1) {
std::vector<int64_t> axes(bias_shape.size() - 1);
std::iota(axes.begin(), axes.end(), 0);
Node *squeeze = makeSqueezeOrUnsqueeze(graph, axes, conv_3rd_input,
orig_conv->node(), kSqueeze);
Node *squeeze = makeSqueezeOrUnsqueeze(
graph, axes, conv_3rd_input, orig_conv->node(), kSqueeze, false);
conv_3rd_input = squeeze->output();
} else if (bias_shape.size() == 0) {
std::vector<int64_t> axes = {0};
Node *unsqueeze = makeSqueezeOrUnsqueeze(graph, axes, conv_3rd_input,
orig_conv->node(), kUnsqueeze);
Node *unsqueeze = makeSqueezeOrUnsqueeze(
graph, axes, conv_3rd_input, orig_conv->node(), kUnsqueeze, false);
conv_3rd_input = unsqueeze->output();
}
if (M > 1) {
Expand Down Expand Up @@ -149,17 +165,25 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
bias_shape[1 + bias_shape.size() - static_cast<unsigned>(rank)]
.dim == M) {
ONNX_ASSERT(bias_shape.size() > 1);
const bool is_input_qdq =
orig_bias->node()->kind() == Symbol("DequantizeLinear") &&
orig_bias->node()->input(0)->node()->kind() ==
Symbol("QuantizeLinear");
if (orig_bias->node()->kind() != kParam &&
orig_conv->node()->isBefore(orig_bias->node())) {
if (is_input_qdq) {
orig_bias->node()->input(0)->node()->moveBefore(orig_conv->node());
}
orig_bias->node()->moveBefore(orig_conv->node());
}
std::vector<int64_t> axes(bias_shape.size());
std::iota(axes.begin(), axes.end(), static_cast<int64_t>(0));
axes.erase(axes.begin() +
(1 + bias_shape.size() - static_cast<unsigned>(rank)));
Node *squeeze = makeSqueezeOrUnsqueeze(graph, axes, orig_bias,
orig_conv->node(), kSqueeze);
orig_conv->node()->addInput(squeeze->output());

Node *new_bias = makeSqueezeOrUnsqueeze(
graph, axes, orig_bias, orig_conv->node(), kSqueeze, is_input_qdq);
orig_conv->node()->addInput(new_bias->output());
} else {
return false;
}
Expand Down
36 changes: 36 additions & 0 deletions onnxoptimizer/test/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def _optimized(self, graph_or_model, opts, fixed_point=False, compare_result=Tru
graph_or_model, producer_name='onnx-test', opset_imports=opset_imports, **kwargs)
checker.check_model(orig_model)
optimized_model = onnxoptimizer.optimize(orig_model, opts, fixed_point)
print(str(optimized_model))
checker.check_model(optimized_model)
if compare_result and len(optimized_model.graph.node) > 0:
if has_ort:
Expand Down Expand Up @@ -1150,6 +1151,41 @@ def test_fuse_add_bias_into_conv_with_non_constant_bias(self):
assert optimized_model.graph.node[2].op_type == 'Conv'
assert optimized_model.graph.output[0].name == 'C'

# type: () -> None
def test_fuse_add_bias_into_conv_with_quanted_bias(self):
nodes = [helper.make_node("Conv", ["X", "Y"], ["Z"]),
helper.make_node("QuantizeLinear", ["A", "scale", "zero_point"], ["B"], axis=0),
helper.make_node("DequantizeLinear", ["B", "scale", "zero_point"], ["C"], axis=0),
helper.make_node("Add", ["Z", "C"], ["D"])]
graph = helper.make_graph(
nodes,
"test",
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 5, 3, 3)),
helper.make_tensor_value_info(
"Y", TensorProto.FLOAT, (16, 5, 3, 3)),
helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 1, 1))],
[helper.make_tensor_value_info(
"D", TensorProto.FLOAT, (1, 16, 1, 1))],
[helper.make_tensor("scale", TensorProto.FLOAT,
dims=(16,),
vals=np.random.rand(16).astype(np.float32).tobytes(),
raw=True),
helper.make_tensor("zero_point", TensorProto.INT8,
dims=(16,),
vals=np.zeros([16]).astype(np.int8).tobytes(),
raw=True)],
value_info=[helper.make_tensor_value_info(
"C", TensorProto.FLOAT, (16, 1, 1))]
)
optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"], opset_imports=[helper.make_opsetid("", 13)])

assert len(list(optimized_model.graph.node)) == 4
assert optimized_model.graph.node[0].op_type == 'Squeeze'
assert optimized_model.graph.node[1].op_type == 'QuantizeLinear'
assert optimized_model.graph.node[2].op_type == 'DequantizeLinear'
assert optimized_model.graph.node[3].op_type == 'Conv'
assert optimized_model.graph.output[0].name == 'D'

def test_fuse_matmul_add_bias_into_gemm(self): # type: () -> None
matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
add = helper.make_node("Add", ["Z", "B"], ["A"])
Expand Down
2 changes: 1 addition & 1 deletion third_party/onnx
Submodule onnx updated 302 files

0 comments on commit ec8ff8e

Please sign in to comment.