Skip to content

Commit

Permalink
Simplify copy.
Browse files Browse the repository at this point in the history
Signed-off-by: Michał Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Sep 2, 2024
1 parent 99b44cd commit 016ef1e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 40 deletions.
63 changes: 35 additions & 28 deletions dali/pipeline/operator/builtin/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,47 @@
namespace dali {

template <>
void Copy<CPUBackend, CPUBackend>::RunImpl(Workspace &ws) {
auto &input = ws.Input<CPUBackend>(0);
auto &output = ws.Output<CPUBackend>(0);

int batch_size = input.num_samples();
output.SetLayout(input.GetLayout());
auto shapes = input.shape();

auto &thread_pool = ws.GetThreadPool();
for (int sample_id = 0; sample_id < batch_size; ++sample_id) {
thread_pool.AddWork(
[sample_id, &input, &output](int tid) {
output.CopySample(sample_id, input, sample_id, AccessOrder::host());
},
shapes.tensor_size(sample_id));
void Copy<CPUBackend>::RunImpl(Workspace &ws) {
  // Input 0 is not held as CPUBackend: fetch it as GPUBackend and copy the
  // whole batch with a single call, using the workspace's output order.
  if (!ws.InputIsType<CPUBackend>(0)) {
    auto &src = ws.Input<GPUBackend>(0);
    auto &dst = ws.Output<CPUBackend>(0);
    dst.Copy(src, ws.output_order());
    return;
  }

  // CPUBackend input: copy sample by sample on the workspace thread pool.
  auto &src = ws.Input<CPUBackend>(0);
  auto &dst = ws.Output<CPUBackend>(0);

  int num_samples = src.num_samples();
  dst.SetLayout(src.GetLayout());
  auto sample_shapes = src.shape();

  auto &pool = ws.GetThreadPool();
  for (int idx = 0; idx < num_samples; ++idx) {
    // One task per sample; the sample's tensor_size is handed to AddWork
    // as its second argument.
    pool.AddWork(
        [idx, &src, &dst](int tid) {
          dst.CopySample(idx, src, idx, AccessOrder::host());
        },
        sample_shapes.tensor_size(idx));
  }
  // Execute the queued per-sample copies.
  pool.RunAll();
}

/// Copies input 0 into output 0, where the output is accessed as GPUBackend.
///
/// The input is fetched with whichever backend the workspace reports for it
/// (CPUBackend or GPUBackend) and copied in one call using the workspace's
/// output order. No thread pool is involved in this specialization.
template <>
void Copy<GPUBackend>::RunImpl(Workspace &ws) {
  auto &output = ws.Output<GPUBackend>(0);
  if (ws.InputIsType<CPUBackend>(0)) {
    auto &input = ws.Input<CPUBackend>(0);
    output.Copy(input, ws.output_order());
  } else {
    auto &input = ws.Input<GPUBackend>(0);
    output.Copy(input, ws.output_order());
  }
}

DALI_REGISTER_OPERATOR(Copy, Copy<CPUBackend>, CPU);
DALI_REGISTER_OPERATOR(Copy, Copy<GPUBackend>, GPU);

using CopyD2H = Copy<CPUBackend, GPUBackend>;
DALI_REGISTER_OPERATOR(CopyD2H, CopyD2H, CPU);

DALI_SCHEMA(Copy)
.DocStr("Creates a copy of the input tensor.")
Expand All @@ -49,14 +66,4 @@ DALI_SCHEMA(Copy)
.AllowSequences()
.SupportVolumetric();


DALI_SCHEMA(CopyD2H)
.DocStr("Creates a copy of the input tensor.")
.NumInput(1)
.InputDevice(0, InputDevice::GPU)
.NumOutput(1)
.AllowSequences()
.SupportVolumetric();


} // namespace dali
17 changes: 6 additions & 11 deletions dali/pipeline/operator/builtin/copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@

namespace dali {

template <typename DstBackend, typename SrcBackend = DstBackend>
class Copy : public StatelessOperator<DstBackend> {
template <typename Backend>
class Copy : public StatelessOperator<Backend> {
public:
explicit Copy(const OpSpec &spec) :
StatelessOperator<DstBackend>(spec) {}
StatelessOperator<Backend>(spec) {}

DISABLE_COPY_MOVE_ASSIGN(Copy);

Expand All @@ -40,17 +40,12 @@ class Copy : public StatelessOperator<DstBackend> {

/// Describes the operator's single output as having the same data type and
/// shape as input 0.
///
/// The type and shape are queried from the workspace directly, so no typed
/// access to the input batch is needed here.
///
/// @param output_desc resized to one entry and filled with the type/shape.
/// @param ws          workspace providing the type and shape of input 0.
/// @return always true.
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
  output_desc.resize(1);
  output_desc[0].type = ws.GetInputDataType(0);
  output_desc[0].shape = ws.GetInputShape(0);
  return true;
}

// Copies input 0 (accessed as SrcBackend) into output 0 (accessed as
// DstBackend) in a single call, using the workspace's output order.
// NOTE(review): a bare `void RunImpl(Workspace &ws) override;` declaration
// follows this definition; this inline body looks like the pre-change version
// that the declaration supersedes - confirm which one should remain.
void RunImpl(Workspace &ws) override {
auto &input = ws.Input<SrcBackend>(0);
auto &output = ws.Output<DstBackend>(0);
output.Copy(input, ws.output_order());
}
void RunImpl(Workspace &ws) override;
};


Expand Down
2 changes: 1 addition & 1 deletion dali/pipeline/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ void Pipeline::ToCPU(std::map<string, EdgeMeta>::iterator it) {
if (it->second.has_cpu)
return;
OpSpec copy_to_host_spec =
OpSpec("CopyD2H")
OpSpec("Copy")
.AddArg("device", "cpu")
.AddInput(it->first, "gpu")
.AddOutput(it->first, "cpu");
Expand Down

0 comments on commit 016ef1e

Please sign in to comment.