Skip to content

Commit

Permalink
[SandboxIR] Implement PHINodes (llvm#101111)
Browse files Browse the repository at this point in the history
This patch implements sandboxir::PHINode which mirrors llvm::PHINode.

Based almost entirely on work by vporpo.
  • Loading branch information
Sterling-Augustine authored Jul 31, 2024
1 parent 65d3c22 commit 3403b59
Show file tree
Hide file tree
Showing 8 changed files with 587 additions and 0 deletions.
98 changes: 98 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ class Value {
friend class CallBrInst; // For getting `Val`.
friend class GetElementPtrInst; // For getting `Val`.
friend class CastInst; // For getting `Val`.
friend class PHINode; // For getting `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -618,6 +619,7 @@ class Instruction : public sandboxir::User {
friend class CallBrInst; // For getTopmostLLVMInstruction().
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
friend class CastInst; // For getTopmostLLVMInstruction().
friend class PHINode; // For getTopmostLLVMInstruction().

/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
Expand Down Expand Up @@ -1515,6 +1517,100 @@ class IntToPtrInst final : public CastInst {
#endif // NDEBUG
};

class PHINode final : public Instruction {
/// Use Context::createPHINode(). Don't call the constructor directly.
PHINode(llvm::PHINode *PHI, Context &Ctx)
: Instruction(ClassID::PHI, Opcode::PHI, PHI, Ctx) {}
friend Context; // for PHINode()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}
/// Helper for mapped_iterator.
struct LLVMBBToBB {
Context &Ctx;
LLVMBBToBB(Context &Ctx) : Ctx(Ctx) {}
BasicBlock *operator()(llvm::BasicBlock *LLVMBB) const;
};

public:
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }
static PHINode *create(Type *Ty, unsigned NumReservedValues,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name = "");
/// For isa/dyn_cast.
static bool classof(const Value *From);

using const_block_iterator =
mapped_iterator<llvm::PHINode::const_block_iterator, LLVMBBToBB>;

const_block_iterator block_begin() const {
LLVMBBToBB BBGetter(Ctx);
return const_block_iterator(cast<llvm::PHINode>(Val)->block_begin(),
BBGetter);
}
const_block_iterator block_end() const {
LLVMBBToBB BBGetter(Ctx);
return const_block_iterator(cast<llvm::PHINode>(Val)->block_end(),
BBGetter);
}
iterator_range<const_block_iterator> blocks() const {
return make_range(block_begin(), block_end());
}

op_range incoming_values() { return operands(); }

const_op_range incoming_values() const { return operands(); }

unsigned getNumIncomingValues() const {
return cast<llvm::PHINode>(Val)->getNumIncomingValues();
}
Value *getIncomingValue(unsigned Idx) const;
void setIncomingValue(unsigned Idx, Value *V);
static unsigned getOperandNumForIncomingValue(unsigned Idx) {
return llvm::PHINode::getOperandNumForIncomingValue(Idx);
}
static unsigned getIncomingValueNumForOperand(unsigned Idx) {
return llvm::PHINode::getIncomingValueNumForOperand(Idx);
}
BasicBlock *getIncomingBlock(unsigned Idx) const;
BasicBlock *getIncomingBlock(const Use &U) const;

void setIncomingBlock(unsigned Idx, BasicBlock *BB);

void addIncoming(Value *V, BasicBlock *BB);

Value *removeIncomingValue(unsigned Idx);
Value *removeIncomingValue(BasicBlock *BB);

int getBasicBlockIndex(const BasicBlock *BB) const;
Value *getIncomingValueForBlock(const BasicBlock *BB) const;

Value *hasConstantValue() const;

bool hasConstantOrUndefValue() const {
return cast<llvm::PHINode>(Val)->hasConstantOrUndefValue();
}
bool isComplete() const { return cast<llvm::PHINode>(Val)->isComplete(); }
// TODO: Implement the below functions:
// void replaceIncomingBlockWith (const BasicBlock *Old, BasicBlock *New);
// void copyIncomingBlocks(iterator_range<const_block_iterator> BBRange,
// uint32_t ToIdx = 0)
// void removeIncomingValueIf(function_ref< bool(unsigned)> Predicate,
// bool DeletePHIIfEmpty=true)
#ifndef NDEBUG
void verify() const final {
assert(isa<llvm::PHINode>(Val) && "Expected PHINode!");
}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};
class PtrToIntInst final : public CastInst {
public:
static Value *create(Value *Src, Type *DestTy, BBIterator WhereIt,
Expand Down Expand Up @@ -1700,6 +1796,8 @@ class Context {
friend GetElementPtrInst; // For createGetElementPtrInst()
CastInst *createCastInst(llvm::CastInst *I);
friend CastInst; // For createCastInst()
PHINode *createPHINode(llvm::PHINode *I);
friend PHINode; // For createPHINode()

public:
Context(LLVMContext &LLVMCtx)
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ DEF_INSTR(Cast, OPCODES(\
OP(BitCast) \
OP(AddrSpaceCast) \
), CastInst)
DEF_INSTR(PHI, OP(PHI), PHINode)

// clang-format on
#ifdef DEF_VALUE
#undef DEF_VALUE
Expand Down
58 changes: 58 additions & 0 deletions llvm/include/llvm/SandboxIR/Tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,64 @@ class UseSet : public IRChangeBase {
#endif
};

class PHISetIncoming : public IRChangeBase {
PHINode &PHI;
unsigned Idx;
PointerUnion<Value *, BasicBlock *> OrigValueOrBB;

public:
enum class What {
Value,
Block,
};
PHISetIncoming(PHINode &PHI, unsigned Idx, What What, Tracker &Tracker);
void revert() final;
void accept() final {}
#ifndef NDEBUG
void dump(raw_ostream &OS) const final {
dumpCommon(OS);
OS << "PHISetIncoming";
}
LLVM_DUMP_METHOD void dump() const final;
#endif
};

class PHIRemoveIncoming : public IRChangeBase {
PHINode &PHI;
unsigned RemovedIdx;
Value *RemovedV;
BasicBlock *RemovedBB;

public:
PHIRemoveIncoming(PHINode &PHI, unsigned RemovedIdx, Tracker &Tracker);
void revert() final;
void accept() final {}
#ifndef NDEBUG
void dump(raw_ostream &OS) const final {
dumpCommon(OS);
OS << "PHISetIncoming";
}
LLVM_DUMP_METHOD void dump() const final;
#endif
};

class PHIAddIncoming : public IRChangeBase {
PHINode &PHI;
unsigned Idx;

public:
PHIAddIncoming(PHINode &PHI, Tracker &Tracker);
void revert() final;
void accept() final {}
#ifndef NDEBUG
void dump(raw_ostream &OS) const final {
dumpCommon(OS);
OS << "PHISetIncoming";
}
LLVM_DUMP_METHOD void dump() const final;
#endif
};

/// Tracks swapping a Use with another Use.
class UseSwap : public IRChangeBase {
Use ThisUse;
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/SandboxIR/Use.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Context;
class Value;
class User;
class CallBase;
class PHINode;

/// Represents a Def-use/Use-def edge in SandboxIR.
/// NOTE: Unlike llvm::Use, this is not an integral part of the use-def chains.
Expand All @@ -43,6 +44,7 @@ class Use {
friend class UserUseIterator; // For accessing members
friend class CallBase; // For LLVMUse
friend class CallBrInst; // For constructor
friend class PHINode; // For LLVMUse

public:
operator Value *() const { return get(); }
Expand Down
108 changes: 108 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,95 @@ void GetElementPtrInst::dump() const {
}
#endif // NDEBUG

BasicBlock *PHINode::LLVMBBToBB::operator()(llvm::BasicBlock *LLVMBB) const {
return cast<BasicBlock>(Ctx.getValue(LLVMBB));
}

PHINode *PHINode::create(Type *Ty, unsigned NumReservedValues,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
llvm::PHINode *NewPHI = llvm::PHINode::Create(
Ty, NumReservedValues, Name, InsertBefore->getTopmostLLVMInstruction());
return Ctx.createPHINode(NewPHI);
}

bool PHINode::classof(const Value *From) {
return From->getSubclassID() == ClassID::PHI;
}

Value *PHINode::getIncomingValue(unsigned Idx) const {
return Ctx.getValue(cast<llvm::PHINode>(Val)->getIncomingValue(Idx));
}
void PHINode::setIncomingValue(unsigned Idx, Value *V) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<PHISetIncoming>(
*this, Idx, PHISetIncoming::What::Value, Tracker));

cast<llvm::PHINode>(Val)->setIncomingValue(Idx, V->Val);
}
BasicBlock *PHINode::getIncomingBlock(unsigned Idx) const {
return cast<BasicBlock>(
Ctx.getValue(cast<llvm::PHINode>(Val)->getIncomingBlock(Idx)));
}
BasicBlock *PHINode::getIncomingBlock(const Use &U) const {
llvm::Use *LLVMUse = U.LLVMUse;
llvm::BasicBlock *BB = cast<llvm::PHINode>(Val)->getIncomingBlock(*LLVMUse);
return cast<BasicBlock>(Ctx.getValue(BB));
}
void PHINode::setIncomingBlock(unsigned Idx, BasicBlock *BB) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<PHISetIncoming>(
*this, Idx, PHISetIncoming::What::Block, Tracker));
cast<llvm::PHINode>(Val)->setIncomingBlock(Idx,
cast<llvm::BasicBlock>(BB->Val));
}
void PHINode::addIncoming(Value *V, BasicBlock *BB) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<PHIAddIncoming>(*this, Tracker));

cast<llvm::PHINode>(Val)->addIncoming(V->Val,
cast<llvm::BasicBlock>(BB->Val));
}
Value *PHINode::removeIncomingValue(unsigned Idx) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<PHIRemoveIncoming>(*this, Idx, Tracker));

