Skip to content

Commit

Permalink
Minimal stream assignment.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Jul 4, 2024
1 parent 48c39e2 commit 10a1584
Show file tree
Hide file tree
Showing 7 changed files with 420 additions and 146 deletions.
20 changes: 16 additions & 4 deletions dali/pipeline/executor/executor2/exec2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ class Executor2::Impl {
}

private:
auto InitIterationData() {
std::shared_ptr<IterationData> InitIterationData() {
auto iter_data = std::make_shared<IterationData>();
iter_data->iteration_index = iter_index++;
iter_data->iteration_index = iter_index_++;
return iter_data;
}


Expand All @@ -79,6 +80,11 @@ class Executor2::Impl {
}

// Runs the graph analysis passes.
// At present only CountNodes() is executed; the FindInputNodes() call below
// is commented out and therefore does not run.
void AnalyzeGraph() {
CountNodes();
//FindInputNodes();
}

void CountNodes() {
for (auto &n : graph_.nodes) {
switch (NodeType(&n)) {
case OpType::CPU:
Expand All @@ -96,6 +102,8 @@ class Executor2::Impl {
if (n.inputs.empty())
graph_info_.num_mixed_roots++;
break;
default:
break;
}
}
}
Expand All @@ -106,7 +114,7 @@ class Executor2::Impl {
"doesn't specify a device id.");
}

void CalculatePrefechDepth() {
void CalculatePrefetchDepth() {
int depth = 1;
if (graph_info_.num_cpu_roots > 0)
depth = std::max(depth, config_.cpu_queue_depth);
Expand Down Expand Up @@ -162,7 +170,7 @@ class Executor2::Impl {
for (int i = 0; i < num_streams; i++)
streams_.push_back(CUDAStreamPool::instance().Get());
for (auto &node : graph_.nodes) {
auto stream_idx = assignment.GetStreamIdx(&node);
auto stream_idx = assignment[&node];

node.env.order = stream_idx.has_value()
? AccessOrder(streams_[*stream_idx])
Expand All @@ -180,6 +188,8 @@ class Executor2::Impl {
int num_cpu_roots = 0;
int num_mixed_roots = 0;
int num_gpu_roots = 0;

std::vector<ExecNode *> input_nodes;
} graph_info_;

std::unique_ptr<ThreadPool> tp_;
Expand All @@ -188,6 +198,8 @@ class Executor2::Impl {

ExecGraph graph_;
std::unique_ptr<tasking::Executor> exec_;

int iter_index_ = 0;
};


Expand Down
7 changes: 7 additions & 0 deletions dali/pipeline/executor/executor2/exec_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "dali/pipeline/executor/executor2/exec_graph.h"
#include "dali/pipeline/executor/executor2/op_task.h"
#include "dali/pipeline/operator/op_spec.h"
#include "dali/pipeline/graph/op_graph2.h"

namespace dali {
namespace exec2 {
Expand Down Expand Up @@ -46,6 +47,12 @@ void ClearWorkspacePayload(Workspace &ws) {
ws.InjectIterationData(nullptr);
}

/// Constructs an execution node that takes ownership of `op` and stores the
/// (possibly null) graph definition pointer `def`.
///
/// When `def` is non-null, `device` is copied from `def->op_type`; otherwise
/// `device` is not assigned here.
ExecNode::ExecNode(std::unique_ptr<OperatorBase> op, const graph::OpNode *def)
    : op(std::move(op)), def(def) {
  // No definition supplied: there is nothing to read the device from.
  if (def == nullptr)
    return;
  device = def->op_type;
}

void ExecNode::PutWorkspace(CachedWorkspace ws) {
ClearWorkspacePayload(*ws);
workspace_cache_.Put(std::move(ws));
Expand Down
6 changes: 4 additions & 2 deletions dali/pipeline/executor/executor2/exec_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ struct PipelineOutputTag {};
class DLL_PUBLIC ExecNode {
public:
ExecNode() = default;
explicit ExecNode(std::unique_ptr<OperatorBase> op, const graph::OpNode *def = nullptr)
: op(std::move(op)), def(def) {}
ExecNode(std::unique_ptr<OperatorBase> op, const graph::OpNode *def = nullptr);
explicit ExecNode(PipelineOutputTag) : is_pipeline_output(true) {}

std::vector<ExecEdge *> inputs, outputs;
Expand Down Expand Up @@ -102,8 +101,11 @@ class DLL_PUBLIC ExecNode {
}

const graph::OpNode *def = nullptr;
OpType device = OpType::CPU;
bool is_pipeline_output = false;

mutable bool visited = false;

private:
CachedWorkspace CreateOutputWorkspace();
CachedWorkspace CreateOpWorkspace();
Expand Down
Loading

0 comments on commit 10a1584

Please sign in to comment.