Skip to content

Commit

Permalink
[DirectX] Lower @llvm.dx.handle.fromBinding to DXIL ops
Browse files Browse the repository at this point in the history
The `@llvm.dx.handle.fromBinding` intrinsic is lowered either to the
`CreateHandle` op or a pair of `CreateHandleFromBinding` and `AnnotateHandle`
ops, depending on the DXIL version. Regardless of the DXIL version we need to
emit metadata about the binding, but that's left to a separate change.

These DXIL ops all need to return the `%dx.types.Handle` type, but the llvm
intrinsic returns a target extension type. To facilitate changing the type of
the operation and all of its users, we introduce `%llvm.dx.cast.handle`, which
can cast between the two handle representations.

Pull Request: llvm#104251
  • Loading branch information
bogner committed Aug 14, 2024
1 parent c76bc28 commit c0d8b67
Show file tree
Hide file tree
Showing 8 changed files with 351 additions and 7 deletions.
8 changes: 8 additions & 0 deletions llvm/include/llvm/Analysis/DXILResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class TargetExtType;
namespace dxil {

class ResourceInfo {
public:
struct ResourceBinding {
uint32_t RecordID;
uint32_t Space;
Expand Down Expand Up @@ -89,6 +90,7 @@ class ResourceInfo {
bool operator!=(const FeedbackInfo &RHS) const { return !(*this == RHS); }
};

private:
// Universal properties.
Value *Symbol;
StringRef Name;
Expand All @@ -115,6 +117,10 @@ class ResourceInfo {

MSInfo MultiSample;

// We need a default constructor if we want to insert this in a MapVector.
ResourceInfo() {}
friend class MapVector<CallInst *, ResourceInfo>;

public:
ResourceInfo(dxil::ResourceClass RC, dxil::ResourceKind Kind, Value *Symbol,
StringRef Name)
Expand Down Expand Up @@ -166,6 +172,8 @@ class ResourceInfo {
MultiSample.Count = Count;
}

dxil::ResourceClass getResourceClass() const { return RC; }

bool operator==(const ResourceInfo &RHS) const;

static ResourceInfo SRV(Value *Symbol, StringRef Name,
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def int_dx_handle_fromBinding
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
[IntrNoMem]>;

// Cast between target extension handle types and dxil-style opaque handles
def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;

def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
Expand Down
24 changes: 24 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def FloatTy : DXILOpParamType;
def DoubleTy : DXILOpParamType;
def ResRetTy : DXILOpParamType;
def HandleTy : DXILOpParamType;
def ResBindTy : DXILOpParamType;
def ResPropsTy : DXILOpParamType;

class DXILOpClass;

Expand Down Expand Up @@ -673,6 +675,14 @@ def Dot4 : DXILOp<56, dot4> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def CreateHandle : DXILOp<57, createHandle> {
let Doc = "creates the handle to a resource";
// ResourceClass, RangeID, Index, NonUniform
let arguments = [Int8Ty, Int32Ty, Int32Ty, Int1Ty];
let result = HandleTy;
let stages = [Stages<DXIL1_0, [all_stages]>];
}

def ThreadId : DXILOp<93, threadId> {
let Doc = "Reads the thread ID";
let LLVMIntrinsic = int_dx_thread_id;
Expand Down Expand Up @@ -712,3 +722,17 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> {
let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def AnnotateHandle : DXILOp<217, annotateHandle> {
let Doc = "annotate handle with resource properties";
let arguments = [HandleTy, ResPropsTy];
let result = HandleTy;
let stages = [Stages<DXIL1_6, [all_stages]>];
}

def CreateHandleFromBinding : DXILOp<218, createHandleFromBinding> {
let Doc = "create resource handle from binding";
let arguments = [ResBindTy, Int32Ty, Int1Ty];
let result = HandleTy;
let stages = [Stages<DXIL1_6, [all_stages]>];
}
44 changes: 44 additions & 0 deletions llvm/lib/Target/DirectX/DXILOpBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,23 @@ static StructType *getHandleType(LLVMContext &Ctx) {
Ctx);
}

static StructType *getResBindType(LLVMContext &Context) {
if (auto *ST = StructType::getTypeByName(Context, "dx.types.ResBind"))
return ST;
Type *Int32Ty = Type::getInt32Ty(Context);
Type *Int8Ty = Type::getInt8Ty(Context);
return StructType::create({Int32Ty, Int32Ty, Int32Ty, Int8Ty},
"dx.types.ResBind");
}

static StructType *getResPropsType(LLVMContext &Context) {
if (auto *ST =
StructType::getTypeByName(Context, "dx.types.ResourceProperties"))
return ST;
Type *Int32Ty = Type::getInt32Ty(Context);
return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties");
}

static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
Type *OverloadTy) {
switch (Kind) {
Expand Down Expand Up @@ -235,6 +252,10 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
return getResRetType(OverloadTy, Ctx);
case OpParamType::HandleTy:
return getHandleType(Ctx);
case OpParamType::ResBindTy:
return getResBindType(Ctx);
case OpParamType::ResPropsTy:
return getResPropsType(Ctx);
}
llvm_unreachable("Invalid parameter kind");
return nullptr;
Expand Down Expand Up @@ -430,6 +451,29 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
return *Result;
}

StructType *DXILOpBuilder::getHandleType() {
return ::getHandleType(IRB.getContext());
}

Constant *DXILOpBuilder::getResBind(uint32_t LowerBound, uint32_t UpperBound,
uint32_t SpaceID, dxil::ResourceClass RC) {
Type *Int32Ty = IRB.getInt32Ty();
Type *Int8Ty = IRB.getInt8Ty();
return ConstantStruct::get(
getResBindType(IRB.getContext()),
{ConstantInt::get(Int32Ty, LowerBound),
ConstantInt::get(Int32Ty, UpperBound),
ConstantInt::get(Int32Ty, SpaceID),
ConstantInt::get(Int8Ty, llvm::to_underlying(RC))});
}

Constant *DXILOpBuilder::getResProps(uint32_t Word0, uint32_t Word1) {
Type *Int32Ty = IRB.getInt32Ty();
return ConstantStruct::get(
getResPropsType(IRB.getContext()),
{ConstantInt::get(Int32Ty, Word0), ConstantInt::get(Int32Ty, Word1)});
}

const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
return ::getOpCodeName(DXILOp);
}
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/DirectX/DXILOpBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
#include "DXILConstants.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/DXILABI.h"
#include "llvm/Support/Error.h"
#include "llvm/TargetParser/Triple.h"

namespace llvm {
class Module;
class IRBuilderBase;
class CallInst;
class Constant;
class Value;
class Type;
class FunctionType;
Expand All @@ -44,6 +46,15 @@ class DXILOpBuilder {
Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
Type *RetTy = nullptr);

/// Get the `%dx.types.Handle` type.
StructType *getHandleType();

/// Get a constant `%dx.types.ResBind` value.
Constant *getResBind(uint32_t LowerBound, uint32_t UpperBound,
uint32_t SpaceID, dxil::ResourceClass RC);
/// Get a constant `%dx.types.ResourceProperties` value.
Constant *getResProps(uint32_t Word0, uint32_t Word1);

/// Return the name of the given opcode.
static const char *getOpCodeName(dxil::OpCode DXILOp);

Expand Down
146 changes: 139 additions & 7 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "DXILOpBuilder.h"
#include "DirectX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
Expand All @@ -20,6 +21,7 @@
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"

Expand Down Expand Up @@ -74,9 +76,11 @@ namespace {
class OpLowerer {
Module &M;
DXILOpBuilder OpBuilder;
DXILResourceMap &DRM;
SmallVector<CallInst *> CleanupCasts;

public:
OpLowerer(Module &M) : M(M), OpBuilder(M) {}
OpLowerer(Module &M, DXILResourceMap &DRM) : M(M), OpBuilder(M), DRM(DRM) {}

void replaceFunction(Function &F,
llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
Expand Down Expand Up @@ -119,6 +123,119 @@ class OpLowerer {
});
}

Value *createTmpHandleCast(Value *V, Type *Ty) {
Function *CastFn = Intrinsic::getDeclaration(&M, Intrinsic::dx_cast_handle,
{Ty, V->getType()});
CallInst *Cast = OpBuilder.getIRB().CreateCall(CastFn, {V});
CleanupCasts.push_back(Cast);
return Cast;
}

void cleanupHandleCasts() {
SmallVector<CallInst *> ToRemove;
SmallVector<Function *> CastFns;

for (CallInst *Cast : CleanupCasts) {
CastFns.push_back(Cast->getCalledFunction());
// All of the ops should be using `dx.types.Handle` at this point, so if
// we're not producing that we should be part of a pair. Track this so we
// can remove it at the end.
if (Cast->getType() != OpBuilder.getHandleType()) {
ToRemove.push_back(Cast);
continue;
}
// Otherwise, we're the second handle in a pair. Forward the arguments and
// remove the (second) cast.
CallInst *Def = cast<CallInst>(Cast->getOperand(0));
assert(Def->getIntrinsicID() == Intrinsic::dx_cast_handle &&
"Unbalanced pair of temporary handle casts");
Cast->replaceAllUsesWith(Def->getOperand(0));
Cast->eraseFromParent();
}
for (CallInst *Cast : ToRemove) {
assert(Cast->user_empty() && "Temporary handle cast still has users");
Cast->eraseFromParent();
}
llvm::sort(CastFns);
CastFns.erase(llvm::unique(CastFns), CastFns.end());
for (Function *F : CastFns)
F->eraseFromParent();

CleanupCasts.clear();
}

void lowerToCreateHandle(Function &F) {
IRBuilder<> &IRB = OpBuilder.getIRB();
Type *Int8Ty = IRB.getInt8Ty();
Type *Int32Ty = IRB.getInt32Ty();

replaceFunction(F, [&](CallInst *CI) -> Error {
IRB.SetInsertPoint(CI);

dxil::ResourceInfo &RI = DRM[CI];
dxil::ResourceInfo::ResourceBinding Binding = RI.getBinding();

std::array<Value *, 4> Args{
ConstantInt::get(Int8Ty, llvm::to_underlying(RI.getResourceClass())),
ConstantInt::get(Int32Ty, Binding.RecordID), CI->getArgOperand(3),
CI->getArgOperand(4)};
Expected<CallInst *> OpCall =
OpBuilder.tryCreateOp(OpCode::CreateHandle, Args);
if (Error E = OpCall.takeError())
return E;

Value *Cast = createTmpHandleCast(*OpCall, CI->getType());

CI->replaceAllUsesWith(Cast);
CI->eraseFromParent();
return Error::success();
});
}

void lowerToBindAndAnnotateHandle(Function &F) {
IRBuilder<> &IRB = OpBuilder.getIRB();

replaceFunction(F, [&](CallInst *CI) -> Error {
IRB.SetInsertPoint(CI);

dxil::ResourceInfo &RI = DRM[CI];
dxil::ResourceInfo::ResourceBinding Binding = RI.getBinding();
std::pair<uint32_t, uint32_t> Props = RI.getAnnotateProps();

Constant *ResBind = OpBuilder.getResBind(
Binding.LowerBound, Binding.LowerBound + Binding.Size - 1,
Binding.Space, RI.getResourceClass());
std::array<Value *, 3> BindArgs{ResBind, CI->getArgOperand(3),
CI->getArgOperand(4)};
Expected<CallInst *> OpBind =
OpBuilder.tryCreateOp(OpCode::CreateHandleFromBinding, BindArgs);
if (Error E = OpBind.takeError())
return E;

std::array<Value *, 2> AnnotateArgs{
*OpBind, OpBuilder.getResProps(Props.first, Props.second)};
Expected<CallInst *> OpAnnotate =
OpBuilder.tryCreateOp(OpCode::AnnotateHandle, AnnotateArgs);
if (Error E = OpAnnotate.takeError())
return E;

Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType());

CI->replaceAllUsesWith(Cast);
CI->eraseFromParent();

return Error::success();
});
}

void lowerHandleFromBinding(Function &F) {
Triple TT(Triple(M.getTargetTriple()));
if (TT.getDXILVersion() < VersionTuple(1, 6))
lowerToCreateHandle(F);
else
lowerToBindAndAnnotateHandle(F);
}

bool lowerIntrinsics() {
bool Updated = false;

Expand All @@ -134,40 +251,55 @@ class OpLowerer {
replaceFunctionWithOp(F, OpCode); \
break;
#include "DXILOperation.inc"
case Intrinsic::dx_handle_fromBinding:
lowerHandleFromBinding(F);
}
Updated = true;
}
if (Updated)
cleanupHandleCasts();

return Updated;
}
};
} // namespace

PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &) {
if (OpLowerer(M).lowerIntrinsics())
return PreservedAnalyses::none();
return PreservedAnalyses::all();
PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) {
DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);

bool MadeChanges = OpLowerer(M, DRM).lowerIntrinsics();
if (!MadeChanges)
return PreservedAnalyses::all();
PreservedAnalyses PA;
PA.preserve<DXILResourceAnalysis>();
return PA;
}

namespace {
class DXILOpLoweringLegacy : public ModulePass {
public:
bool runOnModule(Module &M) override {
return OpLowerer(M).lowerIntrinsics();
DXILResourceMap &DRM =
getAnalysis<DXILResourceWrapperPass>().getResourceMap();

return OpLowerer(M, DRM).lowerIntrinsics();
}
StringRef getPassName() const override { return "DXIL Op Lowering"; }
DXILOpLoweringLegacy() : ModulePass(ID) {}

static char ID; // Pass identification.
void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
// Specify the passes that your pass depends on
AU.addRequired<DXILIntrinsicExpansionLegacy>();
AU.addRequired<DXILResourceWrapperPass>();
AU.addPreserved<DXILResourceWrapperPass>();
}
};
char DXILOpLoweringLegacy::ID = 0;
} // end anonymous namespace

INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
false, false)
INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
false)

Expand Down
Loading

0 comments on commit c0d8b67

Please sign in to comment.