diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index 38c2586f9d73c2..1fd9283313895c 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -226,6 +226,7 @@ class Value { friend class CallBrInst; // For getting `Val`. friend class GetElementPtrInst; // For getting `Val`. friend class CastInst; // For getting `Val`. + friend class PHINode; // For getting `Val`. /// All values point to the context. Context &Ctx; @@ -618,6 +619,7 @@ class Instruction : public sandboxir::User { friend class CallBrInst; // For getTopmostLLVMInstruction(). friend class GetElementPtrInst; // For getTopmostLLVMInstruction(). friend class CastInst; // For getTopmostLLVMInstruction(). + friend class PHINode; // For getTopmostLLVMInstruction(). /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program /// order. @@ -1515,6 +1517,100 @@ class IntToPtrInst final : public CastInst { #endif // NDEBUG }; +class PHINode final : public Instruction { + /// Use Context::createPHINode(). Don't call the constructor directly. + PHINode(llvm::PHINode *PHI, Context &Ctx) + : Instruction(ClassID::PHI, Opcode::PHI, PHI, Ctx) {} + friend Context; // for PHINode() + Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final { + return getOperandUseDefault(OpIdx, Verify); + } + SmallVector getLLVMInstrs() const final { + return {cast(Val)}; + } + /// Helper for mapped_iterator. + struct LLVMBBToBB { + Context &Ctx; + LLVMBBToBB(Context &Ctx) : Ctx(Ctx) {} + BasicBlock *operator()(llvm::BasicBlock *LLVMBB) const; + }; + +public: + unsigned getUseOperandNo(const Use &Use) const final { + return getUseOperandNoDefault(Use); + } + unsigned getNumOfIRInstrs() const final { return 1u; } + static PHINode *create(Type *Ty, unsigned NumReservedValues, + Instruction *InsertBefore, Context &Ctx, + const Twine &Name = ""); + /// For isa/dyn_cast. + static bool classof(const Value *From); + + using const_block_iterator = + mapped_iterator; + + const_block_iterator block_begin() const { + LLVMBBToBB BBGetter(Ctx); + return const_block_iterator(cast(Val)->block_begin(), + BBGetter); + } + const_block_iterator block_end() const { + LLVMBBToBB BBGetter(Ctx); + return const_block_iterator(cast(Val)->block_end(), + BBGetter); + } + iterator_range blocks() const { + return make_range(block_begin(), block_end()); + } + + op_range incoming_values() { return operands(); } + + const_op_range incoming_values() const { return operands(); } + + unsigned getNumIncomingValues() const { + return cast(Val)->getNumIncomingValues(); + } + Value *getIncomingValue(unsigned Idx) const; + void setIncomingValue(unsigned Idx, Value *V); + static unsigned getOperandNumForIncomingValue(unsigned Idx) { + return llvm::PHINode::getOperandNumForIncomingValue(Idx); + } + static unsigned getIncomingValueNumForOperand(unsigned Idx) { + return llvm::PHINode::getIncomingValueNumForOperand(Idx); + } + BasicBlock *getIncomingBlock(unsigned Idx) const; + BasicBlock *getIncomingBlock(const Use &U) const; + + void setIncomingBlock(unsigned Idx, BasicBlock *BB); + + void addIncoming(Value *V, BasicBlock *BB); + + Value *removeIncomingValue(unsigned Idx); + Value *removeIncomingValue(BasicBlock *BB); + + int getBasicBlockIndex(const BasicBlock *BB) const; + Value *getIncomingValueForBlock(const BasicBlock *BB) const; + + Value *hasConstantValue() const; + + bool hasConstantOrUndefValue() const { + return cast(Val)->hasConstantOrUndefValue(); + } + bool isComplete() const { return cast(Val)->isComplete(); } + // TODO: Implement the below functions: + // void replaceIncomingBlockWith (const BasicBlock *Old, BasicBlock *New); + // void copyIncomingBlocks(iterator_range BBRange, + // uint32_t ToIdx = 0) + // void removeIncomingValueIf(function_ref< bool(unsigned)> Predicate, + // bool DeletePHIIfEmpty=true) +#ifndef NDEBUG + void verify() const final { + assert(isa(Val) && "Expected PHINode!"); + } + void dump(raw_ostream &OS) const override; + LLVM_DUMP_METHOD void dump() const override; +#endif +}; class PtrToIntInst final : public CastInst { public: static Value *create(Value *Src, Type *DestTy, BBIterator WhereIt, @@ -1700,6 +1796,8 @@ class Context { friend GetElementPtrInst; // For createGetElementPtrInst() CastInst *createCastInst(llvm::CastInst *I); friend CastInst; // For createCastInst() + PHINode *createPHINode(llvm::PHINode *I); + friend PHINode; // For createPHINode() public: Context(LLVMContext &LLVMCtx) diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def index 243ce6b2c60a9f..4cb601128a507e 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def +++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def @@ -58,6 +58,8 @@ DEF_INSTR(Cast, OPCODES(\ OP(BitCast) \ OP(AddrSpaceCast) \ ), CastInst) +DEF_INSTR(PHI, OP(PHI), PHINode) + // clang-format on #ifdef DEF_VALUE #undef DEF_VALUE diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h index 64068461b94905..238e4e9dacd342 100644 --- a/llvm/include/llvm/SandboxIR/Tracker.h +++ b/llvm/include/llvm/SandboxIR/Tracker.h @@ -102,6 +102,64 @@ class UseSet : public IRChangeBase { #endif }; +class PHISetIncoming : public IRChangeBase { + PHINode &PHI; + unsigned Idx; + PointerUnion OrigValueOrBB; + +public: + enum class What { + Value, + Block, + }; + PHISetIncoming(PHINode &PHI, unsigned Idx, What What, Tracker &Tracker); + void revert() final; + void accept() final {} +#ifndef NDEBUG + void dump(raw_ostream &OS) const final { + dumpCommon(OS); + OS << "PHISetIncoming"; + } + LLVM_DUMP_METHOD void dump() const final; +#endif +}; + +class PHIRemoveIncoming : public IRChangeBase { + PHINode &PHI; + unsigned RemovedIdx; + Value *RemovedV; + BasicBlock *RemovedBB; + +public: + PHIRemoveIncoming(PHINode &PHI, unsigned RemovedIdx, Tracker &Tracker); + void revert() final; + void accept() final {} +#ifndef NDEBUG + void dump(raw_ostream &OS) const final { + dumpCommon(OS); + OS << "PHISetIncoming"; + } + LLVM_DUMP_METHOD void dump() const final; +#endif +}; + +class PHIAddIncoming : public IRChangeBase { + PHINode &PHI; + unsigned Idx; + +public: + PHIAddIncoming(PHINode &PHI, Tracker &Tracker); + void revert() final; + void accept() final {} +#ifndef NDEBUG + void dump(raw_ostream &OS) const final { + dumpCommon(OS); + OS << "PHISetIncoming"; + } + LLVM_DUMP_METHOD void dump() const final; +#endif +}; + /// Tracks swapping a Use with another Use. class UseSwap : public IRChangeBase { Use ThisUse; diff --git a/llvm/include/llvm/SandboxIR/Use.h b/llvm/include/llvm/SandboxIR/Use.h index ef728ea3878516..35d01daf39f6e1 100644 --- a/llvm/include/llvm/SandboxIR/Use.h +++ b/llvm/include/llvm/SandboxIR/Use.h @@ -22,6 +22,7 @@ class Context; class Value; class User; class CallBase; +class PHINode; /// Represents a Def-use/Use-def edge in SandboxIR. /// NOTE: Unlike llvm::Use, this is not an integral part of the use-def chains. @@ -43,6 +44,7 @@ class Use { friend class UserUseIterator; // For accessing members friend class CallBase; // For LLVMUse friend class CallBrInst; // For constructor + friend class PHINode; // For LLVMUse public: operator Value *() const { return get(); } diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index 1ea22c3a8b48e5..4f12985bb0e636 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -1062,6 +1062,95 @@ void GetElementPtrInst::dump() const { } #endif // NDEBUG +BasicBlock *PHINode::LLVMBBToBB::operator()(llvm::BasicBlock *LLVMBB) const { + return cast(Ctx.getValue(LLVMBB)); +} + +PHINode *PHINode::create(Type *Ty, unsigned NumReservedValues, + Instruction *InsertBefore, Context &Ctx, + const Twine &Name) { + llvm::PHINode *NewPHI = llvm::PHINode::Create( + Ty, NumReservedValues, Name, InsertBefore->getTopmostLLVMInstruction()); + return Ctx.createPHINode(NewPHI); +} + +bool PHINode::classof(const Value *From) { + return From->getSubclassID() == ClassID::PHI; +} + +Value *PHINode::getIncomingValue(unsigned Idx) const { + return Ctx.getValue(cast(Val)->getIncomingValue(Idx)); +} +void PHINode::setIncomingValue(unsigned Idx, Value *V) { + auto &Tracker = Ctx.getTracker(); + if (Tracker.isTracking()) + Tracker.track(std::make_unique( + *this, Idx, PHISetIncoming::What::Value, Tracker)); + + cast(Val)->setIncomingValue(Idx, V->Val); +} +BasicBlock *PHINode::getIncomingBlock(unsigned Idx) const { + return cast( + Ctx.getValue(cast(Val)->getIncomingBlock(Idx))); +} +BasicBlock *PHINode::getIncomingBlock(const Use &U) const { + llvm::Use *LLVMUse = U.LLVMUse; + llvm::BasicBlock *BB = cast(Val)->getIncomingBlock(*LLVMUse); + return cast(Ctx.getValue(BB)); +} +void PHINode::setIncomingBlock(unsigned Idx, BasicBlock *BB) { + auto &Tracker = Ctx.getTracker(); + if (Tracker.isTracking()) + Tracker.track(std::make_unique( + *this, Idx, PHISetIncoming::What::Block, Tracker)); + cast(Val)->setIncomingBlock(Idx, + cast(BB->Val)); +} +void PHINode::addIncoming(Value *V, BasicBlock *BB) { + auto &Tracker = Ctx.getTracker(); + if (Tracker.isTracking()) + Tracker.track(std::make_unique(*this, Tracker)); + + cast(Val)->addIncoming(V->Val, + cast(BB->Val)); +} +Value *PHINode::removeIncomingValue(unsigned Idx) { + auto &Tracker = Ctx.getTracker(); + if (Tracker.isTracking()) + Tracker.track(std::make_unique(*this, Idx, Tracker)); + + llvm::Value *LLVMV = + cast(Val)->removeIncomingValue(Idx, + /*DeletePHIIfEmpty=*/false); + return Ctx.getValue(LLVMV); +} +Value *PHINode::removeIncomingValue(BasicBlock *BB) { + auto &Tracker = Ctx.getTracker(); + if (Tracker.isTracking()) + Tracker.track(std::make_unique( + *this, getBasicBlockIndex(BB), Tracker)); + + auto *LLVMBB = cast(BB->Val); + llvm::Value *LLVMV = + cast(Val)->removeIncomingValue(LLVMBB, + /*DeletePHIIfEmpty=*/false); + return Ctx.getValue(LLVMV); +} +int PHINode::getBasicBlockIndex(const BasicBlock *BB) const { + auto *LLVMBB = cast(BB->Val); + return cast(Val)->getBasicBlockIndex(LLVMBB); +} +Value *PHINode::getIncomingValueForBlock(const BasicBlock *BB) const { + auto *LLVMBB = cast(BB->Val); + llvm::Value *LLVMV = + cast(Val)->getIncomingValueForBlock(LLVMBB); + return Ctx.getValue(LLVMV); +} +Value *PHINode::hasConstantValue() const { + llvm::Value *LLVMV = cast(Val)->hasConstantValue(); + return LLVMV != nullptr ? Ctx.getValue(LLVMV) : nullptr; +} + static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) { switch (Opc) { case Instruction::Opcode::ZExt: @@ -1272,6 +1361,16 @@ Value *PtrToIntInst::create(Value *Src, Type *DestTy, BasicBlock *InsertAtEnd, } #ifndef NDEBUG +void PHINode::dump(raw_ostream &OS) const { + dumpCommonPrefix(OS); + dumpCommonSuffix(OS); +} + +void PHINode::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} + void PtrToIntInst::dump(raw_ostream &OS) const { dumpCommonPrefix(OS); dumpCommonSuffix(OS); @@ -1537,6 +1636,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) { It->second = std::unique_ptr(new CastInst(LLVMCast, *this)); return It->second.get(); } + case llvm::Instruction::PHI: { + auto *LLVMPhi = cast(LLVMV); + It->second = std::unique_ptr(new PHINode(LLVMPhi, *this)); + return It->second.get(); + } default: break; } @@ -1606,6 +1710,10 @@ CastInst *Context::createCastInst(llvm::CastInst *I) { auto NewPtr = std::unique_ptr(new CastInst(I, *this)); return cast(registerValue(std::move(NewPtr))); } +PHINode *Context::createPHINode(llvm::PHINode *I) { + auto NewPtr = std::unique_ptr(new PHINode(I, *this)); + return cast(registerValue(std::move(NewPtr))); +} Value *Context::getValue(llvm::Value *V) const { auto It = LLVMValueToValueMap.find(V); diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp index eae55d7b3d962f..0310160e8bf35a 100644 --- a/llvm/lib/SandboxIR/Tracker.cpp +++ b/llvm/lib/SandboxIR/Tracker.cpp @@ -42,6 +42,81 @@ void UseSwap::dump() const { } #endif // NDEBUG +PHISetIncoming::PHISetIncoming(PHINode &PHI, unsigned Idx, What What, + Tracker &Tracker) + : IRChangeBase(Tracker), PHI(PHI), Idx(Idx) { + switch (What) { + case What::Value: + OrigValueOrBB = PHI.getIncomingValue(Idx); + break; + case What::Block: + OrigValueOrBB = PHI.getIncomingBlock(Idx); + break; + } +} + +void PHISetIncoming::revert() { + if (auto *V = OrigValueOrBB.dyn_cast()) + PHI.setIncomingValue(Idx, V); + else + PHI.setIncomingBlock(Idx, OrigValueOrBB.get()); +} + +#ifndef NDEBUG +void PHISetIncoming::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG + +PHIRemoveIncoming::PHIRemoveIncoming(PHINode &PHI, unsigned RemovedIdx, + Tracker &Tracker) + : IRChangeBase(Tracker), PHI(PHI), RemovedIdx(RemovedIdx) { + RemovedV = PHI.getIncomingValue(RemovedIdx); + RemovedBB = PHI.getIncomingBlock(RemovedIdx); +} + +void PHIRemoveIncoming::revert() { + // Special case: if the PHI is now empty, as we don't need to care about the + // order of the incoming values. + unsigned NumIncoming = PHI.getNumIncomingValues(); + if (NumIncoming == 0) { + PHI.addIncoming(RemovedV, RemovedBB); + return; + } + // Shift all incoming values by one starting from the end until `Idx`. + // Start by adding a copy of the last incoming values. + unsigned LastIdx = NumIncoming - 1; + PHI.addIncoming(PHI.getIncomingValue(LastIdx), PHI.getIncomingBlock(LastIdx)); + for (unsigned Idx = LastIdx; Idx > RemovedIdx; --Idx) { + auto *PrevV = PHI.getIncomingValue(Idx - 1); + auto *PrevBB = PHI.getIncomingBlock(Idx - 1); + PHI.setIncomingValue(Idx, PrevV); + PHI.setIncomingBlock(Idx, PrevBB); + } + PHI.setIncomingValue(RemovedIdx, RemovedV); + PHI.setIncomingBlock(RemovedIdx, RemovedBB); +} + +#ifndef NDEBUG +void PHIRemoveIncoming::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG + +PHIAddIncoming::PHIAddIncoming(PHINode &PHI, Tracker &Tracker) + : IRChangeBase(Tracker), PHI(PHI), Idx(PHI.getNumIncomingValues()) {} + +void PHIAddIncoming::revert() { PHI.removeIncomingValue(Idx); } + +#ifndef NDEBUG +void PHIAddIncoming::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG + Tracker::~Tracker() { assert(Changes.empty() && "You must accept or revert changes!"); } diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 9d4fba404a43cf..31feb56a5272f8 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -2112,3 +2112,123 @@ define void @foo(ptr %ptr) { EXPECT_EQ(NewI->getParent(), BB); } } + +TEST_F(SandboxIRTest, PHINode) { + parseIR(C, R"IR( +define void @foo(i32 %arg) { +bb1: + br label %bb2 + +bb2: + %phi = phi i32 [ %arg, %bb1 ], [ 0, %bb2 ] + br label %bb2 + +bb3: + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + auto *LLVMBB1 = getBasicBlockByName(LLVMF, "bb1"); + auto *LLVMBB2 = getBasicBlockByName(LLVMF, "bb2"); + auto *LLVMBB3 = getBasicBlockByName(LLVMF, "bb3"); + auto LLVMIt = LLVMBB2->begin(); + auto *LLVMPHI = cast(&*LLVMIt++); + sandboxir::Context Ctx(C); + sandboxir::Function *F = Ctx.createFunction(&LLVMF); + auto *Arg = F->getArg(0); + auto *BB1 = cast(Ctx.getValue(LLVMBB1)); + auto *BB2 = cast(Ctx.getValue(LLVMBB2)); + auto *BB3 = cast(Ctx.getValue(LLVMBB3)); + auto It = BB2->begin(); + // Check classof(). + auto *PHI = cast(&*It++); + auto *Br = cast(&*It++); + // Check blocks(). + EXPECT_EQ(range_size(PHI->blocks()), range_size(LLVMPHI->blocks())); + auto BlockIt = PHI->block_begin(); + for (llvm::BasicBlock *LLVMBB : LLVMPHI->blocks()) { + sandboxir::BasicBlock *BB = *BlockIt++; + EXPECT_EQ(BB, Ctx.getValue(LLVMBB)); + } + // Check incoming_values(). + EXPECT_EQ(range_size(PHI->incoming_values()), + range_size(LLVMPHI->incoming_values())); + auto IncIt = PHI->incoming_values().begin(); + for (llvm::Value *LLVMV : LLVMPHI->incoming_values()) { + sandboxir::Value *IncV = *IncIt++; + EXPECT_EQ(IncV, Ctx.getValue(LLVMV)); + } + // Check getNumIncomingValues(). + EXPECT_EQ(PHI->getNumIncomingValues(), LLVMPHI->getNumIncomingValues()); + // Check getIncomingValue(). + EXPECT_EQ(PHI->getIncomingValue(0), + Ctx.getValue(LLVMPHI->getIncomingValue(0))); + EXPECT_EQ(PHI->getIncomingValue(1), + Ctx.getValue(LLVMPHI->getIncomingValue(1))); + // Check setIncomingValue(). + auto *OrigV = PHI->getIncomingValue(0); + PHI->setIncomingValue(0, PHI); + EXPECT_EQ(PHI->getIncomingValue(0), PHI); + PHI->setIncomingValue(0, OrigV); + // Check getOperandNumForIncomingValue(). + EXPECT_EQ(sandboxir::PHINode::getOperandNumForIncomingValue(0), + llvm::PHINode::getOperandNumForIncomingValue(0)); + // Check getIncomingValueNumForOperand(). + EXPECT_EQ(sandboxir::PHINode::getIncomingValueNumForOperand(0), + llvm::PHINode::getIncomingValueNumForOperand(0)); + // Check getIncomingBlock(unsigned). + EXPECT_EQ(PHI->getIncomingBlock(0), + Ctx.getValue(LLVMPHI->getIncomingBlock(0))); + // Check getIncomingBlock(Use). + llvm::Use &LLVMUse = LLVMPHI->getOperandUse(0); + sandboxir::Use Use = PHI->getOperandUse(0); + EXPECT_EQ(PHI->getIncomingBlock(Use), + Ctx.getValue(LLVMPHI->getIncomingBlock(LLVMUse))); + // Check setIncomingBlock(). + sandboxir::BasicBlock *OrigBB = PHI->getIncomingBlock(0); + EXPECT_NE(OrigBB, BB2); + PHI->setIncomingBlock(0, BB2); + EXPECT_EQ(PHI->getIncomingBlock(0), BB2); + PHI->setIncomingBlock(0, OrigBB); + EXPECT_EQ(PHI->getIncomingBlock(0), OrigBB); + // Check addIncoming(). + unsigned OrigNumIncoming = PHI->getNumIncomingValues(); + PHI->addIncoming(Arg, BB3); + EXPECT_EQ(PHI->getNumIncomingValues(), LLVMPHI->getNumIncomingValues()); + EXPECT_EQ(PHI->getNumIncomingValues(), OrigNumIncoming + 1); + EXPECT_EQ(PHI->getIncomingValue(OrigNumIncoming), Arg); + EXPECT_EQ(PHI->getIncomingBlock(OrigNumIncoming), BB3); + // Check removeIncomingValue(unsigned). + PHI->removeIncomingValue(OrigNumIncoming); + EXPECT_EQ(PHI->getNumIncomingValues(), OrigNumIncoming); + // Check removeIncomingValue(BasicBlock *). + PHI->addIncoming(Arg, BB3); + PHI->removeIncomingValue(BB3); + EXPECT_EQ(PHI->getNumIncomingValues(), OrigNumIncoming); + // Check getBasicBlockIndex(). + EXPECT_EQ(PHI->getBasicBlockIndex(BB1), LLVMPHI->getBasicBlockIndex(LLVMBB1)); + // Check getIncomingValueForBlock(). + EXPECT_EQ(PHI->getIncomingValueForBlock(BB1), + Ctx.getValue(LLVMPHI->getIncomingValueForBlock(LLVMBB1))); + // Check hasConstantValue(). + llvm::Value *ConstV = LLVMPHI->hasConstantValue(); + EXPECT_EQ(PHI->hasConstantValue(), + ConstV != nullptr ? Ctx.getValue(ConstV) : nullptr); + // Check hasConstantOrUndefValue(). + EXPECT_EQ(PHI->hasConstantOrUndefValue(), LLVMPHI->hasConstantOrUndefValue()); + // Check isComplete(). + EXPECT_EQ(PHI->isComplete(), LLVMPHI->isComplete()); + + // Check create(). + auto *NewPHI = cast( + sandboxir::PHINode::create(PHI->getType(), 0, Br, Ctx, "NewPHI")); + EXPECT_EQ(NewPHI->getType(), PHI->getType()); + EXPECT_EQ(NewPHI->getNextNode(), Br); + EXPECT_EQ(NewPHI->getName(), "NewPHI"); + EXPECT_EQ(NewPHI->getNumIncomingValues(), 0u); + for (auto [Idx, V] : enumerate(PHI->incoming_values())) { + sandboxir::BasicBlock *IncBB = PHI->getIncomingBlock(Idx); + NewPHI->addIncoming(V, IncBB); + } + EXPECT_EQ(NewPHI->getNumIncomingValues(), PHI->getNumIncomingValues()); +} diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp index cd737d33dd1937..d016c7793a52c0 100644 --- a/llvm/unittests/SandboxIR/TrackerTest.cpp +++ b/llvm/unittests/SandboxIR/TrackerTest.cpp @@ -584,3 +584,127 @@ define void @foo(i8 %arg) { Ctx.revert(); EXPECT_EQ(CallBr->getIndirectDest(0), OrigIndirectDest); } + +TEST_F(TrackerTest, PHINodeSetters) { + parseIR(C, R"IR( +define void @foo(i8 %arg0, i8 %arg1, i8 %arg2) { +bb0: + br label %bb2 + +bb1: + %phi = phi i8 [ %arg0, %bb0 ], [ %arg1, %bb1 ] + br label %bb1 + +bb2: + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto &F = *Ctx.createFunction(&LLVMF); + unsigned ArgIdx = 0; + auto *Arg0 = F.getArg(ArgIdx++); + auto *Arg1 = F.getArg(ArgIdx++); + auto *Arg2 = F.getArg(ArgIdx++); + auto *BB0 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb0"))); + auto *BB1 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb1"))); + auto *BB2 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb2"))); + auto *PHI = cast(&*BB1->begin()); + + // Check setIncomingValue(). + Ctx.save(); + EXPECT_EQ(PHI->getIncomingValue(0), Arg0); + PHI->setIncomingValue(0, Arg2); + EXPECT_EQ(PHI->getIncomingValue(0), Arg2); + Ctx.revert(); + EXPECT_EQ(PHI->getIncomingValue(0), Arg0); + EXPECT_EQ(PHI->getNumIncomingValues(), 2u); + EXPECT_EQ(PHI->getIncomingBlock(0), BB0); + EXPECT_EQ(PHI->getIncomingValue(0), Arg0); + EXPECT_EQ(PHI->getIncomingBlock(1), BB1); + EXPECT_EQ(PHI->getIncomingValue(1), Arg1); + + // Check setIncomingBlock(). + Ctx.save(); + EXPECT_EQ(PHI->getIncomingBlock(0), BB0); + PHI->setIncomingBlock(0, BB2); + EXPECT_EQ(PHI->getIncomingBlock(0), BB2); + Ctx.revert(); + EXPECT_EQ(PHI->getIncomingBlock(0), BB0); + EXPECT_EQ(PHI->getNumIncomingValues(), 2u); + EXPECT_EQ(PHI->getIncomingBlock(0), BB0); + EXPECT_EQ(PHI->getIncomingValue(0), Arg0); + EXPECT_EQ(PHI->getIncomingBlock(1), BB1); + EXPECT_EQ(PHI->getIncomingValue(1), Arg1); + + // Check addIncoming(). + Ctx.save(); + EXPECT_EQ(PHI->getNumIncomingValues(), 2u); + PHI->addIncoming(Arg1, BB2); + EXPECT_EQ(PHI->getNumIncomingValues(), 3u); + EXPECT_EQ(PHI->getIncomingBlock(2), BB2); + EXPECT_EQ(PHI->getIncomingValue(2), Arg1); + Ctx.revert(); + EXPECT_EQ(PHI->getNumIncomingValues(), 2u); + EXPECT_EQ(PHI->getIncomingBlock(0), BB0); + EXPECT_EQ(PHI->getIncomingValue(0), Arg0); + EXPECT_EQ(PHI->getIncomingBlock(1), BB1); + EXPECT_EQ(PHI->getIncomingValue(1), Arg1); + + // Check removeIncomingValue(1). + Ctx.save(); + PHI->removeIncomingValue(1); + EXPECT_EQ(PHI->getNumIncomingValues(), 1u); + EXPECT_EQ(PHI->getIncomingBlock(0), BB0); + EXPECT_EQ(PHI->getIncomingValue(0), Arg0); + Ctx.revert(); + EXPECT_EQ(PHI->getNumIncomingValues(), 2u); + EXPECT_EQ(PHI->getIncomingBlock(0), BB0); + EXPECT_EQ(PHI->getIncomingValue(0), Arg0); + EXPECT_EQ(PHI->getIncomingBlock(1), BB1); + EXPECT_EQ(PHI->getIncomingValue(1), Arg1); + + // Check removeIncomingValue(0). + Ctx.save(); + PHI->removeIncomingValue(0u); + EXPECT_EQ(PHI->getNumIncomingValues(), 1u); + EXPECT_EQ(PHI->getIncomingBlock(0), BB1); + EXPECT_EQ(PHI->getIncomingValue(0), Arg1); + Ctx.revert(); + EXPECT_EQ(PHI->getNumIncomingValues(), 2u); + EXPECT_EQ(PHI->getIncomingBlock(0), BB0); + EXPECT_EQ(PHI->getIncomingValue(0), Arg0); + EXPECT_EQ(PHI->getIncomingBlock(1), BB1); + EXPECT_EQ(PHI->getIncomingValue(1), Arg1); + + // Check removeIncomingValue() remove all. + Ctx.save(); + PHI->removeIncomingValue(0u); + EXPECT_EQ(PHI->getNumIncomingValues(), 1u); + EXPECT_EQ(PHI->getIncomingBlock(0), BB1); + EXPECT_EQ(PHI->getIncomingValue(0), Arg1); + PHI->removeIncomingValue(0u); + EXPECT_EQ(PHI->getNumIncomingValues(), 0u); + Ctx.revert(); + EXPECT_EQ(PHI->getNumIncomingValues(), 2u); + EXPECT_EQ(PHI->getIncomingBlock(0), BB0); + EXPECT_EQ(PHI->getIncomingValue(0), Arg0); + EXPECT_EQ(PHI->getIncomingBlock(1), BB1); + EXPECT_EQ(PHI->getIncomingValue(1), Arg1); + + // Check removeIncomingValue(BasicBlock *). + Ctx.save(); + PHI->removeIncomingValue(BB1); + EXPECT_EQ(PHI->getNumIncomingValues(), 1u); + EXPECT_EQ(PHI->getIncomingBlock(0), BB0); + EXPECT_EQ(PHI->getIncomingValue(0), Arg0); + Ctx.revert(); + EXPECT_EQ(PHI->getNumIncomingValues(), 2u); + EXPECT_EQ(PHI->getIncomingBlock(0), BB0); + EXPECT_EQ(PHI->getIncomingValue(0), Arg0); + EXPECT_EQ(PHI->getIncomingBlock(1), BB1); + EXPECT_EQ(PHI->getIncomingValue(1), Arg1); +}