Skip to content

Commit

Permalink
Fix after silently broken rebase.
Browse files Browse the repository at this point in the history
Signed-off-by: Michał Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Aug 9, 2024
1 parent 258ff3b commit b4e2ab8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 97 deletions.
96 changes: 1 addition & 95 deletions dali/pipeline/executor/executor2/exec2_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,106 +22,12 @@
#include "dali/pipeline/operator/operator.h"
#include "dali/pipeline/operator/arg_helper.h"
#include "dali/pipeline/graph/op_graph2.h"
#include "dali/pipeline/executor/executor2/exec2_ops_for_test.h"

namespace dali {
namespace exec2 {
namespace test {

constexpr char kTestOpName[] = "Exec2TestOp";

class DummyOpCPU : public Operator<CPUBackend> {
public:
explicit DummyOpCPU(const OpSpec &spec) : Operator<CPUBackend>(spec) {
instance_name_ = spec_.GetArgument<string>("name");
}

bool SetupImpl(std::vector<OutputDesc> &outs, const Workspace &ws) override {
int N = ws.GetRequestedBatchSize(0);
outs.resize(ws.NumOutput());
outs[0].shape = uniform_list_shape(N, TensorShape<>{});
outs[0].type = DALI_INT32;
return true;
}

void RunImpl(Workspace &ws) override {
int N = ws.GetRequestedBatchSize(0);
addend_.Acquire(spec_, ws, N);
sample_sums_.resize(N);
auto &tp = ws.GetThreadPool();
for (int s = 0; s < N; s++) {
auto sample_sum = [&, s](int) {
int sum = *addend_[s].data + s;
for (int i = 0; i < ws.NumInput(); i++) {
sum += *ws.Input<CPUBackend>(i)[s].data<int>();
}
sample_sums_[s] = sum;
};
tp.AddWork(sample_sum);
}
tp.RunAll(true);
for (int s = 0; s < N; s++)
*ws.Output<CPUBackend>(0)[s].mutable_data<int>() = sample_sums_[s];
}

bool CanInferOutputs() const override { return true; }
ArgValue<int> addend_{"addend", spec_};

std::vector<int> sample_sums_;
std::string instance_name_;
};

class DummyOpGPU : public Operator<GPUBackend> {
public:
explicit DummyOpGPU(const OpSpec &spec) : Operator<GPUBackend>(spec) {
instance_name_ = spec_.GetArgument<string>("name");
}

bool SetupImpl(std::vector<OutputDesc> &outs, const Workspace &ws) override {
int N = ws.GetRequestedBatchSize(0);
outs.resize(ws.NumOutput());
outs[0].shape = uniform_list_shape(N, TensorShape<>{});
outs[0].type = DALI_INT32;
return true;
}

void RunImpl(Workspace &ws) override;

bool CanInferOutputs() const override { return true; }

private:
ArgValue<int> addend_{"addend", spec_};

std::string instance_name_;
};


constexpr char kCounterOpName[] = "Exec2Counter";

class CounterOp : public Operator<CPUBackend> {
public:
explicit CounterOp(const OpSpec &spec) : Operator<CPUBackend>(spec) {
}

bool SetupImpl(std::vector<OutputDesc> &outs, const Workspace &ws) override {
int N = ws.GetRequestedBatchSize(0);
outs.resize(ws.NumOutput());
outs[0].shape = uniform_list_shape(N, TensorShape<>{});
outs[0].type = DALI_INT32;
return true;
}

void RunImpl(Workspace &ws) override {
int N = ws.GetRequestedBatchSize(0);
for (int s = 0; s < N; s++) {
*ws.Output<CPUBackend>(0)[s].mutable_data<int>() = counter++;
}
}

bool CanInferOutputs() const override { return true; }

int counter = 0;
};

inline auto &AddCommonArgs(
OpSpec &spec, int max_batch_size, const std::string &device = "cpu", int num_threads = 1) {
spec.AddArg("max_batch_size", max_batch_size);
Expand Down
4 changes: 2 additions & 2 deletions dali/pipeline/executor/executor2/exec_graph_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,15 +465,15 @@ TEST(ExecGraphTest, LoweredExec) {
ExecEnv env;
env.thread_pool = &tp;
params.env = &env;
params.batch_size = 32;
params.max_batch_size = 32;
params.iter_data = std::make_shared<IterationData>();
{
tasking::Executor ex(4);
ex.Start();
g.PrepareIteration(params);
auto fut = g.Launch(ex);
auto &out = fut.Value<const PipelineOutput &>();
CheckTestGraph1Results(out.workspace, params.batch_size);
CheckTestGraph1Results(out.workspace, params.max_batch_size);
}
}

Expand Down

0 comments on commit b4e2ab8

Please sign in to comment.