Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SandboxIR] Add BasicBlock and adds functionality to Function and Context #97637

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,15 @@
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/raw_ostream.h"
#include <iterator>

namespace llvm {

namespace sandboxir {

class Function;
class Context;
class Instruction;

/// A SandboxIR Value has users. This is the base class.
class Value {
Expand Down Expand Up @@ -106,6 +109,8 @@ class Value {
/// NOTE: Some SBInstructions, like Packs, may include more than one value.
llvm::Value *Val = nullptr;

friend class Context; // For getting `Val`.

/// All values point to the context.
Context &Ctx;
// This is used by eraseFromParent().
Expand Down Expand Up @@ -205,6 +210,48 @@ class Constant : public sandboxir::User {
#endif
};

/// The BasicBlock::iterator.
class BBIterator {
public:
using difference_type = std::ptrdiff_t;
using value_type = Instruction;
using pointer = value_type *;
using reference = value_type &;
using iterator_category = std::bidirectional_iterator_tag;

private:
llvm::BasicBlock *BB;
llvm::BasicBlock::iterator It;
Context *Ctx;
pointer getInstr(llvm::BasicBlock::iterator It) const;

public:
BBIterator() : BB(nullptr), Ctx(nullptr) {}
BBIterator(llvm::BasicBlock *BB, llvm::BasicBlock::iterator It, Context *Ctx)
: BB(BB), It(It), Ctx(Ctx) {}
reference operator*() const { return *getInstr(It); }
BBIterator &operator++();
BBIterator operator++(int) {
auto Copy = *this;
++*this;
return Copy;
}
BBIterator &operator--();
BBIterator operator--(int) {
auto Copy = *this;
--*this;
return Copy;
}
bool operator==(const BBIterator &Other) const {
assert(Ctx == Other.Ctx && "BBIterators in different context!");
return It == Other.It;
}
bool operator!=(const BBIterator &Other) const { return !(*this == Other); }
/// \Returns the SBInstruction that corresponds to this iterator, or null if
/// the instruction is not found in the IR-to-SandboxIR tables.
pointer get() const { return getInstr(It); }
};

/// A sandboxir::User with operands and opcode.
class Instruction : public sandboxir::User {
public:
Expand All @@ -231,6 +278,8 @@ class Instruction : public sandboxir::User {
return OS;
}
#endif
/// This is used by BasicBlock::iterator.
virtual unsigned getNumOfIRInstrs() const = 0;
vporpo marked this conversation as resolved.
Show resolved Hide resolved
/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From);

Expand All @@ -256,6 +305,7 @@ class OpaqueInst : public sandboxir::Instruction {
static bool classof(const sandboxir::Value *From) {
return From->getSubclassID() == ClassID::Opaque;
}
unsigned getNumOfIRInstrs() const final { return 1u; }
#ifndef NDEBUG
void verify() const final {
// Nothing to do
Expand All @@ -270,6 +320,54 @@ class OpaqueInst : public sandboxir::Instruction {
#endif
};

class BasicBlock : public Value {
/// Builds a graph that contains all values in \p BB in their original form
/// i.e., no vectorization is taking place here.
void buildBasicBlockFromLLVMIR(llvm::BasicBlock *LLVMBB);
friend class Context; // For `buildBasicBlockFromIR`

public:
BasicBlock(llvm::BasicBlock *BB, Context &SBCtx)
: Value(ClassID::Block, BB, SBCtx) {
buildBasicBlockFromLLVMIR(BB);
}
~BasicBlock() = default;
/// For isa/dyn_cast.
static bool classof(const Value *From) {
return From->getSubclassID() == Value::ClassID::Block;
}
Function *getParent() const;
using iterator = BBIterator;
iterator begin() const;
iterator end() const {
auto *BB = cast<llvm::BasicBlock>(Val);
return iterator(BB, BB->end(), &Ctx);
}
std::reverse_iterator<iterator> rbegin() const {
return std::make_reverse_iterator(end());
}
std::reverse_iterator<iterator> rend() const {
return std::make_reverse_iterator(begin());
}
Context &getContext() const { return Ctx; }
Instruction *getTerminator() const;
bool empty() const { return begin() == end(); }
Instruction &front() const;
Instruction &back() const;

#ifndef NDEBUG
void verify() const final {
assert(isa<llvm::BasicBlock>(Val) && "Expected BasicBlock!");
}
friend raw_ostream &operator<<(raw_ostream &OS, const BasicBlock &SBBB) {
SBBB.dump(OS);
return OS;
}
void dump(raw_ostream &OS) const final;
LLVM_DUMP_METHOD void dump() const final;
#endif
};

class Context {
protected:
LLVMContext &LLVMCtx;
Expand All @@ -278,12 +376,53 @@ class Context {
DenseMap<llvm::Value *, std::unique_ptr<sandboxir::Value>>
LLVMValueToValueMap;

/// Take ownership of VPtr and store it in `LLVMValueToValueMap`.
Value *registerValue(std::unique_ptr<Value> &&VPtr);

Value *getOrCreateValueInternal(llvm::Value *V, llvm::User *U = nullptr);

Argument *getOrCreateArgument(llvm::Argument *LLVMArg) {
auto Pair = LLVMValueToValueMap.insert({LLVMArg, nullptr});
auto It = Pair.first;
if (Pair.second) {
It->second = std::make_unique<Argument>(LLVMArg, *this);
return cast<Argument>(It->second.get());
}
return cast<Argument>(It->second.get());
}

Value *getOrCreateValue(llvm::Value *LLVMV) {
return getOrCreateValueInternal(LLVMV, 0);
}

BasicBlock *createBasicBlock(llvm::BasicBlock *BB);

friend class BasicBlock; // For getOrCreateValue().

public:
Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx) {}

sandboxir::Value *getValue(llvm::Value *V) const;
const sandboxir::Value *getValue(const llvm::Value *V) const {
return getValue(const_cast<llvm::Value *>(V));
}

Function *createFunction(llvm::Function *F);

/// \Returns the number of values registered with Context.
size_t getNumValues() const { return LLVMValueToValueMap.size(); }
};

class Function : public sandboxir::Value {
/// Helper for mapped_iterator.
struct LLVMBBToBB {
Context &Ctx;
LLVMBBToBB(Context &Ctx) : Ctx(Ctx) {}
BasicBlock &operator()(llvm::BasicBlock &LLVMBB) const {
return *cast<BasicBlock>(Ctx.getValue(&LLVMBB));
}
};

public:
Function(llvm::Function *F, sandboxir::Context &Ctx)
: sandboxir::Value(ClassID::Function, F, Ctx) {}
Expand All @@ -292,6 +431,24 @@ class Function : public sandboxir::Value {
return From->getSubclassID() == ClassID::Function;
}

Argument *getArg(unsigned Idx) const {
llvm::Argument *Arg = cast<llvm::Function>(Val)->getArg(Idx);
return cast<Argument>(Ctx.getValue(Arg));
}

size_t arg_size() const { return cast<llvm::Function>(Val)->arg_size(); }
bool arg_empty() const { return cast<llvm::Function>(Val)->arg_empty(); }

using iterator = mapped_iterator<llvm::Function::iterator, LLVMBBToBB>;
iterator begin() const {
LLVMBBToBB BBGetter(Ctx);
return iterator(cast<llvm::Function>(Val)->begin(), BBGetter);
}
iterator end() const {
LLVMBBToBB BBGetter(Ctx);
return iterator(cast<llvm::Function>(Val)->end(), BBGetter);
}

#ifndef NDEBUG
void verify() const final {
assert(isa<llvm::Function>(Val) && "Expected Function!");
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ DEF_VALUE(Argument, Argument)
#define DEF_USER(ID, CLASS)
#endif
DEF_USER(User, User)
DEF_VALUE(Block, BasicBlock)
DEF_USER(Constant, Constant)

#ifndef DEF_INSTR
Expand Down
Loading
Loading