The USE_DIFF_OF_SQUARES implementation in the tensorrt_llm layer norm plugin can produce nan results #88

Open
xiatwhu opened this issue Sep 9, 2023 · 0 comments


xiatwhu commented Sep 9, 2023

  • Environment
    • TensorRT 9.0.0.2
    • Versions of CUDA (12.1), cuBLAS (12.1.3.1)
    • Container used (registry.cn-hangzhou.aliyuncs.com/trt-hackathon/trt-hackathon:final_v1)
    • NVIDIA driver version (535.86.05)
  • Reproduction Steps
    USE_DIFF_OF_SQUARES computes the variance with the formula Var[x] = E[x²] - E[x]². Because float precision is limited, the computed Var[x] can come out negative in some cases; rsqrtf then returns nan for the reciprocal of the standard deviation, and the layer norm output becomes nan.

A simple verification with torch:

import torch

x = torch.full((1, 10240), 256., device='cuda:0', dtype=torch.float32)
x[0, -1] = 255

x_square_mean = (x * x).mean()
x_mean = x.mean()
x_var = x_square_mean - x_mean * x_mean
x_inv_std = torch.rsqrt(x_var + 1e-5)

print(x_var, x_inv_std)
  • Expected Behavior
    The computed variance is greater than 0, and the reciprocal of the standard deviation is greater than 0.

  • Actual Behavior
    The computed variance is less than 0, and the reciprocal of the standard deviation is nan.

tensor(-0.0039, device='cuda:0') tensor(nan, device='cuda:0')
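
For comparison, the two-pass formula Var[x] = E[(x - E[x])²] stays non-negative on the same input. A minimal torch sketch (assuming a CUDA device is available):

import torch

x = torch.full((1, 10240), 256., device='cuda:0', dtype=torch.float32)
x[0, -1] = 255

# Two-pass variance: subtract the mean first, then average the squares.
# The result is a mean of squares, so it can never be negative.
x_mean = x.mean()
x_var = ((x - x_mean) ** 2).mean()
x_inv_std = torch.rsqrt(x_var + 1e-5)

print(x_var, x_inv_std)  # variance >= 0, inverse std is finite
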
  • Additional Notes
    A C++ program that calls the tensorrt_llm kernel confirms that the layer norm output is nan:
#include "cuda_runtime.h"
#include "tensorrt_llm/kernels/layernormKernels.h"

#include <vector>
#include <iostream>

float* malloc_device(std::vector<float>& host) {
    float* dev = nullptr;
    cudaMalloc(&dev, host.size() * sizeof(float));
    cudaMemcpy(dev, host.data(), host.size() * sizeof(float), cudaMemcpyDefault);
    return dev;
}

int main() {
    const int size = 10240;
    
    auto host_input = std::vector<float>(size, 256.f);
    host_input[size - 1] = 255.f;
    auto host_gamma = std::vector<float>(size, 1.f);
    auto host_beta = std::vector<float>(size, 0.f);
    auto host_output = std::vector<float>(size, 0.f);

    float* dev_input = malloc_device(host_input);
    float* dev_gamma = malloc_device(host_gamma);
    float* dev_beta = malloc_device(host_beta);
    float* dev_output = malloc_device(host_output);

    tensorrt_llm::kernels::invokeGeneralLayerNorm(
            dev_output, dev_input, dev_gamma, dev_beta, 1e-5f, 1, size);

    cudaMemcpy(host_output.data(), dev_output, size * sizeof(float), cudaMemcpyDefault);

    std::cout << host_output[0] << std::endl;

    return 0;
}
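
For reference, the expected (finite) output can be checked against torch.nn.functional.layer_norm on the same input. A minimal sketch; the gamma/beta tensors mirror the all-ones weight and all-zeros bias used above:

import torch
import torch.nn.functional as F

x = torch.full((1, 10240), 256., device='cuda:0', dtype=torch.float32)
x[0, -1] = 255
gamma = torch.ones(10240, device='cuda:0')
beta = torch.zeros(10240, device='cuda:0')

# torch's layer_norm uses a numerically stable variance and returns finite values here.
y = F.layer_norm(x, (10240,), weight=gamma, bias=beta, eps=1e-5)
print(y[0, 0].item(), y[0, -1].item())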

Proposed fix:

  • Option 1: adding a variance = max(variance, 0.f) guard is enough to fix this bug, at line 118 of tensorrt_llm/kernels/layernormKernels.cu:
    if (threadIdx.x == 0)
    {
        mean = mean / hidden_dim;
        s_mean = mean;
        if (USE_DIFF_OF_SQUARES)
        {
            variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
            // add a line here: variance = max(variance, 0.f);
            s_variance = rsqrtf(variance + eps);
        }
    }
    __syncthreads();
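
Applied to the torch reproduction above, the same guard (clamping the variance at zero before the rsqrt) removes the nan. A minimal sketch of the effect of the guard, not of the kernel code itself:

import torch

x = torch.full((1, 10240), 256., device='cuda:0', dtype=torch.float32)
x[0, -1] = 255

x_var = (x * x).mean() - x.mean() * x.mean()  # diff-of-squares, comes out slightly negative
x_var = torch.clamp(x_var, min=0.)            # equivalent of variance = max(variance, 0.f)
x_inv_std = torch.rsqrt(x_var + 1e-5)

print(x_var, x_inv_std)  # variance clamped to 0, inverse std is finite (no nan)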