Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add step builtins and step HLSL function to DirectX and SPIR-V backend #106471

Merged
merged 4 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4763,6 +4763,7 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}


def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_select"];
let Attributes = [NoThrow, Const];
Expand All @@ -4775,6 +4776,12 @@ def HLSLSign : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

// Builtin backing the HLSL `step` intrinsic. The variadic `void(...)`
// prototype means argument and return types are not fixed by this record.
def HLSLStep : LangBuiltin<"HLSL_LANG"> {
  let Spellings = ["__builtin_hlsl_step"];
  let Attributes = [NoThrow, Const];
  let Prototype = "void(...)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18861,6 +18861,16 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {

return SelectVal;
}
case Builtin::BI__builtin_hlsl_step: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
Value *Op1 = EmitScalarExpr(E->getArg(1));
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasFloatingRepresentation() &&
"step operands must have a float representation");
return Builder.CreateIntrinsic(
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(Saturate, saturate)
GENERATE_HLSL_INTRINSIC_FUNCTION(Sign, sign)
GENERATE_HLSL_INTRINSIC_FUNCTION(Step, step)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
Expand Down
33 changes: 33 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -1717,6 +1717,39 @@ float3 sqrt(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
float4 sqrt(float4);

//===----------------------------------------------------------------------===//
// step builtins
//===----------------------------------------------------------------------===//

/// \fn T step(T y, T x)
/// \brief Returns 1 if the x parameter is greater than or equal to the y
/// parameter; otherwise, 0.
/// \param y [in] The first floating-point value to compare.
/// \param x [in] The second floating-point value to compare.
///
/// Step is based on the following formula: (x >= y) ? 1 : 0

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
half step(half, half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
half2 step(half2, half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
half3 step(half3, half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
half4 step(half4, half4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
float step(float, float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
float2 step(float2, float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
float3 step(float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
float4 step(float4, float4);

//===----------------------------------------------------------------------===//
// tan builtins
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1747,6 +1747,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy);
break;
}
case Builtin::BI__builtin_hlsl_step: {
if (SemaRef.checkArgCount(TheCall, 2))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: we usually check the arg count first.

return true;
if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
return true;

ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
// return type is the same as the input type
TheCall->setType(ArgTyA);
break;
}
// Note these are llvm builtins that we want to catch invalid intrinsic
// generation. Normal handling of these builtins will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
Expand Down
84 changes: 84 additions & 0 deletions clang/test/CodeGenHLSL/builtins/step.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,NATIVE_HALF \
// RUN: -DFNATTRS=noundef -DTARGET=dx
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
// RUN: -DFNATTRS=noundef -DTARGET=dx
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,NATIVE_HALF \
// RUN: -DFNATTRS="spir_func noundef" -DTARGET=spv
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
// RUN: -DFNATTRS="spir_func noundef" -DTARGET=spv

// NATIVE_HALF: define [[FNATTRS]] half @
// NATIVE_HALF: call half @llvm.[[TARGET]].step.f16(half
// NO_HALF: call float @llvm.[[TARGET]].step.f32(float
// NATIVE_HALF: ret half
// NO_HALF: ret float
half test_step_half(half p0, half p1)
{
return step(p0, p1);
}
// NATIVE_HALF: define [[FNATTRS]] <2 x half> @
// NATIVE_HALF: call <2 x half> @llvm.[[TARGET]].step.v2f16(<2 x half>
// NO_HALF: call <2 x float> @llvm.[[TARGET]].step.v2f32(<2 x float>
// NATIVE_HALF: ret <2 x half> %hlsl.step
// NO_HALF: ret <2 x float> %hlsl.step
half2 test_step_half2(half2 p0, half2 p1)
{
return step(p0, p1);
}
// NATIVE_HALF: define [[FNATTRS]] <3 x half> @
// NATIVE_HALF: call <3 x half> @llvm.[[TARGET]].step.v3f16(<3 x half>
// NO_HALF: call <3 x float> @llvm.[[TARGET]].step.v3f32(<3 x float>
// NATIVE_HALF: ret <3 x half> %hlsl.step
// NO_HALF: ret <3 x float> %hlsl.step
half3 test_step_half3(half3 p0, half3 p1)
{
return step(p0, p1);
}
// NATIVE_HALF: define [[FNATTRS]] <4 x half> @
// NATIVE_HALF: call <4 x half> @llvm.[[TARGET]].step.v4f16(<4 x half>
// NO_HALF: call <4 x float> @llvm.[[TARGET]].step.v4f32(<4 x float>
// NATIVE_HALF: ret <4 x half> %hlsl.step
// NO_HALF: ret <4 x float> %hlsl.step
half4 test_step_half4(half4 p0, half4 p1)
{
return step(p0, p1);
}

// CHECK: define [[FNATTRS]] float @
// CHECK: call float @llvm.[[TARGET]].step.f32(float
// CHECK: ret float
float test_step_float(float p0, float p1)
{
return step(p0, p1);
}
// CHECK: define [[FNATTRS]] <2 x float> @
// CHECK: %hlsl.step = call <2 x float> @llvm.[[TARGET]].step.v2f32(
// CHECK: ret <2 x float> %hlsl.step
float2 test_step_float2(float2 p0, float2 p1)
{
return step(p0, p1);
}
// CHECK: define [[FNATTRS]] <3 x float> @
// CHECK: %hlsl.step = call <3 x float> @llvm.[[TARGET]].step.v3f32(
// CHECK: ret <3 x float> %hlsl.step
float3 test_step_float3(float3 p0, float3 p1)
{
return step(p0, p1);
}
// CHECK: define [[FNATTRS]] <4 x float> @
// CHECK: %hlsl.step = call <4 x float> @llvm.[[TARGET]].step.v4f32(
// CHECK: ret <4 x float> %hlsl.step
float4 test_step_float4(float4 p0, float4 p1)
{
return step(p0, p1);
}
31 changes: 31 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/step-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -disable-llvm-passes -verify -verify-ignore-unexpected

void test_too_few_arg()
{
return __builtin_hlsl_step();
// expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
}

void test_too_many_arg(float2 p0)
{
return __builtin_hlsl_step(p0, p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
}

// A bool argument must be rejected by the step builtin rather than being
// promoted to float.
bool builtin_bool_to_float_type_promotion(bool p1)
{
  return __builtin_hlsl_step(p1, p1);
  // expected-error@-1 {{passing 'bool' to parameter of incompatible type 'float'}}
}

bool builtin_step_int_to_float_promotion(int p1)
{
return __builtin_hlsl_step(p1, p1);
// expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
}

bool2 builtin_step_int2_to_float2_promotion(int2 p1)
{
return __builtin_hlsl_step(p1, p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
}
2 changes: 1 addition & 1 deletion llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;

def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>]>;
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ let TargetPrefix = "spv" in {
def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_spv_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [LLVMMatchType<0>, llvm_anyfloat_ty]>;
def int_spv_fdot :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
Expand Down
26 changes: 25 additions & 1 deletion llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
case Intrinsic::dx_sign:
case Intrinsic::dx_step:
return true;
}
return false;
Expand Down Expand Up @@ -322,6 +323,28 @@ static Value *expandPowIntrinsic(CallInst *Orig) {
return Exp2Call;
}

/// Expand a `dx.step` call into a floating-point compare followed by a select.
///
/// With call operands (X, Y), the result is 0.0 where `Y < X` under an ordered
/// compare and 1.0 otherwise. Because the compare is ordered, a NaN in either
/// operand makes the condition false and so selects 1.0. Vector operands are
/// handled element-wise.
///
/// \param Orig the step call to expand; the builder is positioned at it.
/// \returns the select instruction that computes the step result.
static Value *expandStepIntrinsic(CallInst *Orig) {
  Value *X = Orig->getOperand(0);
  Value *Y = Orig->getOperand(1);
  Type *Ty = X->getType();
  IRBuilder<> Builder(Orig);

  // ConstantFP::get produces a scalar constant for scalar types and a splat
  // for vector types, so no manual vector handling (and no unchecked
  // dyn_cast<FixedVectorType>) is required here.
  Constant *One = ConstantFP::get(Ty, 1.0);
  Constant *Zero = ConstantFP::get(Ty, 0.0);

  Value *Cond = Builder.CreateFCmpOLT(Y, X);
  return Builder.CreateSelect(Cond, Zero, One);
}

static Intrinsic::ID getMaxForClamp(Type *ElemTy,
Intrinsic::ID ClampIntrinsic) {
if (ClampIntrinsic == Intrinsic::dx_uclamp)
Expand Down Expand Up @@ -433,8 +456,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
case Intrinsic::dx_sign:
Result = expandSignIntrinsic(Orig);
break;
case Intrinsic::dx_step:
Result = expandStepIntrinsic(Orig);
}

if (Result) {
Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
Expand Down
24 changes: 24 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectStep(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectUnmergeValues(MachineInstr &I) const;

Register buildI32Constant(uint32_t Val, MachineInstr &I,
Expand Down Expand Up @@ -1710,6 +1713,25 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
return Result;
}

/// Select the step intrinsic as an `OpExtInst` that invokes `Step` from the
/// GLSL.std.450 extended instruction set.
///
/// \param ResVReg virtual register that receives the result.
/// \param ResType SPIR-V type of the result.
/// \param I the instruction being selected; it must have exactly four
///        operands, with operands 2 and 3 being registers. These are forwarded
///        in order as the two arguments of `Step` (presumably the intrinsic's
///        own two arguments — confirm against the intrinsic operand layout).
/// \returns the result of constraining the operands of the new instruction.
bool SPIRVInstructionSelector::selectStep(Register ResVReg,
                                          const SPIRVType *ResType,
                                          MachineInstr &I) const {
  assert(I.getNumOperands() == 4);
  assert(I.getOperand(2).isReg());
  assert(I.getOperand(3).isReg());

  const Register FirstArg = I.getOperand(2).getReg();
  const Register SecondArg = I.getOperand(3).getReg();
  const auto ExtSet =
      static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450);

  // The new instruction is inserted immediately before I in I's own block.
  auto StepInst =
      BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
          .addDef(ResVReg)
          .addUse(GR.getSPIRVTypeID(ResType))
          .addImm(ExtSet)
          .addImm(GL::Step)
          .addUse(FirstArg)
          .addUse(SecondArg);
  return StepInst.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
Expand Down Expand Up @@ -2468,6 +2490,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
}
case Intrinsic::spv_step:
return selectStep(ResVReg, ResType, I);
default: {
std::string DiagMsg;
raw_string_ostream OS(DiagMsg);
Expand Down
Loading
Loading