-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIRinterface> #87145
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: long.chen (lipracer) Changes… interface> Fixes #86647 Full diff: https://github.com/llvm/llvm-project/pull/87145.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index bd68c27445744e..6da00b4c549a3d 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -22,6 +22,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/ODSSupport.h"
#include "mlir/IR/Operation.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include <optional>
@@ -2110,6 +2111,33 @@ struct DenseMapInfo<T,
}
static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
};
+
+template <typename To, typename From>
+struct CastInfo<
+ To, From,
+ std::enable_if_t<
+ std::is_base_of_v<mlir::OpInterface<To, typename To::InterfaceTraits>,
+ To> &&
+ std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
+ typename std::remove_const_t<
+ From>::InterfaceTraits>,
+ std::remove_const_t<From>>,
+ void>> : NullableValueCastFailed<To>,
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+ static bool isPossible(From &val) {
+ if constexpr (std::is_same_v<To, From>)
+ return true;
+ else
+ return isa<To>(
+ const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+
+ static To doCast(From &val) {
+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+};
+
} // namespace llvm
#endif
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 2a7406f42f34b5..c6409e9ec30ec9 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
// Emit the main interface class declaration.
os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
"public:\n"
- " using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
+ " using ::mlir::{3}<{1}, detail::{2}>::{3};\n"
+ " using InterfaceTraits = detail::{2};\n",
interfaceName, interfaceName, interfaceTraitsName,
interfaceBaseType);
diff --git a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
index 2fc8a43f7c04c6..686bac8c8aaac0 100644
--- a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
+++ b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Parser/Parser.h"
+#include "llvm/ADT/TypeSwitch.h"
#include <gtest/gtest.h>
@@ -103,3 +104,30 @@ TEST_F(ValueShapeRangeTest, SettingShapes) {
EXPECT_EQ(range.getShape(1).getDimSize(0), 1);
EXPECT_FALSE(range.getShape(2));
}
+
+TEST(OperationInterfaceTest, CastOpToInterface) {
+ DialectRegistry registry;
+ MLIRContext ctx;
+
+ const char *ir = R"MLIR(
+ func.func @map(%arg : tensor<1xi64>) {
+ %0 = arith.constant dense<[10]> : tensor<1xi64>
+ %1 = arith.addi %arg, %0 : tensor<1xi64>
+ return
+ }
+ )MLIR";
+
+ registry.insert<func::FuncDialect, arith::ArithDialect>();
+ ctx.appendDialectRegistry(registry);
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+ Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
+
+ ::mlir::OpAsmOpInterface interface = llvm::cast<::mlir::OpAsmOpInterface>(op);
+
+ std::string funcName;
+ llvm::TypeSwitch<::mlir::OpAsmOpInterface, void>(interface)
+ .Case<::mlir::VectorUnrollOpInterface, arith::ConstantOp>(
+ [&](auto op) { funcName = __PRETTY_FUNCTION__; });
+
+ EXPECT_TRUE(funcName.find("ConstantOp") != std::string::npos);
+}
|
3cd0fc9
to
dd0c480
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
dd0c480
to
3ea3d4b
Compare
After I specialized the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The warning does not give me confidence about the fix at this point.
I asked in the issue because you posted there first I believe, but we can discuss here instead if you'd like:
How do you know it does not execute this code? |
Sorry, I placed it here in the hope that the reviewer can understand the background. I compiled the debug mode and then executed it step by step, knowing that it would not be executed here. |
OK I see:
Here the I am not sure how to restructure this all (and I'm too busy preparing for euroLLVM at the moment) but we can't have that warning in-tree I think. |
I couldn't agree more.I am also trying to eliminate the warning. |
8057676
to
e00ec57
Compare
iree-llvm-sandbox CI is green. Is it necessary to merge this feature? If so, I will organize the code and update this PR. If it is not needed for the time being, I will close this issue. |
The CI is green because I found a work-around: I can use a |
iree-org/iree-llvm-sandbox#851 (review) I have removed workaround. |
e00ec57
to
3be1ad2
Compare
@River707 could you look at this maybe? |
1b21dd2
to
1391d96
Compare
1391d96
to
11f7d95
Compare
Fixes #86647