
Support fuse bn into ConvTranspose. #106

Open
wants to merge 1 commit into base: master

Conversation

YuchiWen

No description provided.

@daquexian
Member

Sorry for the late response. Could you please add some tests for the fusion? You can follow the conv-bn fusion https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024

@YuchiWen YuchiWen force-pushed the fuse_bn_into_conv_transpose branch from 6e5ac70 to 5d4d388 Compare March 6, 2023 10:52
@YuchiWen
Author

YuchiWen commented Mar 6, 2023

> Sorry for the late response. Could you please add some tests for the fusion? You can follow the conv-bn fusion https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024

@daquexian Done, please review.

@YuchiWen YuchiWen force-pushed the fuse_bn_into_conv_transpose branch from 5d4d388 to ff1229e Compare March 6, 2023 11:31
@huangzhicong3

Hello, I tried to use this commit to fuse the BN and ConvTranspose layers in my model and ran into a bug. The error message is:
passes/fuse_bn_into_conv.h:71: modify_conv: Assertion conv_W.sizes().size() > 2 && conv_W.sizes()[0] == C failed.

According to the ONNX documentation (https://onnx.ai/onnx/operators/onnx__ConvTranspose.html), the weight of ConvTranspose has shape (Cin, Cout, K, K), which differs from a normal Conv layer's (Cout, Cin, K, K).
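To make the layout difference concrete, a minimal NumPy sketch (shapes are illustrative) of why the conv-bn pass's assertion fails on ConvTranspose weights:

```python
import numpy as np

C = 8  # BN channel count = number of output channels

conv_w = np.zeros((C, 3, 3, 3))   # Conv weight layout: (Cout, Cin, kH, kW)
convt_w = np.zeros((3, C, 3, 3))  # ConvTranspose weight layout: (Cin, Cout, kH, kW)

# The conv-bn pass asserts conv_W.sizes()[0] == C; that holds for Conv...
assert conv_w.shape[0] == C
# ...but not for ConvTranspose, where axis 0 is Cin -- hence the failure
assert convt_w.shape[0] != C

# swapping the first two axes restores the layout the pass expects
assert convt_w.transpose(1, 0, 2, 3).shape[0] == C
```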

@huangzhicong3

huangzhicong3 commented May 30, 2024

Hi, I would like to share my code for fusing ConvTranspose and BN. It has been tested on my own model; I hope it helps others who hit the same issue.

import numpy as np
import onnx
import sclblonnx as so

model = onnx.load('../onnx/models/backbone_clean.onnx')

all_initializer = model.graph.initializer
all_node = model.graph.node

# collect ConvTranspose and BatchNormalization nodes
ConvTranspose_list = [n for n in all_node if n.op_type == "ConvTranspose"]
BatchNormalization_list = [n for n in all_node if n.op_type == "BatchNormalization"]

# pair each ConvTranspose with the BatchNormalization that consumes its output
valid_ConvTranspose_list = []
for node in ConvTranspose_list:
    for bn_node in BatchNormalization_list:
        if node.output[0] in bn_node.input:
            valid_ConvTranspose_list.append({"conv": node, "bn": bn_node})
            break

# collect the initializers (weights and BN statistics) referenced by each pair
param_dict = {}
for pair in valid_ConvTranspose_list:
    param_name = list(pair["conv"].input) + list(pair["bn"].input)
    for initializer in all_initializer:
        if initializer.name in param_name:
            param_dict[initializer.name] = onnx.numpy_helper.to_array(initializer)

for pair in valid_ConvTranspose_list:
    conv = pair["conv"]
    bn = pair["bn"]

    # look attributes up by name; attribute order is not guaranteed in ONNX
    bn_attrs = {a.name: a for a in bn.attribute}
    bn_eps = bn_attrs["epsilon"].f if "epsilon" in bn_attrs else 1e-5

    bn_w = param_dict[bn.input[1]]     # scale (gamma), [Cout]
    bn_b = param_dict[bn.input[2]]     # bias (beta), [Cout]
    bn_mean = param_dict[bn.input[3]]  # running mean, [Cout]
    bn_var = param_dict[bn.input[4]]   # running variance, [Cout]

    conv_w = param_dict[conv.input[1]]  # ConvTranspose weight, [Cin, Cout, kH, kW]
    if len(conv.input) > 2:
        conv_b = param_dict[conv.input[2]]
    else:
        conv_b = np.zeros_like(bn_b)    # [Cout]

    # ConvTranspose stores the output channel on axis 1 (not axis 0 as in
    # Conv), so the per-channel BN scale is applied along axis 1
    scale = bn_w / np.sqrt(bn_var + bn_eps)               # [Cout]
    new_conv_w = conv_w * scale[None, :, None, None]
    new_conv_b = scale * conv_b + bn_b - scale * bn_mean  # element-wise, [Cout]

    # rebuild the node, copying all original attributes verbatim so optional
    # ones such as output_padding are preserved
    new_node = onnx.helper.make_node(
        name=conv.name + '_bn',
        op_type="ConvTranspose",
        inputs=[conv.input[0], conv.name + '_bn.weights', conv.name + '_bn.bias'],
        outputs=[bn.output[0]],
    )
    new_node.attribute.extend(conv.attribute)

    model.graph.initializer.append(onnx.numpy_helper.from_array(
        new_conv_w.astype(np.float32), name=conv.name + '_bn.weights'))
    model.graph.initializer.append(onnx.numpy_helper.from_array(
        new_conv_b.astype(np.float32), name=conv.name + '_bn.bias'))

    # insert the fused node at the position of the original ConvTranspose
    for i, node in enumerate(all_node):
        if conv.name == node.name:
            model.graph.node.insert(i, new_node)
            break
    # remove the original pair
    model.graph.node.remove(conv)
    model.graph.node.remove(bn)

onnx.checker.check_model(model)
onnx.save(model, '../onnx/models/backbone_fuse.onnx')

graph = so.graph_from_file('../onnx/models/backbone_fuse.onnx')
graph = so.clean(graph)
so.check(graph)
so.graph_to_file(graph, '../onnx/models/backbone_fuse.onnx')
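A quick numerical sanity check of the fusion algebra (standalone, no model file needed): with a 1x1 kernel and stride 1, ConvTranspose at a single pixel reduces to a channel matmul, so the fused weight and bias can be compared directly against BN applied to the ConvTranspose output. All sizes below are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
cin, cout, eps = 3, 4, 1e-5

w = rng.normal(size=(cin, cout, 1, 1))  # ConvTranspose weight (Cin, Cout, 1, 1)
b = rng.normal(size=cout)               # ConvTranspose bias
gamma, beta = rng.normal(size=cout), rng.normal(size=cout)
mean, var = rng.normal(size=cout), rng.uniform(0.5, 2.0, size=cout)

x = rng.normal(size=cin)                # one pixel of one input sample

# Reference: ConvTranspose (1x1, stride 1 == channel matmul) followed by BN
y = w[:, :, 0, 0].T @ x + b
ref = gamma * (y - mean) / np.sqrt(var + eps) + beta

# Fused parameters, as computed in the script above
scale = gamma / np.sqrt(var + eps)        # per-output-channel scale
w_fused = w * scale[None, :, None, None]  # scale axis 1 (Cout) of the weight
b_fused = scale * b + beta - scale * mean

fused = w_fused[:, :, 0, 0].T @ x + b_fused
assert np.allclose(ref, fused)
```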

4 participants