Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIRinterface> #87145

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

lipracer
Copy link
Member

@lipracer lipracer commented Mar 30, 2024

Fixes #86647

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Mar 30, 2024
@llvmbot
Copy link

llvmbot commented Mar 30, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: long.chen (lipracer)

Changes

… interface>

Fixes #86647


Full diff: https://github.com/llvm/llvm-project/pull/87145.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/OpDefinition.h (+28)
  • (modified) mlir/tools/mlir-tblgen/OpInterfacesGen.cpp (+2-1)
  • (modified) mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp (+28)
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index bd68c27445744e..6da00b4c549a3d 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -22,6 +22,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/ODSSupport.h"
 #include "mlir/IR/Operation.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 #include <optional>
@@ -2110,6 +2111,33 @@ struct DenseMapInfo<T,
   }
   static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
 };
+
+template <typename To, typename From>
+struct CastInfo<
+    To, From,
+    std::enable_if_t<
+        std::is_base_of_v<mlir::OpInterface<To, typename To::InterfaceTraits>,
+                          To> &&
+            std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
+                                                typename std::remove_const_t<
+                                                    From>::InterfaceTraits>,
+                              std::remove_const_t<From>>,
+        void>> : NullableValueCastFailed<To>,
+                 DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+  static bool isPossible(From &val) {
+    if constexpr (std::is_same_v<To, From>)
+      return true;
+    else
+      return isa<To>(
+          const_cast<std::remove_const_t<From> &>(val).getOperation());
+  }
+
+  static To doCast(From &val) {
+    return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
+  }
+};
+
 } // namespace llvm
 
 #endif
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 2a7406f42f34b5..c6409e9ec30ec9 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
   // Emit the main interface class declaration.
   os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
                       "public:\n"
-                      "  using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
+                      "  using ::mlir::{3}<{1}, detail::{2}>::{3};\n"
+                      "  using InterfaceTraits = detail::{2};\n",
                       interfaceName, interfaceName, interfaceTraitsName,
                       interfaceBaseType);
 
diff --git a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
index 2fc8a43f7c04c6..686bac8c8aaac0 100644
--- a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
+++ b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Parser/Parser.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 #include <gtest/gtest.h>
 
@@ -103,3 +104,30 @@ TEST_F(ValueShapeRangeTest, SettingShapes) {
   EXPECT_EQ(range.getShape(1).getDimSize(0), 1);
   EXPECT_FALSE(range.getShape(2));
 }
+
+TEST(OperationInterfaceTest, CastOpToInterface) {
+  DialectRegistry registry;
+  MLIRContext ctx;
+
+  const char *ir = R"MLIR(
+      func.func @map(%arg : tensor<1xi64>) {
+        %0 = arith.constant dense<[10]> : tensor<1xi64>
+        %1 = arith.addi %arg, %0 : tensor<1xi64>
+        return
+      }
+    )MLIR";
+
+  registry.insert<func::FuncDialect, arith::ArithDialect>();
+  ctx.appendDialectRegistry(registry);
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
+
+  ::mlir::OpAsmOpInterface interface = llvm::cast<::mlir::OpAsmOpInterface>(op);
+
+  std::string funcName;
+  llvm::TypeSwitch<::mlir::OpAsmOpInterface, void>(interface)
+      .Case<::mlir::VectorUnrollOpInterface, arith::ConstantOp>(
+          [&](auto op) { funcName = __PRETTY_FUNCTION__; });
+
+  EXPECT_TRUE(funcName.find("ConstantOp") != std::string::npos);
+}

@lipracer lipracer marked this pull request as draft March 30, 2024 07:19
@lipracer lipracer marked this pull request as ready for review March 30, 2024 07:21
@lipracer lipracer marked this pull request as draft March 30, 2024 07:52
@joker-eph joker-eph changed the title [mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIR… [mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIRinterface> Mar 30, 2024
Copy link

github-actions bot commented Apr 1, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@lipracer
Copy link
Member Author

lipracer commented Apr 1, 2024

After I specialized the CastInfo class, it can support the conversion from OpInterface to OpInterface. However, whether it is gcc-9 or clang-13, it will also give me a warning message returning reference to local temporary object. However, when I actually run tests, It will not execute it at this location.I think one of them may be that the compiler has not yet found a specialized implementation.

@lipracer lipracer marked this pull request as ready for review April 1, 2024 13:42
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The warning does not give me confidence about the fix at this point.

@lipracer
Copy link
Member Author

lipracer commented Apr 2, 2024

It is warning message on my local machine(clang-13 or gcc-9.4.0):
83471b7dc64933e2d7748120690ffdc8

@joker-eph
Copy link
Collaborator

I asked in the issue because you posted there first I believe, but we can discuss here instead if you'd like:

However, when I actually run tests, It will not execute it at this location.

How do you know it does not execute this code?

@lipracer
Copy link
Member Author

lipracer commented Apr 2, 2024

Sorry, I placed it here in the hope that the reviewer can understand the background. I compiled the debug mode and then executed it step by step, knowing that it would not be executed here.

@joker-eph
Copy link
Collaborator

OK I see:

    if (!Self::isPossible(f))
      return castFailed();

Here the if (!Self::isPossible(f)) will always succeed in your case?

I am not sure how to restructure this all (and I'm too busy preparing for euroLLVM at the moment) but we can't have that warning in-tree I think.

@lipracer
Copy link
Member Author

lipracer commented Apr 2, 2024

we can't have that warning in-tree I think.

I couldn't agree more.I am also trying to eliminate the warning.

@lipracer
Copy link
Member Author

iree-llvm-sandbox CI is green. Is it necessary to merge this feature? If so, I will organize the code and update this PR. If it is not needed for the time being, I will close this issue.

@ingomueller-net
Copy link
Contributor

The CI is green because I found a work-around: I can use a TypeSwitch<Operation *>, which works as intended (see the main message of #86647).

@lipracer
Copy link
Member Author

iree-org/iree-llvm-sandbox#851 (review) I have removed workaround.

@joker-eph
Copy link
Collaborator

@River707 could you look at this maybe?

mlir/include/mlir/IR/OpDefinition.h Outdated Show resolved Hide resolved
mlir/include/mlir/TableGen/Class.h Outdated Show resolved Hide resolved
mlir/include/mlir/IR/OpDefinition.h Outdated Show resolved Hide resolved
@lipracer lipracer force-pushed the opinterface-cast branch 2 times, most recently from 1b21dd2 to 1391d96 Compare November 5, 2024 14:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Undefined behavior in CastInfo::castFailed with From=<MLIR interface>
5 participants