diff --git a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_graph.h b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_graph.h index 6b59311a65..27af51f666 100644 --- a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_graph.h +++ b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_graph.h @@ -144,6 +144,9 @@ class TaskComposerGraph : public TaskComposerNode template void serialize(Archive& ar, const unsigned int version); // NOLINT + std::unique_ptr runImpl(TaskComposerContext& context, + OptionalTaskComposerExecutor executor = std::nullopt) const override; + std::map nodes_; std::vector terminals_; }; diff --git a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node.h b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node.h index a15b62278f..677f06fc06 100644 --- a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node.h +++ b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node.h @@ -31,6 +31,7 @@ TESSERACT_COMMON_IGNORE_WARNINGS_PUSH #include #include #include +#include #include #include #include @@ -50,6 +51,8 @@ class Node; namespace tesseract_planning { class TaskComposerDataStorage; +class TaskComposerContext; +class TaskComposerExecutor; enum class TaskComposerNodeType { @@ -68,6 +71,9 @@ class TaskComposerNode using UPtr = std::unique_ptr; using ConstUPtr = std::unique_ptr; + /** @brief Most task will not require a executor so making it optional */ + using OptionalTaskComposerExecutor = std::optional>; + TaskComposerNode(std::string name = "TaskComposerNode", TaskComposerNodeType type = TaskComposerNodeType::NODE, TaskComposerNodePorts ports = TaskComposerNodePorts(), @@ -82,6 +88,8 @@ class TaskComposerNode TaskComposerNode(TaskComposerNode&&) = delete; TaskComposerNode& operator=(TaskComposerNode&&) = delete; + int run(TaskComposerContext& context, OptionalTaskComposerExecutor executor = std::nullopt) const; + /** @brief Set the name of the node */ void setName(const std::string& name); @@ -169,6 +177,9 @@ class TaskComposerNode template void serialize(Archive& ar, const unsigned int version); // NOLINT + virtual std::unique_ptr runImpl(TaskComposerContext& context, + OptionalTaskComposerExecutor executor = std::nullopt) const = 0; + /** @brief The name of the task */ std::string name_; @@ -208,6 +219,9 @@ class TaskComposerNode /** @brief The nodes ports definition */ TaskComposerNodePorts ports_; + /** @brief Indicate if task triggers abort */ + bool trigger_abort_{ false }; + /** @brief This will create a UUID string with no hyphens used when creating dot graph */ static std::string toString(const boost::uuids::uuid& u, const std::string& prefix = ""); diff --git a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node_info.h b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node_info.h index a7d9c2f0f0..db7f04dfaf 100644 --- a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node_info.h +++ b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node_info.h @@ -135,8 +135,7 @@ class TaskComposerNodeInfo private: friend struct tesseract_common::Serialization; friend class boost::serialization::access; - friend class TaskComposerTask; - friend class TaskComposerPipeline; + friend class TaskComposerNode; /** @brief Indicate if task was not ran because abort flag was enabled */ bool aborted_{ false }; diff --git a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_pipeline.h b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_pipeline.h index 396c7407a2..d0287b4ef2 100644 --- a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_pipeline.h +++ b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_pipeline.h @@ -30,7 +30,6 @@ TESSERACT_COMMON_IGNORE_WARNINGS_PUSH #include #include -#include TESSERACT_COMMON_IGNORE_WARNINGS_POP #include @@ -53,8 +52,6 @@ class TaskComposerPipeline : public TaskComposerGraph using ConstPtr = std::shared_ptr; using UPtr = std::unique_ptr; using ConstUPtr = std::unique_ptr; - /** @brief Most task will not require a executor so making it optional */ - using OptionalTaskComposerExecutor = std::optional>; TaskComposerPipeline(std::string name = "TaskComposerPipeline"); TaskComposerPipeline(std::string name, bool conditional); @@ -65,8 +62,6 @@ class TaskComposerPipeline : public TaskComposerGraph TaskComposerPipeline(TaskComposerPipeline&&) = delete; TaskComposerPipeline& operator=(TaskComposerPipeline&&) = delete; - int run(TaskComposerContext& context, OptionalTaskComposerExecutor executor = std::nullopt) const; - bool operator==(const TaskComposerPipeline& rhs) const; bool operator!=(const TaskComposerPipeline& rhs) const; @@ -77,8 +72,8 @@ class TaskComposerPipeline : public TaskComposerGraph template void serialize(Archive& ar, const unsigned int version); // NOLINT - std::unique_ptr runImpl(TaskComposerContext& context, - OptionalTaskComposerExecutor executor = std::nullopt) const; + std::unique_ptr + runImpl(TaskComposerContext& context, OptionalTaskComposerExecutor executor = std::nullopt) const override final; void runRecursive(const TaskComposerNode& node, TaskComposerContext& context, diff --git a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_task.h b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_task.h index 0ea1c7316d..797e7c727e 100644 --- a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_task.h +++ b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_task.h @@ -30,7 +30,6 @@ TESSERACT_COMMON_IGNORE_WARNINGS_PUSH #include #include -#include TESSERACT_COMMON_IGNORE_WARNINGS_POP #include @@ -47,9 +46,6 @@ class TaskComposerTask : public TaskComposerNode using UPtr = std::unique_ptr; using ConstUPtr = std::unique_ptr; - /** @brief Most task will not require a executor so making it optional */ - using OptionalTaskComposerExecutor = std::optional>; - TaskComposerTask(); explicit TaskComposerTask(std::string name, TaskComposerNodePorts ports, bool conditional); explicit TaskComposerTask(std::string name, TaskComposerNodePorts ports, const YAML::Node& config); @@ -68,20 +64,12 @@ class TaskComposerTask : public TaskComposerNode */ void setTriggerAbort(bool enable); - int run(TaskComposerContext& context, OptionalTaskComposerExecutor executor = std::nullopt) const; - protected: - /** @brief Indicate if task triggers abort */ - bool trigger_abort_{ false }; - friend struct tesseract_common::Serialization; friend class boost::serialization::access; template void serialize(Archive& ar, const unsigned int version); // NOLINT - - virtual std::unique_ptr runImpl(TaskComposerContext& context, - OptionalTaskComposerExecutor executor = std::nullopt) const = 0; }; } // namespace tesseract_planning diff --git a/tesseract_task_composer/core/include/tesseract_task_composer/core/test_suite/task_composer_node_info_unit.hpp b/tesseract_task_composer/core/include/tesseract_task_composer/core/test_suite/task_composer_node_info_unit.hpp index c79514130c..f1ee632c68 100644 --- a/tesseract_task_composer/core/include/tesseract_task_composer/core/test_suite/task_composer_node_info_unit.hpp +++ b/tesseract_task_composer/core/include/tesseract_task_composer/core/test_suite/task_composer_node_info_unit.hpp @@ -34,6 +34,7 @@ TESSERACT_COMMON_IGNORE_WARNINGS_POP #include #include #include +#include #include namespace tesseract_planning::test_suite @@ -57,7 +58,7 @@ void runTaskComposerNodeInfoTest() } { // Constructor - TaskComposerNode node; + test_suite::DummyTaskComposerNode node; T node_info(node); EXPECT_EQ(node_info.return_value, -1); EXPECT_EQ(node_info.status_code, 0); diff --git a/tesseract_task_composer/core/include/tesseract_task_composer/core/test_suite/test_task.h b/tesseract_task_composer/core/include/tesseract_task_composer/core/test_suite/test_task.h index a6bfaea693..824fb1a61e 100644 --- a/tesseract_task_composer/core/include/tesseract_task_composer/core/test_suite/test_task.h +++ b/tesseract_task_composer/core/include/tesseract_task_composer/core/test_suite/test_task.h @@ -41,6 +41,14 @@ class TaskComposerPluginFactory; namespace tesseract_planning::test_suite { +class DummyTaskComposerNode : public TaskComposerNode +{ + using TaskComposerNode::TaskComposerNode; + + std::unique_ptr + runImpl(TaskComposerContext& context, OptionalTaskComposerExecutor /*executor*/ = std::nullopt) const override final; +}; + class TestTask : public TaskComposerTask { public: diff --git a/tesseract_task_composer/core/src/task_composer_graph.cpp b/tesseract_task_composer/core/src/task_composer_graph.cpp index 184648708f..a4d77c7327 100644 --- a/tesseract_task_composer/core/src/task_composer_graph.cpp +++ b/tesseract_task_composer/core/src/task_composer_graph.cpp @@ -36,8 +36,12 @@ TESSERACT_COMMON_IGNORE_WARNINGS_PUSH #include #include #include +#include TESSERACT_COMMON_IGNORE_WARNINGS_POP +#include +#include +#include #include #include #include @@ -214,6 +218,55 @@ TaskComposerGraph::TaskComposerGraph(std::string name, throw std::runtime_error(is_valid.second); } +std::unique_ptr TaskComposerGraph::runImpl(TaskComposerContext& context, + OptionalTaskComposerExecutor executor) const +{ + if (terminals_.empty()) + throw std::runtime_error("TaskComposerGraph, with name '" + name_ + "' does not have terminals!"); + + tesseract_common::Timer timer; + timer.start(); + + TaskComposerFuture::UPtr future = executor.value().get().run(*this, context.data_storage, context.dotgraph); + future->wait(); + + // Merge child context data into parent context + context.task_infos.mergeInfoMap(std::move(future->context->task_infos)); + if (future->context->isAborted()) + context.abort(future->context->task_infos.getAbortingNode()); + + auto info = std::make_unique(*this); + auto info_map = context.task_infos.getInfoMap(); + if (context.dotgraph) + { + std::stringstream dot_graph; + dot_graph << "subgraph cluster_" << toString(uuid_) << " {\n color=black;\n label = \"" << name_ << "\\n(" + << uuid_str_ << ")\";\n"; + dump(dot_graph, this, info_map); // dump the graph including dynamic tasks + dot_graph << "}\n"; + info->dotgraph = dot_graph.str(); + } + + for (std::size_t i = 0; i < terminals_.size(); ++i) + { + auto node_info = context.task_infos.getInfo(terminals_[i]); + if (node_info != nullptr) + { + timer.stop(); + info->input_keys = input_keys_; + info->output_keys = output_keys_; + info->return_value = static_cast(i); + info->color = node_info->color; + info->status_code = node_info->status_code; + info->status_message = node_info->status_message; + info->elapsed_time = timer.elapsedSeconds(); + return info; + } + } + + throw std::runtime_error("TaskComposerGraph, with name '" + name_ + "' has no node info for any of the leaf nodes!"); +} + boost::uuids::uuid TaskComposerGraph::addNode(std::unique_ptr task_node) { boost::uuids::uuid uuid = task_node->getUUID(); @@ -392,10 +445,9 @@ TaskComposerGraph::dump(std::ostream& os, } } - if (type_ == TaskComposerNodeType::PIPELINE) + if (type_ == TaskComposerNodeType::GRAPH || type_ == TaskComposerNodeType::PIPELINE) { - const auto& pipeline = static_cast(*this); - if (pipeline.isConditional()) + if (conditional_) { int return_value = -1; diff --git a/tesseract_task_composer/core/src/task_composer_node.cpp b/tesseract_task_composer/core/src/task_composer_node.cpp index 2673d80807..d8c04a648c 100644 --- a/tesseract_task_composer/core/src/task_composer_node.cpp +++ b/tesseract_task_composer/core/src/task_composer_node.cpp @@ -33,8 +33,10 @@ TESSERACT_COMMON_IGNORE_WARNINGS_PUSH #include #include #include +#include TESSERACT_COMMON_IGNORE_WARNINGS_POP +#include #include #include #include @@ -132,6 +134,59 @@ TaskComposerNode::TaskComposerNode(std::string name, validatePorts(); } +int TaskComposerNode::run(TaskComposerContext& context, OptionalTaskComposerExecutor executor) const +{ + auto start_time = std::chrono::system_clock::now(); + if (context.isAborted()) + { + auto info = std::make_unique(*this); + info->start_time = start_time; + info->input_keys = input_keys_; + info->output_keys = output_keys_; + info->return_value = 0; + info->color = "white"; + info->status_code = 0; + info->status_message = "Aborted"; + info->aborted_ = true; + context.task_infos.addInfo(std::move(info)); + return 0; + } + + tesseract_common::Timer timer; + TaskComposerNodeInfo::UPtr results; + timer.start(); + try + { + results = runImpl(context, executor); + } + catch (const std::exception& e) + { + results = std::make_unique(*this); + results->color = "red"; + results->status_code = -1; + results->status_message = "Exception thrown: " + std::string(e.what()); + results->return_value = 0; + } + timer.stop(); + results->input_keys = input_keys_; + results->output_keys = output_keys_; + results->start_time = start_time; + results->elapsed_time = timer.elapsedSeconds(); + + int value = results->return_value; + assert(value >= 0); + + // Call abort if required and is a task + if (type_ == TaskComposerNodeType::TASK && trigger_abort_ && !context.isAborted()) + { + results->status_message += " (Abort Triggered)"; + context.abort(uuid_); + } + + context.task_infos.addInfo(std::move(results)); + return value; +} + void TaskComposerNode::setName(const std::string& name) { name_ = name; } const std::string& TaskComposerNode::getName() const { return name_; } @@ -424,6 +479,7 @@ bool TaskComposerNode::operator==(const TaskComposerNode& rhs) const equal &= output_keys_ == rhs.output_keys_; equal &= conditional_ == rhs.conditional_; equal &= ports_ == rhs.ports_; + equal &= trigger_abort_ == rhs.trigger_abort_; return equal; } bool TaskComposerNode::operator!=(const TaskComposerNode& rhs) const { return !operator==(rhs); } @@ -442,6 +498,7 @@ void TaskComposerNode::serialize(Archive& ar, const unsigned int /*version*/) ar& boost::serialization::make_nvp("output_keys", output_keys_); ar& boost::serialization::make_nvp("conditional", conditional_); ar& boost::serialization::make_nvp("ports", ports_); + ar& boost::serialization::make_nvp("trigger_abort", trigger_abort_); } std::string TaskComposerNode::toString(const boost::uuids::uuid& u, const std::string& prefix) diff --git a/tesseract_task_composer/core/src/task_composer_pipeline.cpp b/tesseract_task_composer/core/src/task_composer_pipeline.cpp index b60f5cd248..2efc8b10d0 100644 --- a/tesseract_task_composer/core/src/task_composer_pipeline.cpp +++ b/tesseract_task_composer/core/src/task_composer_pipeline.cpp @@ -50,51 +50,6 @@ TaskComposerPipeline::TaskComposerPipeline(std::string name, { } -int TaskComposerPipeline::run(TaskComposerContext& context, OptionalTaskComposerExecutor executor) const -{ - auto start_time = std::chrono::system_clock::now(); - if (context.isAborted()) - { - auto info = std::make_unique(*this); - info->start_time = start_time; - info->input_keys = input_keys_; - info->output_keys = output_keys_; - info->return_value = 0; - info->status_code = 0; - info->status_message = "Aborted"; - info->color = "white"; - info->aborted_ = true; - context.task_infos.addInfo(std::move(info)); - return 0; - } - - tesseract_common::Timer timer; - std::unique_ptr results; - timer.start(); - try - { - results = runImpl(context, executor); - } - catch (const std::exception& e) - { - results = std::make_unique(*this); - results->color = "red"; - results->status_code = -1; - results->status_message = "Exception thrown: " + std::string(e.what()); - results->return_value = 0; - } - timer.stop(); - results->input_keys = input_keys_; - results->output_keys = output_keys_; - results->start_time = start_time; - results->elapsed_time = timer.elapsedSeconds(); - - int value = results->return_value; - assert(value >= 0); - context.task_infos.addInfo(std::move(results)); - return value; -} - std::unique_ptr TaskComposerPipeline::runImpl(TaskComposerContext& context, OptionalTaskComposerExecutor executor) const { @@ -144,42 +99,23 @@ void TaskComposerPipeline::runRecursive(const TaskComposerNode& node, TaskComposerContext& context, OptionalTaskComposerExecutor executor) const { - if (node.getType() == TaskComposerNodeType::GRAPH) - throw std::runtime_error("TaskComposerPipeline, does not support GRAPH node types. Name: '" + name_ + "'"); + if (node.getType() == TaskComposerNodeType::NODE) + throw std::runtime_error("TaskComposerPipeline, unsupported node type TaskComposerNodeType::NODE"); - if (node.getType() == TaskComposerNodeType::TASK) + int rv = node.run(context, executor); + if (node.isConditional()) { - const auto& task = static_cast(node); - int rv = task.run(context, executor); - if (task.isConditional()) - { - const auto& edge = node.getOutboundEdges().at(static_cast(rv)); - runRecursive(*nodes_.at(edge), context, executor); - } - else - { - if (node.getOutboundEdges().size() > 1) - throw std::runtime_error("TaskComposerPipeline, non conditional task can only have one out bound edge. Name: " - "'" + - name_ + "'"); - for (const auto& edge : node.getOutboundEdges()) - runRecursive(*(nodes_.at(edge)), context, executor); - } + const auto& edge = node.getOutboundEdges().at(static_cast(rv)); + runRecursive(*nodes_.at(edge), context, executor); } else { - const auto& pipeline = static_cast(node); - int rv = pipeline.run(context, executor); - if (pipeline.isConditional()) - { - const auto& edge = node.getOutboundEdges().at(static_cast(rv)); - runRecursive(*nodes_.at(edge), context, executor); - } - else - { - for (const auto& edge : node.getOutboundEdges()) - runRecursive(*nodes_.at(edge), context, executor); - } + if (node.getOutboundEdges().size() > 1) + throw std::runtime_error("TaskComposerPipeline, non conditional task can only have one out bound edge. Name: " + "'" + + name_ + "'"); + for (const auto& edge : node.getOutboundEdges()) + runRecursive(*(nodes_.at(edge)), context, executor); } } diff --git a/tesseract_task_composer/core/src/task_composer_task.cpp b/tesseract_task_composer/core/src/task_composer_task.cpp index 63156e4448..9d54317d84 100644 --- a/tesseract_task_composer/core/src/task_composer_task.cpp +++ b/tesseract_task_composer/core/src/task_composer_task.cpp @@ -61,66 +61,7 @@ TaskComposerTask::TaskComposerTask(std::string name, TaskComposerNodePorts ports void TaskComposerTask::setTriggerAbort(bool enable) { trigger_abort_ = enable; } -int TaskComposerTask::run(TaskComposerContext& context, OptionalTaskComposerExecutor executor) const -{ - auto start_time = std::chrono::system_clock::now(); - if (context.isAborted()) - { - auto info = std::make_unique(*this); - info->start_time = start_time; - info->input_keys = input_keys_; - info->output_keys = output_keys_; - info->return_value = 0; - info->color = "white"; - info->status_code = 0; - info->status_message = "Aborted"; - info->aborted_ = true; - context.task_infos.addInfo(std::move(info)); - return 0; - } - - tesseract_common::Timer timer; - TaskComposerNodeInfo::UPtr results; - timer.start(); - try - { - results = runImpl(context, executor); - } - catch (const std::exception& e) - { - results = std::make_unique(*this); - results->color = "red"; - results->status_code = -1; - results->status_message = "Exception thrown: " + std::string(e.what()); - results->return_value = 0; - } - timer.stop(); - results->input_keys = input_keys_; - results->output_keys = output_keys_; - results->start_time = start_time; - results->elapsed_time = timer.elapsedSeconds(); - - int value = results->return_value; - assert(value >= 0); - - // Call abort if required - if (trigger_abort_ && !context.isAborted()) - { - results->status_message += " (Abort Triggered)"; - context.abort(uuid_); - } - - context.task_infos.addInfo(std::move(results)); - return value; -} - -bool TaskComposerTask::operator==(const TaskComposerTask& rhs) const -{ - bool equal{ true }; - equal &= trigger_abort_ == rhs.trigger_abort_; - equal &= (TaskComposerNode::operator==(rhs)); - return equal; -} +bool TaskComposerTask::operator==(const TaskComposerTask& rhs) const { return (TaskComposerNode::operator==(rhs)); } // LCOV_EXCL_START bool TaskComposerTask::operator!=(const TaskComposerTask& rhs) const { return !operator==(rhs); } @@ -129,7 +70,6 @@ bool TaskComposerTask::operator!=(const TaskComposerTask& rhs) const { return !o template void TaskComposerTask::serialize(Archive& ar, const unsigned int /*version*/) { - ar& boost::serialization::make_nvp("trigger_abort", trigger_abort_); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(TaskComposerNode); } diff --git a/tesseract_task_composer/core/src/test_suite/test_task.cpp b/tesseract_task_composer/core/src/test_suite/test_task.cpp index 8f7add23a5..d98f12b852 100644 --- a/tesseract_task_composer/core/src/test_suite/test_task.cpp +++ b/tesseract_task_composer/core/src/test_suite/test_task.cpp @@ -37,6 +37,12 @@ TESSERACT_COMMON_IGNORE_WARNINGS_POP namespace tesseract_planning::test_suite { +std::unique_ptr DummyTaskComposerNode::runImpl(TaskComposerContext& /*context*/, + OptionalTaskComposerExecutor /*executor*/) const +{ + return std::make_unique(*this); +} + const std::string TestTask::INOUT_PORT1_PORT = "port1"; const std::string TestTask::INOUT_PORT2_PORT = "port2"; diff --git a/tesseract_task_composer/test/tesseract_task_composer_core_unit.cpp b/tesseract_task_composer/test/tesseract_task_composer_core_unit.cpp index d6b31a3937..e63ce5814e 100644 --- a/tesseract_task_composer/test/tesseract_task_composer_core_unit.cpp +++ b/tesseract_task_composer/test/tesseract_task_composer_core_unit.cpp @@ -121,7 +121,7 @@ TEST(TesseractTaskComposerCoreUnit, TaskComposerDataStorageTests) // NOLINT TEST(TesseractTaskComposerCoreUnit, TaskComposerContextTests) // NOLINT { - TaskComposerNode node; + test_suite::DummyTaskComposerNode node; auto context = std::make_unique("TaskComposerContextTests", std::make_unique()); EXPECT_EQ(context->name, "TaskComposerContextTests"); @@ -142,7 +142,7 @@ TEST(TesseractTaskComposerCoreUnit, TaskComposerContextTests) // NOLINT TEST(TesseractTaskComposerCoreUnit, TaskComposerNodeInfoContainerTests) // NOLINT { - TaskComposerNode node; + test_suite::DummyTaskComposerNode node; auto node_info = std::make_unique(node); auto node_info_container = std::make_unique(); @@ -179,7 +179,7 @@ TEST(TesseractTaskComposerCoreUnit, TaskComposerNodeInfoContainerTests) // NOLI TEST(TesseractTaskComposerCoreUnit, TaskComposerNodeTests) // NOLINT { std::stringstream os; - auto node = std::make_unique(); + auto node = std::make_unique(); // Default EXPECT_EQ(node->getName(), "TaskComposerNode"); EXPECT_EQ(node->getType(), TaskComposerNodeType::NODE); @@ -235,8 +235,8 @@ TEST(TesseractTaskComposerCoreUnit, TaskComposerNodeTests) // NOLINT { std::string str = R"(config:)"; YAML::Node config = YAML::Load(str); - auto task = - std::make_unique(name, TaskComposerNodeType::TASK, TaskComposerNodePorts{}, config["config"]); + auto task = std::make_unique( + name, TaskComposerNodeType::TASK, TaskComposerNodePorts{}, config["config"]); EXPECT_EQ(task->getName(), name); EXPECT_EQ(task->getType(), TaskComposerNodeType::TASK); EXPECT_TRUE(task->getInputKeys().empty()); @@ -251,8 +251,8 @@ TEST(TesseractTaskComposerCoreUnit, TaskComposerNodeTests) // NOLINT std::string str = R"(config: conditional: true)"; YAML::Node config = YAML::Load(str); - auto task = - std::make_unique(name, TaskComposerNodeType::TASK, TaskComposerNodePorts{}, config["config"]); + auto task = std::make_unique( + name, TaskComposerNodeType::TASK, TaskComposerNodePorts{}, config["config"]); EXPECT_EQ(task->getName(), name); EXPECT_EQ(task->getType(), TaskComposerNodeType::TASK); EXPECT_TRUE(task->getInputKeys().empty()); @@ -2363,6 +2363,105 @@ TEST(TesseractTaskComposerCoreUnit, TaskComposerServerTests) // NOLINT } } +TEST(TesseractTaskComposerCoreUnit, TaskComposerPipelineWithGraphChild) // NOLINT +{ + std::string str = R"(task_composer_plugins: + search_paths: + - /usr/local/lib + search_libraries: + - tesseract_task_composer_factories + - tesseract_task_composer_taskflow_factories + executors: + default: TaskflowExecutor + plugins: + TaskflowExecutor: + class: TaskflowTaskComposerExecutorFactory + config: + threads: 5 + tasks: + plugins: + TestPipeline: + class: PipelineTaskFactory + config: + conditional: true + nodes: + StartTask: + class: StartTaskFactory + config: + conditional: false + TestConditionalGraphTask: + task: TestGraph + config: + conditional: true + DoneTask: + class: DoneTaskFactory + config: + conditional: false + AbortTask: + class: DoneTaskFactory + config: + conditional: false + edges: + - source: StartTask + destinations: [TestConditionalGraphTask] + - source: TestConditionalGraphTask + destinations: [AbortTask, DoneTask] + terminals: [AbortTask, DoneTask] + TestGraph: + class: GraphTaskFactory + config: + conditional: false + nodes: + StartTask: + class: StartTaskFactory + config: + conditional: false + TestTask: + class: TestTaskFactory + config: + conditional: true + return_value: 1 + inputs: + port1: input_data + port2: [input_data2] + outputs: + port1: output_data + port2: [output_data2] + DoneTask: + class: DoneTaskFactory + config: + conditional: false + AbortTask: + class: DoneTaskFactory + config: + conditional: false + edges: + - source: StartTask + destinations: [TestTask] + - source: TestTask + destinations: [AbortTask, DoneTask] + terminals: [AbortTask, DoneTask])"; + + TaskComposerServer server; + server.loadConfig(str); + + // Run method using TaskComposerContext + const auto& pipeline = server.getTask("TestPipeline"); + auto data_storage = std::make_unique(); + auto future = server.run(pipeline, std::move(data_storage), false, "TaskflowExecutor"); + future->wait(); + + EXPECT_EQ(future->context->isAborted(), false); + EXPECT_EQ(future->context->isSuccessful(), true); + EXPECT_EQ(future->context->task_infos.getInfoMap().size(), 7); + EXPECT_TRUE(future->context->task_infos.getAbortingNode().is_nil()); + + std::ofstream os1; + os1.open(tesseract_common::getTempPath() + "task_composer_pipeline_with_conditional_child_graph_task.dot"); + EXPECT_NO_THROW(pipeline.dump(os1, nullptr, future->context->task_infos.getInfoMap())); // NOLINT + os1.close(); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv);