Skip to content

Commit

Permalink
[mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIR…
Browse files Browse the repository at this point in the history
… interface>

Fixes #86647
  • Loading branch information
lipracer committed Mar 29, 2024
1 parent e6f63a9 commit 3cd0fc9
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
28 changes: 28 additions & 0 deletions mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/ODSSupport.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/PointerLikeTypeTraits.h"

#include <optional>
Expand Down Expand Up @@ -2110,6 +2111,33 @@ struct DenseMapInfo<T,
}
static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
};

template <typename To, typename From>
struct CastInfo<
To, From,
std::enable_if_t<
std::is_base_of_v<mlir::OpInterface<To, typename To::InterfaceTraits>,
To> &&
std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
typename std::remove_const_t<
From>::InterfaceTraits>,
std::remove_const_t<From>>,
void>> : NullableValueCastFailed<To>,
DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {

static bool isPossible(From &val) {
if constexpr (std::is_same_v<To, From>)
return true;
else
return isa<To>(
const_cast<std::remove_const_t<From> &>(val).getOperation());
}

static To doCast(From &val) {
return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
}
};

} // namespace llvm

#endif
3 changes: 2 additions & 1 deletion mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
// Emit the main interface class declaration.
os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
"public:\n"
" using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
" using ::mlir::{3}<{1}, detail::{2}>::{3};\n"
" using InterfaceTraits = detail::{2};\n",
interfaceName, interfaceName, interfaceTraitsName,
interfaceBaseType);

Expand Down
28 changes: 28 additions & 0 deletions mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/TypeSwitch.h"

#include <gtest/gtest.h>

Expand Down Expand Up @@ -103,3 +104,30 @@ TEST_F(ValueShapeRangeTest, SettingShapes) {
EXPECT_EQ(range.getShape(1).getDimSize(0), 1);
EXPECT_FALSE(range.getShape(2));
}

TEST(OperationInterfaceTest, CastOpToInterface) {
DialectRegistry registry;
MLIRContext ctx;

const char *ir = R"MLIR(
func.func @map(%arg : tensor<1xi64>) {
%0 = arith.constant dense<[10]> : tensor<1xi64>
%1 = arith.addi %arg, %0 : tensor<1xi64>
return
}
)MLIR";

registry.insert<func::FuncDialect, arith::ArithDialect>();
ctx.appendDialectRegistry(registry);
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();

::mlir::OpAsmOpInterface interface = llvm::cast<::mlir::OpAsmOpInterface>(op);

std::string funcName;
llvm::TypeSwitch<::mlir::OpAsmOpInterface, void>(interface)
.Case<::mlir::VectorUnrollOpInterface, arith::ConstantOp>(
[&](auto op) { funcName = __PRETTY_FUNCTION__; });

EXPECT_TRUE(funcName.find("ConstantOp") != std::string::npos);
}

0 comments on commit 3cd0fc9

Please sign in to comment.