Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Add cl_khr_kernel_clock / SPV_KHR_shader_clock extension #92771

Merged
merged 2 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,39 @@ static bool generateGroupUniformInst(const SPIRV::IncomingCall *Call,
return true;
}

// Lowers a call to one of the kernel clock builtins to OpReadClockKHR.
//
// Aborts through report_fatal_error when the subtarget cannot use the
// SPV_KHR_shader_clock extension. Otherwise the call's return register is put
// into the ID register class, the `Scope` operand is materialised as an
// integer constant register, and one OpReadClockKHR producing the call's
// return type is built. Always returns true.
static bool generateKernelClockInst(const SPIRV::IncomingCall *Call,
                                    MachineIRBuilder &MIRBuilder,
                                    SPIRVGlobalRegistry *GR) {
  const SPIRV::DemangledBuiltin *ClockBuiltin = Call->Builtin;

  // Refuse to lower the builtin unless the extension is usable.
  const auto &Subtarget =
      static_cast<const SPIRVSubtarget &>(MIRBuilder.getMF().getSubtarget());
  const bool HasShaderClock =
      Subtarget.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock);
  if (!HasShaderClock) {
    std::string Msg = std::string(ClockBuiltin->Name);
    Msg += ": the builtin requires the following SPIR-V "
           "extension: SPV_KHR_shader_clock";
    report_fatal_error(Msg.c_str(), false);
  }

  const Register ClockReg = Call->ReturnRegister;
  MIRBuilder.getMRI()->setRegClass(ClockReg, &SPIRV::IDRegClass);

  // The scope is not an argument of the builtin; it is taken from the suffix
  // of the builtin's name.
  const SPIRV::Scope::Scope ClockScope =
      StringSwitch<SPIRV::Scope::Scope>(ClockBuiltin->Name)
          .EndsWith("device", SPIRV::Scope::Scope::Device)
          .EndsWith("work_group", SPIRV::Scope::Scope::Workgroup)
          .EndsWith("sub_group", SPIRV::Scope::Scope::Subgroup);
  const Register ScopeId = buildConstantIntReg(ClockScope, MIRBuilder, GR);

  // Operands are appended in the order: result, result type, scope.
  auto ReadClock = MIRBuilder.buildInstr(SPIRV::OpReadClockKHR);
  ReadClock.addDef(ClockReg);
  ReadClock.addUse(GR->getSPIRVTypeID(Call->ReturnType));
  ReadClock.addUse(ScopeId);

  return true;
}

// These queries ask for a single size_t result for a given dimension index, e.g
// size_t get_global_id(uint dimindex). In SPIR-V, the builtins corresponding to
// these values are all vec3 types, so we need to extract the correct index or
Expand Down Expand Up @@ -2290,6 +2323,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateIntelSubgroupsInst(Call.get(), MIRBuilder, GR);
case SPIRV::GroupUniform:
return generateGroupUniformInst(Call.get(), MIRBuilder, GR);
case SPIRV::KernelClock:
return generateKernelClockInst(Call.get(), MIRBuilder, GR);
}
return false;
}
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def LoadStore : BuiltinGroup;
def IntelSubgroups : BuiltinGroup;
def AtomicFloating : BuiltinGroup;
def GroupUniform : BuiltinGroup;
def KernelClock : BuiltinGroup;

//===----------------------------------------------------------------------===//
// Class defining a demangled builtin record. The information in the record
Expand Down Expand Up @@ -952,6 +953,14 @@ defm : DemangledGroupBuiltin<"group_scan_exclusive_logical_xor", OnlyWork, OpGro
defm : DemangledGroupBuiltin<"group_scan_inclusive_logical_xor", OnlyWork, OpGroupLogicalXorKHR>;
defm : DemangledGroupBuiltin<"group_reduce_logical_xor", OnlyWork, OpGroupLogicalXorKHR>;

// cl_khr_kernel_clock / SPV_KHR_shader_clock
// Registers the clock_read_* and clock_read_hilo_* builtins (device,
// work_group and sub_group variants) in the KernelClock group of the
// OpenCL_std set. All six records map to the same opcode, OpReadClockKHR.
// NOTE(review): the two `0, 0` arguments are presumably the minimum and
// maximum argument counts (these builtins are called without arguments) --
// confirm against the DemangledNativeBuiltin definition.
defm : DemangledNativeBuiltin<"clock_read_device", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
defm : DemangledNativeBuiltin<"clock_read_work_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
defm : DemangledNativeBuiltin<"clock_read_sub_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
defm : DemangledNativeBuiltin<"clock_read_hilo_device", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
defm : DemangledNativeBuiltin<"clock_read_hilo_work_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
defm : DemangledNativeBuiltin<"clock_read_hilo_sub_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;

//===----------------------------------------------------------------------===//
// Class defining an atomic instruction on floating-point numbers.
//
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ static const std::map<std::string, SPIRV::Extension::Extension>
SPIRV::Extension::Extension::SPV_INTEL_variable_length_array},
{"SPV_INTEL_function_pointers",
SPIRV::Extension::Extension::SPV_INTEL_function_pointers},
{"SPV_KHR_shader_clock",
SPIRV::Extension::Extension::SPV_KHR_shader_clock},
};

bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,11 @@ def OpGroupNonUniformRotateKHR: Op<4431, (outs ID:$res),
(ins TYPE:$type, ID:$scope, ID:$value, ID:$delta, variable_ops),
"$res = OpGroupNonUniformRotateKHR $type $scope $value $delta">;

// SPV_KHR_shader_clock
// OpReadClockKHR (opcode 5056): defines the ID $res and takes the result type
// $type and a scope ID $scope as inputs.
def OpReadClockKHR: Op<5056, (outs ID:$res),
(ins TYPE:$type, ID:$scope),
"$res = OpReadClockKHR $type $scope">;

