From 3993da23daa0ae75e9e80def76854534903e3761 Mon Sep 17 00:00:00 2001 From: vporpo Date: Tue, 23 Jul 2024 11:50:33 -0700 Subject: [PATCH] [SandboxIR] Implement BranchInst (#100063) This patch implements sandboxir::BranchInst which mirrors llvm::BranchInst. BranchInst::swapSuccessors() relies on User::swapOperandsInternal() so this patch also adds Use::swap() and the corresponding tracking code and test. --- llvm/include/llvm/SandboxIR/SandboxIR.h | 107 ++++++++++++++++++ .../llvm/SandboxIR/SandboxIRValues.def | 1 + llvm/include/llvm/SandboxIR/Tracker.h | 21 ++++ llvm/include/llvm/SandboxIR/Use.h | 1 + llvm/lib/SandboxIR/SandboxIR.cpp | 96 ++++++++++++++++ llvm/lib/SandboxIR/Tracker.cpp | 5 + llvm/unittests/SandboxIR/SandboxIRTest.cpp | 105 ++++++++++++++++- llvm/unittests/SandboxIR/TrackerTest.cpp | 43 +++++++ 8 files changed, 377 insertions(+), 2 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index 0c67206d307efe..6c04c92e3e70e7 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -76,6 +76,7 @@ class Context; class Function; class Instruction; class SelectInst; +class BranchInst; class LoadInst; class ReturnInst; class StoreInst; @@ -179,6 +180,7 @@ class Value { friend class User; // For getting `Val`. friend class Use; // For getting `Val`. friend class SelectInst; // For getting `Val`. + friend class BranchInst; // For getting `Val`. friend class LoadInst; // For getting `Val`. friend class StoreInst; // For getting `Val`. friend class ReturnInst; // For getting `Val`. @@ -343,6 +345,14 @@ class User : public Value { virtual unsigned getUseOperandNo(const Use &Use) const = 0; friend unsigned Use::getOperandNo() const; // For getUseOperandNo() + void swapOperandsInternal(unsigned OpIdxA, unsigned OpIdxB) { + assert(OpIdxA < getNumOperands() && "OpIdxA out of bounds!"); + assert(OpIdxB < getNumOperands() && "OpIdxB out of bounds!"); + auto UseA = getOperandUse(OpIdxA); + auto UseB = getOperandUse(OpIdxB); + UseA.swap(UseB); + } + #ifndef NDEBUG void verifyUserOfLLVMUse(const llvm::Use &Use) const; #endif // NDEBUG @@ -504,6 +514,7 @@ class Instruction : public sandboxir::User { /// returns its topmost LLVM IR instruction. llvm::Instruction *getTopmostLLVMInstruction() const; friend class SelectInst; // For getTopmostLLVMInstruction(). + friend class BranchInst; // For getTopmostLLVMInstruction(). friend class LoadInst; // For getTopmostLLVMInstruction(). friend class StoreInst; // For getTopmostLLVMInstruction(). friend class ReturnInst; // For getTopmostLLVMInstruction(). @@ -617,6 +628,100 @@ class SelectInst : public Instruction { #endif }; +class BranchInst : public Instruction { + /// Use Context::createBranchInst(). Don't call the constructor directly. + BranchInst(llvm::BranchInst *BI, Context &Ctx) + : Instruction(ClassID::Br, Opcode::Br, BI, Ctx) {} + friend Context; // for BranchInst() + Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final { + return getOperandUseDefault(OpIdx, Verify); + } + SmallVector getLLVMInstrs() const final { + return {cast(Val)}; + } + +public: + unsigned getUseOperandNo(const Use &Use) const final { + return getUseOperandNoDefault(Use); + } + unsigned getNumOfIRInstrs() const final { return 1u; } + static BranchInst *create(BasicBlock *IfTrue, Instruction *InsertBefore, + Context &Ctx); + static BranchInst *create(BasicBlock *IfTrue, BasicBlock *InsertAtEnd, + Context &Ctx); + static BranchInst *create(BasicBlock *IfTrue, BasicBlock *IfFalse, + Value *Cond, Instruction *InsertBefore, + Context &Ctx); + static BranchInst *create(BasicBlock *IfTrue, BasicBlock *IfFalse, + Value *Cond, BasicBlock *InsertAtEnd, Context &Ctx); + /// For isa/dyn_cast. + static bool classof(const Value *From); + bool isUnconditional() const { + return cast(Val)->isUnconditional(); + } + bool isConditional() const { + return cast(Val)->isConditional(); + } + Value *getCondition() const; + void setCondition(Value *V) { setOperand(0, V); } + unsigned getNumSuccessors() const { return 1 + isConditional(); } + BasicBlock *getSuccessor(unsigned SuccIdx) const; + void setSuccessor(unsigned Idx, BasicBlock *NewSucc); + void swapSuccessors() { swapOperandsInternal(1, 2); } + +private: + struct LLVMBBToSBBB { + Context &Ctx; + LLVMBBToSBBB(Context &Ctx) : Ctx(Ctx) {} + BasicBlock *operator()(llvm::BasicBlock *BB) const; + }; + + struct ConstLLVMBBToSBBB { + Context &Ctx; + ConstLLVMBBToSBBB(Context &Ctx) : Ctx(Ctx) {} + const BasicBlock *operator()(const llvm::BasicBlock *BB) const; + }; + +public: + using sb_succ_op_iterator = + mapped_iterator; + iterator_range successors() { + iterator_range LLVMRange = + cast(Val)->successors(); + LLVMBBToSBBB BBMap(Ctx); + sb_succ_op_iterator MappedBegin = map_iterator(LLVMRange.begin(), BBMap); + sb_succ_op_iterator MappedEnd = map_iterator(LLVMRange.end(), BBMap); + return make_range(MappedBegin, MappedEnd); + } + + using const_sb_succ_op_iterator = + mapped_iterator; + iterator_range successors() const { + iterator_range ConstLLVMRange = + static_cast(cast(Val)) + ->successors(); + ConstLLVMBBToSBBB ConstBBMap(Ctx); + const_sb_succ_op_iterator ConstMappedBegin = + map_iterator(ConstLLVMRange.begin(), ConstBBMap); + const_sb_succ_op_iterator ConstMappedEnd = + map_iterator(ConstLLVMRange.end(), ConstBBMap); + return make_range(ConstMappedBegin, ConstMappedEnd); + } + +#ifndef NDEBUG + void verify() const final { + assert(isa(Val) && "Expected BranchInst!"); + } + friend raw_ostream &operator<<(raw_ostream &OS, const BranchInst &BI) { + BI.dump(OS); + return OS; + } + void dump(raw_ostream &OS) const override; + LLVM_DUMP_METHOD void dump() const override; +#endif +}; + class LoadInst final : public Instruction { /// Use LoadInst::create() instead of calling the constructor. LoadInst(llvm::LoadInst *LI, Context &Ctx) @@ -870,6 +975,8 @@ class Context { SelectInst *createSelectInst(llvm::SelectInst *SI); friend SelectInst; // For createSelectInst() + BranchInst *createBranchInst(llvm::BranchInst *I); + friend BranchInst; // For createBranchInst() LoadInst *createLoadInst(llvm::LoadInst *LI); friend LoadInst; // For createLoadInst() StoreInst *createStoreInst(llvm::StoreInst *SI); diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def index efa91557555879..f3d616774b3fd9 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def +++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def @@ -26,6 +26,7 @@ DEF_USER(Constant, Constant) // ClassID, Opcode(s), Class DEF_INSTR(Opaque, OP(Opaque), OpaqueInst) DEF_INSTR(Select, OP(Select), SelectInst) +DEF_INSTR(Br, OP(Br), BranchInst) DEF_INSTR(Load, OP(Load), LoadInst) DEF_INSTR(Store, OP(Store), StoreInst) DEF_INSTR(Ret, OP(Ret), ReturnInst) diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h index b88eb3d2a52808..3daec3fd5c63cf 100644 --- a/llvm/include/llvm/SandboxIR/Tracker.h +++ b/llvm/include/llvm/SandboxIR/Tracker.h @@ -101,6 +101,27 @@ class UseSet : public IRChangeBase { #endif }; +/// Tracks swapping a Use with another Use. +class UseSwap : public IRChangeBase { + Use ThisUse; + Use OtherUse; + +public: + UseSwap(const Use &ThisUse, const Use &OtherUse, Tracker &Tracker) + : IRChangeBase(Tracker), ThisUse(ThisUse), OtherUse(OtherUse) { + assert(ThisUse.getUser() == OtherUse.getUser() && "Expected same user!"); + } + void revert() final { ThisUse.swap(OtherUse); } + void accept() final {} +#ifndef NDEBUG + void dump(raw_ostream &OS) const final { + dumpCommon(OS); + OS << "UseSwap"; + } + LLVM_DUMP_METHOD void dump() const final; +#endif +}; + class EraseFromParent : public IRChangeBase { /// Contains all the data we need to restore an "erased" (i.e., detached) /// instruction: the instruction itself and its operands in order. diff --git a/llvm/include/llvm/SandboxIR/Use.h b/llvm/include/llvm/SandboxIR/Use.h index d77b4568d0fab0..03cbfe6cb04463 100644 --- a/llvm/include/llvm/SandboxIR/Use.h +++ b/llvm/include/llvm/SandboxIR/Use.h @@ -47,6 +47,7 @@ class Use { void set(Value *V); class User *getUser() const { return Usr; } unsigned getOperandNo() const; + void swap(Use &OtherUse); Context *getContext() const { return Ctx; } bool operator==(const Use &Other) const { assert(Ctx == Other.Ctx && "Contexts differ!"); diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index 51c9af8a6e1fec..ceadb34f53eafb 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -20,6 +20,13 @@ void Use::set(Value *V) { LLVMUse->set(V->Val); } unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); } +void Use::swap(Use &OtherUse) { + auto &Tracker = Ctx->getTracker(); + if (Tracker.isTracking()) + Tracker.track(std::make_unique(*this, OtherUse, Tracker)); + LLVMUse->swap(*OtherUse.LLVMUse); +} + #ifndef NDEBUG void Use::dump(raw_ostream &OS) const { Value *Def = nullptr; @@ -500,6 +507,85 @@ void SelectInst::dump() const { } #endif // NDEBUG +BranchInst *BranchInst::create(BasicBlock *IfTrue, Instruction *InsertBefore, + Context &Ctx) { + auto &Builder = Ctx.getLLVMIRBuilder(); + Builder.SetInsertPoint(cast(InsertBefore->Val)); + llvm::BranchInst *NewBr = + Builder.CreateBr(cast(IfTrue->Val)); + return Ctx.createBranchInst(NewBr); +} + +BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *InsertAtEnd, + Context &Ctx) { + auto &Builder = Ctx.getLLVMIRBuilder(); + Builder.SetInsertPoint(cast(InsertAtEnd->Val)); + llvm::BranchInst *NewBr = + Builder.CreateBr(cast(IfTrue->Val)); + return Ctx.createBranchInst(NewBr); +} + +BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *IfFalse, + Value *Cond, Instruction *InsertBefore, + Context &Ctx) { + auto &Builder = Ctx.getLLVMIRBuilder(); + Builder.SetInsertPoint(cast(InsertBefore->Val)); + llvm::BranchInst *NewBr = + Builder.CreateCondBr(Cond->Val, cast(IfTrue->Val), + cast(IfFalse->Val)); + return Ctx.createBranchInst(NewBr); +} + +BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *IfFalse, + Value *Cond, BasicBlock *InsertAtEnd, + Context &Ctx) { + auto &Builder = Ctx.getLLVMIRBuilder(); + Builder.SetInsertPoint(cast(InsertAtEnd->Val)); + llvm::BranchInst *NewBr = + Builder.CreateCondBr(Cond->Val, cast(IfTrue->Val), + cast(IfFalse->Val)); + return Ctx.createBranchInst(NewBr); +} + +bool BranchInst::classof(const Value *From) { + return From->getSubclassID() == ClassID::Br; +} + +Value *BranchInst::getCondition() const { + assert(isConditional() && "Cannot get condition of an uncond branch!"); + return Ctx.getValue(cast(Val)->getCondition()); +} + +BasicBlock *BranchInst::getSuccessor(unsigned SuccIdx) const { + assert(SuccIdx < getNumSuccessors() && + "Successor # out of range for Branch!"); + return cast_or_null( + Ctx.getValue(cast(Val)->getSuccessor(SuccIdx))); +} + +void BranchInst::setSuccessor(unsigned Idx, BasicBlock *NewSucc) { + assert((Idx == 0 || Idx == 1) && "Out of bounds!"); + setOperand(2u - Idx, NewSucc); +} + +BasicBlock *BranchInst::LLVMBBToSBBB::operator()(llvm::BasicBlock *BB) const { + return cast(Ctx.getValue(BB)); +} +const BasicBlock * +BranchInst::ConstLLVMBBToSBBB::operator()(const llvm::BasicBlock *BB) const { + return cast(Ctx.getValue(BB)); +} +#ifndef NDEBUG +void BranchInst::dump(raw_ostream &OS) const { + dumpCommonPrefix(OS); + dumpCommonSuffix(OS); +} +void BranchInst::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG + LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align, Instruction *InsertBefore, Context &Ctx, const Twine &Name) { @@ -758,6 +844,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) { It->second = std::unique_ptr(new SelectInst(LLVMSel, *this)); return It->second.get(); } + case llvm::Instruction::Br: { + auto *LLVMBr = cast(LLVMV); + It->second = std::unique_ptr(new BranchInst(LLVMBr, *this)); + return It->second.get(); + } case llvm::Instruction::Load: { auto *LLVMLd = cast(LLVMV); It->second = std::unique_ptr(new LoadInst(LLVMLd, *this)); @@ -796,6 +887,11 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) { return cast(registerValue(std::move(NewPtr))); } +BranchInst *Context::createBranchInst(llvm::BranchInst *BI) { + auto NewPtr = std::unique_ptr(new BranchInst(BI, *this)); + return cast(registerValue(std::move(NewPtr))); +} + LoadInst *Context::createLoadInst(llvm::LoadInst *LI) { auto NewPtr = std::unique_ptr(new LoadInst(LI, *this)); return cast(registerValue(std::move(NewPtr))); diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp index 626c9c27d05e57..c74177608aff20 100644 --- a/llvm/lib/SandboxIR/Tracker.cpp +++ b/llvm/lib/SandboxIR/Tracker.cpp @@ -35,6 +35,11 @@ void UseSet::dump() const { dump(dbgs()); dbgs() << "\n"; } + +void UseSwap::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} #endif // NDEBUG Tracker::~Tracker() { diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index ba90b4f811f8e1..783f606c703802 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -398,7 +398,7 @@ define void @foo(i32 %arg0, i32 %arg1) { EXPECT_EQ(Buff, R"IR( void @foo(i32 %arg0, i32 %arg1) { bb0: - br label %bb1 ; SB3. (Opaque) + br label %bb1 ; SB3. (Br) bb1: ret void ; SB5. (Ret) @@ -466,7 +466,7 @@ define void @foo(i32 %v1) { BB0.dump(BS); EXPECT_EQ(Buff, R"IR( bb0: - br label %bb1 ; SB2. (Opaque) + br label %bb1 ; SB2. (Br) )IR"); } #endif // NDEBUG @@ -629,6 +629,107 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) { } } +TEST_F(SandboxIRTest, BranchInst) { + parseIR(C, R"IR( +define void @foo(i1 %cond0, i1 %cond2) { + bb0: + br i1 %cond0, label %bb1, label %bb2 + bb1: + ret void + bb2: + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + sandboxir::Function *F = Ctx.createFunction(LLVMF); + auto *Cond0 = F->getArg(0); + auto *Cond1 = F->getArg(1); + auto *BB0 = cast( + Ctx.getValue(getBasicBlockByName(*LLVMF, "bb0"))); + auto *BB1 = cast( + Ctx.getValue(getBasicBlockByName(*LLVMF, "bb1"))); + auto *Ret1 = BB1->getTerminator(); + auto *BB2 = cast( + Ctx.getValue(getBasicBlockByName(*LLVMF, "bb2"))); + auto *Ret2 = BB2->getTerminator(); + auto It = BB0->begin(); + auto *Br0 = cast(&*It++); + // Check isUnconditional(). + EXPECT_FALSE(Br0->isUnconditional()); + // Check isConditional(). + EXPECT_TRUE(Br0->isConditional()); + // Check getCondition(). + EXPECT_EQ(Br0->getCondition(), Cond0); + // Check setCondition(). + Br0->setCondition(Cond1); + EXPECT_EQ(Br0->getCondition(), Cond1); + // Check getNumSuccessors(). + EXPECT_EQ(Br0->getNumSuccessors(), 2u); + // Check getSuccessor(). + EXPECT_EQ(Br0->getSuccessor(0), BB1); + EXPECT_EQ(Br0->getSuccessor(1), BB2); + // Check swapSuccessors(). + Br0->swapSuccessors(); + EXPECT_EQ(Br0->getSuccessor(0), BB2); + EXPECT_EQ(Br0->getSuccessor(1), BB1); + // Check successors(). + EXPECT_EQ(range_size(Br0->successors()), 2u); + unsigned SuccIdx = 0; + SmallVector ExpectedSuccs({BB1, BB2}); + for (sandboxir::BasicBlock *Succ : Br0->successors()) + EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]); + + { + // Check unconditional BranchInst::create() InsertBefore. + auto *Br = sandboxir::BranchInst::create(BB1, /*InsertBefore=*/Ret1, Ctx); + EXPECT_FALSE(Br->isConditional()); + EXPECT_TRUE(Br->isUnconditional()); + EXPECT_DEATH(Br->getCondition(), ".*condition.*"); + unsigned SuccIdx = 0; + SmallVector ExpectedSuccs({BB1}); + for (sandboxir::BasicBlock *Succ : Br->successors()) + EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]); + EXPECT_EQ(Br->getNextNode(), Ret1); + } + { + // Check unconditional BranchInst::create() InsertAtEnd. + auto *Br = sandboxir::BranchInst::create(BB1, /*InsertAtEnd=*/BB1, Ctx); + EXPECT_FALSE(Br->isConditional()); + EXPECT_TRUE(Br->isUnconditional()); + EXPECT_DEATH(Br->getCondition(), ".*condition.*"); + unsigned SuccIdx = 0; + SmallVector ExpectedSuccs({BB1}); + for (sandboxir::BasicBlock *Succ : Br->successors()) + EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]); + EXPECT_EQ(Br->getPrevNode(), Ret1); + } + { + // Check conditional BranchInst::create() InsertBefore. + auto *Br = sandboxir::BranchInst::create(BB1, BB2, Cond0, + /*InsertBefore=*/Ret1, Ctx); + EXPECT_TRUE(Br->isConditional()); + EXPECT_EQ(Br->getCondition(), Cond0); + unsigned SuccIdx = 0; + SmallVector ExpectedSuccs({BB2, BB1}); + for (sandboxir::BasicBlock *Succ : Br->successors()) + EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]); + EXPECT_EQ(Br->getNextNode(), Ret1); + } + { + // Check conditional BranchInst::create() InsertAtEnd. + auto *Br = sandboxir::BranchInst::create(BB1, BB2, Cond0, + /*InsertAtEnd=*/BB2, Ctx); + EXPECT_TRUE(Br->isConditional()); + EXPECT_EQ(Br->getCondition(), Cond0); + unsigned SuccIdx = 0; + SmallVector ExpectedSuccs({BB2, BB1}); + for (sandboxir::BasicBlock *Succ : Br->successors()) + EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]); + EXPECT_EQ(Br->getPrevNode(), Ret2); + } +} + TEST_F(SandboxIRTest, LoadInst) { parseIR(C, R"IR( define void @foo(ptr %arg0, ptr %arg1) { diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp index 354cd187adb107..dd9dcd543236ee 100644 --- a/llvm/unittests/SandboxIR/TrackerTest.cpp +++ b/llvm/unittests/SandboxIR/TrackerTest.cpp @@ -69,6 +69,49 @@ define void @foo(ptr %ptr) { EXPECT_EQ(Ld->getOperand(0), Gep0); } +TEST_F(TrackerTest, SwapOperands) { + parseIR(C, R"IR( +define void @foo(i1 %cond) { + bb0: + br i1 %cond, label %bb1, label %bb2 + bb1: + ret void + bb2: + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + Ctx.createFunction(&LLVMF); + auto *BB0 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb0"))); + auto *BB1 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb1"))); + auto *BB2 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb2"))); + auto &Tracker = Ctx.getTracker(); + Tracker.save(); + auto It = BB0->begin(); + auto *Br = cast(&*It++); + + unsigned SuccIdx = 0; + SmallVector ExpectedSuccs({BB2, BB1}); + for (auto *Succ : Br->successors()) + EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]); + + // This calls User::swapOperandsInternal() internally. + Br->swapSuccessors(); + + SuccIdx = 0; + for (auto *Succ : reverse(Br->successors())) + EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]); + + Ctx.getTracker().revert(); + SuccIdx = 0; + for (auto *Succ : Br->successors()) + EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]); +} + TEST_F(TrackerTest, RUWIf_RAUW_RUOW) { parseIR(C, R"IR( define void @foo(ptr %ptr) {