Skip to content

Commit

Permalink
[DirectX] Lower @llvm.dx.typedBufferLoad to DXIL ops
Browse files Browse the repository at this point in the history
The `@llvm.dx.typedBufferLoad` intrinsic is lowered to `@dx.op.bufferLoad`.
There's some complexity here due to translating from a vector return type to a
named struct and trying to avoid excessive IR coming out of that.

Note that this change includes a bit of a hack in how it deals with
`getOverloadKind` for the `dx.ResRet` types - we need to adjust how we deal
with operation overloads to generate a table directly rather than proxy through
the OverloadKind enum, but that's left for a later change here.

Pull Request: llvm#104252
  • Loading branch information
bogner committed Aug 14, 2024
1 parent c0d8b67 commit d78ffd2
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 8 deletions.
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def int_dx_handle_fromBinding
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
[IntrNoMem]>;

def int_dx_typedBufferLoad
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[llvm_any_ty, llvm_i32_ty]>;

// Cast between target extension handle types and dxil-style opaque handles
def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;

Expand Down
16 changes: 15 additions & 1 deletion llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def Int64Ty : DXILOpParamType;
def HalfTy : DXILOpParamType;
def FloatTy : DXILOpParamType;
def DoubleTy : DXILOpParamType;
def ResRetTy : DXILOpParamType;
def ResRetHalfTy : DXILOpParamType;
def ResRetFloatTy : DXILOpParamType;
def ResRetInt16Ty : DXILOpParamType;
def ResRetInt32Ty : DXILOpParamType;
def HandleTy : DXILOpParamType;
def ResBindTy : DXILOpParamType;
def ResPropsTy : DXILOpParamType;
Expand Down Expand Up @@ -683,6 +686,17 @@ def CreateHandle : DXILOp<57, createHandle> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}

def BufferLoad : DXILOp<68, bufferLoad> {
let Doc = "reads from a TypedBuffer";
// Handle, Coord0, Coord1
let arguments = [HandleTy, Int32Ty, Int32Ty];
let result = OverloadTy;
let overloads =
[Overloads<DXIL1_0,
[ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
}

def ThreadId : DXILOp<93, threadId> {
let Doc = "Reads the thread ID";
let LLVMIntrinsic = int_dx_thread_id;
Expand Down
31 changes: 25 additions & 6 deletions llvm/lib/Target/DirectX/DXILOpBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,15 @@ static OverloadKind getOverloadKind(Type *Ty) {
}
case Type::PointerTyID:
return OverloadKind::UserDefineType;
case Type::StructTyID:
case Type::StructTyID: {
// TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
// how we're handling overloads and remove the `OverloadKind` proxy enum.
StructType *ST = cast<StructType>(Ty);
if (ST->hasName() && ST->getName().starts_with("dx.types.ResRet"))
return getOverloadKind(ST->getElementType(0));

return OverloadKind::ObjectType;
}
default:
llvm_unreachable("invalid overload type");
return OverloadKind::VOID;
Expand Down Expand Up @@ -195,10 +202,11 @@ static StructType *getOrCreateStructType(StringRef Name,
return StructType::create(Ctx, EltTys, Name);
}

static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
OverloadKind Kind = getOverloadKind(OverloadTy);
static StructType *getResRetType(Type *ElementTy) {
LLVMContext &Ctx = ElementTy->getContext();
OverloadKind Kind = getOverloadKind(ElementTy);
std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy,
Type::getInt32Ty(Ctx)};
return getOrCreateStructType(TypeName, FieldTypes, Ctx);
}
Expand Down Expand Up @@ -248,8 +256,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
return Type::getInt64Ty(Ctx);
case OpParamType::OverloadTy:
return OverloadTy;
case OpParamType::ResRetTy:
return getResRetType(OverloadTy, Ctx);
case OpParamType::ResRetHalfTy:
return getResRetType(Type::getHalfTy(Ctx));
case OpParamType::ResRetFloatTy:
return getResRetType(Type::getFloatTy(Ctx));
case OpParamType::ResRetInt16Ty:
return getResRetType(Type::getInt16Ty(Ctx));
case OpParamType::ResRetInt32Ty:
return getResRetType(Type::getInt32Ty(Ctx));
case OpParamType::HandleTy:
return getHandleType(Ctx);
case OpParamType::ResBindTy:
Expand Down Expand Up @@ -391,6 +405,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
return makeOpError(OpCode, "Wrong number of arguments");
OverloadTy = Args[ArgIndex]->getType();
}

FunctionType *DXILOpFT =
getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);

Expand Down Expand Up @@ -451,6 +466,10 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
return *Result;
}

StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
return ::getResRetType(ElementTy);
}

