From 10a1584a5e3ef31f480c86b1293fb30617ba796e Mon Sep 17 00:00:00 2001 From: Michal Zientkiewicz Date: Wed, 3 Jul 2024 19:17:51 +0200 Subject: [PATCH] Minimal stream assignment. Signed-off-by: Michal Zientkiewicz --- dali/pipeline/executor/executor2/exec2.cc | 20 +- .../pipeline/executor/executor2/exec_graph.cc | 7 + dali/pipeline/executor/executor2/exec_graph.h | 6 +- .../executor/executor2/stream_assignment.h | 266 +++++++++++------- .../executor2/stream_assignment_test.cc | 160 +++++++++++ dali/pipeline/graph/graph_util.h | 71 +++++ dali/pipeline/graph/op_graph2.h | 36 +-- 7 files changed, 420 insertions(+), 146 deletions(-) create mode 100644 dali/pipeline/executor/executor2/stream_assignment_test.cc create mode 100644 dali/pipeline/graph/graph_util.h diff --git a/dali/pipeline/executor/executor2/exec2.cc b/dali/pipeline/executor/executor2/exec2.cc index afd73addc80..5d37e86b3a3 100644 --- a/dali/pipeline/executor/executor2/exec2.cc +++ b/dali/pipeline/executor/executor2/exec2.cc @@ -68,9 +68,10 @@ class Executor2::Impl { } private: - auto InitIterationData() { + std::shared_ptr InitIterationData() { auto iter_data = std::make_shared(); - iter_data->iteration_index = iter_index++; + iter_data->iteration_index = iter_index_++; + return iter_data; } @@ -79,6 +80,11 @@ class Executor2::Impl { } void AnalyzeGraph() { + CountNodes(); + //FindInputNodes(); + } + + void CountNodes() { for (auto &n : graph_.nodes) { switch (NodeType(&n)) { case OpType::CPU: @@ -96,6 +102,8 @@ class Executor2::Impl { if (n.inputs.empty()) graph_info_.num_mixed_roots++; break; + default: + break; } } } @@ -106,7 +114,7 @@ class Executor2::Impl { "doesn't specify a device id."); } - void CalculatePrefechDepth() { + void CalculatePrefetchDepth() { int depth = 1; if (graph_info_.num_cpu_roots > 0) depth = std::max(depth, config_.cpu_queue_depth); @@ -162,7 +170,7 @@ class Executor2::Impl { for (int i = 0; i < num_streams; i++) streams_.push_back(CUDAStreamPool::instance().Get()); for (auto 
&node : graph_.nodes) { - auto stream_idx = assignment.GetStreamIdx(&node); + auto stream_idx = assignment[&node]; node.env.order = stream_idx.has_value() ? AccessOrder(streams_[*stream_idx]) @@ -180,6 +188,8 @@ class Executor2::Impl { int num_cpu_roots = 0; int num_mixed_roots = 0; int num_gpu_roots = 0; + + std::vector input_nodes; } graph_info_; std::unique_ptr tp_; @@ -188,6 +198,8 @@ class Executor2::Impl { ExecGraph graph_; std::unique_ptr exec_; + + int iter_index_ = 0; }; diff --git a/dali/pipeline/executor/executor2/exec_graph.cc b/dali/pipeline/executor/executor2/exec_graph.cc index 8e8aa8af252..967325d97b2 100644 --- a/dali/pipeline/executor/executor2/exec_graph.cc +++ b/dali/pipeline/executor/executor2/exec_graph.cc @@ -18,6 +18,7 @@ #include "dali/pipeline/executor/executor2/exec_graph.h" #include "dali/pipeline/executor/executor2/op_task.h" #include "dali/pipeline/operator/op_spec.h" +#include "dali/pipeline/graph/op_graph2.h" namespace dali { namespace exec2 { @@ -46,6 +47,12 @@ void ClearWorkspacePayload(Workspace &ws) { ws.InjectIterationData(nullptr); } +ExecNode::ExecNode(std::unique_ptr op, const graph::OpNode *def) +: op(std::move(op)), def(def) { + if (def) + device = def->op_type; +} + void ExecNode::PutWorkspace(CachedWorkspace ws) { ClearWorkspacePayload(*ws); workspace_cache_.Put(std::move(ws)); diff --git a/dali/pipeline/executor/executor2/exec_graph.h b/dali/pipeline/executor/executor2/exec_graph.h index 94b90d1e1d3..679bbedd391 100644 --- a/dali/pipeline/executor/executor2/exec_graph.h +++ b/dali/pipeline/executor/executor2/exec_graph.h @@ -67,8 +67,7 @@ struct PipelineOutputTag {}; class DLL_PUBLIC ExecNode { public: ExecNode() = default; - explicit ExecNode(std::unique_ptr op, const graph::OpNode *def = nullptr) - : op(std::move(op)), def(def) {} + ExecNode(std::unique_ptr op, const graph::OpNode *def = nullptr); explicit ExecNode(PipelineOutputTag) : is_pipeline_output(true) {} std::vector inputs, outputs; @@ -102,8 +101,11 @@ class 
DLL_PUBLIC ExecNode { } const graph::OpNode *def = nullptr; + OpType device = OpType::CPU; bool is_pipeline_output = false; + mutable bool visited = false; + private: CachedWorkspace CreateOutputWorkspace(); CachedWorkspace CreateOpWorkspace(); diff --git a/dali/pipeline/executor/executor2/stream_assignment.h b/dali/pipeline/executor/executor2/stream_assignment.h index ef5f422e188..27e8e3c05f8 100644 --- a/dali/pipeline/executor/executor2/stream_assignment.h +++ b/dali/pipeline/executor/executor2/stream_assignment.h @@ -15,13 +15,13 @@ #ifndef DALI_PIPELINE_EXECUTOR_EXECUTOR2_STREAM_ASSIGNMENT_H_ #define DALI_PIPELINE_EXECUTOR_EXECUTOR2_STREAM_ASSIGNMENT_H_ +#include "dali/pipeline/graph/graph_util.h" #include #include #include #include "dali/pipeline/executor/executor2/exec_graph.h" #include "dali/pipeline/executor/executor2/exec2.h" - namespace dali { namespace exec2 { @@ -29,14 +29,13 @@ template class StreamAssignment; inline bool NeedsStream(const ExecNode *node) { - if (node->def) { - if (node->def->op_type != OpType::CPU) - return true; - } else if (node->is_pipeline_output) { + if (node->is_pipeline_output) { for (auto &pipe_out : node->inputs) { if (pipe_out->device == StorageDevice::GPU) return true; } + } else { + return node->device != OpType::CPU; } return false; } @@ -55,9 +54,9 @@ inline OpType NodeType(const ExecNode *node) { } } return type; + } else { + return node->device; } - assert(node->def); - return node->def->op_type; } template <> @@ -71,7 +70,7 @@ class StreamAssignment { } } - std::optional GetStreamIdx(const ExecNode *node) { + std::optional operator[](const ExecNode *node) const { if (NeedsStream(node)) return 0; else @@ -115,7 +114,7 @@ class StreamAssignment { * If the node is a GPU node it gets stream index 1 if there are any mixed nodes, otherwise * the only stream is the GPU stream and the returned index is 0. 
*/ - std::optional GetStreamIdx(const ExecNode *node) { + std::optional operator[](const ExecNode *node) const { switch (NodeType(node)) { case OpType::CPU: return std::nullopt; @@ -162,10 +161,10 @@ class StreamAssignment { Assign(graph); } - std::optional GetStreamIdx(const ExecNode *node) const { - auto it = stream_assignment_.find(node); - assert(it != stream_assignment_.end()); - return it->second; + std::optional operator[](const ExecNode *node) const { + auto it = node_ids_.find(node); + assert(it != node_ids_.end()); + return stream_assignment_[it->second]; } /** Gets the total number of streams required to run independent operators on separate streams. */ @@ -175,112 +174,169 @@ class StreamAssignment { private: void Assign(ExecGraph &graph) { - int next_stream_id; + // pre-fill the id pool with sequential numbers + for (int i = 0, n = graph.nodes.size(); i < n; i++) { + free_stream_ids_.insert(i); + } + + // Sort the graph topologically with DFS for (auto &node : graph.nodes) { - next_stream_id += Assign(&node, next_stream_id); + Sort(&node); } - for (auto &kv : stream_assignment_) - if (kv.second == -1) - kv.second = std::nullopt; - total_streams_ = next_stream_id; + for (auto &node : graph.nodes) { + if (node.inputs.empty()) + queue_.push({ node_ids_[&node], NextStreamId(&node).value_or(kInvalidStreamIdx) }); + } + + assert(graph.nodes.size() == sorted_nodes_.size()); + stream_assignment_.resize(sorted_nodes_.size()); + + FindGPUContributors(graph); + + graph::ClearVisitMarkers(graph.nodes); + Traverse(); + ClearCPUStreams(); + total_streams_ = CalcNumStreams(); } - int Assign(ExecNode *node, int next_stream_id) { - /* The assignment algorithm. - - The function assigns stream indices in a depth-first fashion. This allows for easy reuse - of a stream for direct successors (they're serialized anyway, so we can just use one - stream for them, skipping synchronization later). 
- The function returns the number of streams needed for parallel execution of independent - GPU/Mixed operators. - - 1. Assign the current stream id to the node, if it needs a stream. - 2. Go over all outputs, depth first. - a. recursively call Assign on a child node - - if a child node already has assignment, it'll be skipped and return 0 streams needed. - b. if processing an output needs some streams, bump the stream index. - c. report the number of streams needed - if the node is a GPU/Mixed node, it'll need - at least 1 stream (see the final return statement). - - Example (C denotes a CPU node, G - a GPU node) - - Graph: - ----C - / - --- C ---- G ---- G --- G - \ \ \ / - \ \ ----- G - - \ \ / - \ ------ C ------ G --- G - \ \ - ----- G --- C ---- G - - --- G - - Visiting order (excludes edges leading to nodes already visited) - ----2 - / - --- 0 ---- 1 ---- 3 --- 4 - \ \ \ - \ \ ----- 5 - \ \ - \ ------ 6 ------ 7 --- 8 - \ \ - ---- 10 --- 11 ---- 9 - - --- 12 - - Return values marked on edges - - --0-C - / - -5- C --2- G --1- G -1- G - \ \ \ - \ \ ---1- G - \ \ - \ ----2- C ---2-- G -1- G - \ \ - ---1- G -0- C --1- G - - -1- G - - next_stream_id (includes CPU operators) - - ---- 0 - / - --- 0 ---- 0 ---- 0 --- 0 - \ \ \ - \ \ ----- 1 - \ \ - \ ------ 2 ------ 2 --- 2 - \ \ - ----- 4 --- 4 ---- 3 - - --- 5 - - The final stream assignment is equal to next_stream_id shown above, but CPU operators get -1. - */ - auto &assignment = stream_assignment_[node]; - if (assignment.has_value()) // this doubles as a visit marker - return 0; - bool needs_stream = NeedsStream(node); + void Traverse() { + while (!queue_.empty()) { + // PrintQueue(); /* uncomment for debugging */ + auto [idx, stream_idx] = queue_.top(); + std::optional stream_id; + if (stream_idx != kInvalidStreamIdx) + stream_id = stream_idx; + + queue_.pop(); + auto *node = sorted_nodes_[idx]; + // This will be true for nodes which has no outputs or which doesn't contribute to any + // GPU nodes. 
+ bool keep_stream_id = stream_id.has_value(); + + graph::Visit v(node); + if (!v) { + assert(stream_assignment_[idx].value_or(kInvalidStreamIdx) <= stream_idx); + continue; // we've been here already - skip + } + + stream_assignment_[idx] = stream_id; + + if (stream_id.has_value()) + free_stream_ids_.insert(*stream_id); + for (auto *out : node->outputs) { + auto out_stream_id = NextStreamId(out->consumer, stream_id); + if (out_stream_id.has_value()) + keep_stream_id = false; + queue_.push({node_ids_[out->consumer], out_stream_id.value_or(kInvalidStreamIdx)}); + } + if (keep_stream_id) + free_stream_ids_.erase(*stream_id); + } + } + + void ClearCPUStreams() { + for (int i = 0, n = sorted_nodes_.size(); i < n; i++) { + if (!NeedsStream(sorted_nodes_[i])) + stream_assignment_[i] = std::nullopt; + } + } + + int CalcNumStreams() { + int max = -1; + for (auto a : stream_assignment_) { + if (a.has_value()) + max = std::max(max, *a); + } + return max + 1; + } + + void PrintQueue(std::ostream &os = std::cout) { + auto q2 = queue_; + while (!q2.empty()) { + auto [idx, stream_idx] = q2.top(); + q2.pop(); + auto *node = sorted_nodes_[idx]; + if (node->def) + os << node->def->instance_name; + else if (node->is_pipeline_output) + os << ""; + else + os << "[" << idx << "]"; + os << "("; + if (stream_idx != kInvalidStreamIdx) + os << stream_idx; + else + os << "none"; + os << ") "; + } + os << "\n"; + } + + void Sort(const ExecNode *node) { + assert(node); + graph::Visit visit(node); + if (!visit) + return; + for (auto *edge : node->inputs) { + assert(edge); + Sort(edge->producer); + } + int idx = sorted_nodes_.size(); // all producers are placed by now, so idx is this node's slot + node_ids_.emplace(node, idx); + sorted_nodes_.push_back(node); + } + + std::optional NextStreamId(const ExecNode *node, + std::optional prev_stream_id = std::nullopt) { + // If the preceding node had a stream, then we have to pass it on through CPU nodes + // if there are any GPU nodes down the graph.
+ // If the preceding node didn't have a stream, then we only need a stream if the current + // node needs a stream. + bool needs_stream = prev_stream_id.has_value() + ? gpu_contributors_.count(node) != 0 + : NeedsStream(node); if (needs_stream) { - assignment = next_stream_id; + assert(!free_stream_ids_.empty()); + auto b = free_stream_ids_.begin(); + int ret = *b; + free_stream_ids_.erase(b); + return ret; } else { - assignment = -1; + return std::nullopt; } + } - int subgraph_streams = 0; - for (auto *edge : node->outputs) { - subgraph_streams += Assign(edge->consumer, next_stream_id + subgraph_streams); + void FindGPUContributors(ExecGraph &graph) { + // Run DFS, output to input, and find nodes which contribute to any node that requires a stream + graph::ClearVisitMarkers(graph.nodes); + for (auto &node : graph.nodes) { + if (node.outputs.empty()) + FindGPUContributors(&node, false); } + } - return std::max(subgraph_streams, needs_stream ? 1 : 0); + void FindGPUContributors(const ExecNode *node, bool is_gpu_contributor) { + graph::Visit v(node); + if (!v) + return; + if (!is_gpu_contributor) + is_gpu_contributor = NeedsStream(node); + if (is_gpu_contributor) + gpu_contributors_.insert(node); + for (auto *inp : node->inputs) + FindGPUContributors(inp->producer, is_gpu_contributor); } - std::unordered_map> stream_assignment_; + + static constexpr int kInvalidStreamIdx = 0x7fffffff; + std::vector> stream_assignment_; int total_streams_ = 0; + std::unordered_map node_ids_; // topologically sorted nodes + std::set gpu_contributors_; + std::vector sorted_nodes_; + std::set free_stream_ids_; + std::priority_queue, std::vector>, std::greater<>> queue_; }; } // namespace exec2 diff --git a/dali/pipeline/executor/executor2/stream_assignment_test.cc b/dali/pipeline/executor/executor2/stream_assignment_test.cc new file mode 100644 index 00000000000..8c51396ab7c --- /dev/null +++ b/dali/pipeline/executor/executor2/stream_assignment_test.cc @@ -0,0 +1,160 @@ +// Copyright (c)
2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "dali/pipeline/executor/executor2/stream_assignment.h" +#include "dali/pipeline/operator/operator.h" +#include "dali/pipeline/operator/operator_factory.h" + +namespace dali { + +template +class StreamAssignmentDummyOp : public Operator { + public: + using Operator::Operator; + USE_OPERATOR_MEMBERS(); + + void RunImpl(Workspace &ws) override {} + bool SetupImpl(std::vector &output_desc, const Workspace &ws) override { + return false; + } + +}; + +DALI_SCHEMA(StreamAssignmentDummyOp) + .NumInput(0, 999) + .NumOutput(0) + .AdditionalOutputsFn([](const OpSpec &spec) { + return spec.NumOutput(); + }); + +DALI_REGISTER_OPERATOR(StreamAssignmentDummyOp, StreamAssignmentDummyOp, CPU); + +DALI_REGISTER_OPERATOR(StreamAssignmentDummyOp, StreamAssignmentDummyOp, GPU); + +namespace exec2 { + +namespace { + +OpSpec SpecDev(const std::string &device) { + return OpSpec("StreamAssignmentDummyOp") + .AddArg("device", device) + .AddArg("num_threads", 1) + .AddArg("max_batch_size", 1); +} + +OpSpec SpecGPU() { + return SpecDev("gpu"); +} + +OpSpec SpecCPU() { + return SpecDev("cpu"); +} + +auto MakeNodeMap(const ExecGraph &graph) { + std::map> map; + for (auto &n : graph.nodes) + if (n.def) { + map[n.def->instance_name] = &n; + } + return map; +} + +} // namespace + +TEST(StreamAssignmentTest, PerOperator) { + ExecGraph eg; + /* + --c-- g + / \ / \ 
+ a -- b ----- d ----- f ---- h ---> out + \ (cpu) / / + --------------e / + / + i ----------------- j(cpu) + + k ----------------------------> out + + */ + graph::OpGraph::Builder b; + b.Add("a", + SpecGPU() + .AddOutput("a->b", "gpu") + .AddOutput("a->e", "gpu")); + b.Add("i", + SpecGPU() + .AddOutput("i->j", "gpu")); + b.Add("j", + SpecCPU() + .AddInput("i->j", "gpu") + .AddOutput("j->h", "cpu")); + b.Add("b", + SpecCPU() + .AddInput("a->b", "gpu") + .AddOutput("b->c", "cpu") + .AddOutput("b->d", "cpu")); + b.Add("c", + SpecGPU() + .AddInput("b->c", "cpu") + .AddOutput("c->d", "gpu")); + b.Add("d", + SpecGPU() + .AddInput("b->d", "cpu") + .AddInput("c->d", "gpu") + .AddOutput("d->f", "gpu")); + b.Add("e", + SpecGPU() + .AddInput("a->e", "gpu") + .AddOutput("e->f", "gpu")); + b.Add("f", + SpecGPU() + .AddInput("d->f", "gpu") + .AddInput("e->f", "gpu") + .AddOutput("f->g", "gpu") + .AddOutput("f->h", "gpu")); + b.Add("g", + SpecGPU() + .AddInput("f->g", "gpu") + .AddOutput("g->h", "gpu")); + b.Add("h", + SpecGPU() + .AddInput("f->h", "gpu") + .AddInput("g->h", "gpu") + .AddInput("j->h", "cpu") + .AddOutput("h->o", "gpu")); + b.Add("k", + SpecGPU() + .AddOutput("k->o", "gpu")); // directly to output + b.AddOutput("h->o_gpu"); + b.AddOutput("k->o_gpu"); + auto g = std::move(b).GetGraph(true); + eg.Lower(g); + + StreamAssignment assignment(eg); + auto map = MakeNodeMap(eg); + EXPECT_EQ(assignment[map["a"]], 0); + EXPECT_EQ(assignment[map["b"]], std::nullopt); + EXPECT_EQ(assignment[map["c"]], 0); + EXPECT_EQ(assignment[map["d"]], 0); + EXPECT_EQ(assignment[map["e"]], 3); + EXPECT_EQ(assignment[map["f"]], 0); + EXPECT_EQ(assignment[map["g"]], 0); + EXPECT_EQ(assignment[map["h"]], 0); + EXPECT_EQ(assignment[map["i"]], 1); + EXPECT_EQ(assignment[map["j"]], std::nullopt); + EXPECT_EQ(assignment[map["k"]], 2); +} + +} // namespace exec2 +} // namespace dali diff --git a/dali/pipeline/graph/graph_util.h b/dali/pipeline/graph/graph_util.h new file mode 100644 index 
00000000000..61436c8985f --- /dev/null +++ b/dali/pipeline/graph/graph_util.h @@ -0,0 +1,71 @@ +// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_PIPELINE_GRAPH_GRAPH_UTIL_H_ +#define DALI_PIPELINE_GRAPH_GRAPH_UTIL_H_ + +#include +#include "dali/core/util.h" + +namespace dali { +namespace graph { + +IMPL_HAS_MEMBER(visit_pending); + +/** A helper for visiting DAG nodes. + * + * When the node has a "visit_pending" member, the helper detects cycles. + */ +template +class Visit { + public: + explicit Visit(Node *n) : node_(n) { + if constexpr (has_member_visit_pending_v) { + if (node_->visit_pending) + throw std::logic_error("Cycle detected."); + node_->visit_pending = true; + } + new_visit_ = !n->visited; + node_->visited = true; + } + ~Visit() { + if constexpr (has_member_visit_pending_v) { + node_->visit_pending = false; + } + } + + explicit operator bool() const { + return new_visit_; + } + + private: + Node *node_; + bool new_visit_; +}; + +/** Clears visit markers. + * + * Sets the `visited` field to `false`. + * Typically used at the beginning of a graph-processing algorithm. 
+ */ +template +static void ClearVisitMarkers(NodeList &nodes) { + for (auto &node : nodes) + node.visited = false; +} + +} // namespace graph +} // namespace dali + +#endif // DALI_PIPELINE_GRAPH_GRAPH_UTIL_H_ diff --git a/dali/pipeline/graph/op_graph2.h b/dali/pipeline/graph/op_graph2.h index 612829f52d8..8e90d59fe16 100644 --- a/dali/pipeline/graph/op_graph2.h +++ b/dali/pipeline/graph/op_graph2.h @@ -24,6 +24,7 @@ #include #include "dali/core/common.h" #include "dali/pipeline/operator/op_spec.h" +#include "dali/pipeline/graph/graph_util.h" namespace dali { namespace graph { @@ -239,41 +240,6 @@ class DLL_PUBLIC OpGraph { friend class SortHelper; }; -/** A helper for visiting DAG nodes - it features previous visit detection and cycle detection. */ -template -class Visit { - public: - explicit Visit(Node *n) : node_(n) { - if (node_->visit_pending) - throw std::logic_error("Cycle detected."); - node_->visit_pending = true; - new_visit_ = !n->visited; - node_->visited = true; - } - ~Visit() { - node_->visit_pending = false; - } - - explicit operator bool() const { - return new_visit_; - } - - private: - Node *node_; - bool new_visit_; -}; - -/** Clears visit markers. - * - * Sets the `visited` field to `false`. - * Typically used at the beginning of a graph-processing algorithm. - */ -template -static void ClearVisitMarkers(NodeList &nodes) { - for (auto &node : nodes) - node.visited = false; -} - /** A single-use class for constructing graphs. */ class DLL_PUBLIC OpGraph::Builder { public: