
TensorRT may produce incorrect results when converting ONNX models that contain Trilu #84

Open
xiatwhu opened this issue Aug 22, 2023 · 1 comment


xiatwhu commented Aug 22, 2023

  • Environment
    • TensorRT 9.0.0.2
    • Versions of CUDA (12.1), cuBLAS (12.1.3.1)
    • Container used (registry.cn-hangzhou.aliyuncs.com/trt-hackathon/trt-hackathon:final_v1)
    • NVIDIA driver version (535.86.05)
  • Reproduction Steps
    This bug was found while converting the CLIP model during the preliminary round; the full reproduction code is below:
import torch
import os
import tensorrt as trt
from cuda import cudart
import numpy as np

class CLIP(torch.nn.Module):
    """Minimal repro of the CLIP causal attention mask."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        mask = torch.empty(x.shape[0], x.shape[1], x.shape[2], device=x.device)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # keep -inf strictly above the diagonal, 0 elsewhere; exports as Trilu
        return mask, mask + x
    
model = CLIP()
device = torch.device('cuda')
x = torch.ones(1, 2, 2, device=device)

torch.onnx.export(model, (x,), 'test.onnx',
                  export_params=True,
                  opset_version=18,
                  do_constant_folding=True,
                  keep_initializers_as_inputs=True,
                  input_names=['x'],
                  output_names=['mask', 'y'])

os.system("trtexec --onnx=test.onnx --saveEngine=test.plan")

trt_logger = trt.Logger(trt.Logger.INFO)
with open('test.plan', 'rb') as f, trt.Runtime(trt_logger) as runtime:
    trt_engine = runtime.deserialize_cuda_engine(f.read())
    trt_ctx = trt_engine.create_execution_context()

    mask = torch.empty(1, 2, 2, device=device)
    y = torch.empty(1, 2, 2, device=device)

    trt_ctx.set_tensor_address('x', x.data_ptr())
    trt_ctx.set_tensor_address('mask', mask.data_ptr())
    trt_ctx.set_tensor_address('y', y.data_ptr())

    stream = cudart.cudaStreamCreateWithPriority(cudart.cudaStreamNonBlocking, 0)[1]
    trt_ctx.execute_async_v3(stream)
    cudart.cudaStreamSynchronize(stream)

    print('trt mask:', mask)
    print('trt y:', y)


mask, y = model(x)
print('torch mask: ', mask)
print('torch y: ', y)
  • Expected Behavior
    TensorRT is expected to produce output consistent with PyTorch:
trt mask: tensor([[[0., -inf],
         [0., 0.]]], device='cuda:0')
trt y: tensor([[[1., -inf],
         [1., 1.]]], device='cuda:0')
torch mask:  tensor([[[0., -inf],
         [0., 0.]]], device='cuda:0')
torch y:  tensor([[[1., -inf],
         [1., 1.]]], device='cuda:0')
  • Actual Behavior
    In the Trilu output, the positions that should be 0 are actually nan, which then propagates nan into all downstream computations:
trt mask: tensor([[[nan, -inf],
         [nan, nan]]], device='cuda:0')
trt y: tensor([[[nan, -inf],
         [nan, nan]]], device='cuda:0')
torch mask:  tensor([[[0., -inf],
         [0., 0.]]], device='cuda:0')
torch y:  tensor([[[1., -inf],
         [1., 1.]]], device='cuda:0')
  • Additional Notes
    The root cause is a utility function, createZeroTensor, used when converting the ONNX Trilu op into TensorRT layers. Its purpose is to create a tensor of zeros with the same shape as the input tensor. The implementation does this by multiplying the input element-wise with a constant 0; when the input contains -inf, this produces nan rather than 0.
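The failure mode can be reproduced in plain Python, independent of TensorRT: IEEE 754 defines 0 * inf as nan, so multiplying the input by a constant 0 does not yield a zero tensor wherever the input holds -inf.

```python
import math

# The root cause in one line: IEEE 754 arithmetic, 0 * (-inf) is nan.
neg_inf = float('-inf')
product = 0.0 * neg_inf
print(product)              # nan
print(math.isnan(product))  # True
```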

Trilu conversion code:

auto* rows = iota(ctx, iotadims, 0);
auto* cols = iota(ctx, iotadims, 1);
auto* zero = createZeroTensor(ctx, data);

createZeroTensor implementation:

nvinfer1::ITensor* createZeroTensor(IImporterContext* ctx, nvinfer1::ITensor* data)
{
    nvinfer1::ITensor* zero
        = addConstant(ctx, std::vector<float>{0.f}, ::ONNX_NAMESPACE::TensorProto::FLOAT, {0, {1}})->getOutput(0);
    zero = castHelper(ctx, zero, data->getType());
    broadcastTensors(ctx, zero, data);
    // Bug: 0 * (-inf) is nan under IEEE 754, so this "zero" tensor is not
    // all zeros whenever `data` contains infinities.
    zero = ctx->network()->addElementWise(*data, *zero, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0);
    return zero;
}
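A minimal NumPy sketch of why the kPROD approach fails and what a value-independent alternative looks like (a proper fix would presumably derive the zeros from the input's shape rather than from its values; that is an assumption about the fix, not the actual patch):

```python
import numpy as np

data = np.array([[0.0, -np.inf],
                 [0.0,  0.0]], dtype=np.float32)

# Mirrors createZeroTensor: broadcast a constant 0 and multiply with the data.
with np.errstate(invalid='ignore'):
    buggy_zero = data * np.float32(0.0)   # -inf * 0 -> nan

# Value-independent alternative: build the zeros from the shape alone.
safe_zero = np.zeros_like(data)

print(buggy_zero)  # nan where data was -inf
print(safe_zero)   # all zeros
```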

wlp9805 commented Feb 5, 2024

Ran into the same problem; after exporting the ONNX model, I worked around it by replacing a constant.
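The comment doesn't say which constant was replaced; one common export-side workaround (an assumption, not necessarily the commenter's exact fix) is to fill the mask with the most negative finite float instead of -inf, so that even a buggy 0 * x path stays finite while softmax still drives masked positions toward 0:

```python
import numpy as np

# Hypothetical workaround: most negative finite float32 instead of -inf.
neg_fill = np.finfo(np.float32).min   # about -3.4e38

mask = np.triu(np.full((2, 2), neg_fill, dtype=np.float32), k=1)
print(mask)                           # neg_fill above the diagonal, 0 elsewhere
print(np.float32(0.0) * neg_fill)     # -0.0, finite: the kPROD path now yields zeros
```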
