diff --git a/.github/workflows/presubmit.yml b/.github/workflows/presubmit.yml index e64695a..8a5b249 100644 --- a/.github/workflows/presubmit.yml +++ b/.github/workflows/presubmit.yml @@ -71,6 +71,9 @@ jobs: -DPERFETTO_TRACE_PROCESSOR_LIB="$(pwd)/third_party/perfetto/out/linux_clang_release/libtrace_processor.a" \ -DPERFETTO_CXX_CONFIG_INCLUDE_PATH="$(pwd)/third_party/perfetto/buildtools/libcxx_config" \ -DPERFETTO_CXX_SYSTEM_INCLUDE_PATH="$(pwd)/third_party/perfetto/buildtools/libcxx/include" \ + -DSPIRV_HEADERS_INCLUDE_DIR="$(pwd)/third_party/spirv-tools/external/spirv-headers/include" \ + -DSPIRV_TOOLS_BUILD_DIR="$(pwd)/third_party/spirv-tools/build" \ + -DSPIRV_TOOLS_SOURCE_DIR="$(pwd)/third_party/spirv-tools" \ -DEXTRACTOR_NOSTDINCXX=1 \ -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH};$(pwd)/install/lib/cmake/SPIRV-Tools-opt;$(pwd)/install/lib/cmake/SPIRV-Tools" \ -DCMAKE_CXX_COMPILER="$(which clang++)" \ @@ -85,6 +88,9 @@ jobs: -DPERFETTO_TRACE_PROCESSOR_LIB="$(pwd)/third_party/perfetto/out/linux_clang_release/libtrace_processor.a" \ -DPERFETTO_CXX_CONFIG_INCLUDE_PATH="$(pwd)/third_party/perfetto/buildtools/libcxx_config" \ -DPERFETTO_CXX_SYSTEM_INCLUDE_PATH="$(pwd)/third_party/perfetto/buildtools/libcxx/include" \ + -DSPIRV_HEADERS_INCLUDE_DIR="$(pwd)/third_party/spirv-tools/external/spirv-headers/include" \ + -DSPIRV_TOOLS_BUILD_DIR="$(pwd)/third_party/spirv-tools/build" \ + -DSPIRV_TOOLS_SOURCE_DIR="$(pwd)/third_party/spirv-tools" \ -DEXTRACTOR_NOSTDINCXX=1 \ -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH};$(pwd)/install/lib/cmake/SPIRV-Tools-opt;$(pwd)/install/lib/cmake/SPIRV-Tools" \ -DCMAKE_CXX_COMPILER="$(which clang++)" \ diff --git a/CMakeLists.txt b/CMakeLists.txt index 3e2eb20..8926d42 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,12 @@ endif() if (NOT PERFETTO_INTERNAL_INCLUDE_PATH) message(FATAL_ERROR "PERFETTO_INTERNAL_INCLUDE_PATH not defined") endif() +if (NOT SPIRV_TOOLS_SOURCE_PATH) + message(FATAL_ERROR "SPIRV_TOOLS_SOURCE_PATH not defined") +endif() +if (NOT SPIRV_TOOLS_BUILD_PATH) + message(FATAL_ERROR "SPIRV_TOOLS_BUILD_PATH not defined") +endif() find_package(SPIRV-Tools-opt) message(STATUS "SPIRV-Tools-opt_LIBRARIES = '${SPIRV-Tools-opt_LIBRARIES}'") diff --git a/README.md b/README.md index fc4e35a..0a63323 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ Using the `vulkan-shader-profiler-extractor` and `vulkan-shader-profiler-runner` To compile it, please run: ``` -cmake -B -S -DPERFETTO_SDK_PATH= -DPERFETTO_TRACE_PROCESSOR_LIB= -DPERFETTO_INTERNAL_INCLUDE_PATH= +cmake -B -S -DPERFETTO_SDK_PATH= -DPERFETTO_TRACE_PROCESSOR_LIB= -DPERFETTO_INTERNAL_INCLUDE_PATH= -DSPIRV_TOOLS_SOURCE_PATH= -DSPIRV_TOOLS_BUILD_PATH= cmake --build ``` @@ -53,15 +53,19 @@ For a real life examples, have a look at: ## Build options -* `PERFETTO_SDK_PATH` (REQUIRED): path to [perfetto](https://github.com/google/perfetto) sdk (`vulkan-kernel-profiler` is looking for `PERFETTO_SDK_PATH/perfetto.cc` and `PERFETTO_SDK_PATH/perfetto.h`). -* `PERFETTO_TRACE_PROCESSOR_LIB` (REQUIRED): path to `libtrace_processor.a` produces by a perfetto build. -* `PERFETTO_INTERNAL_INCLUDE_PATH` (REQUIRED): path to perfetto internal include directory (`/include`), or where it is installed. +* REQUIRED: + * `PERFETTO_SDK_PATH`: path to [perfetto](https://github.com/google/perfetto) sdk (`vulkan-kernel-profiler` is looking for `PERFETTO_SDK_PATH/perfetto.cc` and `PERFETTO_SDK_PATH/perfetto.h`). + * `PERFETTO_TRACE_PROCESSOR_LIB`: path to `libtrace_processor.a` produces by a perfetto build. + * `PERFETTO_INTERNAL_INCLUDE_PATH`: path to perfetto internal include directory (`/include`), or where it is installed. + * `SPIRV_TOOLS_SOURCE_PATH`: path to [SPIRV-Tools](https://github.com/KhronosGroup/SPIRV-Tools) source directory (PR [#5512](https://github.com/KhronosGroup/SPIRV-Tools/pull/5512) is needed). + * `SPIRV_TOOLS_BUILD_PATH`: path to where [SPIRV-Tools](https://github.com/KhronosGroup/SPIRV-Tools) is built (not installed, just built). * OPTIONAL: * `PERFETTO_LIBRARY`: name of a perfetto library already available (avoid having to compile `perfetto.cc`). * `PERFETTO_GEN_INCLUDE_PATH`: path to a a perfetto build (if not installed) `/out/release/gen/build_config`. * `PERFETTO_CXX_CONFIG_INCLUDE_PATH`: path to perfetto buildtools config `/buildtools/libcxx_config`. * `PERFETTO_CXX_SYSTEM_INCLUDE_PATH`: path to perfetto buildtools include `/buildtools/libcxx/include`. * `EXTRACTOR_NOSTDINCXX`: build `vulkan-shader-profiler-extractor` with `-nostdinc++` to be able to link with some `libtrace_processor.a`. + * `SPIRV_HEADERS_INCLUDE_PATH`: path to [SPIRV-Headers](https://github.com/KhronosGroup/SPIRV-Headers) include directory (`/include`). * `BACKEND`: [perfetto](https://github.com/google/perfetto) backend to use * `InProcess` (default): the application will generate the traces ([perfetto documentation](https://perfetto.dev/docs/instrumentation/tracing-sdk#in-process-mode)). Build options and environment variables can be used to control the maximum size of traces and the destination file where the traces will be recorded. * `System`: perfetto `traced` daemon will be responsible for generating the traces ([perfetto documentation](https://perfetto.dev/docs/instrumentation/tracing-sdk#system-mode)). diff --git a/common/common.hpp b/common/common.hpp new file mode 100644 index 0000000..5d4f2c2 --- /dev/null +++ b/common/common.hpp @@ -0,0 +1,137 @@ +// Copyright 2024 The Vulkan Shader Profiler authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace vksp { + +struct vksp_push_constant { + uint32_t offset; + uint32_t size; + uint32_t stageFlags; + const char *pValues; +}; + +#define VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER_BITS (0xf0000000) +#define VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER_MASK (~VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER_BITS) +#define VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER \ + (VK_DESCRIPTOR_TYPE_STORAGE_BUFFER | VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER_BITS) +struct vksp_descriptor_set { + uint32_t ds; + uint32_t binding; + uint32_t type; + union { + struct { + uint32_t flags; + uint32_t queueFamilyIndexCount; + uint32_t sharingMode; + uint32_t size; + uint32_t usage; + uint32_t range; + uint32_t offset; + uint32_t memorySize; + uint32_t memoryType; + uint32_t bindOffset; + } buffer; + struct { + uint32_t imageLayout; + uint32_t imageFlags; + uint32_t imageType; + uint32_t format; + uint32_t width; + uint32_t height; + uint32_t depth; + uint32_t mipLevels; + uint32_t arrayLayers; + uint32_t samples; + uint32_t tiling; + uint32_t usage; + uint32_t sharingMode; + uint32_t queueFamilyIndexCount; + uint32_t initialLayout; + uint32_t aspectMask; + uint32_t baseMipLevel; + uint32_t levelCount; + uint32_t baseArrayLayer; + uint32_t layerCount; + uint32_t viewFlags; + uint32_t viewType; + uint32_t viewFormat; + uint32_t component_a; + uint32_t component_b; + uint32_t component_g; + uint32_t component_r; + uint32_t memorySize; + uint32_t memoryType; + uint32_t bindOffset; + } image; + struct { + uint32_t flags; + uint32_t magFilter; + uint32_t minFilter; + uint32_t mipmapMode; + uint32_t addressModeU; + uint32_t addressModeV; + uint32_t addressModeW; + union { + float fMipLodBias; + uint32_t uMipLodBias; + }; + uint32_t anisotropyEnable; + union { + float fMaxAnisotropy; + uint32_t uMaxAnisotropy; + }; + uint32_t compareEnable; + uint32_t compareOp; + union { + float fMinLod; + uint32_t uMinLod; + }; + union { + float fMaxLod; + uint32_t uMaxLod; + }; + uint32_t borderColor; + uint32_t unnormalizedCoordinates; + } sampler; + }; +}; + +struct vksp_configuration { + const char *enabledExtensionNames; + uint32_t specializationInfoDataSize; + const char *specializationInfoData; + const char *shaderName; + const char *entryPoint; + uint32_t groupCountX; + uint32_t groupCountY; + uint32_t groupCountZ; +}; + +struct vksp_specialization_map_entry { + uint32_t constantID; + uint32_t offset; + uint32_t size; +}; + +struct vksp_counter { + uint32_t index; + const char *name; +}; + +} diff --git a/extractor/CMakeLists.txt b/extractor/CMakeLists.txt index b7bcb96..9fefe1c 100644 --- a/extractor/CMakeLists.txt +++ b/extractor/CMakeLists.txt @@ -13,12 +13,24 @@ # limitations under the License. add_library(vksp-spirv STATIC spirv.cpp) -target_include_directories(vksp-spirv PUBLIC ${SPIRV-Tools-opt_INCLUDE_DIRS}) target_link_libraries(vksp-spirv ${SPIRV-Tools-opt_LIBRARIES}) +target_include_directories(vksp-spirv PUBLIC + ${SPIRV-Tools-opt_INCLUDE_DIRS} + ${CMAKE_SOURCE_DIR} + ${SPIRV_TOOLS_SOURCE_PATH} + ${SPIRV_TOOLS_BUILD_PATH} +) +if (SPIRV_HEADERS_INCLUDE_PATH) + target_include_directories(vksp-spirv PUBLIC ${SPIRV_HEADERS_INCLUDE_PATH}) +endif() add_executable(vulkan-shader-profiler-extractor extractor.cpp) target_link_libraries(vulkan-shader-profiler-extractor ${PERFETTO_TRACE_PROCESSOR_LIB} vksp-spirv sqlite3) -target_include_directories(vulkan-shader-profiler-extractor PUBLIC ${Vulkan_INCLUDE_DIRS} ${PERFETTO_INTERNAL_INCLUDE_PATH}) +target_include_directories(vulkan-shader-profiler-extractor PUBLIC + ${Vulkan_INCLUDE_DIRS} + ${PERFETTO_INTERNAL_INCLUDE_PATH} + ${CMAKE_SOURCE_DIR} +) if (PERFETTO_GEN_INCLUDE_PATH) target_include_directories(vulkan-shader-profiler-extractor PUBLIC ${PERFETTO_GEN_INCLUDE_PATH}) endif() diff --git a/extractor/extractor.cpp b/extractor/extractor.cpp index 6cb1fe4..e0cbf27 100644 --- a/extractor/extractor.cpp +++ b/extractor/extractor.cpp @@ -121,7 +121,7 @@ std::unique_ptr initialize_database() } bool get_dispatch_compute_and_commandBuffer_from_dispatchId(TraceProcessor *tp, uint64_t dispatchId, uint64_t &dispatch, - uint64_t &compute, uint64_t &commandBuffer, vksp_configuration &config) + uint64_t &compute, uint64_t &commandBuffer, vksp::vksp_configuration &config) { std::string query = "SELECT arg_set_id FROM args WHERE args.key = 'debug.dispatchId' AND args.int_value = " + std::to_string(dispatchId); @@ -165,7 +165,7 @@ bool get_min_timestamp(TraceProcessor *tp, uint64_t commandBUffer, uint64_t max_ } bool get_shader_and_device_from_compute(TraceProcessor *tp, uint64_t compute, std::string &shader, - std::vector &shader_buffer, uint64_t &device, vksp_configuration &config) + std::vector &shader_buffer, uint64_t &device, vksp::vksp_configuration &config) { GET_STR_VALUE(tp, compute, "debug.shader", config.shaderName); @@ -222,7 +222,7 @@ bool get_extensions_from_device(TraceProcessor *tp, uint64_t device, const char } bool get_push_constants(TraceProcessor *tp, uint64_t commandBuffer, uint64_t max_timestamp, uint64_t min_timestamp, - std::vector &push_constants_vector) + std::vector &push_constants_vector) { std::string query = "SELECT arg_set_id, ts FROM slice WHERE slice.name = 'vkCmdPushConstants' AND slice.ts > " + std::to_string(min_timestamp) + " AND slice.ts < " + std::to_string(max_timestamp) + " AND " @@ -234,7 +234,7 @@ bool get_push_constants(TraceProcessor *tp, uint64_t commandBuffer, uint64_t max std::set offsetWritten; while (it.Next()) { uint64_t arg_set_id = it.Get(0).AsLong(); - vksp_push_constant pc; + vksp::vksp_push_constant pc; GET_INT_VALUE(tp, arg_set_id, "debug.offset", pc.offset); GET_INT_VALUE(tp, arg_set_id, "debug.size", pc.size); GET_INT_VALUE(tp, arg_set_id, "debug.stageFlags", pc.stageFlags); @@ -256,7 +256,7 @@ bool get_push_constants(TraceProcessor *tp, uint64_t commandBuffer, uint64_t max } bool get_buffer_descriptor_set( - TraceProcessor *tp, uint64_t write_arg_set_id, uint64_t write_timestamp, vksp_descriptor_set &ds) + TraceProcessor *tp, uint64_t write_arg_set_id, uint64_t write_timestamp, vksp::vksp_descriptor_set &ds) { uint64_t buffer; GET_INT_VALUE(tp, write_arg_set_id, "debug.buffer", buffer); @@ -304,7 +304,7 @@ bool get_buffer_descriptor_set( } bool get_image_descriptor_set( - TraceProcessor *tp, uint64_t write_arg_set_id, uint64_t write_timestamp, vksp_descriptor_set &ds) + TraceProcessor *tp, uint64_t write_arg_set_id, uint64_t write_timestamp, vksp::vksp_descriptor_set &ds) { uint64_t image_view; GET_INT_VALUE(tp, write_arg_set_id, "debug.imageView", image_view); @@ -384,7 +384,7 @@ bool get_image_descriptor_set( } bool get_sampler_descriptor_set( - TraceProcessor *tp, uint64_t write_arg_set_id, uint64_t write_timestamp, vksp_descriptor_set &ds) + TraceProcessor *tp, uint64_t write_arg_set_id, uint64_t write_timestamp, vksp::vksp_descriptor_set &ds) { uint64_t sampler; GET_INT_VALUE(tp, write_arg_set_id, "debug.sampler", sampler); @@ -418,7 +418,7 @@ bool get_sampler_descriptor_set( } bool get_descriptor_set(TraceProcessor *tp, uint64_t commandBuffer, uint64_t max_timestamp, uint64_t min_timestamp, - std::vector &descriptor_sets_vector) + std::vector &descriptor_sets_vector) { std::map> dsSeen; std::string query @@ -432,7 +432,7 @@ bool get_descriptor_set(TraceProcessor *tp, uint64_t commandBuffer, uint64_t max do { uint64_t arg_set_id = it.Get(0).AsLong(); uint64_t bind_timestamp = it.Get(1).AsLong(); - vksp_descriptor_set ds; + vksp::vksp_descriptor_set ds; uint64_t dstSet; { @@ -493,7 +493,8 @@ bool get_descriptor_set(TraceProcessor *tp, uint64_t commandBuffer, uint64_t max } bool get_map_entries_from_cmd_buffer(TraceProcessor *tp, uint64_t commandBuffer, uint64_t max_timestamp, - uint64_t min_timestamp, std::vector &map_entry_vector, vksp_configuration &config) + uint64_t min_timestamp, std::vector &map_entry_vector, + vksp::vksp_configuration &config) { std::string query = "SELECT arg_set_id FROM slice WHERE slice.name = 'vkCmdBindPipeline' AND " + std::to_string(commandBuffer) @@ -513,7 +514,7 @@ bool get_map_entries_from_cmd_buffer(TraceProcessor *tp, uint64_t commandBuffer, EXECUTE_QUERY_NO_CHECK(it2, tp, query2); while (it2.Next()) { uint64_t arg_set_id = it2.Get(0).AsLong(); - vksp_specialization_map_entry me; + vksp::vksp_specialization_map_entry me; GET_INT_VALUE(tp, arg_set_id, "debug.constantID", me.constantID); GET_INT_VALUE(tp, arg_set_id, "debug.offset", me.offset); GET_INT_VALUE(tp, arg_set_id, "debug.size", me.size); @@ -604,7 +605,7 @@ int main(int argc, char **argv) CHECK(tp != nullptr, "Initialization failed"); PRINT("%s read with success", gInput.c_str()); - vksp_configuration config; + vksp::vksp_configuration config; uint64_t dispatch, compute, commandBuffer; CHECK(get_dispatch_compute_and_commandBuffer_from_dispatchId( tp.get(), gDispatchId, dispatch, compute, commandBuffer, config), @@ -633,7 +634,7 @@ int main(int argc, char **argv) CHECK(get_min_timestamp(tp.get(), commandBuffer, max_timestamp, min_timestamp), "Could not get min_timestamp"); PRINT("Min timestamp: %lu", min_timestamp); - std::vector map_entry_vector; + std::vector map_entry_vector; CHECK(get_map_entries_from_cmd_buffer( tp.get(), commandBuffer, max_timestamp, min_timestamp, map_entry_vector, config), "Could not get map entries from command buffer"); @@ -646,7 +647,7 @@ int main(int argc, char **argv) "Could not get features and extensions names from device"); PRINT("Extensions: '%s'", config.enabledExtensionNames); - std::vector push_constants_vector; + std::vector push_constants_vector; CHECK(get_push_constants(tp.get(), commandBuffer, max_timestamp, min_timestamp, push_constants_vector), "Could not get push_constants"); for (auto &pc : push_constants_vector) { @@ -654,7 +655,7 @@ int main(int argc, char **argv) pc.pValues); } - std::vector descriptor_sets_vector; + std::vector descriptor_sets_vector; CHECK(get_descriptor_set(tp.get(), commandBuffer, max_timestamp, min_timestamp, descriptor_sets_vector), "Could not get descriptor_set"); for (auto &ds : descriptor_sets_vector) { diff --git a/extractor/spirv.cpp b/extractor/spirv.cpp index 80ecf1d..47febf5 100644 --- a/extractor/spirv.cpp +++ b/extractor/spirv.cpp @@ -18,8 +18,202 @@ #include "spirv.hpp" #include "utils.hpp" +#include "common/common.hpp" +#include "source/opt/pass.h" +#include "spirv/unified1/NonSemanticVkspReflection.h" + #include +namespace vksp { + +class InsertVkspReflectInfoPass : public spvtools::opt::Pass { +public: + InsertVkspReflectInfoPass(std::vector *pc, std::vector *ds, + std::vector *me, vksp_configuration *config) + : pc_(pc) + , ds_(ds) + , me_(me) + , config_(config) + { + } + const char *name() const override { return "insert-vksp-reflect-info"; } + Status Process() override + { + auto module = context()->module(); + + std::vector ext_words = spvtools::utils::MakeVector("NonSemantic.VkspReflection.1"); + auto ExtInstId = context()->TakeNextId(); + auto ExtInst = new spvtools::opt::Instruction( + context(), spv::Op::OpExtInstImport, 0u, ExtInstId, { { SPV_OPERAND_TYPE_LITERAL_STRING, ext_words } }); + module->AddExtInstImport(std::unique_ptr(ExtInst)); + + uint32_t void_ty_id = context()->get_type_mgr()->GetVoidTypeId(); + + std::vector enabledExtensions = spvtools::utils::MakeVector(config_->enabledExtensionNames); + std::vector pData = spvtools::utils::MakeVector(config_->specializationInfoData); + std::vector shaderName = spvtools::utils::MakeVector(config_->shaderName); + std::vector entryPoint = spvtools::utils::MakeVector(config_->entryPoint); + auto ConfigId = context()->TakeNextId(); + auto ConfigInst = new spvtools::opt::Instruction(context(), spv::Op::OpExtInst, void_ty_id, ConfigId, + { + { SPV_OPERAND_TYPE_ID, { ExtInstId } }, + { SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, { NonSemanticVkspReflectionConfiguration } }, + { SPV_OPERAND_TYPE_LITERAL_STRING, enabledExtensions }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { config_->specializationInfoDataSize } }, + { SPV_OPERAND_TYPE_LITERAL_STRING, pData }, + { SPV_OPERAND_TYPE_LITERAL_STRING, shaderName }, + { SPV_OPERAND_TYPE_LITERAL_STRING, entryPoint }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { config_->groupCountX } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { config_->groupCountY } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { config_->groupCountZ } }, + }); + module->AddExtInstDebugInfo(std::unique_ptr(ConfigInst)); + + for (auto &pc : *pc_) { + std::vector pValues = spvtools::utils::MakeVector(pc.pValues); + auto PcInstId = context()->TakeNextId(); + auto PcInst = new spvtools::opt::Instruction(context(), spv::Op::OpExtInst, void_ty_id, PcInstId, + { + { SPV_OPERAND_TYPE_ID, { ExtInstId } }, + { SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, { NonSemanticVkspReflectionPushConstants } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { pc.offset } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { pc.size } }, + { SPV_OPERAND_TYPE_LITERAL_STRING, pValues }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { pc.stageFlags } }, + }); + module->AddExtInstDebugInfo(std::unique_ptr(PcInst)); + } + + for (auto &ds : *ds_) { + switch (ds.type) { + case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER: + case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER: { + auto DsInstId = context()->TakeNextId(); + auto DstInst = new spvtools::opt::Instruction(context(), spv::Op::OpExtInst, void_ty_id, DsInstId, + { + { SPV_OPERAND_TYPE_ID, { ExtInstId } }, + { SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + { NonSemanticVkspReflectionDescriptorSetBuffer } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.ds } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.binding } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.type } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.buffer.flags } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.buffer.queueFamilyIndexCount } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.buffer.sharingMode } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.buffer.size } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.buffer.usage } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.buffer.range } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.buffer.offset } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.buffer.memorySize } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.buffer.memoryType } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.buffer.bindOffset } }, + }); + module->AddExtInstDebugInfo(std::unique_ptr(DstInst)); + } break; + case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE: + case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE: { + auto DsInstId = context()->TakeNextId(); + auto DstInst = new spvtools::opt::Instruction(context(), spv::Op::OpExtInst, void_ty_id, DsInstId, + { + { SPV_OPERAND_TYPE_ID, { ExtInstId } }, + { SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + { NonSemanticVkspReflectionDescriptorSetImage } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.ds } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.binding } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.type } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.imageLayout } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.imageFlags } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.imageType } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.format } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.width } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.height } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.depth } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.mipLevels } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.arrayLayers } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.samples } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.tiling } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.usage } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.sharingMode } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.queueFamilyIndexCount } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.initialLayout } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.aspectMask } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.baseMipLevel } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.levelCount } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.baseArrayLayer } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.layerCount } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.viewFlags } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.viewType } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.viewFormat } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.component_a } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.component_b } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.component_g } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.component_r } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.memorySize } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.memoryType } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.image.bindOffset } }, + }); + module->AddExtInstDebugInfo(std::unique_ptr(DstInst)); + } break; + case VK_DESCRIPTOR_TYPE_SAMPLER: { + auto DsInstId = context()->TakeNextId(); + auto DstInst = new spvtools::opt::Instruction(context(), spv::Op::OpExtInst, void_ty_id, DsInstId, + { + { SPV_OPERAND_TYPE_ID, { ExtInstId } }, + { SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + { NonSemanticVkspReflectionDescriptorSetSampler } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.ds } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.binding } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.type } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.flags } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.magFilter } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.minFilter } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.mipmapMode } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.addressModeU } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.addressModeV } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.addressModeW } }, + { SPV_OPERAND_TYPE_LITERAL_FLOAT, { ds.sampler.uMipLodBias } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.anisotropyEnable } }, + { SPV_OPERAND_TYPE_LITERAL_FLOAT, { ds.sampler.uMaxAnisotropy } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.compareEnable } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.compareOp } }, + { SPV_OPERAND_TYPE_LITERAL_FLOAT, { ds.sampler.uMinLod } }, + { SPV_OPERAND_TYPE_LITERAL_FLOAT, { ds.sampler.uMaxLod } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.borderColor } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { ds.sampler.unnormalizedCoordinates } }, + }); + module->AddExtInstDebugInfo(std::unique_ptr(DstInst)); + } break; + default: + break; + } + } + + for (auto &me : *me_) { + auto MapEntryId = context()->TakeNextId(); + auto MapEntryInst = new spvtools::opt::Instruction(context(), spv::Op::OpExtInst, void_ty_id, MapEntryId, + { + { SPV_OPERAND_TYPE_ID, { ExtInstId } }, + { SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + { NonSemanticVkspReflectionSpecializationMapEntry } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { me.constantID } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { me.offset } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { me.size } }, + }); + module->AddExtInstDebugInfo(std::unique_ptr(MapEntryInst)); + } + + return Status::SuccessWithChange; + }; + +private: + std::vector *pc_; + std::vector *ds_; + std::vector *me_; + vksp_configuration *config_; +}; + +} + bool text_to_binary(spv_context context, std::string *shader, spv_binary &binary) { spv_diagnostic diagnostic; @@ -32,12 +226,13 @@ bool text_to_binary(spv_context context, std::string *shader, spv_binary &binary return true; } -bool create_binary(spv_context context, spv_binary input_binary, std::vector *pc, - std::vector *ds, std::vector *me, - spvtools::vksp_configuration *config, std::vector &output_binary) +bool create_binary(spv_context context, spv_binary input_binary, std::vector *pc, + std::vector *ds, std::vector *me, + vksp::vksp_configuration *config, std::vector &output_binary) { spvtools::Optimizer opt(SPV_ENV_VULKAN_1_3); - opt.RegisterPass(spvtools::CreateInsertVkspReflectInfoPass(pc, ds, me, config)); + opt.RegisterPass( + spvtools::Optimizer::PassToken(std::make_unique(pc, ds, me, config))); spvtools::OptimizerOptions options; options.set_run_validator(false); if (!opt.Run(input_binary->code, input_binary->wordCount, &output_binary, options)) { @@ -64,9 +259,9 @@ bool binary_to_text(spv_context context, std::vector &binary, spv_text } bool store_shader_binary_in_output(spv_context context, spv_binary input_binary, - std::vector *pc, std::vector *ds, - std::vector *me, spvtools::vksp_configuration *config, - const char *output_filename, bool binary_output) + std::vector *pc, std::vector *ds, + std::vector *me, vksp::vksp_configuration *config, const char *output_filename, + bool binary_output) { std::vector binary; if (!create_binary(context, input_binary, pc, ds, me, config, binary)) { @@ -93,9 +288,9 @@ bool store_shader_binary_in_output(spv_context context, spv_binary input_binary, return true; } -extern "C" bool store_shader_in_output(std::string *shader, std::vector *pc, - std::vector *ds, std::vector *me, - spvtools::vksp_configuration *config, const char *output_filename, bool binary_output) +extern "C" bool store_shader_in_output(std::string *shader, std::vector *pc, + std::vector *ds, std::vector *me, + vksp::vksp_configuration *config, const char *output_filename, bool binary_output) { spv_context context = spvContextCreate(SPV_ENV_VULKAN_1_3); spv_binary binary; @@ -111,9 +306,9 @@ extern "C" bool store_shader_in_output(std::string *shader, std::vector *shader_buffer, - std::vector *pc, std::vector *ds, - std::vector *me, spvtools::vksp_configuration *config, - const char *output_filename, bool binary_output) + std::vector *pc, std::vector *ds, + std::vector *me, vksp::vksp_configuration *config, const char *output_filename, + bool binary_output) { spv_context context = spvContextCreate(SPV_ENV_VULKAN_1_3); spv_binary_t binary diff --git a/extractor/spirv.hpp b/extractor/spirv.hpp index 0da78e6..ff67766 100644 --- a/extractor/spirv.hpp +++ b/extractor/spirv.hpp @@ -14,14 +14,15 @@ #pragma once -#include "spirv-tools/optimizer.hpp" +#include "common/common.hpp" #include +#include -extern "C" bool store_shader_in_output(std::string *shader, std::vector *pc, - std::vector *ds, std::vector *me, - spvtools::vksp_configuration *config, const char *output_filename, bool binary_output); +extern "C" bool store_shader_in_output(std::string *shader, std::vector *pc, + std::vector *ds, std::vector *me, + vksp::vksp_configuration *config, const char *output_filename, bool binary_output); extern "C" bool store_shader_buffer_in_output(std::vector *shader_buffer, - std::vector *pc, std::vector *ds, - std::vector *me, spvtools::vksp_configuration *config, - const char *output_filename, bool binary_output); + std::vector *pc, std::vector *ds, + std::vector *me, vksp::vksp_configuration *config, const char *output_filename, + bool binary_output); extern "C" bool read_shader_buffer(std::string *gShaderFile, std::vector *shader_buffer); diff --git a/runner/CMakeLists.txt b/runner/CMakeLists.txt index 4e7c68b..e3733e1 100644 --- a/runner/CMakeLists.txt +++ b/runner/CMakeLists.txt @@ -13,6 +13,15 @@ # limitations under the License. add_executable(vulkan-shader-profiler-runner runner.cpp) -target_include_directories(vulkan-shader-profiler-runner PUBLIC ${Vulkan_INCLUDE_DIRS} ${SPIRV-Tools-opt_INCLUDE_DIRS}) -target_link_libraries(vulkan-shader-profiler-runner ${Vulkan_LIBRARIES} ${SPIRV-Tools-opt_LIBRARIES}) +target_include_directories(vulkan-shader-profiler-runner PUBLIC + ${Vulkan_INCLUDE_DIRS} + ${SPIRV-Tools-opt_INCLUDE_DIRS} + ${SPIRV_TOOLS_SOURCE_PATH} + ${SPIRV_TOOLS_BUILD_PATH} + ${CMAKE_SOURCE_DIR} +) +if (SPIRV_HEADERS_INCLUDE_PATH) + target_include_directories(vulkan-shader-profiler-runner PUBLIC ${SPIRV_HEADERS_INCLUDE_PATH}) +endif() +target_link_libraries(vulkan-shader-profiler-runner ${Vulkan_LIBRARIES} ${SPIRV-Tools-opt_LIBRARIES}) diff --git a/runner/runner.cpp b/runner/runner.cpp index 5bc9459..25e4b7b 100644 --- a/runner/runner.cpp +++ b/runner/runner.cpp @@ -12,20 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include -#include -#include -#include #include +#include "spirv.hpp" + #include +#include #include #include #include +#include +#include +#include #include -#include #define PRINT_IMPL(file, message, ...) \ do { \ @@ -184,9 +185,9 @@ static int get_device_queue_and_cmd_buffer(VkPhysicalDevice &pDevice, VkDevice & return 0; } -static bool extract_from_input(std::vector &shader, std::vector &ds, - std::vector &pc, std::vector &me, - std::vector &counters, spvtools::vksp_configuration &config) +static bool extract_from_input(std::vector &shader, std::vector &ds, + std::vector &pc, std::vector &me, + std::vector &counters, vksp::vksp_configuration &config) { FILE *input = fopen(gInput.c_str(), "r"); fseek(input, 0, SEEK_END); @@ -218,7 +219,8 @@ static bool extract_from_input(std::vector &shader, std::vector(&pc, &ds, &me, &counters, &config))); opt.RegisterPass(spvtools::CreateStripReflectInfoPass()); spvtools::OptimizerOptions options; options.set_run_validator(false); @@ -248,8 +250,8 @@ static bool extract_from_input(std::vector &shader, std::vector &descSet) +static uint32_t handle_descriptor_set_buffer(vksp::vksp_descriptor_set &ds, VkDevice device, VkCommandBuffer cmdBuffer, + VkPhysicalDeviceMemoryProperties &memProperties, std::vector &descSet) { VkResult res; @@ -320,8 +322,8 @@ static uint32_t handle_descriptor_set_buffer(spvtools::vksp_descriptor_set &ds, return 0; } -static uint32_t handle_descriptor_set_image(spvtools::vksp_descriptor_set &ds, VkDevice device, - VkCommandBuffer cmdBuffer, VkPhysicalDeviceMemoryProperties &memProperties, std::vector &descSet) +static uint32_t handle_descriptor_set_image(vksp::vksp_descriptor_set &ds, VkDevice device, VkCommandBuffer cmdBuffer, + VkPhysicalDeviceMemoryProperties &memProperties, std::vector &descSet) { VkResult res; @@ -381,8 +383,8 @@ static uint32_t handle_descriptor_set_image(spvtools::vksp_descriptor_set &ds, V return 0; } -static uint32_t handle_descriptor_set_sampler(spvtools::vksp_descriptor_set &ds, VkDevice device, - VkCommandBuffer cmdBuffer, VkPhysicalDeviceMemoryProperties &memProperties, std::vector &descSet) +static uint32_t handle_descriptor_set_sampler(vksp::vksp_descriptor_set &ds, VkDevice device, VkCommandBuffer cmdBuffer, + VkPhysicalDeviceMemoryProperties &memProperties, std::vector &descSet) { VkResult res; VkSampler sampler; @@ -415,7 +417,7 @@ static uint32_t handle_descriptor_set_sampler(spvtools::vksp_descriptor_set &ds, } static uint32_t allocate_descriptor_set(VkDevice device, std::vector &descSet, - std::vector &dsVector, std::vector &descSetLayoutVector) + std::vector &dsVector, std::vector &descSetLayoutVector) { VkResult res; std::map descTypeCount; @@ -468,7 +470,7 @@ static uint32_t allocate_descriptor_set(VkDevice device, std::vector dsVector) +static uint32_t count_descriptor_set(std::vector dsVector) { std::set ds_set; for (auto &ds : dsVector) { @@ -477,7 +479,7 @@ static uint32_t count_descriptor_set(std::vector return ds_set.size(); } -static uint32_t handle_push_constant(std::vector &pcVector, VkPushConstantRange &range, +static uint32_t handle_push_constant(std::vector &pcVector, VkPushConstantRange &range, VkCommandBuffer cmdBuffer, VkPipelineLayout pipelineLayout) { std::vector pValues(range.size); @@ -500,7 +502,7 @@ static uint32_t handle_push_constant(std::vector & return 0; } -static uint32_t allocate_pipeline_layout(VkDevice device, std::vector &pcVector, +static uint32_t allocate_pipeline_layout(VkDevice device, std::vector &pcVector, std::vector &pcRanges, std::vector &descSetLayoutVector, VkPipelineLayout &pipelineLayout) { @@ -516,8 +518,8 @@ static uint32_t allocate_pipeline_layout(VkDevice device, std::vector &shader, VkPipelineLayout pipelineLayout, VkDevice device, - VkCommandBuffer cmdBuffer, std::vector &meVector, - spvtools::vksp_configuration &config, VkPipeline &pipeline) + VkCommandBuffer cmdBuffer, std::vector &meVector, + vksp::vksp_configuration &config, VkPipeline &pipeline) { VkResult res; const VkShaderModuleCreateInfo shaderModuleCreateInfo @@ -605,13 +607,13 @@ static void print_time(uint64_t time, const char *prefix, uint32_t prefix_max_si } } -static uint32_t counters_size(std::vector &counters) +static uint32_t counters_size(std::vector &counters) { return sizeof(uint64_t) * (2 + counters.size()); } -static uint32_t execute(VkDevice device, VkCommandBuffer cmdBuffer, VkQueue queue, spvtools::vksp_configuration &config, - std::vector &counters, uint64_t *gpu_timestamps, +static uint32_t execute(VkDevice device, VkCommandBuffer cmdBuffer, VkQueue queue, vksp::vksp_configuration &config, + std::vector &counters, uint64_t *gpu_timestamps, std::chrono::steady_clock::time_point *host_timestamps) { VkResult res; @@ -670,8 +672,8 @@ static uint32_t execute(VkDevice device, VkCommandBuffer cmdBuffer, VkQueue queu return 0; } -static uint32_t print_results(VkPhysicalDevice pDevice, VkDevice device, spvtools::vksp_configuration &config, - std::vector &counters, uint64_t *gpu_timestamps, +static uint32_t print_results(VkPhysicalDevice pDevice, VkDevice device, vksp::vksp_configuration &config, + std::vector &counters, uint64_t *gpu_timestamps, std::chrono::steady_clock::time_point *host_timestamps) { VkPhysicalDeviceProperties properties; @@ -862,11 +864,11 @@ int main(int argc, char **argv) gVerbose, spvTargetEnvDescription(gSpvTargetEnv), gHotRun, gColdRun); std::vector shader; - std::vector dsVector; - std::vector pcVector; - std::vector meVector; - std::vector counters; - spvtools::vksp_configuration config; + std::vector dsVector; + std::vector pcVector; + std::vector meVector; + std::vector counters; + vksp::vksp_configuration config; CHECK(extract_from_input(shader, dsVector, pcVector, meVector, counters, config), "Could not extract data from input"); PRINT("Shader name: '%s'", config.shaderName); @@ -940,7 +942,7 @@ int main(int argc, char **argv) } } std::vector pcRanges; - std::map> pcMap; + std::map> pcMap; for (auto &pc : pcVector) { PRINT("push_constants: offset %u size %u stageFlags %u pValues %s", pc.offset, pc.size, pc.stageFlags, pc.pValues); diff --git a/runner/spirv.hpp b/runner/spirv.hpp new file mode 100644 index 0000000..c6244a8 --- /dev/null +++ b/runner/spirv.hpp @@ -0,0 +1,546 @@ +// Copyright 2024 The Vulkan Shader Profiler authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "common/common.hpp" +#include "source/opt/pass.h" +#include "spirv/unified1/NonSemanticVkspReflection.h" + +#define UNDEFINED_ID (UINT32_MAX) + +namespace vksp { + +class ExtractVkspReflectInfoPass : public spvtools::opt::Pass { +public: + ExtractVkspReflectInfoPass(std::vector *pc, std::vector *ds, + std::vector *me, std::vector *counters, vksp_configuration *config) + : pc_(pc) + , ds_(ds) + , me_(me) + , counters_(counters) + , config_(config) + { + } + const char *name() const override { return "extract-vksp-reflect-info"; } + Status Process() override + { + auto module = context()->module(); + uint32_t ext_inst_id = module->GetExtInstImportId("NonSemantic.VkspReflection.1"); + int32_t descriptor_set_0_max_binding = -1; + std::map id_to_descriptor_set; + std::map id_to_binding; + std::vector start_counters; + std::vector stop_counters; + + module->ForEachInst([this, ext_inst_id, &id_to_descriptor_set, &id_to_binding, &descriptor_set_0_max_binding, + &start_counters, &stop_counters](spvtools::opt::Instruction *inst) { + ParseInstruction(inst, ext_inst_id, id_to_descriptor_set, id_to_binding, descriptor_set_0_max_binding, + start_counters, stop_counters); + }); + + context()->AddExtension("SPV_KHR_shader_clock"); + context()->AddExtension("SPV_KHR_storage_buffer_storage_class"); + context()->AddCapability(spv::Capability::ShaderClockKHR); + context()->AddCapability(spv::Capability::Int64); + context()->AddCapability(spv::Capability::Int64Atomics); + + uint32_t global_counters_ds = 0; + uint32_t global_counters_binding = descriptor_set_0_max_binding + 1; + auto counters_size + = (uint32_t)(sizeof(uint64_t) * (2 + start_counters.size())); // 2 for the number of invocations and the + // time of the whole entry point + ds_->push_back( + { global_counters_ds, global_counters_binding, (uint32_t)VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER, + { .buffer = { 0, 0, VK_SHARING_MODE_EXCLUSIVE, counters_size, + VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, counters_size, 0, + counters_size, UINT32_MAX, 0 } } }); + + auto *cst_mgr = context()->get_constant_mgr(); + auto *type_mgr = context()->get_type_mgr(); + + auto u64_ty = type_mgr->GetIntType(64, 0); + spvtools::opt::analysis::RuntimeArray run_arr(u64_ty); + auto u64_run_arr_ty = type_mgr->GetRegisteredType(&run_arr); + spvtools::opt::analysis::Struct st({ u64_run_arr_ty }); + auto u64_run_arr_st_ty = type_mgr->GetRegisteredType(&st); + spvtools::opt::analysis::Pointer u64_run_arr_st_ty_ptr(u64_run_arr_st_ty, spv::StorageClass::StorageBuffer); + auto u64_run_arr_st_ptr_ty = type_mgr->GetRegisteredType(&u64_run_arr_st_ty_ptr); + + spvtools::opt::analysis::Pointer u64_ty_ptr(u64_ty, spv::StorageClass::StorageBuffer); + auto u64_ptr_ty = type_mgr->GetRegisteredType(&u64_ty_ptr); + + auto counters_ty_id = type_mgr->GetId(u64_run_arr_st_ptr_ty); + auto u64_ty_id = type_mgr->GetId(u64_ty); + auto u64_ptr_ty_id = type_mgr->GetId(u64_ptr_ty); + auto u64_arr_ty_id = type_mgr->GetId(u64_run_arr_ty); + auto u64_arr_st_ty_id = type_mgr->GetId(u64_run_arr_st_ty); + + uint32_t local_counters_ty_id = UNDEFINED_ID; + uint32_t u64_private_ptr_ty_id = UNDEFINED_ID; + + if (start_counters.size() > 0) { + spvtools::opt::analysis::Array arr(u64_ty, + spvtools::opt::analysis::Array::LengthInfo { + cst_mgr->GetUIntConstId((uint32_t)start_counters.size()), { 0, (uint32_t)start_counters.size() } }); + auto u64_arr_ty = type_mgr->GetRegisteredType(&arr); + spvtools::opt::analysis::Pointer u64_arr_ty_ptr(u64_arr_ty, spv::StorageClass::Private); + auto u64_arr_ptr_ty = type_mgr->GetRegisteredType(&u64_arr_ty_ptr); + spvtools::opt::analysis::Pointer u64_ty_ptr_private(u64_ty, spv::StorageClass::Private); + auto u64_private_ptr_ty = type_mgr->GetRegisteredType(&u64_ty_ptr_private); + + local_counters_ty_id = type_mgr->GetId(u64_arr_ptr_ty); + u64_private_ptr_ty_id = type_mgr->GetId(u64_private_ptr_ty); + } + + auto subgroup_scope_id = cst_mgr->GetUIntConstId((uint32_t)spv::Scope::Subgroup); + auto device_scope_id = cst_mgr->GetUIntConstId((uint32_t)spv::Scope::Device); + auto acq_rel_mem_sem_id = cst_mgr->GetUIntConstId((uint32_t)spv::MemorySemanticsMask::AcquireRelease); + + uint32_t global_counters_id; + uint32_t local_counters_id; + CreateVariables(u64_arr_ty_id, u64_arr_st_ty_id, local_counters_ty_id, counters_ty_id, global_counters_ds, + global_counters_binding, global_counters_id, local_counters_id); + + bool found = false; + for (auto &entry_point_inst : module->entry_points()) { + auto function_name = entry_point_inst.GetOperand(2).AsString(); + if (function_name != std::string(config_->entryPoint)) { + continue; + } + found = true; + + uint32_t read_clock_id; + spvtools::opt::Function *function; + CreatePrologue(&entry_point_inst, u64_private_ptr_ty_id, u64_ty_id, subgroup_scope_id, global_counters_id, + local_counters_id, start_counters, function, read_clock_id); + + function->ForEachInst([this, read_clock_id, u64_ty_id, u64_ptr_ty_id, u64_private_ptr_ty_id, + subgroup_scope_id, device_scope_id, acq_rel_mem_sem_id, global_counters_id, + local_counters_id, &start_counters](spvtools::opt::Instruction *inst) { + if (inst->opcode() != spv::Op::OpReturn) { + return; + } + CreateEpilogue(inst, read_clock_id, u64_ty_id, u64_ptr_ty_id, u64_private_ptr_ty_id, subgroup_scope_id, + device_scope_id, acq_rel_mem_sem_id, global_counters_id, local_counters_id, start_counters); + }); + + break; + } + if (!found) { + return Status::Failure; + } + + CreateCounters( + u64_ty_id, u64_private_ptr_ty_id, subgroup_scope_id, start_counters, stop_counters, local_counters_id); + + return Status::SuccessWithChange; + }; + +private: + int32_t UpdateMaxBinding(uint32_t ds, uint32_t binding, int32_t max_binding) + { + if (ds != 0) { + return max_binding; + } else { + return std::max(max_binding, (int32_t)binding); + } + } + + void ParseInstruction(spvtools::opt::Instruction *inst, uint32_t ext_inst_id, + std::map &id_to_descriptor_set, std::map &id_to_binding, + int32_t &descriptor_set_0_max_binding, std::vector &start_counters, + std::vector &stop_counters) + { + uint32_t op_id = 2; + if (inst->opcode() == spv::Op::OpDecorate) { + spv::Decoration decoration = (spv::Decoration)inst->GetOperand(1).words[0]; + if (decoration == spv::Decoration::DescriptorSet) { + auto id = inst->GetOperand(0).AsId(); + auto ds = inst->GetOperand(2).words[0]; + id_to_descriptor_set[id] = ds; + if (ds == 0 && id_to_binding.count(id) > 0) { + descriptor_set_0_max_binding + = UpdateMaxBinding(ds, id_to_binding[id], descriptor_set_0_max_binding); + } + } else if (decoration == spv::Decoration::Binding) { + auto id = inst->GetOperand(0).AsId(); + auto binding = inst->GetOperand(2).words[0]; + id_to_binding[id] = binding; + if (id_to_descriptor_set.count(id) > 0) { + descriptor_set_0_max_binding + = UpdateMaxBinding(id_to_descriptor_set[id], binding, descriptor_set_0_max_binding); + } + } + return; + } else if (inst->opcode() != spv::Op::OpExtInst || ext_inst_id != inst->GetOperand(op_id++).AsId()) { + return; + } + + auto vksp_inst = inst->GetOperand(op_id++).words[0]; + switch (vksp_inst) { + case NonSemanticVkspReflectionConfiguration: + config_->enabledExtensionNames = strdup(inst->GetOperand(op_id++).AsString().c_str()); + config_->specializationInfoDataSize = inst->GetOperand(op_id++).words[0]; + config_->specializationInfoData = strdup(inst->GetOperand(op_id++).AsString().c_str()); + config_->shaderName = strdup(inst->GetOperand(op_id++).AsString().c_str()); + config_->entryPoint = strdup(inst->GetOperand(op_id++).AsString().c_str()); + config_->groupCountX = inst->GetOperand(op_id++).words[0]; + config_->groupCountY = inst->GetOperand(op_id++).words[0]; + config_->groupCountZ = inst->GetOperand(op_id++).words[0]; + break; + case NonSemanticVkspReflectionDescriptorSetBuffer: { + vksp_descriptor_set ds; + ds.ds = inst->GetOperand(op_id++).words[0]; + ds.binding = inst->GetOperand(op_id++).words[0]; + ds.type = inst->GetOperand(op_id++).words[0]; + ds.buffer.flags = inst->GetOperand(op_id++).words[0]; + ds.buffer.queueFamilyIndexCount = inst->GetOperand(op_id++).words[0]; + ds.buffer.sharingMode = inst->GetOperand(op_id++).words[0]; + ds.buffer.size = inst->GetOperand(op_id++).words[0]; + ds.buffer.usage = inst->GetOperand(op_id++).words[0]; + ds.buffer.range = inst->GetOperand(op_id++).words[0]; + ds.buffer.offset = inst->GetOperand(op_id++).words[0]; + ds.buffer.memorySize = inst->GetOperand(op_id++).words[0]; + ds.buffer.memoryType = inst->GetOperand(op_id++).words[0]; + ds.buffer.bindOffset = inst->GetOperand(op_id++).words[0]; + ds_->push_back(ds); + descriptor_set_0_max_binding = UpdateMaxBinding(ds.ds, ds.binding, descriptor_set_0_max_binding); + } break; + case NonSemanticVkspReflectionDescriptorSetImage: { + vksp_descriptor_set ds; + ds.ds = inst->GetOperand(op_id++).words[0]; + ds.binding = inst->GetOperand(op_id++).words[0]; + ds.type = inst->GetOperand(op_id++).words[0]; + ds.image.imageLayout = inst->GetOperand(op_id++).words[0]; + ds.image.imageFlags = inst->GetOperand(op_id++).words[0]; + ds.image.imageType = inst->GetOperand(op_id++).words[0]; + ds.image.format = inst->GetOperand(op_id++).words[0]; + ds.image.width = inst->GetOperand(op_id++).words[0]; + ds.image.height = inst->GetOperand(op_id++).words[0]; + ds.image.depth = inst->GetOperand(op_id++).words[0]; + ds.image.mipLevels = inst->GetOperand(op_id++).words[0]; + ds.image.arrayLayers = inst->GetOperand(op_id++).words[0]; + ds.image.samples = inst->GetOperand(op_id++).words[0]; + ds.image.tiling = inst->GetOperand(op_id++).words[0]; + ds.image.usage = inst->GetOperand(op_id++).words[0]; + ds.image.sharingMode = inst->GetOperand(op_id++).words[0]; + ds.image.queueFamilyIndexCount = inst->GetOperand(op_id++).words[0]; + ds.image.initialLayout = inst->GetOperand(op_id++).words[0]; + ds.image.aspectMask = inst->GetOperand(op_id++).words[0]; + ds.image.baseMipLevel = inst->GetOperand(op_id++).words[0]; + ds.image.levelCount = inst->GetOperand(op_id++).words[0]; + ds.image.baseArrayLayer = inst->GetOperand(op_id++).words[0]; + ds.image.layerCount = inst->GetOperand(op_id++).words[0]; + ds.image.viewFlags = inst->GetOperand(op_id++).words[0]; + ds.image.viewType = inst->GetOperand(op_id++).words[0]; + ds.image.viewFormat = inst->GetOperand(op_id++).words[0]; + ds.image.component_a = inst->GetOperand(op_id++).words[0]; + ds.image.component_b = inst->GetOperand(op_id++).words[0]; + ds.image.component_g = inst->GetOperand(op_id++).words[0]; + ds.image.component_r = inst->GetOperand(op_id++).words[0]; + ds.image.memorySize = inst->GetOperand(op_id++).words[0]; + ds.image.memoryType = inst->GetOperand(op_id++).words[0]; + ds.image.bindOffset = inst->GetOperand(op_id++).words[0]; + ds_->push_back(ds); + descriptor_set_0_max_binding = UpdateMaxBinding(ds.ds, ds.binding, descriptor_set_0_max_binding); + } break; + case NonSemanticVkspReflectionDescriptorSetSampler: { + vksp_descriptor_set ds; + ds.ds = inst->GetOperand(op_id++).words[0]; + ds.binding = inst->GetOperand(op_id++).words[0]; + ds.type = inst->GetOperand(op_id++).words[0]; + ds.sampler.flags = inst->GetOperand(op_id++).words[0]; + ds.sampler.magFilter = inst->GetOperand(op_id++).words[0]; + ds.sampler.minFilter = inst->GetOperand(op_id++).words[0]; + ds.sampler.mipmapMode = inst->GetOperand(op_id++).words[0]; + ds.sampler.addressModeU = inst->GetOperand(op_id++).words[0]; + ds.sampler.addressModeV = inst->GetOperand(op_id++).words[0]; + ds.sampler.addressModeW = inst->GetOperand(op_id++).words[0]; + ds.sampler.uMipLodBias = inst->GetOperand(op_id++).words[0]; + ds.sampler.anisotropyEnable = inst->GetOperand(op_id++).words[0]; + ds.sampler.uMaxAnisotropy = inst->GetOperand(op_id++).words[0]; + ds.sampler.compareEnable = inst->GetOperand(op_id++).words[0]; + ds.sampler.compareOp = inst->GetOperand(op_id++).words[0]; + ds.sampler.uMinLod = inst->GetOperand(op_id++).words[0]; + ds.sampler.uMaxLod = inst->GetOperand(op_id++).words[0]; + ds.sampler.borderColor = inst->GetOperand(op_id++).words[0]; + ds.sampler.unnormalizedCoordinates = inst->GetOperand(op_id++).words[0]; + ds_->push_back(ds); + descriptor_set_0_max_binding = UpdateMaxBinding(ds.ds, ds.binding, descriptor_set_0_max_binding); + } break; + case NonSemanticVkspReflectionPushConstants: + vksp_push_constant pc; + pc.offset = inst->GetOperand(op_id++).words[0]; + pc.size = inst->GetOperand(op_id++).words[0]; + pc.pValues = strdup(inst->GetOperand(op_id++).AsString().c_str()); + pc.stageFlags = inst->GetOperand(op_id++).words[0]; + pc_->push_back(pc); + break; + case NonSemanticVkspReflectionSpecializationMapEntry: + vksp_specialization_map_entry me; + me.constantID = inst->GetOperand(op_id++).words[0]; + me.offset = inst->GetOperand(op_id++).words[0]; + me.size = inst->GetOperand(op_id++).words[0]; + me_->push_back(me); + break; + case NonSemanticVkspReflectionStartCounter: + start_counters.push_back(inst); + break; + case NonSemanticVkspReflectionStopCounter: + stop_counters.push_back(inst); + break; + default: + break; + } + } + + void CreateVariables(uint32_t u64_arr_ty_id, uint32_t u64_arr_st_ty_id, uint32_t local_counters_ty_id, + uint32_t counters_ty_id, uint32_t global_counters_ds, uint32_t global_counters_binding, + uint32_t &global_counters_id, uint32_t &local_counters_id) + { + auto module = context()->module(); + + auto decorate_arr_inst = new spvtools::opt::Instruction(context(), spv::Op::OpDecorate, 0, 0, + { { SPV_OPERAND_TYPE_ID, { u64_arr_ty_id } }, + { SPV_OPERAND_TYPE_DECORATION, { (uint32_t)spv::Decoration::ArrayStride } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { 8 } } }); + module->AddAnnotationInst(std::unique_ptr(decorate_arr_inst)); + + auto decorate_member_offset_inst = new spvtools::opt::Instruction(context(), spv::Op::OpMemberDecorate, 0, 0, + { { SPV_OPERAND_TYPE_ID, { u64_arr_st_ty_id } }, { SPV_OPERAND_TYPE_LITERAL_INTEGER, { 0 } }, + { SPV_OPERAND_TYPE_DECORATION, { (uint32_t)spv::Decoration::Offset } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { 0 } } }); + module->AddAnnotationInst(std::unique_ptr(decorate_member_offset_inst)); + + auto decorate_arr_st_inst = new spvtools::opt::Instruction(context(), spv::Op::OpDecorate, 0, 0, + { { SPV_OPERAND_TYPE_ID, { u64_arr_st_ty_id } }, + { SPV_OPERAND_TYPE_DECORATION, { (uint32_t)spv::Decoration::Block } } }); + module->AddAnnotationInst(std::unique_ptr(decorate_arr_st_inst)); + + if (local_counters_ty_id != UNDEFINED_ID) { + local_counters_id = context()->TakeNextId(); + auto local_counters_inst = new spvtools::opt::Instruction(context(), spv::Op::OpVariable, + local_counters_ty_id, local_counters_id, + { { SPV_OPERAND_TYPE_LITERAL_INTEGER, { (uint32_t)spv::StorageClass::Private } } }); + module->AddGlobalValue(std::unique_ptr(local_counters_inst)); + } else { + local_counters_id = UNDEFINED_ID; + } + + global_counters_id = context()->TakeNextId(); + auto global_counters_inst + = new spvtools::opt::Instruction(context(), spv::Op::OpVariable, counters_ty_id, global_counters_id, + { { SPV_OPERAND_TYPE_LITERAL_INTEGER, { (uint32_t)spv::StorageClass::StorageBuffer } } }); + module->AddGlobalValue(std::unique_ptr(global_counters_inst)); + + auto counters_descriptor_set_inst = new spvtools::opt::Instruction(context(), spv::Op::OpDecorate, 0, 0, + { { SPV_OPERAND_TYPE_ID, { global_counters_inst->result_id() } }, + { SPV_OPERAND_TYPE_DECORATION, { (uint32_t)spv::Decoration::DescriptorSet } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { global_counters_ds } } }); + module->AddAnnotationInst(std::unique_ptr(counters_descriptor_set_inst)); + + auto counters_binding_inst = new spvtools::opt::Instruction(context(), spv::Op::OpDecorate, 0, 0, + { { SPV_OPERAND_TYPE_ID, { global_counters_inst->result_id() } }, + { SPV_OPERAND_TYPE_DECORATION, { (uint32_t)spv::Decoration::Binding } }, + { SPV_OPERAND_TYPE_LITERAL_INTEGER, { global_counters_binding } } }); + module->AddAnnotationInst(std::unique_ptr(counters_binding_inst)); + } + + void CreatePrologue(spvtools::opt::Instruction *entry_point_inst, uint32_t u64_private_ptr_ty_id, + uint32_t u64_ty_id, uint32_t subgroup_scope_id, uint32_t global_counters_id, uint32_t local_counters_id, + std::vector &start_counters, spvtools::opt::Function *&function, + uint32_t &read_clock_id) + { + auto *cst_mgr = context()->get_constant_mgr(); + entry_point_inst->AddOperand({ SPV_OPERAND_TYPE_ID, { global_counters_id } }); + if (local_counters_id != UNDEFINED_ID) { + entry_point_inst->AddOperand({ SPV_OPERAND_TYPE_ID, { local_counters_id } }); + } + + auto function_id = entry_point_inst->GetOperand(1).AsId(); + function = context()->GetFunction(function_id); + + auto &function_first_inst = *function->entry()->begin(); + + auto u64_cst0_id = cst_mgr->GetDefiningInstruction(cst_mgr->GetIntConst(0, 64, 0))->result_id(); + + for (unsigned i = 0; i < start_counters.size(); i++) { + auto get_id = context()->TakeNextId(); + auto gep_inst + = new spvtools::opt::Instruction(context(), spv::Op::OpAccessChain, u64_private_ptr_ty_id, get_id, + { { SPV_OPERAND_TYPE_ID, { local_counters_id } }, + { SPV_OPERAND_TYPE_ID, { cst_mgr->GetUIntConstId(i) } } }); + gep_inst->InsertBefore(&function_first_inst); + + auto store_inst = new spvtools::opt::Instruction(context(), spv::Op::OpStore, 0, 0, + { { SPV_OPERAND_TYPE_ID, { get_id } }, { SPV_OPERAND_TYPE_ID, { u64_cst0_id } } }); + store_inst->InsertAfter(gep_inst); + } + + read_clock_id = context()->TakeNextId(); + auto read_clock_inst = new spvtools::opt::Instruction(context(), spv::Op::OpReadClockKHR, u64_ty_id, + read_clock_id, { { SPV_OPERAND_TYPE_SCOPE_ID, { subgroup_scope_id } } }); + read_clock_inst->InsertBefore(&function_first_inst); + } + + void CreateEpilogue(spvtools::opt::Instruction *return_inst, uint32_t read_clock_id, uint32_t u64_ty_id, + uint32_t u64_ptr_ty_id, uint32_t u64_private_ptr_ty_id, uint32_t subgroup_scope_id, uint32_t device_scope_id, + uint32_t acq_rel_mem_sem_id, uint32_t global_counters_id, uint32_t local_counters_id, + std::vector &start_counters) + { + auto *cst_mgr = context()->get_constant_mgr(); + + auto read_clock_end_id = context()->TakeNextId(); + auto read_clock_end_inst = new spvtools::opt::Instruction(context(), spv::Op::OpReadClockKHR, u64_ty_id, + read_clock_end_id, { { SPV_OPERAND_TYPE_SCOPE_ID, { subgroup_scope_id } } }); + read_clock_end_inst->InsertBefore(return_inst); + + auto substraction_id = context()->TakeNextId(); + auto substraction_inst = new spvtools::opt::Instruction(context(), spv::Op::OpISub, u64_ty_id, substraction_id, + { { SPV_OPERAND_TYPE_ID, { read_clock_end_id } }, { SPV_OPERAND_TYPE_ID, { read_clock_id } } }); + substraction_inst->InsertAfter(read_clock_end_inst); + + auto gep_invocations_id = context()->TakeNextId(); + auto gep_invocations_inst = new spvtools::opt::Instruction(context(), spv::Op::OpAccessChain, u64_ptr_ty_id, + gep_invocations_id, + { { SPV_OPERAND_TYPE_ID, { global_counters_id } }, { SPV_OPERAND_TYPE_ID, { cst_mgr->GetUIntConstId(0) } }, + { SPV_OPERAND_TYPE_ID, { cst_mgr->GetUIntConstId(0) } } }); + gep_invocations_inst->InsertAfter(substraction_inst); + + auto atomic_incr_id = context()->TakeNextId(); + auto atomic_incr_inst + = new spvtools::opt::Instruction(context(), spv::Op::OpAtomicIIncrement, u64_ty_id, atomic_incr_id, + { { SPV_OPERAND_TYPE_ID, { gep_invocations_id } }, { SPV_OPERAND_TYPE_SCOPE_ID, { device_scope_id } }, + { SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID, { acq_rel_mem_sem_id } } }); + atomic_incr_inst->InsertAfter(gep_invocations_inst); + + auto gep_entrypoint_counter_id = context()->TakeNextId(); + auto gep_entrypoint_counter_inst = new spvtools::opt::Instruction(context(), spv::Op::OpAccessChain, + u64_ptr_ty_id, gep_entrypoint_counter_id, + { { SPV_OPERAND_TYPE_ID, { global_counters_id } }, { SPV_OPERAND_TYPE_ID, { cst_mgr->GetUIntConstId(0) } }, + { SPV_OPERAND_TYPE_ID, { cst_mgr->GetUIntConstId(1) } } }); + gep_entrypoint_counter_inst->InsertAfter(atomic_incr_inst); + + auto atomic_add_id = context()->TakeNextId(); + auto atomic_add_inst + = new spvtools::opt::Instruction(context(), spv::Op::OpAtomicIAdd, u64_ty_id, atomic_add_id, + { { SPV_OPERAND_TYPE_ID, { gep_entrypoint_counter_id } }, + { SPV_OPERAND_TYPE_SCOPE_ID, { device_scope_id } }, + { SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID, { acq_rel_mem_sem_id } }, + { SPV_OPERAND_TYPE_ID, { substraction_id } } }); + atomic_add_inst->InsertAfter(gep_entrypoint_counter_inst); + + for (unsigned i = 0; i < start_counters.size(); i++) { + auto gep_id = context()->TakeNextId(); + auto gep_inst + = new spvtools::opt::Instruction(context(), spv::Op::OpAccessChain, u64_private_ptr_ty_id, gep_id, + { { SPV_OPERAND_TYPE_ID, { local_counters_id } }, + { SPV_OPERAND_TYPE_ID, { cst_mgr->GetUIntConstId(i) } } }); + gep_inst->InsertAfter(atomic_add_inst); + + auto load_id = context()->TakeNextId(); + auto load_inst = new spvtools::opt::Instruction( + context(), spv::Op::OpLoad, u64_ty_id, load_id, { { SPV_OPERAND_TYPE_ID, { gep_id } } }); + load_inst->InsertAfter(gep_inst); + + auto gep_atomic_id = context()->TakeNextId(); + auto gep_atomic_inst + = new spvtools::opt::Instruction(context(), spv::Op::OpAccessChain, u64_ptr_ty_id, gep_atomic_id, + { { SPV_OPERAND_TYPE_ID, { global_counters_id } }, + { SPV_OPERAND_TYPE_ID, { cst_mgr->GetUIntConstId(0) } }, + { SPV_OPERAND_TYPE_ID, { cst_mgr->GetUIntConstId(2 + i) } } }); + gep_atomic_inst->InsertAfter(load_inst); + + atomic_add_id = context()->TakeNextId(); + atomic_add_inst = new spvtools::opt::Instruction(context(), spv::Op::OpAtomicIAdd, u64_ty_id, atomic_add_id, + { { SPV_OPERAND_TYPE_ID, { gep_atomic_id } }, { SPV_OPERAND_TYPE_SCOPE_ID, { device_scope_id } }, + { SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID, { acq_rel_mem_sem_id } }, + { SPV_OPERAND_TYPE_ID, { load_id } } }); + atomic_add_inst->InsertAfter(gep_atomic_inst); + } + } + + void CreateCounters(uint32_t u64_ty_id, uint32_t u64_private_ptr_ty_id, uint32_t subgroup_scope_id, + std::vector &start_counters, + std::vector &stop_counters, uint32_t local_counters_id) + { + auto *cst_mgr = context()->get_constant_mgr(); + std::map> start_counters_id_map; + uint32_t next_counter_id = 2; + + for (auto *inst : start_counters) { + const char *counter_name = strdup(inst->GetOperand(4).AsString().c_str()); + + auto read_clock_id = context()->TakeNextId(); + auto read_clock_inst = new spvtools::opt::Instruction(context(), spv::Op::OpReadClockKHR, u64_ty_id, + read_clock_id, { { SPV_OPERAND_TYPE_SCOPE_ID, { subgroup_scope_id } } }); + read_clock_inst->InsertBefore(inst); + + counters_->push_back({ next_counter_id, counter_name }); + start_counters_id_map[inst->result_id()] = std::make_pair(read_clock_id, next_counter_id); + next_counter_id++; + } + + for (auto *inst : stop_counters) { + auto read_clock_ext_inst_id = inst->GetOperand(4).AsId(); + if (start_counters_id_map.count(read_clock_ext_inst_id) == 0) { + continue; + } + auto pair = start_counters_id_map[read_clock_ext_inst_id]; + auto read_clock_id = pair.first; + auto counters_var_index = pair.second; + + auto read_clock_end_id = context()->TakeNextId(); + auto read_clock_end_inst = new spvtools::opt::Instruction(context(), spv::Op::OpReadClockKHR, u64_ty_id, + read_clock_end_id, { { SPV_OPERAND_TYPE_SCOPE_ID, { subgroup_scope_id } } }); + read_clock_end_inst->InsertAfter(inst); + + auto substraction_id = context()->TakeNextId(); + auto substraction_inst + = new spvtools::opt::Instruction(context(), spv::Op::OpISub, u64_ty_id, substraction_id, + { { SPV_OPERAND_TYPE_ID, { read_clock_end_id } }, { SPV_OPERAND_TYPE_ID, { read_clock_id } } }); + substraction_inst->InsertAfter(read_clock_end_inst); + + auto gep_id = context()->TakeNextId(); + auto gep_inst + = new spvtools::opt::Instruction(context(), spv::Op::OpAccessChain, u64_private_ptr_ty_id, gep_id, + { { SPV_OPERAND_TYPE_ID, { local_counters_id } }, + { SPV_OPERAND_TYPE_ID, { cst_mgr->GetUIntConstId(counters_var_index - 2) } } }); + gep_inst->InsertAfter(substraction_inst); + + auto load_id = context()->TakeNextId(); + auto load_inst = new spvtools::opt::Instruction( + context(), spv::Op::OpLoad, u64_ty_id, load_id, { { SPV_OPERAND_TYPE_ID, { gep_id } } }); + load_inst->InsertAfter(gep_inst); + + auto add_id = context()->TakeNextId(); + auto add_inst = new spvtools::opt::Instruction(context(), spv::Op::OpIAdd, u64_ty_id, add_id, + { { SPV_OPERAND_TYPE_ID, { load_id } }, { SPV_OPERAND_TYPE_ID, { substraction_id } } }); + add_inst->InsertAfter(load_inst); + + auto store_inst = new spvtools::opt::Instruction(context(), spv::Op::OpStore, 0, 0, + { { SPV_OPERAND_TYPE_ID, { gep_id } }, { SPV_OPERAND_TYPE_ID, { add_id } } }); + store_inst->InsertAfter(add_inst); + } + } + + std::vector *pc_; + std::vector *ds_; + std::vector *me_; + std::vector *counters_; + vksp_configuration *config_; +}; + +}