From 72f2e1aa62dbfc222ed46d321caea822f57a179a Mon Sep 17 00:00:00 2001 From: vporpo Date: Mon, 8 Jul 2024 20:10:28 -0700 Subject: [PATCH] [SandboxIR] Add BasicBlock and adds functionality to Function and Context (#97637) We can now create SandboxIR from LLVM IR using the Context::create* functions. --- llvm/include/llvm/SandboxIR/SandboxIR.h | 157 +++++++++++++ .../llvm/SandboxIR/SandboxIRValues.def | 1 + llvm/lib/SandboxIR/SandboxIR.cpp | 219 +++++++++++++++++- llvm/unittests/SandboxIR/SandboxIRTest.cpp | 135 +++++++++++ 4 files changed, 508 insertions(+), 4 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index ab6273a7ace665..8416e082aec3db 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -62,12 +62,15 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/Support/raw_ostream.h" +#include namespace llvm { namespace sandboxir { +class Function; class Context; +class Instruction; /// A SandboxIR Value has users. This is the base class. class Value { @@ -106,6 +109,8 @@ class Value { /// NOTE: Some SBInstructions, like Packs, may include more than one value. llvm::Value *Val = nullptr; + friend class Context; // For getting `Val`. + /// All values point to the context. Context &Ctx; // This is used by eraseFromParent(). @@ -205,6 +210,48 @@ class Constant : public sandboxir::User { #endif }; +/// The BasicBlock::iterator. +class BBIterator { +public: + using difference_type = std::ptrdiff_t; + using value_type = Instruction; + using pointer = value_type *; + using reference = value_type &; + using iterator_category = std::bidirectional_iterator_tag; + +private: + llvm::BasicBlock *BB; + llvm::BasicBlock::iterator It; + Context *Ctx; + pointer getInstr(llvm::BasicBlock::iterator It) const; + +public: + BBIterator() : BB(nullptr), Ctx(nullptr) {} + BBIterator(llvm::BasicBlock *BB, llvm::BasicBlock::iterator It, Context *Ctx) + : BB(BB), It(It), Ctx(Ctx) {} + reference operator*() const { return *getInstr(It); } + BBIterator &operator++(); + BBIterator operator++(int) { + auto Copy = *this; + ++*this; + return Copy; + } + BBIterator &operator--(); + BBIterator operator--(int) { + auto Copy = *this; + --*this; + return Copy; + } + bool operator==(const BBIterator &Other) const { + assert(Ctx == Other.Ctx && "BBIterators in different context!"); + return It == Other.It; + } + bool operator!=(const BBIterator &Other) const { return !(*this == Other); } + /// \Returns the SBInstruction that corresponds to this iterator, or null if + /// the instruction is not found in the IR-to-SandboxIR tables. + pointer get() const { return getInstr(It); } +}; + /// A sandboxir::User with operands and opcode. class Instruction : public sandboxir::User { public: @@ -231,6 +278,8 @@ class Instruction : public sandboxir::User { return OS; } #endif + /// This is used by BasicBlock::iterator. + virtual unsigned getNumOfIRInstrs() const = 0; /// For isa/dyn_cast. static bool classof(const sandboxir::Value *From); @@ -256,6 +305,7 @@ class OpaqueInst : public sandboxir::Instruction { static bool classof(const sandboxir::Value *From) { return From->getSubclassID() == ClassID::Opaque; } + unsigned getNumOfIRInstrs() const final { return 1u; } #ifndef NDEBUG void verify() const final { // Nothing to do @@ -270,6 +320,54 @@ class OpaqueInst : public sandboxir::Instruction { #endif }; +class BasicBlock : public Value { + /// Builds a graph that contains all values in \p BB in their original form + /// i.e., no vectorization is taking place here. + void buildBasicBlockFromLLVMIR(llvm::BasicBlock *LLVMBB); + friend class Context; // For `buildBasicBlockFromIR` + +public: + BasicBlock(llvm::BasicBlock *BB, Context &SBCtx) + : Value(ClassID::Block, BB, SBCtx) { + buildBasicBlockFromLLVMIR(BB); + } + ~BasicBlock() = default; + /// For isa/dyn_cast. + static bool classof(const Value *From) { + return From->getSubclassID() == Value::ClassID::Block; + } + Function *getParent() const; + using iterator = BBIterator; + iterator begin() const; + iterator end() const { + auto *BB = cast(Val); + return iterator(BB, BB->end(), &Ctx); + } + std::reverse_iterator rbegin() const { + return std::make_reverse_iterator(end()); + } + std::reverse_iterator rend() const { + return std::make_reverse_iterator(begin()); + } + Context &getContext() const { return Ctx; } + Instruction *getTerminator() const; + bool empty() const { return begin() == end(); } + Instruction &front() const; + Instruction &back() const; + +#ifndef NDEBUG + void verify() const final { + assert(isa(Val) && "Expected BasicBlock!"); + } + friend raw_ostream &operator<<(raw_ostream &OS, const BasicBlock &SBBB) { + SBBB.dump(OS); + return OS; + } + void dump(raw_ostream &OS) const final; + LLVM_DUMP_METHOD void dump() const final; +#endif +}; + class Context { protected: LLVMContext &LLVMCtx; @@ -278,12 +376,53 @@ class Context { DenseMap> LLVMValueToValueMap; + /// Take ownership of VPtr and store it in `LLVMValueToValueMap`. + Value *registerValue(std::unique_ptr &&VPtr); + + Value *getOrCreateValueInternal(llvm::Value *V, llvm::User *U = nullptr); + + Argument *getOrCreateArgument(llvm::Argument *LLVMArg) { + auto Pair = LLVMValueToValueMap.insert({LLVMArg, nullptr}); + auto It = Pair.first; + if (Pair.second) { + It->second = std::make_unique(LLVMArg, *this); + return cast(It->second.get()); + } + return cast(It->second.get()); + } + + Value *getOrCreateValue(llvm::Value *LLVMV) { + return getOrCreateValueInternal(LLVMV, 0); + } + + BasicBlock *createBasicBlock(llvm::BasicBlock *BB); + + friend class BasicBlock; // For getOrCreateValue(). + public: Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx) {} + sandboxir::Value *getValue(llvm::Value *V) const; + const sandboxir::Value *getValue(const llvm::Value *V) const { + return getValue(const_cast(V)); + } + + Function *createFunction(llvm::Function *F); + + /// \Returns the number of values registered with Context. + size_t getNumValues() const { return LLVMValueToValueMap.size(); } }; class Function : public sandboxir::Value { + /// Helper for mapped_iterator. + struct LLVMBBToBB { + Context &Ctx; + LLVMBBToBB(Context &Ctx) : Ctx(Ctx) {} + BasicBlock &operator()(llvm::BasicBlock &LLVMBB) const { + return *cast(Ctx.getValue(&LLVMBB)); + } + }; + public: Function(llvm::Function *F, sandboxir::Context &Ctx) : sandboxir::Value(ClassID::Function, F, Ctx) {} @@ -292,6 +431,24 @@ class Function : public sandboxir::Value { return From->getSubclassID() == ClassID::Function; } + Argument *getArg(unsigned Idx) const { + llvm::Argument *Arg = cast(Val)->getArg(Idx); + return cast(Ctx.getValue(Arg)); + } + + size_t arg_size() const { return cast(Val)->arg_size(); } + bool arg_empty() const { return cast(Val)->arg_empty(); } + + using iterator = mapped_iterator; + iterator begin() const { + LLVMBBToBB BBGetter(Ctx); + return iterator(cast(Val)->begin(), BBGetter); + } + iterator end() const { + LLVMBBToBB BBGetter(Ctx); + return iterator(cast(Val)->end(), BBGetter); + } + #ifndef NDEBUG void verify() const final { assert(isa(Val) && "Expected Function!"); diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def index 474b151ae03a41..b090ade3ea0cae 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def +++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def @@ -17,6 +17,7 @@ DEF_VALUE(Argument, Argument) #define DEF_USER(ID, CLASS) #endif DEF_USER(User, User) +DEF_VALUE(Block, BasicBlock) DEF_USER(Constant, Constant) #ifndef DEF_INSTR diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index ea2f15754d3409..bd615f0ee76543 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/SandboxIR/SandboxIR.h" +#include "llvm/IR/Constants.h" #include "llvm/Support/Debug.h" #include @@ -15,7 +16,7 @@ using namespace llvm::sandboxir; Value::Value(ClassID SubclassID, llvm::Value *Val, Context &Ctx) : SubclassID(SubclassID), Val(Val), Ctx(Ctx) { #ifndef NDEBUG - UID = 0; // FIXME: Once SBContext is available. + UID = Ctx.getNumValues(); #endif } @@ -47,8 +48,7 @@ void Value::dumpCommonPrefix(raw_ostream &OS) const { } void Value::dumpCommonSuffix(raw_ostream &OS) const { - OS << " ; " << getName() << " (" << getSubclassIDStr(SubclassID) << ") " - << this; + OS << " ; " << getName() << " (" << getSubclassIDStr(SubclassID) << ")"; } void Value::printAsOperandCommon(raw_ostream &OS) const { @@ -93,6 +93,33 @@ void User::dumpCommonHeader(raw_ostream &OS) const { } #endif // NDEBUG +BBIterator &BBIterator::operator++() { + auto ItE = BB->end(); + assert(It != ItE && "Already at end!"); + ++It; + if (It == ItE) + return *this; + Instruction &NextI = *cast(Ctx->getValue(&*It)); + unsigned Num = NextI.getNumOfIRInstrs(); + assert(Num > 0 && "Bad getNumOfIRInstrs()"); + It = std::next(It, Num - 1); + return *this; +} + +BBIterator &BBIterator::operator--() { + assert(It != BB->begin() && "Already at begin!"); + if (It == BB->end()) { + --It; + return *this; + } + Instruction &CurrI = **this; + unsigned Num = CurrI.getNumOfIRInstrs(); + assert(Num > 0 && "Bad getNumOfIRInstrs()"); + assert(std::prev(It, Num - 1) != BB->begin() && "Already at begin!"); + It = std::prev(It, Num); + return *this; +} + const char *Instruction::getOpcodeName(Opcode Opc) { switch (Opc) { #define DEF_VALUE(ID, CLASS) @@ -148,7 +175,7 @@ void Constant::dump() const { void Function::dumpNameAndArgs(raw_ostream &OS) const { auto *F = cast(Val); - OS << *getType() << " @" << F->getName() << "("; + OS << *F->getReturnType() << " @" << F->getName() << "("; auto NumArgs = F->arg_size(); for (auto [Idx, Arg] : enumerate(F->args())) { auto *SBArg = cast_or_null(Ctx.getValue(&Arg)); @@ -164,6 +191,17 @@ void Function::dumpNameAndArgs(raw_ostream &OS) const { void Function::dump(raw_ostream &OS) const { dumpNameAndArgs(OS); OS << " {\n"; + auto *LLVMF = cast(Val); + interleave( + *LLVMF, + [this, &OS](const llvm::BasicBlock &LLVMBB) { + auto *BB = cast_or_null(Ctx.getValue(&LLVMBB)); + if (BB == nullptr) + OS << "NULL"; + else + OS << *BB; + }, + [&OS] { OS << "\n"; }); OS << "}\n"; } void Function::dump() const { @@ -172,9 +210,182 @@ void Function::dump() const { } #endif // NDEBUG +BasicBlock::iterator::pointer +BasicBlock::iterator::getInstr(llvm::BasicBlock::iterator It) const { + return cast_or_null(Ctx->getValue(&*It)); +} + +Value *Context::registerValue(std::unique_ptr &&VPtr) { + assert(VPtr->getSubclassID() != Value::ClassID::User && + "Can't register a user!"); + Value *V = VPtr.get(); + llvm::Value *Key = V->Val; + LLVMValueToValueMap[Key] = std::move(VPtr); + return V; +} + +Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) { + auto Pair = LLVMValueToValueMap.insert({LLVMV, nullptr}); + auto It = Pair.first; + if (!Pair.second) + return It->second.get(); + + if (auto *C = dyn_cast(LLVMV)) { + for (llvm::Value *COp : C->operands()) + getOrCreateValueInternal(COp, C); + It->second = std::make_unique(C, *this); + return It->second.get(); + } + if (auto *Arg = dyn_cast(LLVMV)) { + It->second = std::make_unique(Arg, *this); + return It->second.get(); + } + if (auto *BB = dyn_cast(LLVMV)) { + assert(isa(U) && + "This won't create a SBBB, don't call this function directly!"); + if (auto *SBBB = getValue(BB)) + return SBBB; + return nullptr; + } + assert(isa(LLVMV) && "Expected Instruction"); + It->second = + std::make_unique(cast(LLVMV), *this); + return It->second.get(); +} + +BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) { + assert(getValue(LLVMBB) == nullptr && "Already exists!"); + auto NewBBPtr = std::make_unique(LLVMBB, *this); + auto *BB = cast(registerValue(std::move(NewBBPtr))); + // Create SandboxIR for BB's body. + BB->buildBasicBlockFromLLVMIR(LLVMBB); + return BB; +} + Value *Context::getValue(llvm::Value *V) const { auto It = LLVMValueToValueMap.find(V); if (It != LLVMValueToValueMap.end()) return It->second.get(); return nullptr; } + +Function *Context::createFunction(llvm::Function *F) { + assert(getValue(F) == nullptr && "Already exists!"); + auto NewFPtr = std::make_unique(F, *this); + // Create arguments. + for (auto &Arg : F->args()) + getOrCreateArgument(&Arg); + // Create BBs. + for (auto &BB : *F) + createBasicBlock(&BB); + auto *SBF = cast(registerValue(std::move(NewFPtr))); + return SBF; +} + +Function *BasicBlock::getParent() const { + auto *BB = cast(Val); + auto *F = BB->getParent(); + if (F == nullptr) + // Detached + return nullptr; + return cast_or_null(Ctx.getValue(F)); +} + +void BasicBlock::buildBasicBlockFromLLVMIR(llvm::BasicBlock *LLVMBB) { + for (llvm::Instruction &IRef : reverse(*LLVMBB)) { + llvm::Instruction *I = &IRef; + Ctx.getOrCreateValue(I); + for (auto [OpIdx, Op] : enumerate(I->operands())) { + // Skip instruction's label operands + if (isa(Op)) + continue; + // Skip metadata + if (isa(Op)) + continue; + // Skip asm + if (isa(Op)) + continue; + Ctx.getOrCreateValue(Op); + } + } +#if !defined(NDEBUG) && defined(SBVEC_EXPENSIVE_CHECKS) + verify(); +#endif +} + +BasicBlock::iterator BasicBlock::begin() const { + llvm::BasicBlock *BB = cast(Val); + llvm::BasicBlock::iterator It = BB->begin(); + if (!BB->empty()) { + auto *V = Ctx.getValue(&*BB->begin()); + assert(V != nullptr && "No SandboxIR for BB->begin()!"); + auto *I = cast(V); + unsigned Num = I->getNumOfIRInstrs(); + assert(Num >= 1u && "Bad getNumOfIRInstrs()"); + It = std::next(It, Num - 1); + } + return iterator(BB, It, &Ctx); +} + +Instruction *BasicBlock::getTerminator() const { + auto *TerminatorV = + Ctx.getValue(cast(Val)->getTerminator()); + return cast_or_null(TerminatorV); +} + +Instruction &BasicBlock::front() const { + auto *BB = cast(Val); + assert(!BB->empty() && "Empty block!"); + auto *SBI = cast(getContext().getValue(&*BB->begin())); + assert(SBI != nullptr && "Expected Instr!"); + return *SBI; +} + +Instruction &BasicBlock::back() const { + auto *BB = cast(Val); + assert(!BB->empty() && "Empty block!"); + auto *SBI = cast(getContext().getValue(&*BB->rbegin())); + assert(SBI != nullptr && "Expected Instr!"); + return *SBI; +} + +#ifndef NDEBUG +void BasicBlock::dump(raw_ostream &OS) const { + llvm::BasicBlock *BB = cast(Val); + const auto &Name = BB->getName(); + OS << Name; + if (!Name.empty()) + OS << ":\n"; + // If there are Instructions in the BB that are not mapped to SandboxIR, then + // use a crash-proof dump. + if (any_of(*BB, [this](llvm::Instruction &I) { + return Ctx.getValue(&I) == nullptr; + })) { + OS << "\n"; + DenseSet Visited; + for (llvm::Instruction &IRef : *BB) { + Value *SBV = Ctx.getValue(&IRef); + if (SBV == nullptr) + OS << IRef << " *** No SandboxIR ***\n"; + else { + auto *SBI = dyn_cast(SBV); + if (SBI == nullptr) { + OS << IRef << " *** Not a SBInstruction!!! ***\n"; + } else { + if (Visited.insert(SBI).second) + OS << *SBI << "\n"; + } + } + } + } else { + for (auto &SBI : *this) { + SBI.dump(OS); + OS << "\n"; + } + } +} +void BasicBlock::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 0b0409aa15c4f0..e523ae90966d79 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -27,6 +27,12 @@ struct SandboxIRTest : public testing::Test { if (!M) Err.print("SandboxIRTest", errs()); } + BasicBlock *getBasicBlockByName(Function &F, StringRef Name) { + for (BasicBlock &BB : F) + if (BB.getName() == Name) + return &BB; + llvm_unreachable("Expected to find basic block!"); + } }; TEST_F(SandboxIRTest, UserInstantiation) { @@ -89,4 +95,133 @@ define void @foo(i32 %v1) { EXPECT_FALSE(isa(Arg0)); EXPECT_TRUE(isa(Const0)); EXPECT_TRUE(isa(OpaqueI)); + +#ifndef NDEBUG + // The dump() functions should be very forgiving and should not crash even if + // sandboxir has not been built properly. + F.dump(); + Arg0.dump(); + Const0.dump(); + OpaqueI.dump(); +#endif +} + +TEST_F(SandboxIRTest, Function) { + parseIR(C, R"IR( +define void @foo(i32 %arg0, i32 %arg1) { +bb0: + br label %bb1 +bb1: + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + llvm::Argument *LLVMArg0 = LLVMF->getArg(0); + llvm::Argument *LLVMArg1 = LLVMF->getArg(1); + + sandboxir::Context Ctx(C); + sandboxir::Function *F = Ctx.createFunction(LLVMF); + + // Check F arguments + EXPECT_EQ(F->arg_size(), 2u); + EXPECT_FALSE(F->arg_empty()); + EXPECT_EQ(F->getArg(0), Ctx.getValue(LLVMArg0)); + EXPECT_EQ(F->getArg(1), Ctx.getValue(LLVMArg1)); + + // Check F.begin(), F.end(), Function::iterator + llvm::BasicBlock *LLVMBB = &*LLVMF->begin(); + for (sandboxir::BasicBlock &BB : *F) { + EXPECT_EQ(&BB, Ctx.getValue(LLVMBB)); + LLVMBB = LLVMBB->getNextNode(); + } + +#ifndef NDEBUG + { + // Check F.dumpNameAndArgs() + std::string Buff; + raw_string_ostream BS(Buff); + F->dumpNameAndArgs(BS); + EXPECT_EQ(Buff, "void @foo(i32 %arg0, i32 %arg1)"); + } + { + // Check F.dump() + std::string Buff; + raw_string_ostream BS(Buff); + BS << "\n"; + F->dump(BS); + EXPECT_EQ(Buff, R"IR( +void @foo(i32 %arg0, i32 %arg1) { +bb0: + br label %bb1 ; SB3. (Opaque) + +bb1: + ret void ; SB5. (Opaque) +} +)IR"); + } +#endif // NDEBUG +} + +TEST_F(SandboxIRTest, BasicBlock) { + parseIR(C, R"IR( +define void @foo(i32 %v1) { +bb0: + br label %bb1 +bb1: + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + llvm::BasicBlock *LLVMBB0 = getBasicBlockByName(*LLVMF, "bb0"); + llvm::BasicBlock *LLVMBB1 = getBasicBlockByName(*LLVMF, "bb1"); + + sandboxir::Context Ctx(C); + sandboxir::Function *F = Ctx.createFunction(LLVMF); + auto &BB0 = cast(*Ctx.getValue(LLVMBB0)); + auto &BB1 = cast(*Ctx.getValue(LLVMBB1)); + + // Check BB::classof() + EXPECT_TRUE(isa(BB0)); + EXPECT_FALSE(isa(BB0)); + EXPECT_FALSE(isa(BB0)); + EXPECT_FALSE(isa(BB0)); + EXPECT_FALSE(isa(BB0)); + + // Check BB.getParent() + EXPECT_EQ(BB0.getParent(), F); + EXPECT_EQ(BB1.getParent(), F); + + // Check BBIterator, BB.begin(), BB.end(). + llvm::Instruction *LLVMI = &*LLVMBB0->begin(); + for (sandboxir::Instruction &I : BB0) { + EXPECT_EQ(&I, Ctx.getValue(LLVMI)); + LLVMI = LLVMI->getNextNode(); + } + LLVMI = &*LLVMBB1->begin(); + for (sandboxir::Instruction &I : BB1) { + EXPECT_EQ(&I, Ctx.getValue(LLVMI)); + LLVMI = LLVMI->getNextNode(); + } + + // Check BB.getTerminator() + EXPECT_EQ(BB0.getTerminator(), Ctx.getValue(LLVMBB0->getTerminator())); + EXPECT_EQ(BB1.getTerminator(), Ctx.getValue(LLVMBB1->getTerminator())); + + // Check BB.rbegin(), BB.rend() + EXPECT_EQ(&*BB0.rbegin(), BB0.getTerminator()); + EXPECT_EQ(&*std::prev(BB0.rend()), &*BB0.begin()); + +#ifndef NDEBUG + { + // Check BB.dump() + std::string Buff; + raw_string_ostream BS(Buff); + BS << "\n"; + BB0.dump(BS); + EXPECT_EQ(Buff, R"IR( +bb0: + br label %bb1 ; SB2. (Opaque) +)IR"); + } +#endif // NDEBUG }