// 3.49.7, Constant-Creation Instructions

// - SPV_INTEL_function_pointers
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,14 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
}
break;
case SPIRV::OpReadClockKHR:
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
report_fatal_error("OpReadClockKHR instruction requires the "
"following SPIR-V extension: SPV_KHR_shader_clock",
false);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
break;
case SPIRV::OpFunctionPointerCallINTEL:
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ defm ImageGatherBiasLodAMD : CapabilityOperand<5009, 0, 0, [], [Shader]>;
defm FragmentMaskAMD : CapabilityOperand<5010, 0, 0, [], [Shader]>;
defm StencilExportEXT : CapabilityOperand<5013, 0, 0, [], [Shader]>;
defm ImageReadWriteLodAMD : CapabilityOperand<5015, 0, 0, [], [Shader]>;
defm ShaderClockKHR : CapabilityOperand<5055, 0, 0, [SPV_KHR_shader_clock], []>;
defm SampleMaskOverrideCoverageNV : CapabilityOperand<5249, 0, 0, [], [SampleRateShading]>;
defm GeometryShaderPassthroughNV : CapabilityOperand<5251, 0, 0, [], [Geometry]>;
defm ShaderViewportIndexLayerEXT : CapabilityOperand<5254, 0, 0, [], [MultiViewport]>;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_shader_clock %s -o - | FileCheck %s
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_shader_clock %s -o - -filetype=obj | spirv-val %}
svenvh marked this conversation as resolved.
Show resolved Hide resolved

; CHECK-ERROR: LLVM ERROR: clock_read_device: the builtin requires the following SPIR-V extension: SPV_KHR_shader_clock

; CHECK: OpCapability ShaderClockKHR
; CHECK: OpExtension "SPV_KHR_shader_clock"
; CHECK-DAG: [[uint:%[a-z0-9_]+]] = OpTypeInt 32
; CHECK-DAG: [[ulong:%[a-z0-9_]+]] = OpTypeInt 64
; CHECK-DAG: [[v2uint:%[a-z0-9_]+]] = OpTypeVector [[uint]] 2
; CHECK-DAG: [[uint_1:%[a-z0-9_]+]] = OpConstant [[uint]] 1
; CHECK-DAG: [[uint_2:%[a-z0-9_]+]] = OpConstant [[uint]] 2
; CHECK-DAG: [[uint_3:%[a-z0-9_]+]] = OpConstant [[uint]] 3
; CHECK: OpReadClockKHR [[ulong]] [[uint_1]]
; CHECK: OpReadClockKHR [[ulong]] [[uint_2]]
; CHECK: OpReadClockKHR [[ulong]] [[uint_3]]
; CHECK: OpReadClockKHR [[v2uint]] [[uint_1]]
; CHECK: OpReadClockKHR [[v2uint]] [[uint_2]]
; CHECK: OpReadClockKHR [[v2uint]] [[uint_3]]

; Test kernel: calls each of the six clock builtins once and stores the result.
; The three i64 results (device, work_group, sub_group) are written to %out64
; at byte offsets 0, 8 and 16; the three <2 x i32> "hilo" results are written
; to %outv2 at byte offsets 0, 8 and 16, in the same order.
define dso_local spir_kernel void @test_clocks(ptr addrspace(1) nocapture noundef writeonly align 8 %out64, ptr addrspace(1) nocapture noundef writeonly align 8 %outv2) {
entry:
svenvh marked this conversation as resolved.
Show resolved Hide resolved
%call = tail call spir_func i64 @_Z17clock_read_devicev()
store i64 %call, ptr addrspace(1) %out64, align 8
%call1 = tail call spir_func i64 @_Z21clock_read_work_groupv()
%arrayidx2 = getelementptr inbounds i8, ptr addrspace(1) %out64, i32 8
store i64 %call1, ptr addrspace(1) %arrayidx2, align 8
%call3 = tail call spir_func i64 @_Z20clock_read_sub_groupv()
%arrayidx4 = getelementptr inbounds i8, ptr addrspace(1) %out64, i32 16
store i64 %call3, ptr addrspace(1) %arrayidx4, align 8
%call5 = tail call spir_func <2 x i32> @_Z22clock_read_hilo_devicev()
store <2 x i32> %call5, ptr addrspace(1) %outv2, align 8
%call7 = tail call spir_func <2 x i32> @_Z26clock_read_hilo_work_groupv()
%arrayidx8 = getelementptr inbounds i8, ptr addrspace(1) %outv2, i32 8
store <2 x i32> %call7, ptr addrspace(1) %arrayidx8, align 8
%call9 = tail call spir_func <2 x i32> @_Z25clock_read_hilo_sub_groupv()
%arrayidx10 = getelementptr inbounds i8, ptr addrspace(1) %outv2, i32 16
store <2 x i32> %call9, ptr addrspace(1) %arrayidx10, align 8
ret void
}

; Function Attrs: convergent nounwind
declare spir_func i64 @_Z17clock_read_devicev() local_unnamed_addr

; Function Attrs: convergent nounwind
declare spir_func i64 @_Z21clock_read_work_groupv() local_unnamed_addr

; Function Attrs: convergent nounwind
declare spir_func i64 @_Z20clock_read_sub_groupv() local_unnamed_addr

; Function Attrs: convergent nounwind
declare spir_func <2 x i32> @_Z22clock_read_hilo_devicev() local_unnamed_addr

; Function Attrs: convergent nounwind
declare spir_func <2 x i32> @_Z26clock_read_hilo_work_groupv() local_unnamed_addr

; Function Attrs: convergent nounwind
declare spir_func <2 x i32> @_Z25clock_read_hilo_sub_groupv() local_unnamed_addr
Loading