StructType *DXILOpBuilder::getHandleType() {
return ::getHandleType(IRB.getContext());
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/DirectX/DXILOpBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class DXILOpBuilder {
Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
Type *RetTy = nullptr);

/// Get a `%dx.types.ResRet` type with the given element type.
StructType *getResRetType(Type *ElementTy);
/// Get the `%dx.types.Handle` type.
StructType *getHandleType();

Expand Down
57 changes: 57 additions & 0 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,59 @@ class OpLowerer {
lowerToBindAndAnnotateHandle(F);
}

void lowerTypedBufferLoad(Function &F) {
IRBuilder<> &IRB = OpBuilder.getIRB();
Type *Int32Ty = IRB.getInt32Ty();

replaceFunction(F, [&](CallInst *CI) -> Error {
IRB.SetInsertPoint(CI);

Value *Handle =
createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
Value *Index0 = CI->getArgOperand(1);
Value *Index1 = UndefValue::get(Int32Ty);
Type *RetTy = OpBuilder.getResRetType(CI->getType()->getScalarType());

std::array<Value *, 3> Args{Handle, Index0, Index1};
Expected<CallInst *> OpCall =
OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, RetTy);
if (Error E = OpCall.takeError())
return E;

std::array<Value *, 4> Extracts = {};

// We've switched the return type from a vector to a struct, but at this
// point most vectors have probably already been scalarized. Try to
// forward arguments directly rather than inserting into and immediately
// extracting from a vector.
for (Use &U : make_early_inc_range(CI->uses()))
if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser()))
if (auto *Index = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
size_t IndexVal = Index->getZExtValue();
assert(IndexVal < 4 && "Index into buffer load out of range");
if (!Extracts[IndexVal])
Extracts[IndexVal] = IRB.CreateExtractValue(*OpCall, IndexVal);
EEI->replaceAllUsesWith(Extracts[IndexVal]);
EEI->eraseFromParent();
}

// If there are still uses then we need to create a vector.
if (!CI->use_empty()) {
for (int I = 0, E = 4; I != E; ++I)
if (!Extracts[I])
Extracts[I] = IRB.CreateExtractValue(*OpCall, I);

Value *Vec = UndefValue::get(CI->getType());
for (int I = 0, E = 4; I != E; ++I)
Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
CI->replaceAllUsesWith(Vec);
}

CI->eraseFromParent();
return Error::success();
});
}

bool lowerIntrinsics() {
bool Updated = false;

Expand All @@ -253,6 +306,10 @@ class OpLowerer {
#include "DXILOperation.inc"
case Intrinsic::dx_handle_fromBinding:
lowerHandleFromBinding(F);
break;
case Intrinsic::dx_typedBufferLoad:
lowerTypedBufferLoad(F);
break;
}
Updated = true;
}
Expand Down
102 changes: 102 additions & 0 deletions llvm/test/CodeGen/DirectX/BufferLoad.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
; RUN: opt -S -dxil-op-lower %s | FileCheck %s

target triple = "dxil-pc-shadermodel6.6-compute"

declare void @scalar_user(float)
declare void @vector_user(<4 x float>)

define void @loadfloats() {
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
%buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; The temporary casts should all have been cleaned up
; CHECK-NOT: %dx.cast_handle

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
%data0 = call <4 x float> @llvm.dx.typedBufferLoad(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)

; The extract order depends on the users, so don't enforce that here.
; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
%data0_0 = extractelement <4 x float> %data0, i32 0
; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
%data0_2 = extractelement <4 x float> %data0, i32 2

; If all of the uses are extracts, we skip creating a vector
; CHECK-NOT: insertelement
call void @scalar_user(float %data0_0)
call void @scalar_user(float %data0_2)

; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
%data4 = call <4 x float> @llvm.dx.typedBufferLoad(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)

; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3
; CHECK: insertelement <4 x float> undef
; CHECK: insertelement <4 x float>
; CHECK: insertelement <4 x float>
; CHECK: insertelement <4 x float>
call void @vector_user(<4 x float> %data4)

; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
%data12 = call <4 x float> @llvm.dx.typedBufferLoad(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)

; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3
%data12_3 = extractelement <4 x float> %data12, i32 3

; If there are a mix of users we need the vector, but extracts are direct
; CHECK: call void @scalar_user(float [[DATA12_3]])
call void @scalar_user(float %data12_3)
call void @vector_user(<4 x float> %data12)

ret void
}

define void @loadint() {
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
%buffer = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0)
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
%data0 = call <4 x i32> @llvm.dx.typedBufferLoad(
target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)

ret void
}

define void @loadhalf() {
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
%buffer = call target("dx.TypedBuffer", <4 x half>, 0, 0, 0)
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
%data0 = call <4 x half> @llvm.dx.typedBufferLoad(
target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)

ret void
}

define void @loadi16() {
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
%buffer = call target("dx.TypedBuffer", <4 x i16>, 0, 0, 0)
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
%data0 = call <4 x i16> @llvm.dx.typedBufferLoad(
target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)

ret void
}
6 changes: 5 additions & 1 deletion llvm/utils/TableGen/DXILEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,11 @@ static StringRef getOverloadKindStr(const Record *R) {
.Case("Int8Ty", "OverloadKind::I8")
.Case("Int16Ty", "OverloadKind::I16")
.Case("Int32Ty", "OverloadKind::I32")
.Case("Int64Ty", "OverloadKind::I64");
.Case("Int64Ty", "OverloadKind::I64")
.Case("ResRetHalfTy", "OverloadKind::HALF")
.Case("ResRetFloatTy", "OverloadKind::FLOAT")
.Case("ResRetInt16Ty", "OverloadKind::I16")
.Case("ResRetInt32Ty", "OverloadKind::I32");
}

/// Return a string representation of valid overload information denoted
Expand Down

0 comments on commit d78ffd2

Please sign in to comment.