Skip to content

Commit

Permalink
[SandboxIR] Implement BranchInst (#100063)
Browse files Browse the repository at this point in the history
This patch implements sandboxir::BranchInst which mirrors
llvm::BranchInst.

BranchInst::swapSuccessors() relies on User::swapOperandsInternal() so
this patch also adds Use::swap() and the corresponding tracking code and
test.
  • Loading branch information
vporpo authored Jul 23, 2024
1 parent 95ea37c commit 3993da2
Show file tree
Hide file tree
Showing 8 changed files with 377 additions and 2 deletions.
107 changes: 107 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class Context;
class Function;
class Instruction;
class SelectInst;
class BranchInst;
class LoadInst;
class ReturnInst;
class StoreInst;
Expand Down Expand Up @@ -179,6 +180,7 @@ class Value {
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class SelectInst; // For getting `Val`.
friend class BranchInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class ReturnInst; // For getting `Val`.
Expand Down Expand Up @@ -343,6 +345,14 @@ class User : public Value {
virtual unsigned getUseOperandNo(const Use &Use) const = 0;
friend unsigned Use::getOperandNo() const; // For getUseOperandNo()

void swapOperandsInternal(unsigned OpIdxA, unsigned OpIdxB) {
assert(OpIdxA < getNumOperands() && "OpIdxA out of bounds!");
assert(OpIdxB < getNumOperands() && "OpIdxB out of bounds!");
auto UseA = getOperandUse(OpIdxA);
auto UseB = getOperandUse(OpIdxB);
UseA.swap(UseB);
}

#ifndef NDEBUG
void verifyUserOfLLVMUse(const llvm::Use &Use) const;
#endif // NDEBUG
Expand Down Expand Up @@ -504,6 +514,7 @@ class Instruction : public sandboxir::User {
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
friend class SelectInst; // For getTopmostLLVMInstruction().
friend class BranchInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().
Expand Down Expand Up @@ -617,6 +628,100 @@ class SelectInst : public Instruction {
#endif
};

class BranchInst : public Instruction {
/// Use Context::createBranchInst(). Don't call the constructor directly.
BranchInst(llvm::BranchInst *BI, Context &Ctx)
: Instruction(ClassID::Br, Opcode::Br, BI, Ctx) {}
friend Context; // for BranchInst()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}

public:
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }
static BranchInst *create(BasicBlock *IfTrue, Instruction *InsertBefore,
Context &Ctx);
static BranchInst *create(BasicBlock *IfTrue, BasicBlock *InsertAtEnd,
Context &Ctx);
static BranchInst *create(BasicBlock *IfTrue, BasicBlock *IfFalse,
Value *Cond, Instruction *InsertBefore,
Context &Ctx);
static BranchInst *create(BasicBlock *IfTrue, BasicBlock *IfFalse,
Value *Cond, BasicBlock *InsertAtEnd, Context &Ctx);
/// For isa/dyn_cast.
static bool classof(const Value *From);
bool isUnconditional() const {
return cast<llvm::BranchInst>(Val)->isUnconditional();
}
bool isConditional() const {
return cast<llvm::BranchInst>(Val)->isConditional();
}
Value *getCondition() const;
void setCondition(Value *V) { setOperand(0, V); }
unsigned getNumSuccessors() const { return 1 + isConditional(); }
BasicBlock *getSuccessor(unsigned SuccIdx) const;
void setSuccessor(unsigned Idx, BasicBlock *NewSucc);
void swapSuccessors() { swapOperandsInternal(1, 2); }

private:
struct LLVMBBToSBBB {
Context &Ctx;
LLVMBBToSBBB(Context &Ctx) : Ctx(Ctx) {}
BasicBlock *operator()(llvm::BasicBlock *BB) const;
};

struct ConstLLVMBBToSBBB {
Context &Ctx;
ConstLLVMBBToSBBB(Context &Ctx) : Ctx(Ctx) {}
const BasicBlock *operator()(const llvm::BasicBlock *BB) const;
};

public:
using sb_succ_op_iterator =
mapped_iterator<llvm::BranchInst::succ_op_iterator, LLVMBBToSBBB>;
iterator_range<sb_succ_op_iterator> successors() {
iterator_range<llvm::BranchInst::succ_op_iterator> LLVMRange =
cast<llvm::BranchInst>(Val)->successors();
LLVMBBToSBBB BBMap(Ctx);
sb_succ_op_iterator MappedBegin = map_iterator(LLVMRange.begin(), BBMap);
sb_succ_op_iterator MappedEnd = map_iterator(LLVMRange.end(), BBMap);
return make_range(MappedBegin, MappedEnd);
}

using const_sb_succ_op_iterator =
mapped_iterator<llvm::BranchInst::const_succ_op_iterator,
ConstLLVMBBToSBBB>;
iterator_range<const_sb_succ_op_iterator> successors() const {
iterator_range<llvm::BranchInst::const_succ_op_iterator> ConstLLVMRange =
static_cast<const llvm::BranchInst *>(cast<llvm::BranchInst>(Val))
->successors();
ConstLLVMBBToSBBB ConstBBMap(Ctx);
const_sb_succ_op_iterator ConstMappedBegin =
map_iterator(ConstLLVMRange.begin(), ConstBBMap);
const_sb_succ_op_iterator ConstMappedEnd =
map_iterator(ConstLLVMRange.end(), ConstBBMap);
return make_range(ConstMappedBegin, ConstMappedEnd);
}

#ifndef NDEBUG
void verify() const final {
assert(isa<llvm::BranchInst>(Val) && "Expected BranchInst!");
}
friend raw_ostream &operator<<(raw_ostream &OS, const BranchInst &BI) {
BI.dump(OS);
return OS;
}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};

class LoadInst final : public Instruction {
/// Use LoadInst::create() instead of calling the constructor.
LoadInst(llvm::LoadInst *LI, Context &Ctx)
Expand Down Expand Up @@ -870,6 +975,8 @@ class Context {

SelectInst *createSelectInst(llvm::SelectInst *SI);
friend SelectInst; // For createSelectInst()
BranchInst *createBranchInst(llvm::BranchInst *I);
friend BranchInst; // For createBranchInst()
LoadInst *createLoadInst(llvm::LoadInst *LI);
friend LoadInst; // For createLoadInst()
StoreInst *createStoreInst(llvm::StoreInst *SI);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ DEF_USER(Constant, Constant)
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Br, OP(Br), BranchInst)
DEF_INSTR(Load, OP(Load), LoadInst)
DEF_INSTR(Store, OP(Store), StoreInst)
DEF_INSTR(Ret, OP(Ret), ReturnInst)
Expand Down
21 changes: 21 additions & 0 deletions llvm/include/llvm/SandboxIR/Tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,27 @@ class UseSet : public IRChangeBase {
#endif
};

/// Tracks swapping a Use with another Use.
class UseSwap : public IRChangeBase {
Use ThisUse;
Use OtherUse;

public:
UseSwap(const Use &ThisUse, const Use &OtherUse, Tracker &Tracker)
: IRChangeBase(Tracker), ThisUse(ThisUse), OtherUse(OtherUse) {
assert(ThisUse.getUser() == OtherUse.getUser() && "Expected same user!");
}
void revert() final { ThisUse.swap(OtherUse); }
void accept() final {}
#ifndef NDEBUG
void dump(raw_ostream &OS) const final {
dumpCommon(OS);
OS << "UseSwap";
}
LLVM_DUMP_METHOD void dump() const final;
#endif
};

class EraseFromParent : public IRChangeBase {
/// Contains all the data we need to restore an "erased" (i.e., detached)
/// instruction: the instruction itself and its operands in order.
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/Use.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class Use {
void set(Value *V);
class User *getUser() const { return Usr; }
unsigned getOperandNo() const;
void swap(Use &OtherUse);
Context *getContext() const { return Ctx; }
bool operator==(const Use &Other) const {
assert(Ctx == Other.Ctx && "Contexts differ!");
Expand Down
96 changes: 96 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ void Use::set(Value *V) { LLVMUse->set(V->Val); }

unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); }

void Use::swap(Use &OtherUse) {
auto &Tracker = Ctx->getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<UseSwap>(*this, OtherUse, Tracker));
LLVMUse->swap(*OtherUse.LLVMUse);
}

#ifndef NDEBUG
void Use::dump(raw_ostream &OS) const {
Value *Def = nullptr;
Expand Down Expand Up @@ -500,6 +507,85 @@ void SelectInst::dump() const {
}
#endif // NDEBUG

BranchInst *BranchInst::create(BasicBlock *IfTrue, Instruction *InsertBefore,
Context &Ctx) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::Instruction>(InsertBefore->Val));
llvm::BranchInst *NewBr =
Builder.CreateBr(cast<llvm::BasicBlock>(IfTrue->Val));
return Ctx.createBranchInst(NewBr);
}

BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *InsertAtEnd,
Context &Ctx) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
llvm::BranchInst *NewBr =
Builder.CreateBr(cast<llvm::BasicBlock>(IfTrue->Val));
return Ctx.createBranchInst(NewBr);
}

BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *IfFalse,
Value *Cond, Instruction *InsertBefore,
Context &Ctx) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::Instruction>(InsertBefore->Val));
llvm::BranchInst *NewBr =
Builder.CreateCondBr(Cond->Val, cast<llvm::BasicBlock>(IfTrue->Val),
cast<llvm::BasicBlock>(IfFalse->Val));
return Ctx.createBranchInst(NewBr);
}

BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *IfFalse,
Value *Cond, BasicBlock *InsertAtEnd,
Context &Ctx) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
llvm::BranchInst *NewBr =
Builder.CreateCondBr(Cond->Val, cast<llvm::BasicBlock>(IfTrue->Val),
cast<llvm::BasicBlock>(IfFalse->Val));
return Ctx.createBranchInst(NewBr);
}

bool BranchInst::classof(const Value *From) {
return From->getSubclassID() == ClassID::Br;
}

Value *BranchInst::getCondition() const {
assert(isConditional() && "Cannot get condition of an uncond branch!");
return Ctx.getValue(cast<llvm::BranchInst>(Val)->getCondition());
}

BasicBlock *BranchInst::getSuccessor(unsigned SuccIdx) const {
assert(SuccIdx < getNumSuccessors() &&
"Successor # out of range for Branch!");
return cast_or_null<BasicBlock>(
Ctx.getValue(cast<llvm::BranchInst>(Val)->getSuccessor(SuccIdx)));
}

void BranchInst::setSuccessor(unsigned Idx, BasicBlock *NewSucc) {
assert((Idx == 0 || Idx == 1) && "Out of bounds!");
setOperand(2u - Idx, NewSucc);
}

BasicBlock *BranchInst::LLVMBBToSBBB::operator()(llvm::BasicBlock *BB) const {
return cast<BasicBlock>(Ctx.getValue(BB));
}
const BasicBlock *
BranchInst::ConstLLVMBBToSBBB::operator()(const llvm::BasicBlock *BB) const {
return cast<BasicBlock>(Ctx.getValue(BB));
}
#ifndef NDEBUG
void BranchInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}
void BranchInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
Expand Down Expand Up @@ -758,6 +844,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
return It->second.get();
}
case llvm::Instruction::Br: {
auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
return It->second.get();
}
case llvm::Instruction::Load: {
auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
Expand Down Expand Up @@ -796,6 +887,11 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
return cast<SelectInst>(registerValue(std::move(NewPtr)));
}

BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
return cast<BranchInst>(registerValue(std::move(NewPtr)));
}

LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
return cast<LoadInst>(registerValue(std::move(NewPtr)));
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/SandboxIR/Tracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ void UseSet::dump() const {
dump(dbgs());
dbgs() << "\n";
}

void UseSwap::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

Tracker::~Tracker() {
Expand Down
Loading

0 comments on commit 3993da2

Please sign in to comment.