diff --git a/applications/endoscopy_tool_tracking/cpp/main.cpp b/applications/endoscopy_tool_tracking/cpp/main.cpp index 983c87ee5..e6379fa7d 100644 --- a/applications/endoscopy_tool_tracking/cpp/main.cpp +++ b/applications/endoscopy_tool_tracking/cpp/main.cpp @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #include +#include #include #include #include @@ -24,7 +25,6 @@ #include #include #include -#include "holoscan/holoscan.hpp" #ifdef VTK_RENDERER #include #endif @@ -38,6 +38,11 @@ #include #endif +#include + +#define HOLOSCAN_VERSION \ + (HOLOSCAN_VERSION_MAJOR * 10000 + HOLOSCAN_VERSION_MINOR * 100 + HOLOSCAN_VERSION_PATCH) + class App : public holoscan::Application { public: void set_source(const std::string& source) { source_ = source; } @@ -111,6 +116,10 @@ class App : public holoscan::Application { height = 480; source = make_operator( "replayer", from_config("replayer"), Arg("directory", datapath)); +#if HOLOSCAN_VERSION >= 20600 + // the RMMAllocator supported since v2.6 is much faster than the default UnboundAllocator + source->add_arg(Arg("allocator", make_resource("video_replayer_allocator"))); +#endif source_block_size = width * height * 3 * 4; source_num_blocks = 2; } @@ -151,16 +160,19 @@ class App : public holoscan::Application { "pool", 1, lstm_inferer_block_size, lstm_inferer_num_blocks), Arg("cuda_stream_pool") = cuda_stream_pool); - const uint64_t tool_tracking_postprocessor_block_size = 107 * 60 * 7 * 4; - const uint64_t tool_tracking_postprocessor_num_blocks = 2; + // the tool tracking post process outputs + // - a RGBA float32 color mask + // - coordinates with x,y and size in float32 + const uint64_t tool_tracking_postprocessor_block_size = + std::max(107 * 60 * 7 * 4 * sizeof(float), 7 * 3 * sizeof(float)); + const uint64_t tool_tracking_postprocessor_num_blocks = 2 * 2; auto tool_tracking_postprocessor = make_operator( "tool_tracking_postprocessor", Arg("device_allocator") = make_resource("device_allocator", 1, tool_tracking_postprocessor_block_size, - tool_tracking_postprocessor_num_blocks), - Arg("host_allocator") = make_resource("host_allocator")); + tool_tracking_postprocessor_num_blocks)); if (this->visualizer_name == "holoviz") { std::shared_ptr visualizer_allocator; diff --git a/applications/endoscopy_tool_tracking/python/endoscopy_tool_tracking.py b/applications/endoscopy_tool_tracking/python/endoscopy_tool_tracking.py index 2a724ad82..43054ffdf 100644 --- a/applications/endoscopy_tool_tracking/python/endoscopy_tool_tracking.py +++ b/applications/endoscopy_tool_tracking/python/endoscopy_tool_tracking.py @@ -28,7 +28,6 @@ BlockMemoryPool, CudaStreamPool, MemoryStorageType, - UnboundedAllocator, ) from holohub.lstm_tensor_rt_inference import LSTMTensorRTInferenceOp @@ -114,6 +113,12 @@ def compose(self): directory=video_dir, **self.kwargs("replayer"), ) + # the RMMAllocator supported since v2.6 is much faster than the default UnboundAllocator + try: + from holoscan.resources import RMMAllocator + source.add_arg(allocator=RMMAllocator(self, name="video_replayer_allocator")) + except Exception: + pass # 4 bytes/channel, 3 channels source_block_size = width * height * 3 * 4 source_num_blocks = 2 @@ -133,9 +138,7 @@ def compose(self): pool=BlockMemoryPool(self, name="pool", **source_pool_kwargs), **self.kwargs("recorder_format_converter"), ) - recorder = VideoStreamRecorderOp( - name="recorder", fragment=self, **self.kwargs("recorder") - ) + recorder = VideoStreamRecorderOp(name="recorder", fragment=self, **self.kwargs("recorder")) config_key_name = "format_converter_" + self.source.lower() @@ -177,8 +180,14 @@ def compose(self): **self.kwargs("lstm_inference"), ) - tool_tracking_postprocessor_block_size = 107 * 60 * 7 * 4 - tool_tracking_postprocessor_num_blocks = 2 + # the tool tracking post process outputs + # - a RGBA float32 color mask + # - coordinates with x,y and size in float32 + bytes_per_float32 = 4 + tool_tracking_postprocessor_block_size = max( + 107 * 60 * 7 * 4 * bytes_per_float32, 7 * 3 * bytes_per_float32 + ) + tool_tracking_postprocessor_num_blocks = 2 * 2 tool_tracking_postprocessor = ToolTrackingPostprocessorOp( self, name="tool_tracking_postprocessor", @@ -189,7 +198,6 @@ def compose(self): block_size=tool_tracking_postprocessor_block_size, num_blocks=tool_tracking_postprocessor_num_blocks, ), - host_allocator=UnboundedAllocator(self, name="host_allocator"), ) if (record_type == "visualizer") and (self.source == "replayer"): diff --git a/gxf_extensions/lstm_tensor_rt_inference/tensor_rt_inference.cpp b/gxf_extensions/lstm_tensor_rt_inference/tensor_rt_inference.cpp index 8449ce610..12bfb039f 100644 --- a/gxf_extensions/lstm_tensor_rt_inference/tensor_rt_inference.cpp +++ b/gxf_extensions/lstm_tensor_rt_inference/tensor_rt_inference.cpp @@ -289,12 +289,6 @@ gxf_result_t TensorRtInference::registerInterface(gxf::Registrar* registrar) { "Relaxed Dimension Check", "Ignore dimensions of 1 for input tensor dimension check.", true); - result &= registrar->parameter(clock_, - "clock", - "Clock", - "Instance of clock for publish time.", - gxf::Registrar::NoDefaultParameter(), - GXF_PARAMETER_FLAGS_OPTIONAL); result &= registrar->parameter(rx_, "rx", "RX", "List of receivers to take input tensors"); result &= registrar->parameter(tx_, "tx", "TX", "Transmitter to publish output tensors"); diff --git a/gxf_extensions/lstm_tensor_rt_inference/tensor_rt_inference.hpp b/gxf_extensions/lstm_tensor_rt_inference/tensor_rt_inference.hpp index 6e8c8b31a..b49d78361 100644 --- a/gxf_extensions/lstm_tensor_rt_inference/tensor_rt_inference.hpp +++ b/gxf_extensions/lstm_tensor_rt_inference/tensor_rt_inference.hpp @@ -32,7 +32,6 @@ #include "gxf/cuda/cuda_stream.hpp" #include "gxf/cuda/cuda_stream_pool.hpp" #include "gxf/std/allocator.hpp" -#include "gxf/std/clock.hpp" #include "gxf/std/codelet.hpp" #include "gxf/std/receiver.hpp" #include "gxf/std/tensor.hpp" @@ -110,12 +109,11 @@ class TensorRtInference : public gxf::Codelet { gxf::Parameter> pool_; gxf::Parameter> cuda_stream_pool_; gxf::Parameter max_workspace_size_; - gxf::Parameter dla_core_; + gxf::Parameter dla_core_; gxf::Parameter max_batch_size_; gxf::Parameter enable_fp16_; gxf::Parameter relaxed_dimension_check_; gxf::Parameter verbose_; - gxf::Parameter> clock_; gxf::Parameter>> rx_; gxf::Parameter> tx_; diff --git a/operators/lstm_tensor_rt_inference/README.md b/operators/lstm_tensor_rt_inference/README.md index 294c98e75..352af4f65 100644 --- a/operators/lstm_tensor_rt_inference/README.md +++ b/operators/lstm_tensor_rt_inference/README.md @@ -38,7 +38,7 @@ This implementation is based on `nvidia::gxf::TensorRtInference`. - **`max_workspace_size`**: Size of working space in bytes (default: `67108864l` (64MB)) - type: `int64_t` - **`dla_core`**: DLA Core to use. Fallback to GPU is always enabled. Default to use GPU only (`optional`) - - type: `int64_t` + - type: `int32_t` - **`max_batch_size`**: Maximum possible batch size in case the first dimension is dynamic and used as batch size (default: `1`) - type: `int32_t` - **`enable_fp16_`**: Enable inference with FP16 and FP32 fallback (default: `false`) @@ -47,8 +47,6 @@ This implementation is based on `nvidia::gxf::TensorRtInference`. - type: `bool` - **`relaxed_dimension_check`**: Ignore dimensions of 1 for input tensor dimension check (default: `true`) - type: `bool` -- **`clock`**: Instance of clock for publish time (`optional`) - - type: `gxf::Handle` - **`rx`**: List of receivers to take input tensors - type: `std::vector>` - **`tx`**: Transmitter to publish output tensors diff --git a/operators/lstm_tensor_rt_inference/lstm_tensor_rt_inference.cpp b/operators/lstm_tensor_rt_inference/lstm_tensor_rt_inference.cpp index d4ad3ad8a..8ce9bbf4e 100644 --- a/operators/lstm_tensor_rt_inference/lstm_tensor_rt_inference.cpp +++ b/operators/lstm_tensor_rt_inference/lstm_tensor_rt_inference.cpp @@ -94,7 +94,8 @@ void LSTMTensorRTInferenceOp::setup(OperatorSpec& spec) { "dla_core", "DLA Core", "DLA Core to use. Fallback to GPU is always enabled. " - "Default to use GPU only."); + "Default to use GPU only.", + ParameterFlag::kOptional); spec.param(max_batch_size_, "max_batch_size", "Max Batch Size", @@ -117,7 +118,6 @@ void LSTMTensorRTInferenceOp::setup(OperatorSpec& spec) { "Relaxed Dimension Check", "Ignore dimensions of 1 for input tensor dimension check.", true); - spec.param(clock_, "clock", "Clock", "Instance of clock for publish time."); spec.param(rx_, "rx", "RX", "List of receivers to take input tensors", {&in_tensor}); spec.param(tx_, "tx", "TX", "Transmitter to publish output tensors", &out_tensor); diff --git a/operators/lstm_tensor_rt_inference/lstm_tensor_rt_inference.hpp b/operators/lstm_tensor_rt_inference/lstm_tensor_rt_inference.hpp index 82e267273..3d45eb96f 100644 --- a/operators/lstm_tensor_rt_inference/lstm_tensor_rt_inference.hpp +++ b/operators/lstm_tensor_rt_inference/lstm_tensor_rt_inference.hpp @@ -58,12 +58,11 @@ class LSTMTensorRTInferenceOp : public holoscan::ops::GXFOperator { Parameter> pool_; Parameter> cuda_stream_pool_; Parameter max_workspace_size_; - Parameter dla_core_; + Parameter dla_core_; Parameter max_batch_size_; Parameter enable_fp16_; Parameter relaxed_dimension_check_; Parameter verbose_; - Parameter> clock_; Parameter> rx_; Parameter tx_; diff --git a/operators/lstm_tensor_rt_inference/python/lstm_tensor_rt_inference.cpp b/operators/lstm_tensor_rt_inference/python/lstm_tensor_rt_inference.cpp index 413ad533e..df4653a06 100644 --- a/operators/lstm_tensor_rt_inference/python/lstm_tensor_rt_inference.cpp +++ b/operators/lstm_tensor_rt_inference/python/lstm_tensor_rt_inference.cpp @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,11 +25,11 @@ #include #include -#include "../../operator_util.hpp" #include #include #include #include +#include "../../operator_util.hpp" #include "holoscan/core/resources/gxf/cuda_stream_pool.hpp" using std::string_literals::operator""s; @@ -63,11 +63,8 @@ class PyLSTMTensorRTInferenceOp : public LSTMTensorRTInferenceOp { const std::vector& output_tensor_names, const std::vector& input_binding_names, const std::vector& output_binding_names, const std::string& model_file_path, - const std::string& engine_cache_dir, - // int64_t dla_core, - std::shared_ptr pool, - std::shared_ptr cuda_stream_pool, - // std::shared_ptr clock, + const std::string& engine_cache_dir, std::shared_ptr pool, + std::shared_ptr cuda_stream_pool, std::optional dla_core, const std::string& plugins_lib_namespace = "", const std::vector& input_state_tensor_names = std::vector{}, const std::vector& output_state_tensor_names = std::vector{}, @@ -80,10 +77,8 @@ class PyLSTMTensorRTInferenceOp : public LSTMTensorRTInferenceOp { Arg{"output_binding_names", output_binding_names}, Arg{"model_file_path", model_file_path}, Arg{"engine_cache_dir", engine_cache_dir}, - // Arg{"dla_core", dla_core}, Arg{"pool", pool}, Arg{"cuda_stream_pool", cuda_stream_pool}, - // Arg{"clock", clock}, Arg{"plugins_lib_namespace", plugins_lib_namespace}, Arg{"input_state_tensor_names", input_state_tensor_names}, Arg{"output_state_tensor_names", output_state_tensor_names}, @@ -93,6 +88,7 @@ class PyLSTMTensorRTInferenceOp : public LSTMTensorRTInferenceOp { Arg{"relaxed_dimension_check", relaxed_dimension_check}, Arg{"max_workspace_size", max_workspace_size}, Arg{"max_batch_size", max_batch_size}}) { + if (dla_core.has_value()) { add_arg(Arg{"dla_core", dla_core.value()}); } add_positional_condition_and_resource_args(this, args); name_ = name; fragment_ = fragment; @@ -131,10 +127,9 @@ PYBIND11_MODULE(_lstm_tensor_rt_inference, m) { const std::vector&, const std::string&, const std::string&, - // int64_t, // dla_core std::shared_ptr, std::shared_ptr, - // std::shared_ptr, // clock + std::optional, const std::string&, const std::vector&, const std::vector&, @@ -152,10 +147,9 @@ PYBIND11_MODULE(_lstm_tensor_rt_inference, m) { "output_binding_names"_a, "model_file_path"_a, "engine_cache_dir"_a, - // "dla_core"_a, "pool"_a, "cuda_stream_pool"_a, - // "clock"_a, + "dla_core"_a = py::none(), "plugins_lib_namespace"_a = "", "input_state_tensor_names"_a = std::vector{}, "output_state_tensor_names"_a = std::vector{}, diff --git a/operators/tool_tracking_postprocessor/CMakeLists.txt b/operators/tool_tracking_postprocessor/CMakeLists.txt index cc90a3e54..ccbaad72a 100644 --- a/operators/tool_tracking_postprocessor/CMakeLists.txt +++ b/operators/tool_tracking_postprocessor/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023-2034 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. cmake_minimum_required(VERSION 3.20) + project(tool_tracking_postprocessor LANGUAGES CXX CUDA) find_package(holoscan REQUIRED CONFIG @@ -25,10 +26,21 @@ add_library(tool_tracking_postprocessor SHARED tool_tracking_postprocessor.cuh ) -set_target_properties(tool_tracking_postprocessor PROPERTIES CUDA_ARCHITECTURES "70;80") +set_target_properties(tool_tracking_postprocessor + PROPERTIES + # separable compilation is required since we launch kernels from within kernels + CUDA_SEPARABLE_COMPILATION ON + ) -target_link_libraries(tool_tracking_postprocessor holoscan::core) -target_include_directories(tool_tracking_postprocessor INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(tool_tracking_postprocessor + PRIVATE + holoscan::core + ) + +target_include_directories(tool_tracking_postprocessor + INTERFACE + ${CMAKE_CURRENT_SOURCE_DIR} + ) if(HOLOHUB_BUILD_PYTHON) add_subdirectory(python) diff --git a/operators/tool_tracking_postprocessor/python/tool_tracking_postprocessor.cpp b/operators/tool_tracking_postprocessor/python/tool_tracking_postprocessor.cpp index 820fb449c..6278f35fb 100644 --- a/operators/tool_tracking_postprocessor/python/tool_tracking_postprocessor.cpp +++ b/operators/tool_tracking_postprocessor/python/tool_tracking_postprocessor.cpp @@ -73,12 +73,11 @@ class PyToolTrackingPostprocessorOp : public ToolTrackingPostprocessorOp { // Define a constructor that fully initializes the object. PyToolTrackingPostprocessorOp( Fragment* fragment, const py::args& args, std::shared_ptr device_allocator, - std::shared_ptr host_allocator, float min_prob = 0.5f, + float min_prob = 0.5f, std::vector> overlay_img_colors = VIZ_TOOL_DEFAULT_COLORS, std::shared_ptr cuda_stream_pool = nullptr, const std::string& name = "tool_tracking_postprocessor") : ToolTrackingPostprocessorOp(ArgList{Arg{"device_allocator", device_allocator}, - Arg{"host_allocator", host_allocator}, Arg{"min_prob", min_prob}, Arg{"overlay_img_colors", overlay_img_colors}}) { if (cuda_stream_pool) { this->add_arg(Arg{"cuda_stream_pool", cuda_stream_pool}); } @@ -116,14 +115,12 @@ PYBIND11_MODULE(_tool_tracking_postprocessor, m) { .def(py::init, - std::shared_ptr, float, std::vector>, std::shared_ptr, const std::string&>(), "fragment"_a, "device_allocator"_a, - "host_allocator"_a, "min_prob"_a = 0.5f, "overlay_img_colors"_a = VIZ_TOOL_DEFAULT_COLORS, "cuda_stream_pool"_a = py::none(), diff --git a/operators/tool_tracking_postprocessor/python/tool_tracking_postprocessor_pydoc.hpp b/operators/tool_tracking_postprocessor/python/tool_tracking_postprocessor_pydoc.hpp index 23c48652c..339c26827 100644 --- a/operators/tool_tracking_postprocessor/python/tool_tracking_postprocessor_pydoc.hpp +++ b/operators/tool_tracking_postprocessor/python/tool_tracking_postprocessor_pydoc.hpp @@ -40,7 +40,7 @@ Operator performing post-processing for the endoscopy tool tracking demo. **==Named Outputs==** out_coords : nvidia::gxf::Tensor - Coordinates tensor, stored on the host (CPU). + Coordinates tensor, stored on the device (GPU). out_mask : nvidia::gxf::Tensor Binary mask tensor, stored on device (GPU). @@ -51,8 +51,6 @@ fragment : Fragment The fragment that the operator belongs to. device_allocator : ``holoscan.resources.Allocator`` Output allocator used on the device side. -host_allocator : ``holoscan.resources.Allocator`` - Output allocator used on the host side. min_prob : float, optional Minimum probability (in range [0, 1]). Default value is 0.5. overlay_img_colors : sequence of sequence of float, optional diff --git a/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cpp b/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cpp index 0019d4c30..fcae481f6 100644 --- a/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cpp +++ b/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cpp @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,33 +22,11 @@ #include #include -#include "holoscan/core/fragment.hpp" -#include "holoscan/core/gxf/entity.hpp" -#include "holoscan/core/execution_context.hpp" -#include "holoscan/core/io_context.hpp" -#include "holoscan/core/operator_spec.hpp" - -#include "holoscan/core/conditions/gxf/boolean.hpp" -#include "holoscan/core/resources/gxf/allocator.hpp" -#include "holoscan/core/resources/gxf/cuda_stream_pool.hpp" - -#include "gxf/std/tensor.hpp" - -using holoscan::ops::tool_tracking_postprocessor::cuda_postprocess; - -#define CUDA_TRY(stmt) \ - ({ \ - cudaError_t _holoscan_cuda_err = stmt; \ - if (cudaSuccess != _holoscan_cuda_err) { \ - GXF_LOG_ERROR("CUDA Runtime call %s in line %d of file %s failed with '%s' (%d).", \ - #stmt, \ - __LINE__, \ - __FILE__, \ - cudaGetErrorString(_holoscan_cuda_err), \ - _holoscan_cuda_err); \ - } \ - _holoscan_cuda_err; \ - }) +#include + +#include + +#include "tool_tracking_postprocessor.cuh" namespace holoscan::ops { @@ -88,19 +66,22 @@ void ToolTrackingPostprocessorOp::setup(OperatorSpec& spec) { "Color of the image overlays, a list of RGB values with components between 0 and 1", DEFAULT_COLORS); - spec.param(host_allocator_, "host_allocator", "Allocator", "Output Allocator"); spec.param(device_allocator_, "device_allocator", "Allocator", "Output Allocator"); cuda_stream_handler_.define_params(spec); } +void ToolTrackingPostprocessorOp::stop() { + if (dev_colors_) { + CUDA_TRY(cudaFree(dev_colors_)); + dev_colors_ = nullptr; + } +} + void ToolTrackingPostprocessorOp::compute(InputContext& op_input, OutputContext& op_output, ExecutionContext& context) { // The type of `in_message` is 'holoscan::gxf::Entity'. auto in_message = op_input.receive("in").value(); - auto maybe_tensor = in_message.get("probs"); - if (!maybe_tensor) { throw std::runtime_error("Tensor 'probs' not found in message."); } - auto probs_tensor = maybe_tensor; // get the CUDA stream from the input message gxf_result_t stream_handler_result = @@ -109,118 +90,105 @@ void ToolTrackingPostprocessorOp::compute(InputContext& op_input, OutputContext& throw std::runtime_error("Failed to get the CUDA stream from incoming messages"); } - std::vector probs(probs_tensor->size()); - CUDA_TRY(cudaMemcpyAsync(probs.data(), - probs_tensor->data(), - probs_tensor->nbytes(), - cudaMemcpyDeviceToHost, - cuda_stream_handler_.get_cuda_stream(context.context()))); + auto maybe_tensor = in_message.get("probs"); + if (!maybe_tensor) { throw std::runtime_error("Tensor 'probs' not found in message."); } + auto probs_tensor = maybe_tensor; maybe_tensor = in_message.get("scaled_coords"); if (!maybe_tensor) { throw std::runtime_error("Tensor 'scaled_coords' not found in message."); } auto scaled_coords_tensor = maybe_tensor; - std::vector scaled_coords(scaled_coords_tensor->size()); - CUDA_TRY(cudaMemcpyAsync(scaled_coords.data(), - scaled_coords_tensor->data(), - scaled_coords_tensor->nbytes(), - cudaMemcpyDeviceToHost, - cuda_stream_handler_.get_cuda_stream(context.context()))); - maybe_tensor = in_message.get("binary_masks"); if (!maybe_tensor) { throw std::runtime_error("Tensor 'binary_masks' not found in message."); } auto binary_masks_tensor = maybe_tensor; - // Create a new message (nvidia::nvidia::gxf::Entity) for host tensor(s) - auto out_message_host = nvidia::gxf::Entity::New(context.context()); - - // filter coordinates based on probability - std::vector visible_classes; - { - // wait for the CUDA memory copy to finish - CUDA_TRY(cudaStreamSynchronize(cuda_stream_handler_.get_cuda_stream(context.context()))); - - std::vector filtered_scaled_coords; - for (size_t index = 0; index < probs.size(); ++index) { - if (probs[index] > min_prob_) { - filtered_scaled_coords.push_back(scaled_coords[index * 2]); - filtered_scaled_coords.push_back(scaled_coords[index * 2 + 1]); - visible_classes.push_back(index); - } else { - filtered_scaled_coords.push_back(-1.f); - filtered_scaled_coords.push_back(-1.f); - } - } + // get Handle to underlying nvidia::gxf::Allocator from std::shared_ptr + auto device_allocator = nvidia::gxf::Handle::Create( + context.context(), device_allocator_.get()->gxf_cid()); - auto out_coords_tensor = out_message_host.value().add("scaled_coords"); - if (!out_coords_tensor) { - throw std::runtime_error("Failed to allocate output tensor 'scaled_coords'"); - } + // Create a new message (nvidia::nvidia::gxf::Entity) for the scaled coords + auto out_coords_message = nvidia::gxf::Entity::New(context.context()); - // get Handle to underlying nvidia::gxf::Allocator from std::shared_ptr - auto host_allocator = nvidia::gxf::Handle::Create( - context.context(), host_allocator_.get()->gxf_cid()); + // Create a new tensor for the scaled coords + auto out_coords_tensor = out_coords_message.value().add("scaled_coords"); + if (!out_coords_tensor) { + throw std::runtime_error("Failed to allocate output tensor 'scaled_coords'"); + } - const nvidia::gxf::Shape output_shape{1, int32_t(filtered_scaled_coords.size() / 2), 2}; - out_coords_tensor.value()->reshape( - output_shape, nvidia::gxf::MemoryStorageType::kHost, host_allocator.value()); - if (!out_coords_tensor.value()->pointer()) { - throw std::runtime_error( - "Failed to allocate output tensor buffer for tensor 'scaled_coords'."); - } - memcpy(out_coords_tensor.value()->data().value(), - filtered_scaled_coords.data(), - filtered_scaled_coords.size() * sizeof(float)); + const nvidia::gxf::Shape coords_shape{int32_t(probs_tensor->size()), 3}; + out_coords_tensor.value()->reshape( + coords_shape, nvidia::gxf::MemoryStorageType::kDevice, device_allocator.value()); + if (!out_coords_tensor.value()->pointer()) { + throw std::runtime_error("Failed to allocate output tensor buffer for tensor 'scaled_coords'."); + } + + // Create a new message (nvidia::nvidia::gxf::Entity) for the mask + auto out_mask_message = nvidia::gxf::Entity::New(context.context()); + + // Create a new tensor for the mask + auto out_mask_tensor = out_mask_message.value().add("mask"); + if (!out_mask_tensor) { throw std::runtime_error("Failed to allocate output tensor 'mask'"); } + + const nvidia::gxf::Shape mask_shape{static_cast(binary_masks_tensor->shape()[2]), + static_cast(binary_masks_tensor->shape()[3]), + 4}; + out_mask_tensor.value()->reshape( + mask_shape, nvidia::gxf::MemoryStorageType::kDevice, device_allocator.value()); + if (!out_mask_tensor.value()->pointer()) { + throw std::runtime_error("Failed to allocate output tensor buffer for tensor 'mask'."); } - // Create a new message (nvidia::nvidia::gxf::Entity) for device tensor(s) - auto out_message_device = nvidia::gxf::Entity::New(context.context()); - - // filter binary mask - { - auto out_mask_tensor = out_message_device.value().add("mask"); - if (!out_mask_tensor) { throw std::runtime_error("Failed to allocate output tensor 'mask'"); } - - // get Handle to underlying nvidia::gxf::Allocator from std::shared_ptr - auto device_allocator = nvidia::gxf::Handle::Create( - context.context(), device_allocator_.get()->gxf_cid()); - - const nvidia::gxf::Shape output_shape{static_cast(binary_masks_tensor->shape()[2]), - static_cast(binary_masks_tensor->shape()[3]), - 4}; - out_mask_tensor.value()->reshape( - output_shape, nvidia::gxf::MemoryStorageType::kDevice, device_allocator.value()); - if (!out_mask_tensor.value()->pointer()) { - throw std::runtime_error("Failed to allocate output tensor buffer for tensor 'mask'."); + const cudaStream_t cuda_stream = cuda_stream_handler_.get_cuda_stream(context.context()); + + if (num_colors_ != probs_tensor->size()) { + num_colors_ = probs_tensor->size(); + if (dev_colors_) { + CUDA_TRY(cudaFree(dev_colors_)); + dev_colors_ = nullptr; } + } + + if (!dev_colors_) { + // copy colors to CUDA device memory, this is needed by the postprocessing kernel + CUDA_TRY(cudaMalloc(&dev_colors_, num_colors_ * sizeof(float3))); - float* const out_data = out_mask_tensor.value()->data().value(); - const size_t layer_size = output_shape.dimension(0) * output_shape.dimension(1); - bool first = true; - for (auto& index : visible_classes) { + // build a vector with the colors, if more colors are required than specified, repeat the + // last color + std::vector colors; + for (auto index = 0; index < num_colors_; ++index) { const auto& img_color = - overlay_img_colors_.get()[std::min(index, uint32_t(overlay_img_colors_.get().size()))]; - const std::array color{{img_color[0], img_color[1], img_color[2]}}; - cuda_postprocess(output_shape.dimension(0), - output_shape.dimension(1), - color, - first, - static_cast(binary_masks_tensor->data()) + index * layer_size, - reinterpret_cast(out_data), - cuda_stream_handler_.get_cuda_stream(context.context())); - first = false; + overlay_img_colors_.get()[std::min(index, int(overlay_img_colors_.get().size()))]; + colors.push_back(make_float3(img_color[0], img_color[1], img_color[2])); } + + CUDA_TRY(cudaMemcpyAsync(dev_colors_, + colors.data(), + num_colors_ * sizeof(float3), + cudaMemcpyHostToDevice, + cuda_stream)); } - // pass the CUDA stream to the output message - stream_handler_result = cuda_stream_handler_.to_message(out_message_device); + // filter coordinates based on probability and create a colored mask from the binary mask + cuda_postprocess(probs_tensor->size(), + min_prob_, + reinterpret_cast(probs_tensor->data()), + reinterpret_cast(scaled_coords_tensor->data()), + reinterpret_cast(out_coords_tensor.value()->pointer()), + mask_shape.dimension(0), + mask_shape.dimension(1), + reinterpret_cast(dev_colors_), + reinterpret_cast(binary_masks_tensor->data()), + reinterpret_cast(out_mask_tensor.value()->pointer()), + cuda_stream); + // pass the CUDA stream to the output message + stream_handler_result = cuda_stream_handler_.to_message(out_mask_message); if (stream_handler_result != GXF_SUCCESS) { throw std::runtime_error("Failed to add the CUDA stream to the outgoing messages"); } - auto result_host = gxf::Entity(std::move(out_message_host.value())); - auto result_device = gxf::Entity(std::move(out_message_device.value())); + auto result_host = gxf::Entity(std::move(out_coords_message.value())); + auto result_device = gxf::Entity(std::move(out_mask_message.value())); op_output.emit(result_host, "out_coords"); op_output.emit(result_device, "out_mask"); } diff --git a/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cu b/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cu index 2eb1de6e7..5c4557210 100644 --- a/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cu +++ b/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,16 +20,20 @@ #include "tool_tracking_postprocessor.cuh" namespace holoscan::ops { -namespace tool_tracking_postprocessor { -__global__ void postprocessing_kernel(uint32_t width, uint32_t height, const float3 color, - bool first, const float* input, float4* output) { +static __device__ __host__ uint32_t ceil_div(uint32_t numerator, uint32_t denominator) { + return (numerator + denominator - 1) / denominator; +} + +__global__ void filter_binary_mask_kernel(uint32_t width, uint32_t height, uint32_t index, + const float3* colors, const float* binary_mask, + float4* colored_mask) { const uint32_t x = blockIdx.x * blockDim.x + threadIdx.x; const uint32_t y = blockIdx.y * blockDim.y + threadIdx.y; if ((x >= width) || (y >= height)) { return; } - float value = input[y * width + x]; + float value = binary_mask[((index * height) + y) * width + x]; const float minV = 0.3f; const float maxV = 0.99f; @@ -39,25 +43,57 @@ __global__ void postprocessing_kernel(uint32_t width, uint32_t height, const flo value /= range; value *= 0.7f; - const float4 dst = first ? make_float4(0.f, 0.f, 0.f, 0.f) : output[y * width + x]; - output[y * width + x] = make_float4((1.0f - value) * dst.x + color.x * value, - (1.0f - value) * dst.y + color.y * value, - (1.0f - value) * dst.z + color.z * value, - (1.0f - value) * dst.w + 1.f * value); + const float4 dst = colored_mask[y * width + x]; + colored_mask[y * width + x] = make_float4((1.0f - value) * dst.x + colors[index].x * value, + (1.0f - value) * dst.y + colors[index].y * value, + (1.0f - value) * dst.z + colors[index].z * value, + (1.0f - value) * dst.w + 1.f * value); } -uint16_t ceil_div(uint16_t numerator, uint16_t denominator) { - uint32_t accumulator = numerator + denominator - 1; - return accumulator / denominator; +__global__ void filter_coordinates_kernel(uint32_t count, float min_prob, const float* probs, + const float2* scaled_coords, + float3* filtered_scaled_coords, uint32_t width, + uint32_t height, const float3* colors, + const float* binary_mask, float4* colored_mask) { + const uint32_t index = blockIdx.x * blockDim.x + threadIdx.x; + // the third component of the coordinate is the size of the crosses and the text + constexpr float ITEM_SIZE = 0.05f; + + if (index > count) { return; } + + // check if the probability meets the minimum probability + if (probs[index] > min_prob) { + filtered_scaled_coords[index] = + make_float3(scaled_coords[index].x, scaled_coords[index].y, ITEM_SIZE); + // add the binary mask to the result only if probabiliy is met + const dim3 block(32, 32, 1); + const dim3 grid(ceil_div(width, block.x), ceil_div(height, block.y), 1); + filter_binary_mask_kernel<<>>( + width, height, index, colors, binary_mask, colored_mask); + } else { + // move outside of the screen + filtered_scaled_coords[index] = make_float3(-1.f, -1.f, ITEM_SIZE); + } } +void cuda_postprocess(uint32_t count, float min_prob, const float* probs, + const float2* scaled_coords, float3* filtered_scaled_coords, uint32_t width, + uint32_t height, const float3* colors, const float* binary_mask, + float4* colored_mask, cudaStream_t cuda_stream) { + // initialize the output mask to zero + CUDA_TRY(cudaMemsetAsync(colored_mask, 0, width * height * sizeof(float4))); -void cuda_postprocess(uint32_t width, uint32_t height, const std::array& color, - bool first, const float* input, float4* output, cudaStream_t cuda_stream) { - const dim3 block(32, 32, 1); - const dim3 grid(ceil_div(width, block.x), ceil_div(height, block.y), 1); - postprocessing_kernel<<>>( - width, height, make_float3(color[0], color[1], color[2]), first, input, output); + const dim3 block(32, 1, 1); + const dim3 grid(ceil_div(count, block.x), 1, 1); + filter_coordinates_kernel<<>>(count, + min_prob, + probs, + scaled_coords, + filtered_scaled_coords, + width, + height, + colors, + binary_mask, + colored_mask); } -} // namespace tool_tracking_postprocessor } // namespace holoscan::ops diff --git a/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cuh b/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cuh index eaf584870..726faa540 100644 --- a/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cuh +++ b/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,11 +20,27 @@ #include #include +#include + +#define CUDA_TRY(stmt) \ + ({ \ + cudaError_t _holoscan_cuda_err = stmt; \ + if (cudaSuccess != _holoscan_cuda_err) { \ + HOLOSCAN_LOG_ERROR("CUDA Runtime call {} in line {} of file {} failed with '{}' ({}).", \ + #stmt, \ + __LINE__, \ + __FILE__, \ + cudaGetErrorString(_holoscan_cuda_err), \ + static_cast(_holoscan_cuda_err)); \ + } \ + _holoscan_cuda_err; \ + }) + namespace holoscan::ops { -namespace tool_tracking_postprocessor { -void cuda_postprocess(uint32_t width, uint32_t height, const std::array& color, - bool first, const float* input, float4* output, cudaStream_t cuda_stream); +void cuda_postprocess(uint32_t count, float min_prob, const float* probs, + const float2* scaled_coords, float3* filtered_scaled_coords, uint32_t width, + uint32_t height, const float3 *colors, const float* binary_mask, + float4* colored_mask, cudaStream_t cuda_stream); -} // namespace tool_tracking_postprocessor } // namespace holoscan::ops diff --git a/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.hpp b/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.hpp index 18d6adc60..72f533efb 100644 --- a/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.hpp +++ b/operators/tool_tracking_postprocessor/tool_tracking_postprocessor.hpp @@ -22,8 +22,6 @@ #include #include -#include "tool_tracking_postprocessor.cuh" - #include "holoscan/core/gxf/gxf_operator.hpp" #include "holoscan/utils/cuda_stream_handler.hpp" @@ -42,17 +40,15 @@ namespace holoscan::ops { * ==Named Outputs== * * - **out_coords** : `nvidia::gxf::Tensor` - * - Coordinates tensor, stored on the host (CPU). + * - Coordinates tensor, stored on the device (GPU). * * - **out_mask** : `nvidia::gxf::Tensor` * - Binary mask tensor, stored on device (GPU). * * ==Parameters== * - * - **host_allocator**: The holoscan::Allocator class (e.g. UnboundedAllocator) use for host - * memory allocation of the `out_coords` tensor. * - **device_allocator**: The holoscan::Allocator class (e.g. UnboundedAllocator or - * BlockMemoryPool) used for device memory allocation for the `out_mask` tensor. + * BlockMemoryPool) used for device memory allocation for the `out_coords` and `out_mask` tensor. * - **min_prob**: Minimum probability threshold used by the operator. * Optional (default: 0.5). * - **overlay_img_colors**: A `vector>` where each inner vector is a set of three @@ -68,6 +64,7 @@ class ToolTrackingPostprocessorOp : public holoscan::Operator { ToolTrackingPostprocessorOp() = default; void setup(OperatorSpec& spec) override; + void stop() override; void compute(InputContext& op_input, OutputContext& op_output, ExecutionContext& context) override; @@ -79,10 +76,12 @@ class ToolTrackingPostprocessorOp : public holoscan::Operator { Parameter min_prob_; Parameter>> overlay_img_colors_; - Parameter> host_allocator_; Parameter> device_allocator_; CudaStreamHandler cuda_stream_handler_; + + uint32_t num_colors_ = 0; + void *dev_colors_ = nullptr; }; } // namespace holoscan::ops diff --git a/operators/vtk_renderer/vtk_renderer.cpp b/operators/vtk_renderer/vtk_renderer.cpp index ee65420cd..a089d419d 100644 --- a/operators/vtk_renderer/vtk_renderer.cpp +++ b/operators/vtk_renderer/vtk_renderer.cpp @@ -145,11 +145,11 @@ void VtkRendererOp::compute(InputContext& op_input, OutputContext&, ExecutionCon this->internals->foreground_renderer->RemoveAllViewProps(); - // scale_coords comes in the format [X0 Y0 X1 Y1 ... Xn Yn] + // scale_coords comes in the format [X0 Y0 S0 X1 Y1 S1 ... Xn Yn Sn] // each numbered tuple represent a label (scissors, clipper...) - for (int i = 0; i < scaled_coords.size(); i += 2) { + for (int i = 0; i < scaled_coords.size(); i += 3) { if (scaled_coords[i] > 0) { - std::string label = labels.get()[i / 2]; + std::string label = labels.get()[i / 3]; float x = scaled_coords[i]; float y = scaled_coords[i + 1]; render_text_at_location(label, x, y, this->internals->foreground_renderer);