Skip to content

Commit

Permalink
add CastInfo to support cast Interface to Op
Browse files Browse the repository at this point in the history
  • Loading branch information
lipracer committed Jul 12, 2024
1 parent a58e3d2 commit 8057676
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 8 deletions.
10 changes: 10 additions & 0 deletions config.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
cmake -G Ninja llvm -B build \
-DCMAKE_C_COMPILER=clang \
-DCMAKE_CXX_COMPILER=clang++ \
-DLLVM_ENABLE_LLD=OFF \
-DLLVM_ENABLE_PROJECTS=mlir \
-DLLVM_BUILD_EXAMPLES=ON \
-DLLVM_TARGETS_TO_BUILD="Native;NVPTX;AMDGPU" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_INCLUDE_INTEGRATION_TESTS=ON
46 changes: 44 additions & 2 deletions mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -2157,7 +2157,7 @@ struct CastInfo<
void>> : NullableValueCastFailed<To>,
DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {

static bool isPossible(From &val) {
static inline bool isPossible(From &val) {
if constexpr (std::is_same_v<To, From>)
return true;
else
Expand All @@ -2166,7 +2166,49 @@ struct CastInfo<
const_cast<std::remove_const_t<From> &>(val).getOperation());
}

static To doCast(From &val) {
static inline To doCast(From &val) {
return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
}
};

template <typename OpT, typename = void>
struct is_concrete_op_type : public std::false_type {};

template <typename OpT, template <typename T> typename... Traits>
constexpr auto concrete_op_base_type_impl(std::tuple<Traits<OpT>...>) {
return mlir::Op<OpT, Traits...>(nullptr);
}

template <typename OpT>
using concrete_op_base_type =
decltype(concrete_op_base_type_impl<OpT>(typename OpT::traits()));

template <typename OpT>
struct is_concrete_op_type<
OpT, std::enable_if_t<std::is_base_of_v<concrete_op_base_type<OpT>, OpT>>>
: public std::true_type {};

template <typename To, typename From>
struct CastInfo<
To, From,
std::enable_if_t<
is_concrete_op_type<To>() &&
std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
typename std::remove_const_t<
From>::InterfaceTraits>,
std::remove_const_t<From>>>>
: NullableValueCastFailed<To>,
DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {

static inline bool isPossible(From &val) {
if constexpr (std::is_same_v<To, From>)
return true;
else
return isa<To>(
const_cast<std::remove_const_t<From> &>(val).getOperation());
}

static inline To doCast(From &val) {
return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
}
};
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/TableGen/Class.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,8 @@ class ParentClass {
/// Write the parent class declaration.
void writeTo(raw_indented_ostream &os) const;

friend class OpClass;

private:
/// The fully resolved C++ name of the parent class.
std::string name;
Expand Down
9 changes: 9 additions & 0 deletions mlir/tools/mlir-tblgen/OpClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,16 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration,
}

void OpClass::finalize() {
std::string traitList;
llvm::raw_string_ostream os(traitList);
iterator_range parentTemplateParams(std::begin(parent.templateParams) + 1,
std::end(parent.templateParams));
llvm::interleaveComma(parentTemplateParams, os, [&](auto &trait) {
os << trait << "<" << getClassName().str() << ">";
});
declare<UsingDeclaration>("traits", "std::tuple<" + traitList + ">");
Class::finalize();

declare<VisibilityDeclaration>(Visibility::Public);
declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
}
20 changes: 14 additions & 6 deletions mlir/unittests/IR/InterfaceTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/TypeSwitch.h"

Expand Down Expand Up @@ -88,7 +89,7 @@ TEST(InterfaceTest, TestImplicitConversion) {
EXPECT_EQ(typeA, typeB);
}

TEST(OperationInterfaceTest, CastOpToInterface) {
TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) {
DialectRegistry registry;
MLIRContext ctx;

Expand All @@ -105,13 +106,20 @@ TEST(OperationInterfaceTest, CastOpToInterface) {
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();

static_assert(std::is_base_of_v<llvm::concrete_op_base_type<arith::AddIOp>,
arith::AddIOp>,
"");
static_assert(llvm::is_concrete_op_type<arith::AddIOp>(), "");
static_assert(!llvm::is_concrete_op_type<OpAsmOpInterface>(), "");

OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);

bool constantOp =
llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
.Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) {
return std::is_same_v<decltype(op), arith::ConstantOp>;
});
bool constantOp = llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
.Case<arith::AddIOp, arith::ConstantOp>([&](auto op) {
bool is_same =
std::is_same_v<decltype(op), arith::ConstantOp>;
return is_same;
});

EXPECT_TRUE(constantOp);

Expand Down

0 comments on commit 8057676

Please sign in to comment.