From 016ef1e8d101acca2cf25b3ec3dcbe5cac86e400 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?= Date: Sat, 3 Aug 2024 22:55:26 +0200 Subject: [PATCH] Simplify copy. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: MichaƂ Zientkiewicz --- dali/pipeline/operator/builtin/copy.cc | 63 ++++++++++++++------------ dali/pipeline/operator/builtin/copy.h | 17 +++---- dali/pipeline/pipeline.cc | 2 +- 3 files changed, 42 insertions(+), 40 deletions(-) diff --git a/dali/pipeline/operator/builtin/copy.cc b/dali/pipeline/operator/builtin/copy.cc index 30aeb5682b..8588bc92f0 100644 --- a/dali/pipeline/operator/builtin/copy.cc +++ b/dali/pipeline/operator/builtin/copy.cc @@ -17,30 +17,47 @@ namespace dali { template <> -void Copy::RunImpl(Workspace &ws) { - auto &input = ws.Input(0); - auto &output = ws.Output(0); - - int batch_size = input.num_samples(); - output.SetLayout(input.GetLayout()); - auto shapes = input.shape(); - - auto &thread_pool = ws.GetThreadPool(); - for (int sample_id = 0; sample_id < batch_size; ++sample_id) { - thread_pool.AddWork( - [sample_id, &input, &output](int tid) { - output.CopySample(sample_id, input, sample_id, AccessOrder::host()); - }, - shapes.tensor_size(sample_id)); +void Copy::RunImpl(Workspace &ws) { + if (ws.InputIsType(0)) { + auto &input = ws.Input(0); + auto &output = ws.Output(0); + + int batch_size = input.num_samples(); + output.SetLayout(input.GetLayout()); + auto shapes = input.shape(); + + auto &thread_pool = ws.GetThreadPool(); + for (int sample_id = 0; sample_id < batch_size; ++sample_id) { + thread_pool.AddWork( + [sample_id, &input, &output](int tid) { + output.CopySample(sample_id, input, sample_id, AccessOrder::host()); + }, + shapes.tensor_size(sample_id)); + } + thread_pool.RunAll(); + } else { + auto &input = ws.Input(0); + auto &output = ws.Output(0); + output.Copy(input, ws.output_order()); + } +} + +template <> +void Copy::RunImpl(Workspace &ws) { + if (ws.InputIsType(0)) { + auto &input = ws.Input(0); + auto &output = ws.Output(0); + output.Copy(input, ws.output_order()); + } else { + auto &input = ws.Input(0); + auto &output = ws.Output(0); + output.Copy(input, ws.output_order()); } - thread_pool.RunAll(); } DALI_REGISTER_OPERATOR(Copy, Copy, CPU); DALI_REGISTER_OPERATOR(Copy, Copy, GPU); -using CopyD2H = Copy; -DALI_REGISTER_OPERATOR(CopyD2H, CopyD2H, CPU); DALI_SCHEMA(Copy) .DocStr("Creates a copy of the input tensor.") @@ -49,14 +66,4 @@ DALI_SCHEMA(Copy) .AllowSequences() .SupportVolumetric(); - -DALI_SCHEMA(CopyD2H) - .DocStr("Creates a copy of the input tensor.") - .NumInput(1) - .InputDevice(0, InputDevice::GPU) - .NumOutput(1) - .AllowSequences() - .SupportVolumetric(); - - } // namespace dali diff --git a/dali/pipeline/operator/builtin/copy.h b/dali/pipeline/operator/builtin/copy.h index 139ca2b288..91c2995ea6 100644 --- a/dali/pipeline/operator/builtin/copy.h +++ b/dali/pipeline/operator/builtin/copy.h @@ -25,11 +25,11 @@ namespace dali { -template -class Copy : public StatelessOperator { +template +class Copy : public StatelessOperator { public: explicit Copy(const OpSpec &spec) : - StatelessOperator(spec) {} + StatelessOperator(spec) {} DISABLE_COPY_MOVE_ASSIGN(Copy); @@ -40,17 +40,12 @@ class Copy : public StatelessOperator { bool SetupImpl(std::vector &output_desc, const Workspace &ws) override { output_desc.resize(1); - const auto &input = ws.Input(0); - output_desc[0].type = input.type(); - output_desc[0].shape = input.shape(); + output_desc[0].type = ws.GetInputDataType(0); + output_desc[0].shape = ws.GetInputShape(0); return true; } - void RunImpl(Workspace &ws) override { - auto &input = ws.Input(0); - auto &output = ws.Output(0); - output.Copy(input, ws.output_order()); - } + void RunImpl(Workspace &ws) override; }; diff --git a/dali/pipeline/pipeline.cc b/dali/pipeline/pipeline.cc index 0b50d12616..d8f7999077 100644 --- a/dali/pipeline/pipeline.cc +++ b/dali/pipeline/pipeline.cc @@ -633,7 +633,7 @@ void Pipeline::ToCPU(std::map::iterator it) { if (it->second.has_cpu) return; OpSpec copy_to_host_spec = - OpSpec("CopyD2H") + OpSpec("Copy") .AddArg("device", "cpu") .AddInput(it->first, "gpu") .AddOutput(it->first, "cpu");