From 544394b547be49f22249d8cd14df850a1fe73336 Mon Sep 17 00:00:00 2001 From: Andrii Levytskyi <107996072+aabysswalker@users.noreply.github.com> Date: Tue, 23 Jul 2024 21:03:39 +0300 Subject: [PATCH] [SPIRV][HLSL] Add lowering of frac to SPIR-V (#97111) Summary: Implements frac lowering to SPIR-V. Closes #88059 Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251390 --- clang/lib/CodeGen/CGBuiltin.cpp | 8 +- clang/lib/CodeGen/CGHLSLRuntime.h | 1 + clang/test/CodeGenHLSL/builtins/frac.hlsl | 109 +++++++++++------- llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 + .../Target/SPIRV/SPIRVInstructionSelector.cpp | 22 ++++ .../CodeGen/SPIRV/hlsl-intrinsics/frac.ll | 68 +++++++++++ 6 files changed, 166 insertions(+), 43 deletions(-) create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/frac.ll diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index c1999769560857..a0d03b87ccdc95 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -18471,10 +18471,10 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, if (!E->getArg(0)->getType()->hasFloatingRepresentation()) llvm_unreachable("frac operand must have a float representation"); return Builder.CreateIntrinsic( - /*ReturnType=*/Op0->getType(), Intrinsic::dx_frac, - ArrayRef{Op0}, nullptr, "dx.frac"); - } - case Builtin::BI__builtin_hlsl_elementwise_isinf: { + /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getFracIntrinsic(), + ArrayRef{Op0}, nullptr, "hlsl.frac"); +} +case Builtin::BI__builtin_hlsl_elementwise_isinf: { Value *Op0 = EmitScalarExpr(E->getArg(0)); llvm::Type *Xty = Op0->getType(); llvm::Type *retType = llvm::Type::getInt1Ty(this->getLLVMContext()); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index 4036ce711bea11..8c067f49639556 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -74,6 +74,7 @@ class CGHLSLRuntime { GENERATE_HLSL_INTRINSIC_FUNCTION(All, all) GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any) + GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac) GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp) GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt) GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id) diff --git a/clang/test/CodeGenHLSL/builtins/frac.hlsl b/clang/test/CodeGenHLSL/builtins/frac.hlsl index 7c4d1468e96d27..b457f5c2787918 100644 --- a/clang/test/CodeGenHLSL/builtins/frac.hlsl +++ b/clang/test/CodeGenHLSL/builtins/frac.hlsl @@ -1,53 +1,84 @@ // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ // RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \ -// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ -// RUN: --check-prefixes=CHECK,NATIVE_HALF +// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ +// RUN: --check-prefixes=CHECK,DXIL_CHECK,DXIL_NATIVE_HALF,NATIVE_HALF // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ // RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \ -// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF +// RUN: -o - | FileCheck %s --check-prefixes=CHECK,DXIL_CHECK,NO_HALF,DXIL_NO_HALF +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \ +// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ +// RUN: --check-prefixes=CHECK,SPIR_CHECK,NATIVE_HALF,SPIR_NATIVE_HALF +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \ +// RUN: -o - | FileCheck %s --check-prefixes=CHECK,SPIR_CHECK,NO_HALF,SPIR_NO_HALF -// NATIVE_HALF: define noundef half @ -// NATIVE_HALF: %dx.frac = call half @llvm.dx.frac.f16( -// NATIVE_HALF: ret half %dx.frac -// NO_HALF: define noundef float @"?test_frac_half@@YA$halff@$halff@@Z"( -// NO_HALF: %dx.frac = call float @llvm.dx.frac.f32( -// NO_HALF: ret float %dx.frac +// DXIL_NATIVE_HALF: define noundef half @ +// SPIR_NATIVE_HALF: define spir_func noundef half @ +// DXIL_NATIVE_HALF: %hlsl.frac = call half @llvm.dx.frac.f16( +// SPIR_NATIVE_HALF: %hlsl.frac = call half @llvm.spv.frac.f16( +// NATIVE_HALF: ret half %hlsl.frac +// DXIL_NO_HALF: define noundef float @ +// SPIR_NO_HALF: define spir_func noundef float @ +// DXIL_NO_HALF: %hlsl.frac = call float @llvm.dx.frac.f32( +// SPIR_NO_HALF: %hlsl.frac = call float @llvm.spv.frac.f32( +// NO_HALF: ret float %hlsl.frac half test_frac_half(half p0) { return frac(p0); } -// NATIVE_HALF: define noundef <2 x half> @ -// NATIVE_HALF: %dx.frac = call <2 x half> @llvm.dx.frac.v2f16 -// NATIVE_HALF: ret <2 x half> %dx.frac -// NO_HALF: define noundef <2 x float> @ -// NO_HALF: %dx.frac = call <2 x float> @llvm.dx.frac.v2f32( -// NO_HALF: ret <2 x float> %dx.frac +// DXIL_NATIVE_HALF: define noundef <2 x half> @ +// SPIR_NATIVE_HALF: define spir_func noundef <2 x half> @ +// DXIL_NATIVE_HALF: %hlsl.frac = call <2 x half> @llvm.dx.frac.v2f16 +// SPIR_NATIVE_HALF: %hlsl.frac = call <2 x half> @llvm.spv.frac.v2f16 +// NATIVE_HALF: ret <2 x half> %hlsl.frac +// DXIL_NO_HALF: define noundef <2 x float> @ +// SPIR_NO_HALF: define spir_func noundef <2 x float> @ +// DXIL_NO_HALF: %hlsl.frac = call <2 x float> @llvm.dx.frac.v2f32( +// SPIR_NO_HALF: %hlsl.frac = call <2 x float> @llvm.spv.frac.v2f32( +// NO_HALF: ret <2 x float> %hlsl.frac half2 test_frac_half2(half2 p0) { return frac(p0); } -// NATIVE_HALF: define noundef <3 x half> @ -// NATIVE_HALF: %dx.frac = call <3 x half> @llvm.dx.frac.v3f16 -// NATIVE_HALF: ret <3 x half> %dx.frac -// NO_HALF: define noundef <3 x float> @ -// NO_HALF: %dx.frac = call <3 x float> @llvm.dx.frac.v3f32( -// NO_HALF: ret <3 x float> %dx.frac +// DXIL_NATIVE_HALF: define noundef <3 x half> @ +// SPIR_NATIVE_HALF: define spir_func noundef <3 x half> @ +// DXIL_NATIVE_HALF: %hlsl.frac = call <3 x half> @llvm.dx.frac.v3f16 +// SPIR_NATIVE_HALF: %hlsl.frac = call <3 x half> @llvm.spv.frac.v3f16 +// NATIVE_HALF: ret <3 x half> %hlsl.frac +// DXIL_NO_HALF: define noundef <3 x float> @ +// SPIR_NO_HALF: define spir_func noundef <3 x float> @ +// DXIL_NO_HALF: %hlsl.frac = call <3 x float> @llvm.dx.frac.v3f32( +// SPIR_NO_HALF: %hlsl.frac = call <3 x float> @llvm.spv.frac.v3f32( +// NO_HALF: ret <3 x float> %hlsl.frac half3 test_frac_half3(half3 p0) { return frac(p0); } -// NATIVE_HALF: define noundef <4 x half> @ -// NATIVE_HALF: %dx.frac = call <4 x half> @llvm.dx.frac.v4f16 -// NATIVE_HALF: ret <4 x half> %dx.frac -// NO_HALF: define noundef <4 x float> @ -// NO_HALF: %dx.frac = call <4 x float> @llvm.dx.frac.v4f32( -// NO_HALF: ret <4 x float> %dx.frac +// DXIL_NATIVE_HALF: define noundef <4 x half> @ +// SPIR_NATIVE_HALF: define spir_func noundef <4 x half> @ +// DXIL_NATIVE_HALF: %hlsl.frac = call <4 x half> @llvm.dx.frac.v4f16 +// SPIR_NATIVE_HALF: %hlsl.frac = call <4 x half> @llvm.spv.frac.v4f16 +// NATIVE_HALF: ret <4 x half> %hlsl.frac +// DXIL_NO_HALF: define noundef <4 x float> @ +// SPIR_NO_HALF: define spir_func noundef <4 x float> @ +// DXIL_NO_HALF: %hlsl.frac = call <4 x float> @llvm.dx.frac.v4f32( +// SPIR_NO_HALF: %hlsl.frac = call <4 x float> @llvm.spv.frac.v4f32( +// NO_HALF: ret <4 x float> %hlsl.frac half4 test_frac_half4(half4 p0) { return frac(p0); } -// CHECK: define noundef float @ -// CHECK: %dx.frac = call float @llvm.dx.frac.f32( -// CHECK: ret float %dx.frac +// DXIL_CHECK: define noundef float @ +// SPIR_CHECK: define spir_func noundef float @ +// DXIL_CHECK: %hlsl.frac = call float @llvm.dx.frac.f32( +// SPIR_CHECK: %hlsl.frac = call float @llvm.spv.frac.f32( +// CHECK: ret float %hlsl.frac float test_frac_float(float p0) { return frac(p0); } -// CHECK: define noundef <2 x float> @ -// CHECK: %dx.frac = call <2 x float> @llvm.dx.frac.v2f32 -// CHECK: ret <2 x float> %dx.frac +// DXIL_CHECK: define noundef <2 x float> @ +// SPIR_CHECK: define spir_func noundef <2 x float> @ +// DXIL_CHECK: %hlsl.frac = call <2 x float> @llvm.dx.frac.v2f32 +// SPIR_CHECK: %hlsl.frac = call <2 x float> @llvm.spv.frac.v2f32 +// CHECK: ret <2 x float> %hlsl.frac float2 test_frac_float2(float2 p0) { return frac(p0); } -// CHECK: define noundef <3 x float> @ -// CHECK: %dx.frac = call <3 x float> @llvm.dx.frac.v3f32 -// CHECK: ret <3 x float> %dx.frac +// DXIL_CHECK: define noundef <3 x float> @ +// SPIR_CHECK: define spir_func noundef <3 x float> @ +// DXIL_CHECK: %hlsl.frac = call <3 x float> @llvm.dx.frac.v3f32 +// SPIR_CHECK: %hlsl.frac = call <3 x float> @llvm.spv.frac.v3f32 +// CHECK: ret <3 x float> %hlsl.frac float3 test_frac_float3(float3 p0) { return frac(p0); } -// CHECK: define noundef <4 x float> @ -// CHECK: %dx.frac = call <4 x float> @llvm.dx.frac.v4f32 -// CHECK: ret <4 x float> %dx.frac +// DXIL_CHECK: define noundef <4 x float> @ +// SPIR_CHECK: define spir_func noundef <4 x float> @ +// DXIL_CHECK: %hlsl.frac = call <4 x float> @llvm.dx.frac.v4f32 +// SPIR_CHECK: %hlsl.frac = call <4 x float> @llvm.spv.frac.v4f32 +// CHECK: ret <4 x float> %hlsl.frac float4 test_frac_float4(float4 p0) { return frac(p0); } diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index 683acf4a6ffa90..ef6ddf12c32f68 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -60,6 +60,7 @@ let TargetPrefix = "spv" in { Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>; def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>; def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>; + def int_spv_frac : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>; def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], [IntrNoMem, IntrWillReturn] >; def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 04def5ef01e7b3..8391e0dec9a395 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -173,6 +173,9 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectFmix(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectFrac(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectRsqrt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; @@ -1330,6 +1333,23 @@ bool SPIRVInstructionSelector::selectFmix(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } +bool SPIRVInstructionSelector::selectFrac(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + + assert(I.getNumOperands() == 3); + assert(I.getOperand(2).isReg()); + MachineBasicBlock &BB = *I.getParent(); + + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addImm(static_cast(SPIRV::InstructionSet::GLSL_std_450)) + .addImm(GL::Fract) + .addUse(I.getOperand(2).getReg()) + .constrainAllUses(TII, TRI, RBI); +} + bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { @@ -2059,6 +2079,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectAny(ResVReg, ResType, I); case Intrinsic::spv_lerp: return selectFmix(ResVReg, ResType, I); + case Intrinsic::spv_frac: + return selectFrac(ResVReg, ResType, I); case Intrinsic::spv_rsqrt: return selectRsqrt(ResVReg, ResType, I); case Intrinsic::spv_lifetime_start: diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/frac.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/frac.ll new file mode 100644 index 00000000000000..3c48782a185862 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/frac.ll @@ -0,0 +1,68 @@ +; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450" + +; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32 +; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16 +; CHECK-DAG: %[[#float_64:]] = OpTypeFloat 64 + +; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4 +; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4 +; CHECK-DAG: %[[#vec4_float_64:]] = OpTypeVector %[[#float_64]] 4 + +define noundef float @frac_float(float noundef %a) { +entry: +; CHECK: %[[#float_32_arg:]] = OpFunctionParameter %[[#float_32]] +; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] Fract %[[#float_32_arg]] + %elt.frac = call float @llvm.spv.frac.f32(float %a) + ret float %elt.frac +} + +define noundef half @frac_half(half noundef %a) { +entry: +; CHECK: %[[#float_16_arg:]] = OpFunctionParameter %[[#float_16]] +; CHECK: %[[#]] = OpExtInst %[[#float_16]] %[[#op_ext_glsl]] Fract %[[#float_16_arg]] + %elt.frac = call half @llvm.spv.frac.f16(half %a) + ret half %elt.frac +} + +define noundef double @frac_double(double noundef %a) { +entry: +; CHECK: %[[#float_64_arg:]] = OpFunctionParameter %[[#float_64]] +; CHECK: %[[#]] = OpExtInst %[[#float_64]] %[[#op_ext_glsl]] Fract %[[#float_64_arg]] + %elt.frac = call double @llvm.spv.frac.f64(double %a) + ret double %elt.frac +} + +define noundef <4 x float> @frac_float_vector(<4 x float> noundef %a) { +entry: +; CHECK: %[[#vec4_float_32_arg:]] = OpFunctionParameter %[[#vec4_float_32]] +; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] Fract %[[#vec4_float_32_arg]] + %elt.frac = call <4 x float> @llvm.spv.frac.v4f32(<4 x float> %a) + ret <4 x float> %elt.frac +} + +define noundef <4 x half> @frac_half_vector(<4 x half> noundef %a) { +entry: +; CHECK: %[[#vec4_float_16_arg:]] = OpFunctionParameter %[[#vec4_float_16]] +; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] Fract %[[#vec4_float_16_arg]] + %elt.frac = call <4 x half> @llvm.spv.frac.v4f16(<4 x half> %a) + ret <4 x half> %elt.frac +} + +define noundef <4 x double> @frac_double_vector(<4 x double> noundef %a) { +entry: +; CHECK: %[[#vec4_float_64_arg:]] = OpFunctionParameter %[[#vec4_float_64]] +; CHECK: %[[#]] = OpExtInst %[[#vec4_float_64]] %[[#op_ext_glsl]] Fract %[[#vec4_float_64_arg]] + %elt.frac = call <4 x double> @llvm.spv.frac.v4f64(<4 x double> %a) + ret <4 x double> %elt.frac +} + +declare half @llvm.spv.frac.f16(half) +declare float @llvm.spv.frac.f32(float) +declare double @llvm.spv.frac.f64(double) + +declare <4 x float> @llvm.spv.frac.v4f32(<4 x float>) +declare <4 x half> @llvm.spv.frac.v4f16(<4 x half>) +declare <4 x double> @llvm.spv.frac.v4f64(<4 x double>)