diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 8687be075ea678..866ab0d2228f79 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -859,7 +859,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { OpFoldResult PadOp::fold(FoldAdaptor adaptor) { // If the pad is all zeros we can fold this operation away. - if (adaptor.getPadding()) { + if (adaptor.getPadding() && getInput1().getType() == getType()) { auto densePad = llvm::cast(adaptor.getPadding()); if (densePad.isSplat() && densePad.getSplatValue().isZero()) { return getInput1(); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index accc792c8f2aca..3bcf58015831ba 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -217,6 +217,20 @@ func.func @pad_noop(%arg0: tensor) -> tensor { // ----- +// CHECK-LABEL: @pad_noop_type_mismatch_nofold +func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor { + // CHECK: %[[PAD:.+]] = tosa.pad + // CHECK: return %[[PAD]] + + %c0_i32 = arith.constant 0 : i32 + %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32> + + %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: @pad_determine_val_i32 func.func @pad_determine_val_i32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor}