Problem description
The main problem: the ONNX file I exported from a diffusion-model UNet is about 155 MB, but the FP32 TensorRT engine built from it is 264 MB. The FP16 engine is 81 MB and the INT8 engine is 45 MB, which both look like normal sizes, yet the FP32 engine is more than 100 MB larger than the ONNX file.
The UNet comes from https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-
The exact steps are as follows:
First, export the PyTorch model to an ONNX file:
```python
x = torch.randn(1, 6, 256, 256)
time = torch.randn(1)
torch.onnx.export(model, (x, time), 'unet_change.onnx')
```
Then read the ONNX file and build an FP32 TensorRT engine from it:
```python
import tensorrt as trt

def build_engine(onnx_file, trt_file, selection, calibration_data=None):
    TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(TRT_LOGGER)
    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_file, 'rb') as model:
        if not parser.parse(model.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None
    config = builder.create_builder_config()
    # Workspace limit: 1 << 20 bytes is only 1 MiB
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, (1 << 20))
    if selection == 'fp16':
        config.set_flag(trt.BuilderFlag.FP16)
    elif selection == 'int8':
        config.set_flag(trt.BuilderFlag.INT8)
        config.int8_calibrator = MyCalibrator_for_more(calibration_data)
    # build_serialized_network returns the serialized engine as IHostMemory
    engine = builder.build_serialized_network(network, config)
    with open(trt_file, 'wb') as f:
        f.write(engine)
    return engine
```
```python
model = 'unet'
onnx_file_path = model + '.onnx'
engine_fp32 = build_engine(onnx_file_path, model + "_fp32.trt", 'fp32')
```
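To reproduce the size comparison from the description, a small stdlib helper is enough; `report_sizes` is a hypothetical name, and the file list simply mirrors the names used above:

```python
import os

def report_sizes(paths):
    """Return {path: size in MiB} for the files that exist (hypothetical helper)."""
    sizes = {}
    for path in paths:
        if os.path.exists(path):
            sizes[path] = os.path.getsize(path) / 2**20
    return sizes

# Compare the ONNX file against the engines built above
for path, mib in report_sizes(
    ['unet.onnx', 'unet_fp32.trt', 'unet_fp16.trt', 'unet_int8.trt']
).items():
    print(f"{path}: {mib:.1f} MiB")
```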
Other attempts
Before this, I briefly tested ResNet18 and ResNet50, and with the same code they converted perfectly. I then deepened the UNet's ResConv blocks, and once the ONNX model grew to about 280 MB, the FP32 engine came out more than 200 MB larger.
Any insight would be appreciated; I've been stuck on this for over a week. QAQ