From b4e2ab8a74b4ffd65a1baf1c1e15d3f5a99be630 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?= Date: Thu, 8 Aug 2024 15:30:03 +0200 Subject: [PATCH] Fix after silently broken rebase. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: MichaƂ Zientkiewicz --- dali/pipeline/executor/executor2/exec2_test.h | 96 +------------------ .../executor/executor2/exec_graph_test.cc | 4 +- 2 files changed, 3 insertions(+), 97 deletions(-) diff --git a/dali/pipeline/executor/executor2/exec2_test.h b/dali/pipeline/executor/executor2/exec2_test.h index e2574c01b2..f60e288893 100644 --- a/dali/pipeline/executor/executor2/exec2_test.h +++ b/dali/pipeline/executor/executor2/exec2_test.h @@ -22,106 +22,12 @@ #include "dali/pipeline/operator/operator.h" #include "dali/pipeline/operator/arg_helper.h" #include "dali/pipeline/graph/op_graph2.h" +#include "dali/pipeline/executor/executor2/exec2_ops_for_test.h" namespace dali { namespace exec2 { namespace test { -constexpr char kTestOpName[] = "Exec2TestOp"; - -class DummyOpCPU : public Operator { - public: - explicit DummyOpCPU(const OpSpec &spec) : Operator(spec) { - instance_name_ = spec_.GetArgument("name"); - } - - bool SetupImpl(std::vector &outs, const Workspace &ws) override { - int N = ws.GetRequestedBatchSize(0); - outs.resize(ws.NumOutput()); - outs[0].shape = uniform_list_shape(N, TensorShape<>{}); - outs[0].type = DALI_INT32; - return true; - } - - void RunImpl(Workspace &ws) override { - int N = ws.GetRequestedBatchSize(0); - addend_.Acquire(spec_, ws, N); - sample_sums_.resize(N); - auto &tp = ws.GetThreadPool(); - for (int s = 0; s < N; s++) { - auto sample_sum = [&, s](int) { - int sum = *addend_[s].data + s; - for (int i = 0; i < ws.NumInput(); i++) { - sum += *ws.Input(i)[s].data(); - } - sample_sums_[s] = sum; - }; - tp.AddWork(sample_sum); - } - tp.RunAll(true); - for (int s = 0; s < N; s++) - *ws.Output(0)[s].mutable_data() = sample_sums_[s]; - } - - bool CanInferOutputs() const override { return true; } - ArgValue addend_{"addend", spec_}; - - std::vector sample_sums_; - std::string instance_name_; -}; - -class DummyOpGPU : public Operator { - public: - explicit DummyOpGPU(const OpSpec &spec) : Operator(spec) { - instance_name_ = spec_.GetArgument("name"); - } - - bool SetupImpl(std::vector &outs, const Workspace &ws) override { - int N = ws.GetRequestedBatchSize(0); - outs.resize(ws.NumOutput()); - outs[0].shape = uniform_list_shape(N, TensorShape<>{}); - outs[0].type = DALI_INT32; - return true; - } - - void RunImpl(Workspace &ws) override; - - bool CanInferOutputs() const override { return true; } - - private: - ArgValue addend_{"addend", spec_}; - - std::string instance_name_; -}; - - -constexpr char kCounterOpName[] = "Exec2Counter"; - -class CounterOp : public Operator { - public: - explicit CounterOp(const OpSpec &spec) : Operator(spec) { - } - - bool SetupImpl(std::vector &outs, const Workspace &ws) override { - int N = ws.GetRequestedBatchSize(0); - outs.resize(ws.NumOutput()); - outs[0].shape = uniform_list_shape(N, TensorShape<>{}); - outs[0].type = DALI_INT32; - return true; - } - - void RunImpl(Workspace &ws) override { - int N = ws.GetRequestedBatchSize(0); - for (int s = 0; s < N; s++) { - *ws.Output(0)[s].mutable_data() = counter++; - } - } - - bool CanInferOutputs() const override { return true; } - - int counter = 0; -}; - inline auto &AddCommonArgs( OpSpec &spec, int max_batch_size, const std::string &device = "cpu", int num_threads = 1) { spec.AddArg("max_batch_size", max_batch_size); diff --git a/dali/pipeline/executor/executor2/exec_graph_test.cc b/dali/pipeline/executor/executor2/exec_graph_test.cc index bdc86561b0..4a56f64aa5 100644 --- a/dali/pipeline/executor/executor2/exec_graph_test.cc +++ b/dali/pipeline/executor/executor2/exec_graph_test.cc @@ -465,7 +465,7 @@ TEST(ExecGraphTest, LoweredExec) { ExecEnv env; env.thread_pool = &tp; params.env = &env; - params.batch_size = 32; + params.max_batch_size = 32; params.iter_data = std::make_shared(); { tasking::Executor ex(4); @@ -473,7 +473,7 @@ TEST(ExecGraphTest, LoweredExec) { g.PrepareIteration(params); auto fut = g.Launch(ex); auto &out = fut.Value(); - CheckTestGraph1Results(out.workspace, params.batch_size); + CheckTestGraph1Results(out.workspace, params.max_batch_size); } }