diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp index b7008690b6ab..a1b579e8ba3e 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp @@ -304,6 +304,10 @@ class NamedImplicitCastOpConversion : public OpInterfaceRewritePattern { return false; } + if (!llvm::all_of(producer.getIndexingMapsArray(), + [](AffineMap map) { return map.isIdentity(); })) + return false; + std::optional castOp = getDefiningNonI1ExtendingCastOp(operand.get()); if (!castOp) { diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir index a1cd2d63216e..c84f128ed15e 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir @@ -566,6 +566,33 @@ util.func public @matmul_extsi(%arg0 : tensor<10x20xi32>, // CHECK: util.return %[[RESULT]] // ----- +// Regression test. extsi is transposed, don't fuse into matmul. 
+util.func public @matmul_extsi_transposed(%arg0 : tensor<10x20xi32>, + %arg1 : tensor<40x20xi16>) -> tensor<10x40xi32> { + %0 = tensor.empty() : tensor<20x40xi32> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg1 : tensor<40x20xi16>) outs(%0 : tensor<20x40xi32>) { + ^bb0(%b0 : i16, %b1 : i32): + %e = arith.extsi %b0 : i16 to i32 + linalg.yield %e : i32 + } -> tensor<20x40xi32> + %2 = tensor.empty() : tensor<10x40xi32> + %3 = arith.constant 0 : i32 + %4 = linalg.fill ins(%3 : i32) outs(%2 : tensor<10x40xi32>) -> tensor<10x40xi32> + %5 = linalg.matmul ins(%arg0, %1 : tensor<10x20xi32>, tensor<20x40xi32>) + outs(%4 : tensor<10x40xi32>) -> tensor<10x40xi32> + util.return %5 : tensor<10x40xi32> +} +// CHECK-LABEL: util.func public @matmul_extsi_transposed +// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xi32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<40x20xi16> +// CHECK: %[[GEN:.+]] = linalg.generic +// CHECK: %[[RESULT:.+]] = linalg.matmul ins(%[[ARG0]], %[[GEN]] +// CHECK: util.return %[[RESULT]] +// ----- + util.func public @matmul_extsi_a(%arg0 : tensor<10x20xi16>, %arg1 : tensor<20x40xi32>) -> tensor<10x40xi32> { %0 = tensor.empty() : tensor<10x20xi32>