Skip to content

Commit

Permalink
[mlir][tosa] Fix for incorrect cannonicalization of tosa.pad (#98356)
Browse files Browse the repository at this point in the history
The current fold method for tosa.pad can produce invalid IR by replacing
the padded value with the tosa.pad is a noop. When the type of the input
value does not match the type of the tosa.pad, the canonicalizer detects
the change in types and asserts.

This change addresses the issue by avoiding folding when the input and
result types do not match.
  • Loading branch information
sabauma authored Jul 11, 2024
1 parent b64c1de commit 2fb53f3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
// If the pad is all zeros we can fold this operation away.
if (adaptor.getPadding()) {
if (adaptor.getPadding() && getInput1().getType() == getType()) {
auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
return getInput1();
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,20 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {

// -----

// CHECK-LABEL: @pad_noop_type_mismatch_nofold
func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32> {
// CHECK: %[[PAD:.+]] = tosa.pad
// CHECK: return %[[PAD]]

%c0_i32 = arith.constant 0 : i32
%shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32>

%0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: @pad_determine_val_i32
func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
Expand Down

0 comments on commit 2fb53f3

Please sign in to comment.