diff --git a/config.sh b/config.sh new file mode 100644 index 00000000000000..55ab08224a32d1 --- /dev/null +++ b/config.sh @@ -0,0 +1,10 @@ +cmake -G Ninja llvm -B build \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DLLVM_ENABLE_LLD=OFF \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_BUILD_EXAMPLES=ON \ + -DLLVM_TARGETS_TO_BUILD="Native;NVPTX;AMDGPU" \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DMLIR_INCLUDE_INTEGRATION_TESTS=ON diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 5610daadfbecb5..52aac19289cf4e 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -2157,7 +2157,7 @@ struct CastInfo< void>> : NullableValueCastFailed, DefaultDoCastIfPossible> { - static bool isPossible(From &val) { + static inline bool isPossible(From &val) { if constexpr (std::is_same_v) return true; else @@ -2166,7 +2166,49 @@ struct CastInfo< const_cast &>(val).getOperation()); } - static To doCast(From &val) { + static inline To doCast(From &val) { + return To(const_cast &>(val).getOperation()); + } +}; + +template +struct is_concrete_op_type : public std::false_type {}; + +template typename... Traits> +constexpr auto concrete_op_base_type_impl(std::tuple...>) { + return mlir::Op(nullptr); +} + +template +using concrete_op_base_type = + decltype(concrete_op_base_type_impl(typename OpT::traits())); + +template +struct is_concrete_op_type< + OpT, std::enable_if_t, OpT>>> + : public std::true_type {}; + +template +struct CastInfo< + To, From, + std::enable_if_t< + is_concrete_op_type() && + std::is_base_of_v, + typename std::remove_const_t< + From>::InterfaceTraits>, + std::remove_const_t>>> + : NullableValueCastFailed, + DefaultDoCastIfPossible> { + + static inline bool isPossible(From &val) { + if constexpr (std::is_same_v) + return true; + else + return isa( + const_cast &>(val).getOperation()); + } + + static inline To doCast(From &val) { return To(const_cast &>(val).getOperation()); } }; diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h index 92fec6a3b11d94..7616f56aa2e3de 100644 --- a/mlir/include/mlir/TableGen/Class.h +++ b/mlir/include/mlir/TableGen/Class.h @@ -520,6 +520,8 @@ class ParentClass { /// Write the parent class declaration. void writeTo(raw_indented_ostream &os) const; + friend class OpClass; + private: /// The fully resolved C++ name of the parent class. std::string name; diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp index 60fa1833ce625e..5426302dfed3e3 100644 --- a/mlir/tools/mlir-tblgen/OpClass.cpp +++ b/mlir/tools/mlir-tblgen/OpClass.cpp @@ -36,7 +36,16 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration, } void OpClass::finalize() { + std::string traitList; + llvm::raw_string_ostream os(traitList); + iterator_range parentTemplateParams(std::begin(parent.templateParams) + 1, + std::end(parent.templateParams)); + llvm::interleaveComma(parentTemplateParams, os, [&](auto &trait) { + os << trait << "<" << getClassName().str() << ">"; + }); + declare("traits", "std::tuple<" + traitList + ">"); Class::finalize(); + declare(Visibility::Public); declare(extraClassDeclaration, extraClassDefinition); } diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp index 741365b3efb5fc..c9ae6938e8b403 100644 --- a/mlir/unittests/IR/InterfaceTest.cpp +++ b/mlir/unittests/IR/InterfaceTest.cpp @@ -18,6 +18,7 @@ #include "../../test/lib/Dialect/Test/TestOps.h" #include "../../test/lib/Dialect/Test/TestTypes.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Parser/Parser.h" #include "llvm/ADT/TypeSwitch.h" @@ -88,7 +89,7 @@ TEST(InterfaceTest, TestImplicitConversion) { EXPECT_EQ(typeA, typeB); } -TEST(OperationInterfaceTest, CastOpToInterface) { +TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) { DialectRegistry registry; MLIRContext ctx; @@ -105,13 +106,20 @@ TEST(OperationInterfaceTest, CastOpToInterface) { OwningOpRef module = parseSourceString(ir, &ctx); Operation &op = cast(module->front()).getBody().front().front(); + static_assert(std::is_base_of_v, + arith::AddIOp>, + ""); + static_assert(llvm::is_concrete_op_type(), ""); + static_assert(!llvm::is_concrete_op_type(), ""); + OpAsmOpInterface interface = llvm::cast(op); - bool constantOp = - llvm::TypeSwitch(interface) - .Case([&](auto op) { - return std::is_same_v; - }); + bool constantOp = llvm::TypeSwitch(interface) + .Case([&](auto op) { + bool is_same = + std::is_same_v; + return is_same; + }); EXPECT_TRUE(constantOp);