Corrupt stack detected at runtime #279

Closed · WangJialei-A opened this issue Aug 23, 2024 · 2 comments · Fixed by #287
Labels: bug, CPU

@WangJialei-A (Contributor)

Please try to reproduce the issue on PR #161.
This issue may share the same root cause as #278, but the symptoms are not the same.

Reproduce command:

python -m benchgc --verbose 1 --driver mlir --case /path/to/mlir
python: /home/jovyan/graph-compiler/lib/gc/ExecutionEngine/CPURuntime/MemoryPool.cpp:213: void (anonymous namespace)::FILOMemoryPool::dealloc(void *): Assertion `chunk->canary == MemoryChunk::magicCheckNum && "Corrupt stack detected"' failed.
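
For context on the assertion: a canary-guarded stack allocator typically stamps a magic value into each chunk header at alloc time and re-checks it at dealloc time. Below is a minimal self-contained sketch of that mechanism with hypothetical names; FILOMemoryPool here is an illustration only, not the actual MemoryPool.cpp implementation.

#include <cassert>
#include <cstddef>
#include <cstdint>

namespace {

// Hypothetical header stamped in front of every chunk the pool hands out.
struct MemoryChunk {
  static constexpr uint64_t magicCheckNum = 0xdeadbeefdeadbeefULL;
  uint64_t canary; // must still equal magicCheckNum when the chunk is freed
  uint64_t size;   // payload size in bytes
};

// Hypothetical bump-pointer pool that expects strictly FILO deallocation.
struct FILOMemoryPool {
  char *top; // current top of the allocation stack

  void *alloc(size_t size) {
    auto *chunk = reinterpret_cast<MemoryChunk *>(top);
    chunk->canary = MemoryChunk::magicCheckNum;
    chunk->size = size;
    top += sizeof(MemoryChunk) + size;
    return chunk + 1; // payload starts right after the header
  }

  void dealloc(void *ptr) {
    auto *chunk = reinterpret_cast<MemoryChunk *>(ptr) - 1;
    // Fires when the header has been overwritten, e.g. after a non-FILO
    // dealloc rolled the stack back past a still-live chunk and a later
    // allocation reused the memory that overlaps this chunk's header.
    assert(chunk->canary == MemoryChunk::magicCheckNum &&
           "Corrupt stack detected");
    top = reinterpret_cast<char *>(chunk); // pop the chunk off the stack
  }
};

} // anonymous namespace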

Please save the following module as an .mlir file:

module {
  func.func @entry(%arg0: tensor<1x32x4096xbf16>, %arg1: tensor<4096x4096xbf16>, %arg2: tensor<1x32x4096xbf16>, %arg3: tensor<1xf32>, %arg4: tensor<4096xbf16>, %arg5: tensor<4096x11008xbf16>, %arg6: tensor<4096x11008xbf16>, %arg7: tensor<11008x4096xbf16>, %arg8: tensor<1xf32>, %arg9: tensor<4096xbf16>) -> tensor<1x32x4096xbf16> attributes {llvm.emit_c_interface} {
    %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16>
    %cst = arith.constant 0.000000e+00 : bf16
    %0 = tensor.empty() : tensor<32x4096xbf16>
    %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16>
    %2 = linalg.matmul ins(%collapsed, %arg1 : tensor<32x4096xbf16>, tensor<4096x4096xbf16>) outs(%1 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16>
    %expanded = tensor.expand_shape %2 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16>
    %3 = tensor.empty() : tensor<1x32x4096xbf16>
    %4 = linalg.add ins(%arg2, %expanded : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%3 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16>
    %5 = tensor.empty() : tensor<1x32x4096xf32>
    %6 = linalg.copy ins(%4 : tensor<1x32x4096xbf16>) outs(%5 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32>
    %cst_0 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32>
    %7 = tensor.empty() : tensor<1x32x4096xf32>
    %8 = linalg.powf ins(%6, %cst_0 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%7 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32>
    %cst_1 = arith.constant 0.000000e+00 : f32
    %9 = tensor.empty() : tensor<1x32xf32>
    %10 = linalg.fill ins(%cst_1 : f32) outs(%9 : tensor<1x32xf32>) -> tensor<1x32xf32>
    %reduced = linalg.reduce ins(%8 : tensor<1x32x4096xf32>) outs(%10 : tensor<1x32xf32>) dimensions = [2] 
      (%in: f32, %init: f32) {
        %26 = arith.addf %in, %init : f32
        linalg.yield %26 : f32
      }
    %cst_2 = arith.constant dense<4.096000e+03> : tensor<1x32xf32>
    %11 = tensor.empty() : tensor<1x32xf32>
    %12 = linalg.div ins(%reduced, %cst_2 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%11 : tensor<1x32xf32>) -> tensor<1x32xf32>
    %expanded_3 = tensor.expand_shape %12 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32>
    %13 = tensor.empty() : tensor<1x32x1xf32>
    %broadcasted = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%13 : tensor<1x32x1xf32>) dimensions = [0, 1] 
    %14 = tensor.empty() : tensor<1x32x1xf32>
    %15 = linalg.add ins(%expanded_3, %broadcasted : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%14 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32>
    %cst_4 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32>
    %16 = tensor.empty() : tensor<1x32x1xf32>
    %17 = linalg.powf ins(%15, %cst_4 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%16 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32>
    %collapsed_5 = tensor.collapse_shape %17 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32>
    %18 = tensor.empty() : tensor<1x32x4096xf32>
    %broadcasted_6 = linalg.broadcast ins(%collapsed_5 : tensor<1x32xf32>) outs(%18 : tensor<1x32x4096xf32>) dimensions = [2] 
    %19 = tensor.empty() : tensor<1x32x4096xf32>
    %20 = linalg.mul ins(%6, %broadcasted_6 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%19 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32>
    %21 = tensor.empty() : tensor<1x32x4096xbf16>
    %22 = linalg.copy ins(%20 : tensor<1x32x4096xf32>) outs(%21 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16>
    %23 = tensor.empty() : tensor<1x32x4096xbf16>
    %broadcasted_7 = linalg.broadcast ins(%arg4 : tensor<4096xbf16>) outs(%23 : tensor<1x32x4096xbf16>) dimensions = [0, 1] 
    %24 = tensor.empty() : tensor<1x32x4096xbf16>
    %25 = linalg.mul ins(%broadcasted_7, %22 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%24 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16>
    return %25 : tensor<1x32x4096xbf16>
  }
}
WangJialei-A added the bug and CPU labels on Aug 23, 2024
@crazydemo (Contributor)

Checked the generated IR. The dealloc ops inside the for-loop generated by createBufferDeallocationPass do not obey the FILO rule, which conflicts with the design of the gc memory pool. createConvertMemRefToCPURuntime can be disabled until the bufferDeallocation issue is solved.
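
To make the FILO constraint concrete: using the hypothetical FILOMemoryPool sketched in the issue description above, freeing a chunk that is not the top of the stack rolls the bump pointer back past a still-live chunk; the next allocation then overlaps that chunk's header, so its canary check fails on the later dealloc. An illustrative trace (again, not the actual lowered IR):

#include <cstring>

int main() {
  alignas(16) static char arena[1 << 12];
  FILOMemoryPool pool{arena};

  void *a = pool.alloc(64);  // pushed first
  void *b = pool.alloc(64);  // pushed second; b is now the stack top
  pool.dealloc(a);           // FILO violated: top rolls back below b
  void *c = pool.alloc(128); // reuses a's slot, overlapping b's header
  std::memset(c, 0, 128);    // writing c's payload stomps b's canary
  pool.dealloc(b);           // Assertion "Corrupt stack detected" fails
}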

@yifeizh2 (Contributor)

> Checked the generated IR. The dealloc ops inside the for-loop generated by createBufferDeallocationPass do not obey the FILO rule, which conflicts with the design of the gc memory pool. createConvertMemRefToCPURuntime can be disabled until the bufferDeallocation issue is solved.

Thanks. Issue #278 will not occur after disabling createConvertMemRefToCPURuntime.
