From e76a7590cc42a87e2093d0283d68c39e2e326fed Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Mon, 9 Sep 2024 23:27:27 +0200 Subject: [PATCH] [SPIRV] Add sign intrinsic part 1 (#101987) partially fixes #70078 ### Changes - Added `int_spv_sign` intrinsic in `IntrinsicsSPIRV.td` - Added lowering and map to `int_spv_sign` in `SPIRVInstructionSelector.cpp` - Added SPIR-V backend test case in `llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sign.ll` ### Related PRs - https://github.com/llvm/llvm-project/pull/101988 - https://github.com/llvm/llvm-project/pull/101989 --- llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 + .../Target/SPIRV/SPIRVInstructionSelector.cpp | 52 +++++++ .../CodeGen/SPIRV/hlsl-intrinsics/sign.ll | 143 ++++++++++++++++++ 3 files changed, 196 insertions(+) create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sign.ll diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index cbf6e04f2844d62..766fc0d99d2dbb1 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -80,4 +80,5 @@ let TargetPrefix = "spv" in { [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, Commutative] >; def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>; + def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>; } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index fed82b904af4f7b..1e861da35aaac92 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -28,6 +28,7 @@ #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineModuleInfoImpls.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/Register.h" #include "llvm/CodeGen/TargetOpcodes.h" #include "llvm/IR/IntrinsicsSPIRV.h" #include
"llvm/Support/Debug.h" @@ -184,6 +185,9 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectRsqrt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectSign(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectFloatDot(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; @@ -1603,6 +1607,52 @@ bool SPIRVInstructionSelector::selectSaturate(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } +bool SPIRVInstructionSelector::selectSign(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + assert(I.getNumOperands() == 3); + assert(I.getOperand(2).isReg()); + MachineBasicBlock &BB = *I.getParent(); + Register InputRegister = I.getOperand(2).getReg(); + SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister); + auto &DL = I.getDebugLoc(); + + if (!InputType) + report_fatal_error("Input Type could not be determined."); + + bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat); + + unsigned SignBitWidth = GR.getScalarOrVectorBitWidth(InputType); + unsigned ResBitWidth = GR.getScalarOrVectorBitWidth(ResType); + + bool NeedsConversion = IsFloatTy || SignBitWidth != ResBitWidth; + + auto SignOpcode = IsFloatTy ? GL::FSign : GL::SSign; + Register SignReg = NeedsConversion + ? MRI->createVirtualRegister(&SPIRV::IDRegClass) + : ResVReg; + + bool Result = + BuildMI(BB, I, DL, TII.get(SPIRV::OpExtInst)) + .addDef(SignReg) + .addUse(GR.getSPIRVTypeID(InputType)) + .addImm(static_cast(SPIRV::InstructionSet::GLSL_std_450)) + .addImm(SignOpcode) + .addUse(InputRegister) + .constrainAllUses(TII, TRI, RBI); + + if (NeedsConversion) { + auto ConvertOpcode = IsFloatTy ? 
SPIRV::OpConvertFToS : SPIRV::OpSConvert; + Result |= BuildMI(*I.getParent(), I, DL, TII.get(ConvertOpcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SignReg) + .constrainAllUses(TII, TRI, RBI); + } + + return Result; +} + bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { @@ -2339,6 +2389,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectNormalize(ResVReg, ResType, I); case Intrinsic::spv_rsqrt: return selectRsqrt(ResVReg, ResType, I); + case Intrinsic::spv_sign: + return selectSign(ResVReg, ResType, I); case Intrinsic::spv_lifetime_start: case Intrinsic::spv_lifetime_end: { unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sign.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sign.ll new file mode 100644 index 000000000000000..52a41c3d3ad6487 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sign.ll @@ -0,0 +1,143 @@ +; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450" + +; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16 +; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32 +; CHECK-DAG: %[[#float_64:]] = OpTypeFloat 64 + +; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16 +; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 +; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64 + +; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4 +; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4 +; CHECK-DAG: %[[#vec4_float_64:]] = OpTypeVector %[[#float_64]] 4 + +; CHECK-DAG: %[[#vec4_int_16:]] = OpTypeVector %[[#int_16]] 4 +; CHECK-DAG: %[[#vec4_int_32:]] = OpTypeVector %[[#int_32]] 4 +; CHECK-DAG: %[[#vec4_int_64:]] = OpTypeVector %[[#int_64]] 4 + + +define noundef i32 @sign_half(half noundef 
%a) { +entry: +; CHECK: %[[#float_16_arg:]] = OpFunctionParameter %[[#float_16]] +; CHECK: %[[#fsign:]] = OpExtInst %[[#float_16]] %[[#op_ext_glsl]] FSign %[[#float_16_arg]] +; CHECK: %[[#]] = OpConvertFToS %[[#int_32]] %[[#fsign]] + %elt.sign = call i32 @llvm.spv.sign.f16(half %a) + ret i32 %elt.sign +} + +define noundef i32 @sign_float(float noundef %a) { +entry: +; CHECK: %[[#float_32_arg:]] = OpFunctionParameter %[[#float_32]] +; CHECK: %[[#fsign:]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] FSign %[[#float_32_arg]] +; CHECK: %[[#]] = OpConvertFToS %[[#int_32]] %[[#fsign]] + %elt.sign = call i32 @llvm.spv.sign.f32(float %a) + ret i32 %elt.sign +} + +define noundef i32 @sign_double(double noundef %a) { +entry: +; CHECK: %[[#float_64_arg:]] = OpFunctionParameter %[[#float_64]] +; CHECK: %[[#fsign:]] = OpExtInst %[[#float_64]] %[[#op_ext_glsl]] FSign %[[#float_64_arg]] +; CHECK: %[[#]] = OpConvertFToS %[[#int_32]] %[[#fsign]] + %elt.sign = call i32 @llvm.spv.sign.f64(double %a) + ret i32 %elt.sign +} + +define noundef i32 @sign_i16(i16 noundef %a) { +entry: +; CHECK: %[[#int_16_arg:]] = OpFunctionParameter %[[#int_16]] +; CHECK: %[[#ssign:]] = OpExtInst %[[#int_16]] %[[#op_ext_glsl]] SSign %[[#int_16_arg]] +; CHECK: %[[#]] = OpSConvert %[[#int_32]] %[[#ssign]] + %elt.sign = call i32 @llvm.spv.sign.i16(i16 %a) + ret i32 %elt.sign +} + +define noundef i32 @sign_i32(i32 noundef %a) { +entry: +; CHECK: %[[#int_32_arg:]] = OpFunctionParameter %[[#int_32]] +; CHECK: %[[#]] = OpExtInst %[[#int_32]] %[[#op_ext_glsl]] SSign %[[#int_32_arg]] + %elt.sign = call i32 @llvm.spv.sign.i32(i32 %a) + ret i32 %elt.sign +} + +define noundef i32 @sign_i64(i64 noundef %a) { +entry: +; CHECK: %[[#int_64_arg:]] = OpFunctionParameter %[[#int_64]] +; CHECK: %[[#ssign:]] = OpExtInst %[[#int_64]] %[[#op_ext_glsl]] SSign %[[#int_64_arg]] +; CHECK: %[[#]] = OpSConvert %[[#int_32]] %[[#ssign]] + %elt.sign = call i32 @llvm.spv.sign.i64(i64 %a) + ret i32 %elt.sign +} + +define noundef <4 x i32> 
@sign_half_vector(<4 x half> noundef %a) { +entry: +; CHECK: %[[#vec4_float_16_arg:]] = OpFunctionParameter %[[#vec4_float_16]] +; CHECK: %[[#fsign:]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] FSign %[[#vec4_float_16_arg]] +; CHECK: %[[#]] = OpConvertFToS %[[#vec4_int_32]] %[[#fsign]] + %elt.sign = call <4 x i32> @llvm.spv.sign.v4f16(<4 x half> %a) + ret <4 x i32> %elt.sign +} + +define noundef <4 x i32> @sign_float_vector(<4 x float> noundef %a) { +entry: +; CHECK: %[[#vec4_float_32_arg:]] = OpFunctionParameter %[[#vec4_float_32]] +; CHECK: %[[#fsign:]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] FSign %[[#vec4_float_32_arg]] +; CHECK: %[[#]] = OpConvertFToS %[[#vec4_int_32]] %[[#fsign]] + %elt.sign = call <4 x i32> @llvm.spv.sign.v4f32(<4 x float> %a) + ret <4 x i32> %elt.sign +} + +define noundef <4 x i32> @sign_double_vector(<4 x double> noundef %a) { +entry: +; CHECK: %[[#vec4_float_64_arg:]] = OpFunctionParameter %[[#vec4_float_64]] +; CHECK: %[[#fsign:]] = OpExtInst %[[#vec4_float_64]] %[[#op_ext_glsl]] FSign %[[#vec4_float_64_arg]] +; CHECK: %[[#]] = OpConvertFToS %[[#vec4_int_32]] %[[#fsign]] + %elt.sign = call <4 x i32> @llvm.spv.sign.v4f64(<4 x double> %a) + ret <4 x i32> %elt.sign +} + +define noundef <4 x i32> @sign_i16_vector(<4 x i16> noundef %a) { +entry: +; CHECK: %[[#vec4_int_16_arg:]] = OpFunctionParameter %[[#vec4_int_16]] +; CHECK: %[[#ssign:]] = OpExtInst %[[#vec4_int_16]] %[[#op_ext_glsl]] SSign %[[#vec4_int_16_arg]] +; CHECK: %[[#]] = OpSConvert %[[#vec4_int_32]] %[[#ssign]] + %elt.sign = call <4 x i32> @llvm.spv.sign.v4i16(<4 x i16> %a) + ret <4 x i32> %elt.sign +} + +define noundef <4 x i32> @sign_i32_vector(<4 x i32> noundef %a) { +entry: +; CHECK: %[[#vec4_int_32_arg:]] = OpFunctionParameter %[[#vec4_int_32]] +; CHECK: %[[#]] = OpExtInst %[[#vec4_int_32]] %[[#op_ext_glsl]] SSign %[[#vec4_int_32_arg]] + %elt.sign = call <4 x i32> @llvm.spv.sign.v4i32(<4 x i32> %a) + ret <4 x i32> %elt.sign +} + +define noundef <4 x i32> 
@sign_i64_vector(<4 x i64> noundef %a) { +entry: +; CHECK: %[[#vec4_int_64_arg:]] = OpFunctionParameter %[[#vec4_int_64]] +; CHECK: %[[#ssign:]] = OpExtInst %[[#vec4_int_64]] %[[#op_ext_glsl]] SSign %[[#vec4_int_64_arg]] +; CHECK: %[[#]] = OpSConvert %[[#vec4_int_32]] %[[#ssign]] + %elt.sign = call <4 x i32> @llvm.spv.sign.v4i64(<4 x i64> %a) + ret <4 x i32> %elt.sign +} + +declare i32 @llvm.spv.sign.f16(half) +declare i32 @llvm.spv.sign.f32(float) +declare i32 @llvm.spv.sign.f64(double) + +declare i32 @llvm.spv.sign.i16(i16) +declare i32 @llvm.spv.sign.i32(i32) +declare i32 @llvm.spv.sign.i64(i64) + +declare <4 x i32> @llvm.spv.sign.v4f16(<4 x half>) +declare <4 x i32> @llvm.spv.sign.v4f32(<4 x float>) +declare <4 x i32> @llvm.spv.sign.v4f64(<4 x double>) + +declare <4 x i32> @llvm.spv.sign.v4i16(<4 x i16>) +declare <4 x i32> @llvm.spv.sign.v4i32(<4 x i32>) +declare <4 x i32> @llvm.spv.sign.v4i64(<4 x i64>)