Skip to content

Commit

Permalink
Support fuse bn into ConvTranspose.
Browse files Browse the repository at this point in the history
Signed-off-by: wenyuchi.wyc <[email protected]>
  • Loading branch information
wenyuchi.wyc committed Mar 6, 2023
1 parent 807cff7 commit 5d4d388
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 7 deletions.
16 changes: 9 additions & 7 deletions onnxoptimizer/passes/fuse_bn_into_conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
return "fuse_bn_into_conv";
}

bool modify_conv(Node* conv, Node* bn, Graph& graph) {
bool modify_conv(Node* conv, Node* bn, Graph& graph, const bool is_conv) {
const auto& bn_inputs = bn->inputs();
const auto& conv_inputs = conv->inputs();

Expand Down Expand Up @@ -123,10 +123,9 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
Node* unsqueeze = graph.create(kUnsqueeze, 1);
unsqueeze->insertAfter(scale);
unsqueeze->addInput(scale->output());
std::vector<int64_t> insert_dims;
for (int i = 1; i < conv_W.sizes().size(); ++i) {
insert_dims.push_back(i);
}
std::vector<int64_t> insert_dims(conv_W.sizes().size());
std::iota(insert_dims.begin(), insert_dims.end(), 0);
insert_dims.erase(insert_dims.begin() + (is_conv ? 0 : 1));
if (getOpsetVersion(graph) > 11) {
Tensor shape_s_t;
shape_s_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64;
Expand Down Expand Up @@ -181,7 +180,8 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
}

bool patternMatchPredicate(Node* n) override {
return CheckKind(n, kBatchNormalization, 0, kConv) &&
return (CheckKind(n, kBatchNormalization, 0, kConv) ||
CheckKind(n, kBatchNormalization, 0, kConvTranspose)) &&
GetValueFromAttrWithDefault(n, "training_mode", (int64_t)0) == 0 &&
n->input(0)->uses().size() == 1 && n->outputs().size() == 1 &&
IsConstantTensor(n, 1) && IsConstantTensor(n, 2) &&
Expand All @@ -190,10 +190,12 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
}
bool runTransform(Node* n, Graph& graph,
NodeDestroyType& destroy_current) override {
const bool is_conv = CheckKind(n, kBatchNormalization, 0, kConv);

Node* bn = n;
Node* conv = PrevNode(n, 0);
auto origInput = bn->inputs()[0];
if (!modify_conv(conv, bn, graph)) {
if (!modify_conv(conv, bn, graph, is_conv)) {
destroy_current = NodeDestroyType::DestroyZero;
return false;
}
Expand Down
40 changes: 40 additions & 0 deletions onnxoptimizer/test/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3063,6 +3063,46 @@ def test_fuse_bn_into_conv_simple(self): # type: () -> None
)
optimized_model = self._optimized(graph, ["fuse_bn_into_conv"]) # noqa

def test_fuse_bn_into_conv_transpose_simple(self): # type: () -> None
for (tensor_type, np_type) in [(TensorProto.FLOAT, np.float32)]:
conv = helper.make_node("ConvTranspose", ["X", "W", "B"], ["Y"])
bn = helper.make_node(
"BatchNormalization", ["Y", "scale", "b", "mean", "var"], ["Z"]
)

W = np.random.randn(64, 64, 2, 2).astype(np_type) + 2
B = np.random.randn(64,).astype(np_type) + 2
scale = np.random.randn(64,).astype(np_type) + 2
b = np.random.randn(64,).astype(np_type) + 2
mean = np.random.randn(64,).astype(np_type) + 2
var = np.abs(np.random.randn(64,).astype(np_type)) + 2

initializers = [
helper.make_tensor(
name, tensor_type, npa.shape, npa.tobytes(), raw=True
)
for name, npa in [
("W", W),
("B", B),
("scale", scale),
("b", b),
("mean", mean),
("var", var),
]
]
graph = helper.make_graph(
[conv, bn],
"test",
[helper.make_tensor_value_info("X", tensor_type, (1, 64, 160, 160))],
[helper.make_tensor_value_info("Z", tensor_type, (1, 64, 320, 320))],
initializer=initializers,
value_info=[
helper.make_tensor_value_info("Y", tensor_type, (1, 64, 320, 320))
],
)

optimized_model = self._optimized(graph, ["fuse_bn_into_conv"])

def _internal_test_deadend_elimination(self, fixed): # type: (bool) -> None
softmax = helper.make_node("Softmax", ["X"], ["Y"], axis=2)
log = helper.make_node("Log", ["Y"], ["Z"])
Expand Down

0 comments on commit 5d4d388

Please sign in to comment.