From 408c18bb0d8a7c1b300e02fd7f6bb58369fdf4c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?= Date: Wed, 11 Sep 2024 20:06:14 +0200 Subject: [PATCH] Enable GPU->CPU transfers (#5593) * Add experimental_exec_dynamic flag to Pipeline to enable new executor * Add DataNode.cpu() that triggers a GPU->CPU copy * Remove checks that prevented GPU->CPU transitions from Python and Pipeline class * Remove checks that prevented CPU operators from taking GPU inputs * Use old executor's graph lowering to run the checks * Add cpu->gpu tests * TODO: Improve input backend checks (#5631) * TODO: Add tensorflow support --------- Signed-off-by: Michal Zientkiewicz --- dali/c_api/c_api.cc | 33 +++- dali/pipeline/executor/executor_factory.cc | 26 +-- dali/pipeline/executor/executor_factory.h | 2 +- dali/pipeline/executor/lowered_graph.cc | 25 ++- dali/pipeline/operator/builtin/copy.cc | 40 +++- dali/pipeline/operator/builtin/copy.cu | 27 --- dali/pipeline/operator/builtin/copy.h | 36 +--- .../operator/checkpointing/checkpoint_test.cc | 2 +- dali/pipeline/pipeline.cc | 173 ++++++++---------- dali/pipeline/pipeline.h | 59 +++--- dali/pipeline/pipeline_test.cc | 41 +---- dali/python/backend_impl.cc | 21 ++- dali/python/nvidia/dali/_utils/eager_utils.py | 7 +- dali/python/nvidia/dali/data_node.py | 15 +- dali/python/nvidia/dali/fn/__init__.py | 7 +- dali/python/nvidia/dali/pipeline.py | 5 + dali/test/python/operator_1/test_slice.py | 4 +- dali/test/python/test_pipeline.py | 112 +++++++++++- include/dali/c_api.h | 48 ++++- 19 files changed, 415 insertions(+), 268 deletions(-) delete mode 100644 dali/pipeline/operator/builtin/copy.cu diff --git a/dali/c_api/c_api.cc b/dali/c_api/c_api.cc index c56ff3ea48..13eabadadb 100644 --- a/dali/c_api/c_api.cc +++ b/dali/c_api/c_api.cc @@ -246,13 +246,33 @@ daliCreatePipeline2(daliPipelineHandle *pipe_handle, const char *serialized_pipe int async_execution, int separated_execution, int prefetch_queue_depth, int cpu_prefetch_queue_depth, int gpu_prefetch_queue_depth, int enable_memory_stats) { - bool se = separated_execution != 0; - bool pe = pipelined_execution != 0; - bool ae = async_execution != 0; + dali_exec_flags_t flags = {}; + if (async_execution) + flags = flags | DALI_EXEC_IS_ASYNC; + if (pipelined_execution) + flags = flags | DALI_EXEC_IS_PIPELINED; + if (separated_execution) + flags = flags | DALI_EXEC_IS_SEPARATED; + daliCreatePipeline3(pipe_handle, serialized_pipeline, length, + max_batch_size, num_threads, device_id, flags, + prefetch_queue_depth, cpu_prefetch_queue_depth, gpu_prefetch_queue_depth, + enable_memory_stats); +} - auto pipeline = - std::make_unique(std::string(serialized_pipeline, length), max_batch_size, - num_threads, device_id, pe, prefetch_queue_depth, ae); +DLL_PUBLIC void +daliCreatePipeline3(daliPipelineHandle *pipe_handle, const char *serialized_pipeline, int length, + int max_batch_size, int num_threads, int device_id, + dali_exec_flags_t exec_flags, int prefetch_queue_depth, + int cpu_prefetch_queue_depth, int gpu_prefetch_queue_depth, + int enable_memory_stats) { + bool se = exec_flags & DALI_EXEC_IS_SEPARATED; + bool pe = exec_flags & DALI_EXEC_IS_PIPELINED; + bool ae = exec_flags & DALI_EXEC_IS_ASYNC; + bool de = exec_flags & DALI_EXEC_IS_DYNAMIC; + + auto pipeline = std::make_unique( + std::string(serialized_pipeline, length), max_batch_size, num_threads, device_id, + pe, prefetch_queue_depth, ae, de); pipeline->SetExecutionTypes(pe, se, ae); if (se) { pipeline->SetQueueSizes(cpu_prefetch_queue_depth, 
gpu_prefetch_queue_depth); @@ -263,7 +283,6 @@ daliCreatePipeline2(daliPipelineHandle *pipe_handle, const char *serialized_pipe *pipe_handle = WrapPipeline(std::move(pipeline)).release(); } - void daliDeserializeDefault(daliPipelineHandle *pipe_handle, const char *serialized_pipeline, int length) { auto pipeline = std::make_unique<Pipeline>(std::string(serialized_pipeline, length)); diff --git a/dali/pipeline/executor/executor_factory.cc b/dali/pipeline/executor/executor_factory.cc index a89a4b02b6..2d2f84d943 100644 --- a/dali/pipeline/executor/executor_factory.cc +++ b/dali/pipeline/executor/executor_factory.cc @@ -58,40 +58,44 @@ bool ForceExec2() { } // namespace template <typename... T> -std::unique_ptr<ExecutorBase> GetExecutorImpl(bool pipelined, bool separated, bool async, - T&&... args) { - if (async && separated && pipelined) { +std::unique_ptr<ExecutorBase> GetExecutorImpl( + bool pipelined, bool separated, bool async, bool dynamic, T&&... args) { + // Go over possible combinations and throw otherwise. + if (async && separated && pipelined && !dynamic) { return std::make_unique<AsyncSeparatedPipelinedExecutor>(std::forward<T>(args)...); } else if (async && !separated && pipelined) { - if (ForceExec2()) { - std::cerr << "\n!!! Forced use of Executor 2.0 !!!" << std::endl; + bool force_exec2 = ForceExec2(); + if (dynamic || force_exec2) { + if (force_exec2) + std::cerr << "\n!!! Forced use of Executor 2.0 !!!" << std::endl; return std::make_unique<exec2::Executor2>(MakeExec2Config(std::forward<T>(args)...)); } else { return std::make_unique<AsyncPipelinedExecutor>(std::forward<T>(args)...); } - } else if (!async && separated && pipelined) { + } else if (!async && separated && pipelined && !dynamic) { return std::make_unique<SeparatedPipelinedExecutor>(std::forward<T>(args)...); - } else if (!async && !separated && pipelined) { + } else if (!async && !separated && pipelined && !dynamic) { return std::make_unique<PipelinedExecutor>(std::forward<T>(args)...); - } else if (!async && !separated && !pipelined) { + } else if (!async && !separated && !pipelined && !dynamic) { return std::make_unique<SimpleExecutor>(std::forward<T>(args)...); } std::stringstream error; error << std::boolalpha; error << "No supported executor selected for pipelined = " << pipelined - << ", separated = " << separated << ", async = " << async << std::endl; + << ", separated = " << separated << ", async = " << async + << ", dynamic = " << dynamic << std::endl; DALI_FAIL(error.str()); } -std::unique_ptr<ExecutorBase> GetExecutor(bool pipelined, bool separated, bool async, +std::unique_ptr<ExecutorBase> GetExecutor(bool pipelined, bool separated, bool async, bool dynamic, int batch_size, int num_thread, int device_id, size_t bytes_per_sample_hint, bool set_affinity, int max_num_stream, int default_cuda_stream_priority, QueueSizes prefetch_queue_depth) { return GetExecutorImpl( - pipelined, separated, async, + pipelined, separated, async, dynamic, batch_size, num_thread, device_id, bytes_per_sample_hint, set_affinity, max_num_stream, default_cuda_stream_priority, prefetch_queue_depth); } diff --git a/dali/pipeline/executor/executor_factory.h b/dali/pipeline/executor/executor_factory.h index ec8cc22b94..d8c5f684c7 100644 --- a/dali/pipeline/executor/executor_factory.h +++ b/dali/pipeline/executor/executor_factory.h @@ -23,7 +23,7 @@ namespace dali { DLL_PUBLIC -std::unique_ptr<ExecutorBase> GetExecutor(bool pipelined, bool separated, bool async, +std::unique_ptr<ExecutorBase> GetExecutor(bool pipelined, bool separated, bool async, bool dynamic, int batch_size, int num_thread, int device_id, size_t bytes_per_sample_hint, bool set_affinity = false, int max_num_stream = -1, diff --git a/dali/pipeline/executor/lowered_graph.cc b/dali/pipeline/executor/lowered_graph.cc index
eec420d300..cfa93221a5 100644 --- a/dali/pipeline/executor/lowered_graph.cc +++ b/dali/pipeline/executor/lowered_graph.cc @@ -97,8 +97,18 @@ void OpGraph::Lower(const graph::OpGraph &definition) { if (!op_nodes_.empty() || !tensor_nodes_.empty()) throw std::logic_error("The target graph must be empty"); for (auto &node : definition.OpNodes()) { - auto &lowered_op = AddOp(node.spec, node.instance_name); - lowered_op.definition = &node; + try { + auto &lowered_op = AddOp(node.spec, node.instance_name); + lowered_op.definition = &node; + } catch (...) { + PropagateError({ + std::current_exception(), + make_string( + "Critical error when building pipeline:\n", + GetErrorContextMessage(node.spec)), + "\nCurrent pipeline object is no longer valid." + }); + } } for (auto &t : tensor_nodes_) { t.definition = definition.GetData(t.name); @@ -131,14 +141,19 @@ OpNode &OpGraph::AddOp(const OpSpec &spec, const std::string &op_name) { // Validate the op specification CheckOpConstraints(spec); + const char *gpu2cpu_error = + "This pipeline doesn't support transition from GPU to CPU.\n" + "To enable GPU->CPU transitions, use the experimental \"dynamic\" executor.\n" + "Specify experimental_exec_dynamic=True in your Pipeline constructor or @pipeline_def."; + string device = spec.GetArgument<std::string>("device"); auto op_type = ParseOpType(device); // TODO(klecki): refactor this out switch (op_type) { case OpType::CPU: { // Enforce graph constraints - DALI_ENFORCE(AllInputsCPU(spec), "CPU ops cannot receive GPU input data."); - DALI_ENFORCE(AllOutputsCPU(spec), "CPU ops can only produce CPU output data."); + DALI_ENFORCE(AllInputsCPU(spec), gpu2cpu_error); + DALI_ENFORCE(AllOutputsCPU(spec), "CPU operators can only produce CPU output data."); break; } case OpType::GPU: { @@ -146,7 +161,7 @@ OpNode &OpGraph::AddOp(const OpSpec &spec, const std::string &op_name) { } case OpType::MIXED: { // Enforce graph constraints - DALI_ENFORCE(AllInputsCPU(spec), "Mixed ops cannot receive GPU input data."); + DALI_ENFORCE(AllInputsCPU(spec), gpu2cpu_error); break; } default: diff --git a/dali/pipeline/operator/builtin/copy.cc b/dali/pipeline/operator/builtin/copy.cc index e3523b7e29..8588bc92f0 100644 --- a/dali/pipeline/operator/builtin/copy.cc +++ b/dali/pipeline/operator/builtin/copy.cc @@ -17,11 +17,47 @@ namespace dali { template <> -void Copy<CPUBackend>::RunCopies(Workspace &ws) { - scatter_gather_.Run(ws.GetThreadPool(), true); +void Copy<CPUBackend>::RunImpl(Workspace &ws) { + if (ws.InputIsType<CPUBackend>(0)) { + auto &input = ws.Input<CPUBackend>(0); + auto &output = ws.Output<CPUBackend>(0); + + int batch_size = input.num_samples(); + output.SetLayout(input.GetLayout()); + auto shapes = input.shape(); + + auto &thread_pool = ws.GetThreadPool(); + for (int sample_id = 0; sample_id < batch_size; ++sample_id) { + thread_pool.AddWork( + [sample_id, &input, &output](int tid) { + output.CopySample(sample_id, input, sample_id, AccessOrder::host()); + }, + shapes.tensor_size(sample_id)); + } + thread_pool.RunAll(); + } else { + auto &input = ws.Input<GPUBackend>(0); + auto &output = ws.Output<CPUBackend>(0); + output.Copy(input, ws.output_order()); + } +} + +template <> +void Copy<GPUBackend>::RunImpl(Workspace &ws) { + if (ws.InputIsType<CPUBackend>(0)) { + auto &input = ws.Input<CPUBackend>(0); + auto &output = ws.Output<GPUBackend>(0); + output.Copy(input, ws.output_order()); + } else { + auto &input = ws.Input<GPUBackend>(0); + auto &output = ws.Output<GPUBackend>(0); + output.Copy(input, ws.output_order()); + } } DALI_REGISTER_OPERATOR(Copy, Copy<CPUBackend>, CPU); +DALI_REGISTER_OPERATOR(Copy, Copy<GPUBackend>, GPU); + DALI_SCHEMA(Copy) .DocStr("Creates a copy of the input tensor.") diff --git
a/dali/pipeline/operator/builtin/copy.cu b/dali/pipeline/operator/builtin/copy.cu deleted file mode 100644 index 3a8e41795e..0000000000 --- a/dali/pipeline/operator/builtin/copy.cu +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "dali/pipeline/operator/builtin/copy.h" - -namespace dali { - -template <> -void Copy::RunCopies(Workspace &ws) { - scatter_gather_.Run(ws.stream(), true); -} - -DALI_REGISTER_OPERATOR(Copy, Copy, GPU); - -} // namespace dali diff --git a/dali/pipeline/operator/builtin/copy.h b/dali/pipeline/operator/builtin/copy.h index ebf04a8ded..91c2995ea6 100644 --- a/dali/pipeline/operator/builtin/copy.h +++ b/dali/pipeline/operator/builtin/copy.h @@ -28,10 +28,8 @@ namespace dali { template class Copy : public StatelessOperator { public: - inline explicit Copy(const OpSpec &spec) : - StatelessOperator(spec), scatter_gather_(kMaxSizePerBlock) {} - - inline ~Copy() override = default; + explicit Copy(const OpSpec &spec) : + StatelessOperator(spec) {} DISABLE_COPY_MOVE_ASSIGN(Copy); @@ -42,37 +40,15 @@ class Copy : public StatelessOperator { bool SetupImpl(std::vector &output_desc, const Workspace &ws) override { output_desc.resize(1); - const auto &input = ws.Input(0); - output_desc[0].type = input.type(); - output_desc[0].shape = input.shape(); + output_desc[0].type = ws.GetInputDataType(0); + output_desc[0].shape = ws.GetInputShape(0); return true; } - void RunImpl(Workspace &ws) override { - auto &input = ws.Input(0); - auto data_type_size = input.type_info().size(); - auto &output = ws.Output(0); - output.SetLayout(input.GetLayout()); - for (int i = 0; i < input.num_samples(); i++) { - auto tensor_shape = input.tensor_shape(i); - auto tensor_size = volume(tensor_shape); - scatter_gather_.AddCopy(output.raw_mutable_tensor(i), input.raw_tensor(i), - tensor_size * data_type_size); - } - RunCopies(ws); - } - - void RunCopies(Workspace &ws); - - std::conditional_t< - std::is_same::value, - kernels::ScatterGatherCPU, - kernels::ScatterGatherGPU> scatter_gather_; - // 256 kB per block for GPU - static constexpr size_t kMaxSizePerBlock = - std::is_same::value ? 
kernels::ScatterGatherCPU::kAnyBlockSize : 1 << 18; + void RunImpl(Workspace &ws) override; }; + } // namespace dali #endif // DALI_PIPELINE_OPERATOR_BUILTIN_COPY_H_ diff --git a/dali/pipeline/operator/checkpointing/checkpoint_test.cc b/dali/pipeline/operator/checkpointing/checkpoint_test.cc index 6d1fce29e1..a72feca647 100644 --- a/dali/pipeline/operator/checkpointing/checkpoint_test.cc +++ b/dali/pipeline/operator/checkpointing/checkpoint_test.cc @@ -36,7 +36,7 @@ void BuildFromLegacyGraph(Checkpoint &checkpoint, const OpGraph &graph) { } auto GetSimpleExecutor() { - return GetExecutor(false, false, false, 1, 1, CPU_ONLY_DEVICE_ID, 0); + return GetExecutor(false, false, false, false, 1, 1, CPU_ONLY_DEVICE_ID, 0); } } // namespace diff --git a/dali/pipeline/pipeline.cc b/dali/pipeline/pipeline.cc index fc1d956907..02b162a891 100644 --- a/dali/pipeline/pipeline.cc +++ b/dali/pipeline/pipeline.cc @@ -96,20 +96,20 @@ void InitializeMemoryResources() { Pipeline::Pipeline(int max_batch_size, int num_threads, int device_id, int64_t seed, bool pipelined_execution, int prefetch_queue_depth, - bool async_execution, size_t bytes_per_sample_hint, - bool set_affinity, int max_num_stream, int default_cuda_stream_priority) - : built_(false), separated_execution_{false} { + bool async_execution, bool dynamic_execution, size_t bytes_per_sample_hint, + bool set_affinity, int max_num_stream, int default_cuda_stream_priority) { InitializeMemoryResources(); Init(max_batch_size, num_threads, device_id, seed, pipelined_execution, separated_execution_, - async_execution, bytes_per_sample_hint, set_affinity, max_num_stream, + async_execution, dynamic_execution, bytes_per_sample_hint, set_affinity, max_num_stream, default_cuda_stream_priority, QueueSizes{prefetch_queue_depth}); } -Pipeline::Pipeline(const string &serialized_pipe, int batch_size, int num_threads, int device_id, - bool pipelined_execution, int prefetch_queue_depth, bool async_execution, +Pipeline::Pipeline(const string &serialized_pipe, + int batch_size, int num_threads, int device_id, + bool pipelined_execution, int prefetch_queue_depth, + bool async_execution, bool dynamic_execution, size_t bytes_per_sample_hint, bool set_affinity, int max_num_stream, - int default_cuda_stream_priority, int64_t seed) - : built_(false), separated_execution_(false) { + int default_cuda_stream_priority, int64_t seed) { InitializeMemoryResources(); dali_proto::PipelineDef def; DALI_ENFORCE(DeserializePipeline(serialized_pipe, def), "Error parsing serialized pipeline."); @@ -138,6 +138,7 @@ Pipeline::Pipeline(const string &serialized_pipe, int batch_size, int num_thread pipelined_execution, separated_execution_, async_execution, + dynamic_execution, bytes_per_sample_hint, set_affinity, max_num_stream, @@ -177,7 +178,8 @@ Pipeline::~Pipeline() { } void Pipeline::Init(int max_batch_size, int num_threads, int device_id, int64_t seed, - bool pipelined_execution, bool separated_execution, bool async_execution, + bool pipelined_execution, bool separated_execution, + bool async_execution, bool dynamic_execution, size_t bytes_per_sample_hint, bool set_affinity, int max_num_stream, int default_cuda_stream_priority, QueueSizes prefetch_queue_depth) { DALI_ENFORCE(device_id == CPU_ONLY_DEVICE_ID || cuInitChecked(), @@ -194,6 +196,7 @@ void Pipeline::Init(int max_batch_size, int num_threads, int device_id, int64_t this->original_seed_ = seed < 0 ? 
Clock::now().time_since_epoch().count() : seed; this->pipelined_execution_ = pipelined_execution; this->async_execution_ = async_execution; + this->dynamic_execution_ = dynamic_execution; this->bytes_per_sample_hint_ = bytes_per_sample_hint; this->set_affinity_ = set_affinity; this->max_num_stream_ = max_num_stream; @@ -329,43 +332,14 @@ int Pipeline::AddOperatorImpl(const OpSpec &const_spec, const std::string &inst_ make_string("Data node \"", input_name, "\" requested as ", FormatInput(spec, i), " to operator \"", inst_name, "\" is not known to the pipeline.")); - // Table of possible scenarios: - // Op location / requested input type / data location - // cpu / cpu / cpu -> everything is fine - // cpu / cpu / gpu -> error, data does not exist on cpu - // cpu / gpu / cpu -> error, cpu op not allowed to have gpu inputs - // cpu / gpu / gpu -> both of above errors - // gpu / cpu / cpu -> need to use contiguous version - // gpu / cpu / gpu -> error, data not in specified location - // gpu / gpu / cpu -> need to insert copy to device - // gpu / gpu / gpu -> everything is fine - // mixed / cpu / cpu -> everything is fine - // mixed / cpu / gpu -> error, data does not exist on cpu - // mixed / gpu / cpu -> error, mixed op not allowed to have gpu inputs - // mixed / gpu / gpu -> both of above errors - if (device == "cpu" || device == "mixed") { - DALI_ENFORCE(input_device == "cpu", - make_string("Error while specifying ", FormatInput(spec, i), - ". CPU/Mixed ops can only take CPU inputs. CPU operator cannot " - "follow GPU operator. ")); - DALI_ENFORCE(it->second.has_cpu, - make_string("Error while specifying ", FormatInput(spec, i), - ". CPU input requested by operator exists only on GPU. CPU " - "operator cannot follow GPU operator.")); - DALI_ENFORCE(device_id_ != CPU_ONLY_DEVICE_ID || device == "cpu", - "Cannot add a Mixed operator with a GPU output, 'device_id' " - "should not be `CPU_ONLY_DEVICE_ID`."); - } + DALI_ENFORCE(device_id_ != CPU_ONLY_DEVICE_ID || device == "cpu", + "Cannot add a Mixed operator with a GPU output, 'device_id' " + "should not be `CPU_ONLY_DEVICE_ID`."); if (input_device == "gpu") { - SetupGPUInput(it); + ToGPU(it); } else { - // device == gpu - // TODO(michalz): Add a D2H copy instead - DALI_ENFORCE(it->second.has_cpu, - make_string("Error while specifying ", FormatInput(spec, i), - ". CPU input requested by operator exists only on GPU. CPU " - "operator cannot follow GPU operator.")); + ToCPU(it); } } @@ -385,6 +359,8 @@ int Pipeline::AddOperatorImpl(const OpSpec &const_spec, const std::string &inst_ ". Named argument inputs to operators must be CPU data nodes. " "However, a GPU data node was provided.")); } + + ToCPU(it); } // Verify and record the outputs of the op @@ -398,30 +374,14 @@ int Pipeline::AddOperatorImpl(const OpSpec &const_spec, const std::string &inst_ make_string("Error while specifying ", FormatOutput(spec, i), ". Output name \"", output_name, "\" conflicts with an existing intermediate result name.")); - // Validate output data conforms to graph constraints - // Note: DALI CPU -> GPU flow is enforced, when the operators are added via the Python layer - // in `generate_outputs` - the output_device is calculated and assigned to DataNode. - // TODO(michalz): Remove this constraint! Insert GPU->CPU copy instead. - bool mark_explicitly_contiguous = false; if (device == "cpu") { DALI_ENFORCE(output_device == "cpu", make_string("Error while specifying ", FormatOutput(spec, i), ". 
Only CPU operators can produce CPU outputs.")); - } else if (device == "gpu") { - if (output_device == "cpu") { - mark_explicitly_contiguous = true; - } } - // The edge describes that the named output of this operator produces the CPU or GPU data, - // the former for "cpu" ops, the latter for "mixed" and "gpu" (see Note above about the DALI - // CPU -> GPU flow). - // There are exceptions: we can have CPU output from Mixed MakeContiguous - see - // [cpu output of mixed] where we break the constraints from Python frontend. + // the former for "cpu" ops, the latter for "mixed" and "gpu". EdgeMeta meta = NewEdge(output_device); - if (mark_explicitly_contiguous) { - meta.has_contiguous = true; - } DALI_ENFORCE(edge_names_.insert({output_name, meta}).second, make_string("Error while specifying ", FormatOutput(spec, i), "node name: \"", @@ -511,28 +471,13 @@ void Pipeline::Build(std::vector output_descs) { make_string("User specified incorrect number of outputs (", num_outputs, ").")); executor_ = - GetExecutor(pipelined_execution_, separated_execution_, async_execution_, max_batch_size_, - num_threads_, device_id_, bytes_per_sample_hint_, set_affinity_, max_num_stream_, - default_cuda_stream_priority_, prefetch_queue_depth_); + GetExecutor(pipelined_execution_, separated_execution_, async_execution_, dynamic_execution_, + max_batch_size_, num_threads_, device_id_, bytes_per_sample_hint_, set_affinity_, + max_num_stream_, default_cuda_stream_priority_, prefetch_queue_depth_); executor_->EnableMemoryStats(enable_memory_stats_); executor_->EnableCheckpointing(checkpointing_); executor_->Init(); - // Creating the graph - - for (auto& name_op_spec : op_specs_) { - string& inst_name = name_op_spec.instance_name; - OpSpec op_spec = name_op_spec.spec; - PrepareOpSpec(&op_spec, name_op_spec.logical_id); - try { - graph_builder_.Add(inst_name, op_spec); - } catch (...) { - PropagateError({std::current_exception(), - "Critical error when building pipeline:\n" + GetErrorContextMessage(op_spec), - "\nCurrent pipeline object is no longer valid."}); - } - } - // Validate the output tensors names vector outputs; for (const auto &out_desc : output_descs_) { @@ -543,14 +488,17 @@ void Pipeline::Build(std::vector output_descs) { name + "' is not known to the pipeline."); if (device == "cpu") { - DALI_ENFORCE(it->second.has_cpu, "Requested cpu output '" + - name + "' only exists on gpu."); - // Add a make contiguous op to produce this output - we need pipeline outputs to be dense. - auto output_name = AddMakeContiguousNode(it->second, name, "cpu", "cpu", "cpu"); - if (!it->second.has_contiguous) { - it->second.has_contiguous = true; + if (!it->second.has_cpu) + ToCPU(it); + + if (!it->second.has_contiguous_cpu) { + // Add a make contiguous op to produce this output - we need pipeline outputs to be dense. + auto output_name = AddMakeContiguousNode(it->second, name, "cpu", "cpu", "cpu"); + outputs.push_back(output_name); + } else { + outputs.push_back(it->first + "_cpu"); } - outputs.push_back(output_name); + } else if (device == "gpu") { DALI_ENFORCE(device_id_ != CPU_ONLY_DEVICE_ID, make_string( @@ -559,16 +507,14 @@ void Pipeline::Build(std::vector output_descs) { "is set to `CPU_ONLY_DEVICE_ID`. 
Set 'device_id' " "to a valid GPU identifier to enable GPU features " "in the pipeline.")); - if (!it->second.has_gpu) { - DALI_ENFORCE(it->second.has_cpu, "Output '" + name + - "' exists on neither cpu or gpu, internal error"); - // Add a copy to device to create the gpu output - auto output_name = AddMakeContiguousNode(it->second, name, "cpu", "mixed", "gpu"); - outputs.push_back(output_name); - } else { - // Add an optional copy/pass through to normalize the output. + if (!it->second.has_gpu) + ToGPU(it); + + if (!it->second.has_contiguous_gpu) { auto output_name = AddMakeContiguousNode(it->second, name, "gpu", "gpu", "gpu"); outputs.push_back(output_name); + } else { + outputs.push_back(it->first + "_gpu"); } } else { DALI_FAIL("Invalid device argument \"" + device + @@ -576,6 +522,21 @@ void Pipeline::Build(std::vector output_descs) { } } + // Creating the graph + + for (auto &name_op_spec : op_specs_) { + const string &inst_name = name_op_spec.instance_name; + OpSpec op_spec = name_op_spec.spec; + try { + PrepareOpSpec(&op_spec, name_op_spec.logical_id); + graph_builder_.Add(inst_name, op_spec); + } catch (...) { + PropagateError({std::current_exception(), + "Critical error when building pipeline:\n" + GetErrorContextMessage(op_spec), + "\nCurrent pipeline object is no longer valid."}); + } + } + for (auto &out : outputs) graph_builder_.AddOutput(out); @@ -677,8 +638,25 @@ void Pipeline::ReleaseOutputs() { } } -void Pipeline::SetupGPUInput(std::map::iterator it) { - if (it->second.has_gpu) return; +void Pipeline::ToCPU(std::map::iterator it) { + // Insert a D2H copy, if needed + if (it->second.has_cpu) + return; + OpSpec copy_to_host_spec = + OpSpec("Copy") + .AddArg("device", "cpu") + .AddInput(it->first, "gpu") + .AddOutput(it->first, "cpu"); + // don't put it into op_specs_for_serialization_, only op_specs_ + AddToOpSpecs("__Copy_GpuToCpu_" + it->first, copy_to_host_spec, GetNextInternalLogicalId()); + it->second.has_cpu = true; + it->second.has_contiguous_cpu = true; // the result is always contiguous +} + +void Pipeline::ToGPU(std::map::iterator it) { + // Insert a H2D copy, if needed + if (it->second.has_gpu) + return; OpSpec copy_to_dev_spec = OpSpec("MakeContiguous") .AddArg("device", "mixed") @@ -687,6 +665,7 @@ void Pipeline::SetupGPUInput(std::map::iterator it) { // don't put it into op_specs_for_serialization_, only op_specs_ AddToOpSpecs("__Copy_CpuToGpu_" + it->first, copy_to_dev_spec, GetNextInternalLogicalId()); it->second.has_gpu = true; + it->second.has_contiguous_gpu = true; // the result is always contiguous } void Pipeline::PrepareOpSpec(OpSpec *spec, int logical_id) { @@ -1048,8 +1027,8 @@ std::string Pipeline::AddMakeContiguousNode(EdgeMeta &meta, const std::string &i } // Add a make contiguous op to produce this output - PrepareOpSpec(&spec, GetNextInternalLogicalId()); - graph_builder_.Add(op_name, std::move(spec)); + auto id = GetNextInternalLogicalId(); + AddToOpSpecs(op_name, std::move(spec), id); if (output_dev == "cpu") { meta.has_make_contiguous_cpu = true; diff --git a/dali/pipeline/pipeline.h b/dali/pipeline/pipeline.h index c85983aa55..dece608da3 100644 --- a/dali/pipeline/pipeline.h +++ b/dali/pipeline/pipeline.h @@ -65,13 +65,7 @@ class DLL_PUBLIC Pipeline { * @brief Creates a pipeline that will produce batches of size `batch_size`, * using `num_threads` worker threads on gpu `device_id`. 
* - * GPU memory and pinned memory allocations cause implicit synchronization of - * the device, resulting in very slow startup times as dali buffer sizes - * stabilize. To avoid this slowdown, we optionally take in an estimated size - * of each image that will be processed in bytes. This hint is used to - * pre-size buffers, potentially avoiding slow startup if the hint is close - * to the true amount of memory that will be needed by the largest image to - * be processed. + * TODO(michalz): Rework Pipeline construction to use a configuration structure. * * @param max_batch_size the maximum size of the batch that can be produced. * @param num_threads the number of threads to use in the prefetch stage. @@ -83,8 +77,10 @@ class DLL_PUBLIC Pipeline { * @param prefetch_queue_depth sets the length of the executor internal pipeline * @param async_execution whether to use extra host-threads to enable asynchronous execution * of cpu and gpu work. See AsyncExecutor/AsyncPipelinedExecutor. + * @param dynamic_execution whether to use the dynamic executor, enabling GPU->CPU transfers + * and dynamic allocation of memory. * @param bytes_per_sample_hint Estimated size of each sample to be processed. - * Defaults to 0. + * Defaults to 0. Ignored when dynamic_execution is true. * @param set_affinity indicates whether thread affinity should be * configured in the thread pool. Defaults to 'false'. * @param max_num_stream set an upper limit on the number of cudaStreams */ DLL_PUBLIC Pipeline(int max_batch_size, int num_threads, int device_id, int64_t seed = -1, bool pipelined_execution = true, int prefetch_queue_depth = 2, - bool async_execution = true, size_t bytes_per_sample_hint = 0, - bool set_affinity = false, int max_num_stream = -1, - int default_cuda_stream_priority = 0); + bool async_execution = true, bool dynamic_execution = false, + size_t bytes_per_sample_hint = 0, bool set_affinity = false, + int max_num_stream = -1, int default_cuda_stream_priority = 0); - DLL_PUBLIC Pipeline(const string &serialized_pipe, int max_batch_size = -1, int num_threads = -1, - int device_id = -1, bool pipelined_execution = true, - int prefetch_queue_depth = 2, bool async_execution = true, + DLL_PUBLIC Pipeline(const string &serialized_pipe, + int max_batch_size = -1, int num_threads = -1, int device_id = -1, + bool pipelined_execution = true, int prefetch_queue_depth = 2, + bool async_execution = true, bool dynamic_execution = false, size_t bytes_per_sample_hint = 0, bool set_affinity = false, int max_num_stream = -1, int default_cuda_stream_priority = 0, int64_t seed = -1); @@ -115,10 +112,10 @@ class DLL_PUBLIC Pipeline { * device placement. */ DLL_PUBLIC int AddExternalInput(const string &name, - const string &device = "cpu", - DALIDataType dtype = DALI_NO_TYPE, - int ndim = -1, - const TensorLayout &layout = "") { + const string &device = "cpu", + DALIDataType dtype = DALI_NO_TYPE, + int ndim = -1, + const TensorLayout &layout = "") { auto spec = OpSpec("ExternalSource") .AddArg("name", name) .AddArg("device", device) @@ -280,6 +277,9 @@ class DLL_PUBLIC Pipeline { /** * @brief Set execution characteristics for this Pipeline * + * TODO(michalz): Remove this function and rework Pipeline construction + * to use a configuration structure. 
+ * * @param pipelined_execution Use pipelined execution * @param separated_execution Use separated queues * @param async_execution Use worker threads for RunX() functions @@ -585,14 +585,17 @@ class DLL_PUBLIC Pipeline { * @brief Initializes the Pipeline internal state */ void Init(int batch_size, int num_threads, int device_id, int64_t seed, bool pipelined_execution, - bool separated_execution, bool async_execution, size_t bytes_per_sample_hint, + bool separated_execution, bool async_execution, bool dynamic_execution, + size_t bytes_per_sample_hint, bool set_affinity, int max_num_stream, int default_cuda_stream_priority, QueueSizes prefetch_queue_depth = QueueSizes{2}); - using EdgeMeta = struct { + struct EdgeMeta { bool has_cpu; bool has_gpu; - bool has_contiguous; + // Whether the given backend is guaranteed to have contiguous storage + bool has_contiguous_cpu; + bool has_contiguous_gpu; // MakeContiguous was added after this node to be used as output on specified device: bool has_make_contiguous_cpu; bool has_make_contiguous_gpu; @@ -614,22 +617,17 @@ class DLL_PUBLIC Pipeline { */ int AddOperatorImpl(const OpSpec &spec, const std::string& inst_name, int logical_id); - void SetupGPUInput(std::map::iterator it); + void ToCPU(std::map::iterator it); + void ToGPU(std::map::iterator it); inline EdgeMeta NewEdge(const std::string &device) { - EdgeMeta edge; - edge.has_cpu = false; - edge.has_gpu = false; - edge.has_contiguous = false; - edge.has_make_contiguous_cpu = false; - edge.has_make_contiguous_gpu = false; + EdgeMeta edge{}; if (device == "cpu") { edge.has_cpu = true; } else if (device == "gpu") { edge.has_gpu = true; } else if (device == "mixed") { edge.has_gpu = true; - edge.has_contiguous = true; } else { DALI_FAIL("Invalid device argument \"" + device + "\". " "Valid options are \"cpu\", \"gpu\" or \"mixed\"."); @@ -714,6 +712,7 @@ class DLL_PUBLIC Pipeline { bool pipelined_execution_ = false; bool separated_execution_ = false; bool async_execution_ = false; + bool dynamic_execution_ = false; size_t bytes_per_sample_hint_ = 0; int set_affinity_ = 0; int max_num_stream_ = 0; @@ -725,7 +724,7 @@ class DLL_PUBLIC Pipeline { bool checkpointing_ = false; std::vector seed_; - int original_seed_ = 0; + int64_t original_seed_ = 0; size_t current_seed_ = 0; std::unique_ptr executor_; diff --git a/dali/pipeline/pipeline_test.cc b/dali/pipeline/pipeline_test.cc index 1ea9da7e80..f339eefc9c 100644 --- a/dali/pipeline/pipeline_test.cc +++ b/dali/pipeline/pipeline_test.cc @@ -54,50 +54,11 @@ class PipelineTest : public DALITest { void RunTestEnforce(const string &dev1, const string &dev2) { Pipeline pipe(1, 1, 0); - // TODO(michalz): This is a totally artificial limitation. Remove the constraint and the tests. - - // Inputs must be know to the pipeline, i.e. ops - // must be added in a topological ordering. - ASSERT_THROW( - pipe.AddOperator( - OpSpec("Copy") - .AddArg("device", dev1) - .AddInput("data", dev1) - .AddOutput("copy_out", dev1)), - std::runtime_error); - pipe.AddOperator( OpSpec("ExternalSource") .AddArg("device", "gpu") .AddOutput("data", "gpu")); - // TODO(michalz): Remove this constraint and the tests. This should be a build-time error, - // with old executor, not a construction-time error. - - // For dev1 = "cpu": Inputs to CPU ops must be on CPU, - // we do not auto-copy them from gpu to cpu. - // For dev1 = "gpu": CPU inputs to GPU ops must be on CPU, - // we will not copy them back to the host. 
- ASSERT_THROW( - pipe.AddOperator( - OpSpec("Copy") - .AddArg("device", dev1) - .AddInput("data", dev2) - .AddOutput("copy_out", dev1)), - std::runtime_error); - - if (dev1 == "cpu") { - // Inputs to CPU ops must already exist on CPU, - // we do not auto-copy them from gpu to cpu. - ASSERT_THROW( - pipe.AddOperator( - OpSpec("Copy") - .AddArg("device", dev1) - .AddInput("data", dev1) - .AddOutput("copy_out", dev1)), - std::runtime_error); - } - pipe.AddOperator( OpSpec("ExternalSource") .AddArg("device", dev1) @@ -405,6 +366,7 @@ TEST_F(PipelineTestOnce, TestPresize) { const int num_thread = 1; const bool pipelined = false; const bool async = false; + const bool dynamic = false; DALIImageType img_type = DALI_RGB; const int presize_val_CPU = 11; @@ -418,6 +380,7 @@ TEST_F(PipelineTestOnce, TestPresize) { num_thread, 0, -1, pipelined, 3, async, + dynamic, presize_val_default); TensorList data; diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc index 532c327257..c4eabb2d06 100644 --- a/dali/python/backend_impl.cc +++ b/dali/python/backend_impl.cc @@ -1882,13 +1882,15 @@ PYBIND11_MODULE(backend_impl, m) { .def(py::init( [](int batch_size, int num_threads, int device_id, int64_t seed = -1, bool pipelined_execution = true, int prefetch_queue_depth = 2, - bool async_execution = true, size_t bytes_per_sample_hint = 0, + bool async_execution = true, bool dynamic_execution = false, + size_t bytes_per_sample_hint = 0, bool set_affinity = false, int max_num_stream = -1, int default_cuda_stream_priority = 0) { return std::make_unique( - batch_size, num_threads, device_id, seed, pipelined_execution, - prefetch_queue_depth, async_execution, bytes_per_sample_hint, set_affinity, - max_num_stream, default_cuda_stream_priority); + batch_size, num_threads, device_id, seed, + pipelined_execution, prefetch_queue_depth, async_execution, dynamic_execution, + bytes_per_sample_hint, set_affinity, + max_num_stream, default_cuda_stream_priority); }), "batch_size"_a, "num_threads"_a, @@ -1897,6 +1899,7 @@ PYBIND11_MODULE(backend_impl, m) { "exec_pipelined"_a = true, "prefetch_queue_depth"_a = 2, "exec_async"_a = true, + "exec_dynamic"_a = false, "bytes_per_sample_hint"_a = 0, "set_affinity"_a = false, "max_num_stream"_a = -1, @@ -1907,14 +1910,15 @@ PYBIND11_MODULE(backend_impl, m) { [](string serialized_pipe, int batch_size = -1, int num_threads = -1, int device_id = -1, bool pipelined_execution = true, int prefetch_queue_depth = 2, - bool async_execution = true, size_t bytes_per_sample_hint = 0, - bool set_affinity = false, int max_num_stream = -1, + bool async_execution = true, bool dynamic_execution = false, + size_t bytes_per_sample_hint = 0, bool set_affinity = false, int max_num_stream = -1, int default_cuda_stream_priority = 0) { return std::make_unique( serialized_pipe, batch_size, num_threads, device_id, pipelined_execution, - prefetch_queue_depth, async_execution, bytes_per_sample_hint, - set_affinity, max_num_stream, default_cuda_stream_priority); + prefetch_queue_depth, async_execution, dynamic_execution, + bytes_per_sample_hint, set_affinity, + max_num_stream, default_cuda_stream_priority); }), "serialized_pipe"_a, "batch_size"_a = -1, @@ -1923,6 +1927,7 @@ PYBIND11_MODULE(backend_impl, m) { "exec_pipelined"_a = true, "prefetch_queue_depth"_a = 2, "exec_async"_a = true, + "exec_dynamic"_a = true, "bytes_per_sample_hint"_a = 0, "set_affinity"_a = false, "max_num_stream"_a = -1, diff --git a/dali/python/nvidia/dali/_utils/eager_utils.py b/dali/python/nvidia/dali/_utils/eager_utils.py 
index 466937c305..da04d3c327 100644 --- a/dali/python/nvidia/dali/_utils/eager_utils.py +++ b/dali/python/nvidia/dali/_utils/eager_utils.py @@ -604,8 +604,11 @@ def _choose_device(op_name, wrapper_name, inputs, device_param): device = device_param device_id = 0 - if device == "cpu" and input_device == "gpu": - raise ValueError("An operator with device='cpu' cannot accept GPU inputs.") + # TODO(michalz): Verify against InputDevice from the schema. + # TODO(michalz): Add InputDevice::Any for operators which can take any input backend + # Temporarily the check is disabled + # if device == "cpu" and input_device == "gpu": + # raise ValueError("An operator with device='cpu' cannot accept GPU inputs.") if device != "cpu" and device != "gpu": raise ValueError(f"Incorrect device type '{device}'.") diff --git a/dali/python/nvidia/dali/data_node.py b/dali/python/nvidia/dali/data_node.py index 48d636c0a2..ebd8312dbb 100644 --- a/dali/python/nvidia/dali/data_node.py +++ b/dali/python/nvidia/dali/data_node.py @@ -77,19 +77,28 @@ def __str__(self): __repr__ = __str__ + def gpu(self) -> DataNode: + return self._to_backend("gpu") + + def cpu(self) -> DataNode: + return self._to_backend("cpu") + # Note: Regardless of whether we want the cpu or gpu version # of a tensor, we keep the source argument the same so that # the pipeline can backtrack through the user-defined graph - def gpu(self) -> DataNode: + def _to_backend(self, backend) -> DataNode: + if self.device == backend: + return self + from nvidia.dali import _conditionals if _conditionals.conditionals_enabled(): # Treat it the same way as regular operator would behave [self_split], _ = _conditionals.apply_conditional_split_to_args([self], {}) - transferred_node = DataNode(self_split.name, "gpu", self_split.source) + transferred_node = DataNode(self_split.name, backend, self_split.source) _conditionals.register_data_nodes(transferred_node, [self]) return transferred_node - return DataNode(self.name, "gpu", self.source) + return DataNode(self.name, backend, self.source) def __add__(self, other) -> DataNode: return _arithm_op("add", self, other) diff --git a/dali/python/nvidia/dali/fn/__init__.py b/dali/python/nvidia/dali/fn/__init__.py index dcbc6e8539..a7f215e972 100644 --- a/dali/python/nvidia/dali/fn/__init__.py +++ b/dali/python/nvidia/dali/fn/__init__.py @@ -70,8 +70,11 @@ def op_wrapper(*inputs, **kwargs): init_args, call_args = nvidia.dali.ops._separate_kwargs(kwargs) default_dev = nvidia.dali.ops._choose_device(inputs) - if default_dev == "gpu" and init_args.get("device") == "cpu": - raise ValueError("An operator with device='cpu' cannot accept GPU inputs.") + # TODO(michalz): Verify against InputDevice from the schema. 
+ # TODO(michalz): Add InputDevice::Any for operators which can take any input backend + # Temporarily the check is disabled + # if default_dev == "gpu" and init_args.get("device") == "cpu": + # raise ValueError("An operator with device='cpu' cannot accept GPU inputs.") if "device" not in init_args: init_args["device"] = default_dev diff --git a/dali/python/nvidia/dali/pipeline.py b/dali/python/nvidia/dali/pipeline.py index e29930f925..69d6aa7dcb 100644 --- a/dali/python/nvidia/dali/pipeline.py +++ b/dali/python/nvidia/dali/pipeline.py @@ -227,6 +227,7 @@ def __init__( py_callback_pickler=None, output_dtype=None, output_ndim=None, + experimental_exec_dynamic=False, ): self._pipe = None self._sinks = [] @@ -255,6 +256,7 @@ def __init__( self._batches_to_consume = 0 self._names_and_devices = None self._exec_async = exec_async + self._exec_dynamic = experimental_exec_dynamic self._bytes_per_sample = bytes_per_sample self._set_affinity = set_affinity self._max_streams = max_streams @@ -865,6 +867,7 @@ def _init_pipeline_backend(self): self._exec_pipelined, self._cpu_queue_size, self._exec_async, + self._exec_dynamic, self._bytes_per_sample, self._set_affinity, self._max_streams, @@ -1519,6 +1522,7 @@ def deserialize(cls, serialized_pipeline=None, filename=None, **kwargs): kw.get("exec_pipelined", True), kw.get("prefetch_queue_depth", 2), kw.get("exec_async", True), + kw.get("exec_dynamic", False), kw.get("bytes_per_sample", 0), kw.get("set_affinity", False), kw.get("max_streams", -1), @@ -1566,6 +1570,7 @@ def deserialize_and_build(self, serialized_pipeline): self._exec_pipelined, self._prefetch_queue_depth, self._exec_async, + self._exec_dynamic, self._bytes_per_sample, self._set_affinity, self._max_streams, diff --git a/dali/test/python/operator_1/test_slice.py b/dali/test/python/operator_1/test_slice.py index c486f6c8fc..69bd4cb760 100644 --- a/dali/test/python/operator_1/test_slice.py +++ b/dali/test/python/operator_1/test_slice.py @@ -1175,7 +1175,9 @@ def make_pipe(): sliced = fn.slice(fake_data, rel_start, rel_shape, device="cpu") return sliced - with assert_raises(ValueError, glob="An operator with device='cpu' cannot accept GPU inputs"): + # TODO(michalz): Restore the old check when we have proper verification against schema + # with assert_raises(ValueError, glob="An operator with device='cpu' cannot accept GPU inputs"): + with assert_raises(RuntimeError, glob="doesn't support transition from GPU to CPU"): p = make_pipe() p.build() p.run() diff --git a/dali/test/python/test_pipeline.py b/dali/test/python/test_pipeline.py index dd6016248a..0a6456666e 100644 --- a/dali/test/python/test_pipeline.py +++ b/dali/test/python/test_pipeline.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -2229,3 +2229,113 @@ def test_subgraph_stealing(): glob="The pipeline is invalid because it contains operators with non-unique names", ): p2.build() + + +def test_gpu2cpu(): + bs = 8 + + @pipeline_def(batch_size=bs, num_threads=4, device_id=0, experimental_exec_dynamic=True) + def pdef(): + enc, _ = fn.readers.file(file_root=jpeg_folder) + img = fn.decoders.image(enc, device="mixed") + return img, img.cpu() + + pipe = pdef() + pipe.build() + for i in range(10): + gpu, cpu = pipe.run() + assert isinstance(gpu, dali.backend_impl.TensorListGPU) + assert isinstance(cpu, dali.backend_impl.TensorListCPU) + check_batch(cpu, gpu, bs, 0, 0, "HWC") + + +def test_shapes_gpu(): + bs = 8 + + @pipeline_def(batch_size=bs, num_threads=4, device_id=0, experimental_exec_dynamic=True) + def pdef(): + enc, _ = fn.readers.file(file_root=jpeg_folder) + img = fn.decoders.image(enc, device="mixed") + peek = fn.peek_image_shape(enc) + return peek, fn.shapes(img, device="cpu"), fn.shapes(img.cpu()) + + pipe = pdef() + pipe.build() + for i in range(10): + peek, shape_of_gpu, shape_of_cpu = pipe.run() + # all results must be CPU tensor lists + assert isinstance(peek, dali.backend_impl.TensorListCPU) + assert isinstance(shape_of_gpu, dali.backend_impl.TensorListCPU) + assert isinstance(shape_of_cpu, dali.backend_impl.TensorListCPU) + check_batch(shape_of_gpu, peek, bs, 0, 0) + check_batch(shape_of_cpu, peek, bs, 0, 0) + + +def test_gpu2cpu_old_exec_error(): + bs = 8 + + @pipeline_def( + batch_size=bs, + num_threads=4, + device_id=0, + exec_async=False, + exec_pipelined=False, + experimental_exec_dynamic=False, + ) + def pdef(): + gpu = fn.external_source("input", device="gpu") + return gpu.cpu() + + pipe = pdef() + with assert_raises(RuntimeError, glob="doesn't support transition from GPU to CPU"): + pipe.build() + + +def test_gpu2cpu_conditionals(): + bs = 4 + + @pipeline_def( + batch_size=bs, + num_threads=4, + device_id=0, + experimental_exec_dynamic=True, # use new executor + enable_conditionals=True, + ) + def def_test(): + enc, label = fn.readers.file(file_root=jpeg_folder) + img = fn.decoders.image(enc, device="mixed") + # return inverted image for even samples + if (label[0] & 1) == 0: + out = img ^ np.uint8(255) + out_cpu = out.cpu() + else: + out = img + out_cpu = out.cpu() + return out, out_cpu + + @pipeline_def( + batch_size=bs, + num_threads=4, + device_id=0, + exec_async=False, # use old executor, even in presence of DALI_USE_EXEC2 + exec_pipelined=False, + ) + def def_ref(): + enc, label = fn.readers.file(file_root=jpeg_folder) + img = fn.decoders.image(enc, device="mixed") + # return inverted image for even samples + even = (label[0] & 1) == 0 + mask = fn.cast(even * 255, dtype=types.UINT8) + return img ^ mask + + test_pipe = def_test() + test_pipe.build() + ref_pipe = def_ref() + ref_pipe.build() + for i in range(3): + gpu, cpu = test_pipe.run() + assert isinstance(gpu, dali.backend_impl.TensorListGPU) + assert isinstance(cpu, dali.backend_impl.TensorListCPU) + (ref,) = ref_pipe.run() + check_batch(cpu, ref, bs, 0, 0, "HWC") + check_batch(gpu, ref, bs, 0, 0, "HWC") diff --git a/include/dali/c_api.h b/include/dali/c_api.h index 49d07a5ca1..56d62d8c46 100644 --- a/include/dali/c_api.h +++ b/include/dali/c_api.h @@ -59,6 +59,27 @@ typedef enum { DALI_BOOL = 11, } dali_data_type_t; +typedef enum { + DALI_EXEC_IS_PIPELINED = 1, + DALI_EXEC_IS_ASYNC = 2, + DALI_EXEC_IS_SEPARATED = 4, + DALI_EXEC_IS_DYNAMIC = 8, + + DALI_EXEC_SIMPLE = 0, + DALI_EXEC_ASYNC_PIPELINED = DALI_EXEC_IS_PIPELINED | 
DALI_EXEC_IS_ASYNC, + DALI_EXEC_DYNAMIC = DALI_EXEC_ASYNC_PIPELINED | DALI_EXEC_IS_DYNAMIC, +} dali_exec_flags_t; + +#ifdef __cplusplus +constexpr dali_exec_flags_t operator|(dali_exec_flags_t x, dali_exec_flags_t y) { + return dali_exec_flags_t(static_cast<int>(x) | static_cast<int>(y)); +} + +constexpr dali_exec_flags_t operator&(dali_exec_flags_t x, dali_exec_flags_t y) { + return dali_exec_flags_t(static_cast<int>(x) & static_cast<int>(y)); +} + +#endif /* * Need to keep that in sync with ReaderMeta from operator.h @@ -158,6 +179,31 @@ daliCreatePipeline2(daliPipelineHandle *pipe_handle, const char *serialized_pipe int cpu_prefetch_queue_depth, int gpu_prefetch_queue_depth, int enable_memory_stats); +/** + * Create a DALI Pipeline, using a pipeline that has been serialized beforehand. + * + * @param pipe_handle Pipeline handle. + * @param serialized_pipeline Serialized pipeline. + * @param length Length of the serialized pipeline string. + * @param max_batch_size Maximum batch size. + * @param num_threads Number of CPU threads which this pipeline uses. + * @param device_id ID of the GPU device which this pipeline uses. + * @param exec_flags Executor configuration flags. + * @param prefetch_queue_depth Depth of the prefetching queue; ignored when separated execution is used. + * @param cpu_prefetch_queue_depth Depth of the prefetching queue in the CPU stage. + * If `separated_execution == 0`, this value is ignored + * @param gpu_prefetch_queue_depth Depth of the prefetching queue in the GPU stage. + * If `separated_execution == 0`, this value is ignored + * @param enable_memory_stats Enable memory stats. + */ +DLL_PUBLIC void +daliCreatePipeline3(daliPipelineHandle *pipe_handle, const char *serialized_pipeline, int length, + int max_batch_size, int num_threads, int device_id, + dali_exec_flags_t exec_flags, + int prefetch_queue_depth, + int cpu_prefetch_queue_depth, int gpu_prefetch_queue_depth, + int enable_memory_stats); + /** * Convenient overload. Use it, if the Pipeline should inherit its parameters * from serialized pipeline. @@ -663,7 +709,7 @@ DLL_PUBLIC void daliLoadPluginDirectory(const char* plugin_dir); /** * @brief Load default plugin library * @remarks DALI_PRELOAD_PLUGINS are environment variables that can be used to control what - * plugins are loaded. If the variable is set, it is interpreted as a list of paths separated + * plugins are loaded. If the variable is set, it is interpreted as a list of paths separated * by colon (:), where each element can be a directory or library path. * If not set, the "default" path is scanned, which is a subdirectory called plugin under the * directory where the DALI library is installed.
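
Usage notes (Python). The sketches below show how the feature introduced by this patch is meant to be used from the Python frontend; they assume a DALI build that contains this patch. The dataset directory "my_jpegs/" and the pipeline names are placeholders.

A minimal pipeline that opts into the dynamic executor and returns both GPU data and its CPU copy, obtained with the new DataNode.cpu():

from nvidia.dali import pipeline_def, fn

@pipeline_def(batch_size=8, num_threads=4, device_id=0,
              experimental_exec_dynamic=True)   # enable the dynamic executor
def my_pipe():
    enc, _ = fn.readers.file(file_root="my_jpegs/")  # CPU
    img = fn.decoders.image(enc, device="mixed")     # CPU -> GPU
    return img, img.cpu()                            # img.cpu() adds a GPU -> CPU copy

pipe = my_pipe()
pipe.build()
gpu_out, cpu_out = pipe.run()   # TensorListGPU, TensorListCPU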
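
Because _to_backend() returns the node itself when it already lives on the requested device, calling .cpu() on a CPU node is a no-op rather than an extra copy. With the input-backend checks relaxed, a CPU operator can also consume a GPU node directly and the required device-to-host transfer happens under the hood, as exercised by test_shapes_gpu above. A sketch, continuing the example:

@pipeline_def(batch_size=8, num_threads=4, device_id=0,
              experimental_exec_dynamic=True)
def shapes_pipe():
    enc, _ = fn.readers.file(file_root="my_jpegs/")
    img = fn.decoders.image(enc, device="mixed")  # GPU data
    s_implicit = fn.shapes(img, device="cpu")     # CPU op fed by a GPU node
    s_explicit = fn.shapes(img.cpu())             # explicit transfer, same result
    return s_implicit, s_explicit                 # both outputs are CPU tensor lists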
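
Without the dynamic executor, a GPU->CPU edge is still rejected, now with the more descriptive message added to OpGraph::AddOp. A sketch of the failure mode, mirroring test_gpu2cpu_old_exec_error but using a decoder instead of an external source:

@pipeline_def(batch_size=8, num_threads=4, device_id=0,
              exec_async=False, exec_pipelined=False)  # old executor
def bad_pipe():
    enc, _ = fn.readers.file(file_root="my_jpegs/")
    img = fn.decoders.image(enc, device="mixed")
    return img.cpu()

try:
    bad_pipe().build()
except RuntimeError as e:
    assert "doesn't support transition from GPU to CPU" in str(e)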
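
Deserialization accepts the flag as well; note that Pipeline.deserialize() reads it from kwargs as exec_dynamic rather than experimental_exec_dynamic. A sketch, reusing my_pipe from the first example:

import nvidia.dali as dali

serialized = my_pipe().serialize()
restored = dali.Pipeline.deserialize(serialized_pipeline=serialized,
                                     exec_dynamic=True)
restored.build()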