Using TensorRT to quantize the UNet of a diffusion model: FP16 and INT8 are normal, but FP32 is abnormal #105

Refrain-lhc commented May 30, 2024

Problem description

The main problem: the ONNX file I export from the diffusion model's UNet is about 155 MB, but the FP32 TRT engine built with TensorRT comes out at 264 MB, while FP16 is only 81 MB and INT8 is 45 MB. The latter two sizes look normal, yet the FP32 engine is somehow more than 100 MB larger than the ONNX file.
The UNet comes from https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-
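
For reference, a minimal sketch (the file name is assumed from step 1 below) that sums the initializer payloads of the ONNX file, to confirm the 155 MB is essentially FP32 weight data:

    import onnx
    from onnx import numpy_helper

    # Sum the byte size of every initializer (weight tensor) in the graph.
    m = onnx.load('unet_change.onnx')
    total = sum(numpy_helper.to_array(t).nbytes for t in m.graph.initializer)
    print(f'initializer payload: {total / 2**20:.1f} MB')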

The exact steps are as follows:

  1. First export the PyTorch model to an ONNX file
    x = torch.randn(1, 6, 256, 256)
    time = torch.randn(1, )
    torch.onnx.export(model, (x, time), 'unet_change.onnx')
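
A slightly more explicit version of the same export, which also validates the result (the tensor names and opset below are illustrative assumptions, not from the original script):

    import torch
    import onnx

    x = torch.randn(1, 6, 256, 256)
    time = torch.randn(1)
    # Named inputs/outputs and a pinned opset make the parser step
    # in part 2 easier to debug; opset 13 is an assumed choice.
    torch.onnx.export(model, (x, time), 'unet_change.onnx',
                      input_names=['x', 'time'], output_names=['out'],
                      opset_version=13)
    onnx.checker.check_model(onnx.load('unet_change.onnx'))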

  2. Then read the ONNX file back and build it into an FP32 TRT engine:
    import tensorrt as trt

    def build_engine(onnx_file, trt_file, selection, calibration_data=None):
        TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
        builder = trt.Builder(TRT_LOGGER)

        # Explicit-batch network, as required by the ONNX parser.
        flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        network = builder.create_network(flag)

        parser = trt.OnnxParser(network, TRT_LOGGER)
        with open(onnx_file, 'rb') as model:
            if not parser.parse(model.read()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, (1 << 20))

        if selection == 'fp16':
            config.set_flag(trt.BuilderFlag.FP16)
        elif selection == 'int8':
            config.set_flag(trt.BuilderFlag.INT8)
            config.int8_calibrator = MyCalibrator_for_more(calibration_data)

        # build_serialized_network returns the serialized engine bytes.
        engine = builder.build_serialized_network(network, config)
        with open(trt_file, 'wb') as f:
            f.write(engine)
        return engine

model = 'unet'
onnx_file_path = model + '.onnx'
engine_fp32 = build_engine(onnx_file_path, model + "_fp32.trt", 'fp32')
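
For completeness, a small usage sketch that builds the FP16 engine with the same function and prints the on-disk sizes side by side (the INT8 path needs MyCalibrator_for_more and calibration data, so it is omitted here):

    import os

    engine_fp16 = build_engine(onnx_file_path, model + "_fp16.trt", 'fp16')
    for path in [onnx_file_path, model + "_fp32.trt", model + "_fp16.trt"]:
        # Report sizes in MB, matching the numbers quoted above.
        print(path, round(os.path.getsize(path) / 2**20, 1), 'MB')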

Other experiments

Before this I briefly tested ResNet18 and ResNet50, and the code above worked very well for both. I then made the UNet deeper by stacking more ResConv blocks, and once the ONNX model grew to over 280 MB, the FP32 engine came out more than 200 MB larger still.
Could someone please explain what is going on? I've been stuck on this for over a week QAQ
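
One way to narrow down where the extra bytes go is to inspect the built engine layer by layer. A sketch, assuming TensorRT >= 8.2 (per-layer detail also requires building with config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED):

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    with open('unet_fp32.trt', 'rb') as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    # Dump per-layer information; the chosen tactics and formats can hint
    # at why the serialized engine outgrew the ONNX weights.
    inspector = engine.create_engine_inspector()
    print(inspector.get_engine_information(trt.LayerInformationFormat.JSON))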
