-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BACKEND] Fix combineSelectAndIf when the user of select is in ifOp. (…
…#5031) The CombineTensorSelectAndIf pass currently doesn’t work correctly **when the user of select is inside the scf.if block**. For example:

```mlir
%select = arith.select %cond, %trueVal, %falseVal : i32
%if = scf.if %cond -> (i32) {
  %sub = arith.subi %select, %val1 : i32
  scf.yield %sub : i32
} else {
  %mul = arith.muli %select, %val2 : i32
  scf.yield %mul : i32
}
use %select
```

In this case, dom.dominates(ifOp, user) will return true, but directly using replaceAllUsesWith would lead to incorrect replacement behavior.

```mlir
// without this PR (the user inside ifOp uses the result of ifOp)
%if:2 = scf.if %cond -> (i32, i32) {
  %sub = arith.subi %if#1, %val1 : i32
  scf.yield %sub, %trueVal : i32, i32
} else {
  %mul = arith.muli %if#1, %val2 : i32
  scf.yield %mul, %falseVal : i32, i32
}
use %if#1
```

To address this, we need to adjust the user’s operand based on the specific region it is in.

```mlir
// with this PR (the users inside ifOp are canonicalized first)
%if:2 = scf.if %cond -> (i32, i32) {
  %sub = arith.subi %trueVal, %val1 : i32
  scf.yield %sub, %trueVal : i32, i32
} else {
  %mul = arith.muli %falseVal, %val2 : i32
  scf.yield %mul, %falseVal : i32, i32
}
use %if#1
```
- Loading branch information
1 parent
0b443ce
commit 73df068
Showing
2 changed files
with
107 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,77 @@ | ||
// RUN: triton-opt %s -split-input-file -tritongpu-combine-tensor-select-and-if | FileCheck %s | ||
|
||
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> | ||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { | ||
// CHECK-LABEL: @select_if_combine | ||
tt.func public @select_if_combine(%arg0: tensor<64xf32, #blocked>, %dst_ptr: tensor<64x!tt.ptr<f32>, #blocked>, %cnd: i1) attributes {noinline = false} { | ||
// CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> | ||
%cst = arith.constant dense<0.000000e+00> : tensor<64xf32, #blocked> | ||
// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> | ||
%cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32, #blocked> | ||
// CHECK-NOT: arith.select | ||
%sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32, #blocked> | ||
// CHECK: %[[IF_RES:.*]] = scf.if | ||
scf.if %cnd { | ||
tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr<f32>, #blocked> | ||
// CHECK: scf.yield %[[CST0]] | ||
} | ||
// CHECK: else | ||
// CHECK: scf.yield %[[CST1]] | ||
// CHECK: tt.store %{{.*}}, %[[IF_RES]] | ||
tt.store %dst_ptr, %sel : tensor<64x!tt.ptr<f32>, #blocked> | ||
tt.return | ||
tt.func public @select_if_combine(%arg0: tensor<64xf32>, %dst_ptr: tensor<64x!tt.ptr<f32>>, %cnd: i1) { | ||
// CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> | ||
%cst = arith.constant dense<0.000000e+00> : tensor<64xf32> | ||
// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> | ||
%cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32> | ||
// CHECK-NOT: arith.select | ||
%sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32> | ||
// CHECK: %[[R:.+]] = scf.if %{{.*}} | ||
// CHECK: tt.store %{{.*}}, %{{.*}} | ||
// CHECK: scf.yield %[[CST0]] | ||
// CHECK: } else { | ||
// CHECK: scf.yield %[[CST1]] | ||
// CHECK: } | ||
scf.if %cnd { | ||
tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr<f32>> | ||
} | ||
// CHECK: tt.store %{{.*}}, %[[R]] | ||
tt.store %dst_ptr, %sel : tensor<64x!tt.ptr<f32>> | ||
tt.return | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @if_multiple_sel | ||
tt.func @if_multiple_sel(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32){ | ||
// CHECK-NOT: select | ||
// CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) { | ||
// CHECK: scf.yield {{.*}} : i32, i32, f32 | ||
// CHECK: } else { | ||
// CHECK: scf.yield {{.*}} : i32, i32, f32 | ||
// CHECK: } | ||
// CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32 | ||
// CHECK-NOT: arith.select | ||
%0 = arith.select %arg0, %arg1, %arg2 : i32 | ||
%1 = arith.select %arg0, %arg3, %arg4 : f32 | ||
// CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) { | ||
// CHECK: scf.yield {{.*}} : i32, i32, f32 | ||
// CHECK: } else { | ||
// CHECK: scf.yield {{.*}} : i32, i32, f32 | ||
// CHECK: } | ||
%2 = scf.if %arg0 -> (i32) { | ||
%3 = arith.subi %arg1, %arg2 : i32 | ||
scf.yield %3 : i32 | ||
} else { | ||
scf.yield %arg1 : i32 | ||
} | ||
// CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32 | ||
tt.return %0, %1, %2 : i32, f32, i32 | ||
} | ||
|
||
// ----- | ||
// CHECK-LABEL: tt.func @users_in_if( | ||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i1 | ||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i32 | ||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: i32 | ||
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: f32 | ||
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: f32 | ||
tt.func @users_in_if(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32, i32) { | ||
// CHECK: %[[CST:.*]] = arith.constant 8 : i32 | ||
%c8_i32 = arith.constant 8 : i32 | ||
// CHECK-NOT: arith.select | ||
%0 = arith.select %arg0, %arg1, %arg2 : i32 | ||
%1 = arith.select %arg0, %arg3, %arg4 : f32 | ||
// CHECK: %[[R:.+]]:4 = scf.if %[[ARG0]] -> (i32, i32, i32, f32) { | ||
// CHECK: %[[MULI:.*]] = arith.muli %[[ARG1]], %[[ARG2]] : i32 | ||
// CHECK: %[[ADDI:.*]] = arith.addi %[[ARG1]], %[[CST]] : i32 | ||
// CHECK: scf.yield %[[MULI]], %[[ADDI]], %[[ARG1]], %[[ARG3]] : i32, i32, i32, f32 | ||
// CHECK: } else { | ||
// CHECK: %[[ADDI:.*]] = arith.subi %[[ARG2]], %[[CST]] : i32 | ||
// CHECK: scf.yield %[[ARG1]], %[[ADDI]], %[[ARG2]], %[[ARG4]] : i32, i32, i32, f32 | ||
// CHECK: } | ||
%2:2 = scf.if %arg0 -> (i32, i32) { | ||
%3 = arith.muli %0, %arg2 : i32 | ||
%4 = arith.addi %0, %c8_i32 : i32 | ||
scf.yield %3, %4 : i32, i32 | ||
} else { | ||
%3 = arith.subi %0, %c8_i32 : i32 | ||
scf.yield %arg1, %3 : i32, i32 | ||
} | ||
// CHECK: tt.return %[[R]]#2, %[[R]]#3, %[[R]]#0, %[[R]]#1 : i32, f32, i32, i32 | ||
tt.return %0, %1, %2#0, %2#1 : i32, f32, i32, i32 | ||
} |