llvm::Value *LLVMV =
cast<llvm::PHINode>(Val)->removeIncomingValue(Idx,
/*DeletePHIIfEmpty=*/false);
return Ctx.getValue(LLVMV);
}
Value *PHINode::removeIncomingValue(BasicBlock *BB) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<PHIRemoveIncoming>(
*this, getBasicBlockIndex(BB), Tracker));

auto *LLVMBB = cast<llvm::BasicBlock>(BB->Val);
llvm::Value *LLVMV =
cast<llvm::PHINode>(Val)->removeIncomingValue(LLVMBB,
/*DeletePHIIfEmpty=*/false);
return Ctx.getValue(LLVMV);
}
int PHINode::getBasicBlockIndex(const BasicBlock *BB) const {
auto *LLVMBB = cast<llvm::BasicBlock>(BB->Val);
return cast<llvm::PHINode>(Val)->getBasicBlockIndex(LLVMBB);
}
Value *PHINode::getIncomingValueForBlock(const BasicBlock *BB) const {
auto *LLVMBB = cast<llvm::BasicBlock>(BB->Val);
llvm::Value *LLVMV =
cast<llvm::PHINode>(Val)->getIncomingValueForBlock(LLVMBB);
return Ctx.getValue(LLVMV);
}
Value *PHINode::hasConstantValue() const {
llvm::Value *LLVMV = cast<llvm::PHINode>(Val)->hasConstantValue();
return LLVMV != nullptr ? Ctx.getValue(LLVMV) : nullptr;
}

static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
switch (Opc) {
case Instruction::Opcode::ZExt:
Expand Down Expand Up @@ -1272,6 +1361,16 @@ Value *PtrToIntInst::create(Value *Src, Type *DestTy, BasicBlock *InsertAtEnd,
}

#ifndef NDEBUG
void PHINode::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}

void PHINode::dump() const {
dump(dbgs());
dbgs() << "\n";
}

void PtrToIntInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
Expand Down Expand Up @@ -1537,6 +1636,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
return It->second.get();
}
case llvm::Instruction::PHI: {
auto *LLVMPhi = cast<llvm::PHINode>(LLVMV);
It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
return It->second.get();
}
default:
break;
}
Expand Down Expand Up @@ -1606,6 +1710,10 @@ CastInst *Context::createCastInst(llvm::CastInst *I) {
auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
return cast<CastInst>(registerValue(std::move(NewPtr)));
}
PHINode *Context::createPHINode(llvm::PHINode *I) {
auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
return cast<PHINode>(registerValue(std::move(NewPtr)));
}

Value *Context::getValue(llvm::Value *V) const {
auto It = LLVMValueToValueMap.find(V);
Expand Down
Loading

0 comments on commit 3403b59

Please sign in to comment.