From 287fde491350bf30f091e3f1e83135ec9fa594f0 Mon Sep 17 00:00:00 2001 From: Levi Armstrong Date: Wed, 14 Aug 2024 13:32:59 -0500 Subject: [PATCH] Update TaskComposerNodeInfo to allow searching graph --- .../core/task_composer_graph.h | 7 +++ .../core/task_composer_node.h | 2 +- .../core/task_composer_node_info.h | 25 ++++++++++- .../src/nodes/has_data_storage_entry_task.cpp | 2 +- .../core/src/task_composer_executor.cpp | 4 +- .../core/src/task_composer_graph.cpp | 44 ++++++++++++++++--- .../core/src/task_composer_node_info.cpp | 38 +++++++++++++++- .../core/src/task_composer_pipeline.cpp | 2 - .../planning/nodes/motion_planner_task.hpp | 4 +- .../nodes/continuous_contact_check_task.cpp | 4 +- .../src/nodes/discrete_contact_check_task.cpp | 4 +- .../src/nodes/fix_state_collision_task.cpp | 4 +- 12 files changed, 113 insertions(+), 27 deletions(-) diff --git a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_graph.h b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_graph.h index 27af51f666..2132a16e4b 100644 --- a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_graph.h +++ b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_graph.h @@ -111,6 +111,12 @@ class TaskComposerGraph : public TaskComposerNode */ void setTerminalTriggerAbortByIndex(int terminal_index); + /** Get the abort terminal uuid if set */ + boost::uuids::uuid getAbortTerminal() const; + + /** Get the abort terminal index if set */ + int getAbortTerminalIndex() const; + /** * @brief Check if the current state of the graph is valid * @todo Replace return type with std::expected when upgraded to use c++23 @@ -149,6 +155,7 @@ class TaskComposerGraph : public TaskComposerNode std::map nodes_; std::vector terminals_; + int abort_terminal_{ -1 }; }; } // namespace tesseract_planning diff --git a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node.h b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node.h index 677f06fc06..2d23a2de30 100644 --- a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node.h +++ b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node.h @@ -41,7 +41,6 @@ TESSERACT_COMMON_IGNORE_WARNINGS_POP #include #include -#include namespace YAML { @@ -53,6 +52,7 @@ namespace tesseract_planning class TaskComposerDataStorage; class TaskComposerContext; class TaskComposerExecutor; +class TaskComposerNodeInfo; enum class TaskComposerNodeType { diff --git a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node_info.h b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node_info.h index 815363e333..cd2bcc634b 100644 --- a/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node_info.h +++ b/tesseract_task_composer/core/include/tesseract_task_composer/core/task_composer_node_info.h @@ -40,13 +40,12 @@ TESSERACT_COMMON_IGNORE_WARNINGS_POP #include +#include #include #include namespace tesseract_planning { -class TaskComposerNode; - /** Stores information about a node */ class TaskComposerNodeInfo { @@ -79,6 +78,15 @@ class TaskComposerNodeInfo */ boost::uuids::uuid parent_uuid{}; + /** @brief The node type */ + TaskComposerNodeType type; + + /** @brief The task type hash code from std::type_index */ + std::size_t type_hash_code{ 0 }; + + /** @brief The task is conditional or not */ + bool conditional{ false }; + /** @brief The nodes inbound edges */ std::vector inbound_edges; @@ -91,6 +99,12 @@ class TaskComposerNodeInfo /** @brief The output keys */ TaskComposerKeys output_keys; + /** @brief The graph of pipeline terminals */ + std::vector terminals; + + /** @brief Indicate if abort terminal was assigned. Only valid for graph and pipelines */ + int abort_terminal{ -1 }; + /** @brief Value returned from the Task on completion */ int return_value{ -1 }; @@ -189,6 +203,12 @@ class TaskComposerNodeInfoContainer /** @brief Merge the contents of another container's info map */ void mergeInfoMap(TaskComposerNodeInfoContainer&& container); + /** @brief Set the root node */ + void setRootNode(const boost::uuids::uuid& node_uuid); + + /** @brief Get the root node */ + boost::uuids::uuid getRootNode() const; + /** * @brief Called if aborted * @details This is set if abort is called in input @@ -217,6 +237,7 @@ class TaskComposerNodeInfoContainer void serialize(Archive& ar, const unsigned int version); // NOLINT mutable std::shared_mutex mutex_; + boost::uuids::uuid root_node_{}; boost::uuids::uuid aborting_node_{}; std::map info_map_; diff --git a/tesseract_task_composer/core/src/nodes/has_data_storage_entry_task.cpp b/tesseract_task_composer/core/src/nodes/has_data_storage_entry_task.cpp index 0ea3490f90..e3b59440a7 100644 --- a/tesseract_task_composer/core/src/nodes/has_data_storage_entry_task.cpp +++ b/tesseract_task_composer/core/src/nodes/has_data_storage_entry_task.cpp @@ -67,7 +67,7 @@ std::unique_ptr HasDataStorageEntryTask::runImpl(TaskCompo bool HasDataStorageEntryTask::operator==(const HasDataStorageEntryTask& rhs) const { - return (TaskComposerNode::operator==(rhs)); + return (TaskComposerTask::operator==(rhs)); } bool HasDataStorageEntryTask::operator!=(const HasDataStorageEntryTask& rhs) const { return !operator==(rhs); } diff --git a/tesseract_task_composer/core/src/task_composer_executor.cpp b/tesseract_task_composer/core/src/task_composer_executor.cpp index 332803a313..4fb51fb5e9 100644 --- a/tesseract_task_composer/core/src/task_composer_executor.cpp +++ b/tesseract_task_composer/core/src/task_composer_executor.cpp @@ -46,7 +46,9 @@ std::unique_ptr TaskComposerExecutor::run(const TaskComposer std::shared_ptr data_storage, bool dotgraph) { - return run(node, std::make_shared(node.getName(), std::move(data_storage), dotgraph)); + auto context = std::make_shared(node.getName(), std::move(data_storage), dotgraph); + context->task_infos.setRootNode(node.getUUID()); + return run(node, context); } bool TaskComposerExecutor::operator==(const TaskComposerExecutor& rhs) const { return (name_ == rhs.name_); } diff --git a/tesseract_task_composer/core/src/task_composer_graph.cpp b/tesseract_task_composer/core/src/task_composer_graph.cpp index a4d77c7327..76fb10b5e0 100644 --- a/tesseract_task_composer/core/src/task_composer_graph.cpp +++ b/tesseract_task_composer/core/src/task_composer_graph.cpp @@ -321,17 +321,31 @@ void TaskComposerGraph::setTerminalTriggerAbort(boost::uuids::uuid terminal) { if (!terminal.is_nil()) { - auto& n = nodes_.at(terminal); - if (n->getType() == TaskComposerNodeType::TASK) - static_cast(*n).setTriggerAbort(true); - else - throw std::runtime_error("Tasks can only trigger abort!"); + abort_terminal_ = -1; + for (std::size_t i = 0; i < terminals_.size(); ++i) + { + const boost::uuids::uuid& uuid = terminals_[i]; + if (uuid == terminal) + { + abort_terminal_ = static_cast(i); + auto& n = nodes_.at(terminal); + if (n->getType() == TaskComposerNodeType::TASK) + static_cast(*n).setTriggerAbort(true); + else + throw std::runtime_error("Tasks can only trigger abort!"); + + break; + } + } + if (abort_terminal_ < 0) + throw std::runtime_error("Task with uuid: " + boost::uuids::to_string(terminal) + " is not a terminal node"); } else { - for (const auto& terminal : terminals_) + abort_terminal_ = -1; + for (const auto& t : terminals_) { - auto& n = nodes_.at(terminal); + auto& n = nodes_.at(t); if (n->getType() == TaskComposerNodeType::TASK) static_cast(*n).setTriggerAbort(false); } @@ -342,6 +356,7 @@ void TaskComposerGraph::setTerminalTriggerAbortByIndex(int terminal_index) { if (terminal_index >= 0) { + abort_terminal_ = terminal_index; auto& n = nodes_.at(terminals_.at(static_cast(terminal_index))); if (n->getType() == TaskComposerNodeType::TASK) static_cast(*n).setTriggerAbort(true); @@ -350,6 +365,7 @@ void TaskComposerGraph::setTerminalTriggerAbortByIndex(int terminal_index) } else { + abort_terminal_ = -1; for (const auto& terminal : terminals_) { auto& n = nodes_.at(terminal); @@ -359,6 +375,16 @@ void TaskComposerGraph::setTerminalTriggerAbortByIndex(int terminal_index) } } +boost::uuids::uuid TaskComposerGraph::getAbortTerminal() const +{ + if (abort_terminal_ >= 0) + return terminals_.at(static_cast(abort_terminal_)); + + return {}; +} + +int TaskComposerGraph::getAbortTerminalIndex() const { return abort_terminal_; } + std::pair TaskComposerGraph::isValid() const { int root_node_cnt{ 0 }; @@ -409,6 +435,7 @@ TaskComposerGraph::dump(std::ostream& os, << "\\nUUID: " << uuid_str_ << "\\l"; os << "Inputs:\\l" << input_keys_; os << "Outputs:\\l" << output_keys_; + os << "Abort Terminal: " << abort_terminal_ << "\\l"; os << "Conditional: " << ((conditional_) ? "True" : "False") << "\\l"; if (getType() == TaskComposerNodeType::PIPELINE || getType() == TaskComposerNodeType::GRAPH) { @@ -436,6 +463,7 @@ TaskComposerGraph::dump(std::ostream& os, << "\\l"; os << "Inputs:\\l" << input_keys; os << "Outputs:\\l" << output_keys; + os << "Abort Terminal: " << static_cast(*node).abort_terminal_ << "\\l"; os << "Conditional: " << ((node->isConditional()) ? "True" : "False") << "\\l"; if (it != results_map.end()) os << "Time: " << std::fixed << std::setprecision(3) << it->second->elapsed_time << "s\\l"; @@ -506,6 +534,7 @@ bool TaskComposerGraph::operator==(const TaskComposerGraph& rhs) const } } equal &= (terminals_ == rhs.terminals_); + equal &= (abort_terminal_ == rhs.abort_terminal_); equal &= TaskComposerNode::operator==(rhs); return equal; } @@ -519,6 +548,7 @@ void TaskComposerGraph::serialize(Archive& ar, const unsigned int /*version*/) { ar& boost::serialization::make_nvp("nodes", nodes_); ar& boost::serialization::make_nvp("terminals", terminals_); + ar& boost::serialization::make_nvp("abort_terminal", abort_terminal_); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(TaskComposerNode); } diff --git a/tesseract_task_composer/core/src/task_composer_node_info.cpp b/tesseract_task_composer/core/src/task_composer_node_info.cpp index c66c5a03ce..9860ab722e 100644 --- a/tesseract_task_composer/core/src/task_composer_node_info.cpp +++ b/tesseract_task_composer/core/src/task_composer_node_info.cpp @@ -39,7 +39,7 @@ TESSERACT_COMMON_IGNORE_WARNINGS_PUSH TESSERACT_COMMON_IGNORE_WARNINGS_POP #include -#include +#include namespace tesseract_planning { @@ -48,9 +48,20 @@ TaskComposerNodeInfo::TaskComposerNodeInfo(const TaskComposerNode& node) , ns(node.ns_) , uuid(node.uuid_) , parent_uuid(node.parent_uuid_) + , type(node.type_) + , type_hash_code(std::type_index(typeid(node)).hash_code()) + , conditional(node.conditional_) , inbound_edges(node.inbound_edges_) , outbound_edges(node.outbound_edges_) + , input_keys(node.input_keys_) + , output_keys(node.output_keys_) { + if (type == TaskComposerNodeType::GRAPH || type == TaskComposerNodeType::PIPELINE) + { + const auto& graph = static_cast(node); + terminals = graph.getTerminals(); + abort_terminal = graph.getAbortTerminalIndex(); + } } bool TaskComposerNodeInfo::operator==(const TaskComposerNodeInfo& rhs) const @@ -62,6 +73,9 @@ bool TaskComposerNodeInfo::operator==(const TaskComposerNodeInfo& rhs) const equal &= ns == rhs.ns; equal &= uuid == rhs.uuid; equal &= parent_uuid == rhs.parent_uuid; + equal &= type == rhs.type; + equal &= type_hash_code == rhs.type_hash_code; + equal &= conditional == rhs.conditional; equal &= return_value == rhs.return_value; equal &= status_code == rhs.status_code; equal &= status_message == rhs.status_message; @@ -71,6 +85,8 @@ bool TaskComposerNodeInfo::operator==(const TaskComposerNodeInfo& rhs) const equal &= tesseract_common::isIdentical(outbound_edges, rhs.outbound_edges, true); equal &= input_keys == rhs.input_keys; equal &= output_keys == rhs.output_keys; + equal &= terminals == rhs.terminals; + equal &= abort_terminal == rhs.abort_terminal; equal &= color == rhs.color; equal &= dotgraph == rhs.dotgraph; equal &= data_storage == rhs.data_storage; @@ -89,6 +105,9 @@ void TaskComposerNodeInfo::serialize(Archive& ar, const unsigned int /*version*/ ar& boost::serialization::make_nvp("ns", ns); ar& boost::serialization::make_nvp("uuid", uuid); ar& boost::serialization::make_nvp("parent_uuid", parent_uuid); + ar& boost::serialization::make_nvp("type", type); + ar& boost::serialization::make_nvp("type_hash_code", type_hash_code); + ar& boost::serialization::make_nvp("conditional", conditional); ar& boost::serialization::make_nvp("return_value", return_value); ar& boost::serialization::make_nvp("status_code", status_code); ar& boost::serialization::make_nvp("status_message", status_message); @@ -99,6 +118,8 @@ void TaskComposerNodeInfo::serialize(Archive& ar, const unsigned int /*version*/ ar& boost::serialization::make_nvp("outbound_edges", outbound_edges); ar& boost::serialization::make_nvp("input_keys", input_keys); ar& boost::serialization::make_nvp("output_keys", output_keys); + ar& boost::serialization::make_nvp("terminals", terminals); + ar& boost::serialization::make_nvp("abort_terminal", abort_terminal); ar& boost::serialization::make_nvp("color", color); ar& boost::serialization::make_nvp("dotgraph", dotgraph); ar& boost::serialization::make_nvp("data_storage", data_storage); @@ -178,6 +199,18 @@ TaskComposerNodeInfoContainer::find(const std::function lock(mutex_); + root_node_ = node_uuid; +} + +boost::uuids::uuid TaskComposerNodeInfoContainer::getRootNode() const +{ + std::shared_lock lock(mutex_); + return root_node_; +} + void TaskComposerNodeInfoContainer::setAborted(const boost::uuids::uuid& node_uuid) { assert(!node_uuid.is_nil()); @@ -264,6 +297,8 @@ bool TaskComposerNodeInfoContainer::operator==(const TaskComposerNodeInfoContain std::scoped_lock lock{ lhs_lock, rhs_lock }; bool equal = true; + equal &= root_node_ == rhs.root_node_; + equal &= aborting_node_ == rhs.aborting_node_; auto equality = [](const TaskComposerNodeInfo::UPtr& p1, const TaskComposerNodeInfo::UPtr& p2) { return (p1 && p2 && *p1 == *p2) || (!p1 && !p2); }; @@ -285,6 +320,7 @@ template void TaskComposerNodeInfoContainer::serialize(Archive& ar, const unsigned int /*version*/) { std::unique_lock lock(mutex_); + ar& BOOST_SERIALIZATION_NVP(root_node_); ar& BOOST_SERIALIZATION_NVP(aborting_node_); ar& BOOST_SERIALIZATION_NVP(info_map_); } diff --git a/tesseract_task_composer/core/src/task_composer_pipeline.cpp b/tesseract_task_composer/core/src/task_composer_pipeline.cpp index 2efc8b10d0..4c8fb4543e 100644 --- a/tesseract_task_composer/core/src/task_composer_pipeline.cpp +++ b/tesseract_task_composer/core/src/task_composer_pipeline.cpp @@ -80,8 +80,6 @@ std::unique_ptr TaskComposerPipeline::runImpl(TaskComposer { timer.stop(); auto info = std::make_unique(*this); - info->input_keys = input_keys_; - info->output_keys = output_keys_; info->return_value = static_cast(i); info->color = node_info->color; info->status_code = node_info->status_code; diff --git a/tesseract_task_composer/planning/include/tesseract_task_composer/planning/nodes/motion_planner_task.hpp b/tesseract_task_composer/planning/include/tesseract_task_composer/planning/nodes/motion_planner_task.hpp index 8719614659..ab0d2b9e5b 100644 --- a/tesseract_task_composer/planning/include/tesseract_task_composer/planning/nodes/motion_planner_task.hpp +++ b/tesseract_task_composer/planning/include/tesseract_task_composer/planning/nodes/motion_planner_task.hpp @@ -157,9 +157,7 @@ class MotionPlannerTask : public TaskComposerTask return info; } - std::shared_ptr env = - env_poly.template as>()->clone(); - info->data_storage.setData("environment", env); + auto env = env_poly.template as>(); auto input_data_poly = getData(*context.data_storage, INOUT_PROGRAM_PORT); if (input_data_poly.getType() != std::type_index(typeid(CompositeInstruction))) diff --git a/tesseract_task_composer/planning/src/nodes/continuous_contact_check_task.cpp b/tesseract_task_composer/planning/src/nodes/continuous_contact_check_task.cpp index a0c902bdfa..7ee6a1aa58 100644 --- a/tesseract_task_composer/planning/src/nodes/continuous_contact_check_task.cpp +++ b/tesseract_task_composer/planning/src/nodes/continuous_contact_check_task.cpp @@ -122,9 +122,7 @@ ContinuousContactCheckTask::runImpl(TaskComposerContext& context, OptionalTaskCo return info; } - std::shared_ptr env = - env_poly.as>()->clone(); - info->data_storage.setData("environment", env); + auto env = env_poly.as>(); auto input_data_poly = getData(*context.data_storage, INPUT_PROGRAM_PORT); if (input_data_poly.getType() != std::type_index(typeid(CompositeInstruction))) diff --git a/tesseract_task_composer/planning/src/nodes/discrete_contact_check_task.cpp b/tesseract_task_composer/planning/src/nodes/discrete_contact_check_task.cpp index e4be57f259..23b6bd854d 100644 --- a/tesseract_task_composer/planning/src/nodes/discrete_contact_check_task.cpp +++ b/tesseract_task_composer/planning/src/nodes/discrete_contact_check_task.cpp @@ -122,9 +122,7 @@ std::unique_ptr DiscreteContactCheckTask::runImpl(TaskComp return info; } - std::shared_ptr env = - env_poly.as>()->clone(); - info->data_storage.setData("environment", env); + auto env = env_poly.as>(); auto input_data_poly = getData(*context.data_storage, INPUT_PROGRAM_PORT); if (input_data_poly.getType() != std::type_index(typeid(CompositeInstruction))) diff --git a/tesseract_task_composer/planning/src/nodes/fix_state_collision_task.cpp b/tesseract_task_composer/planning/src/nodes/fix_state_collision_task.cpp index 92eea7bc58..4b515fabbe 100644 --- a/tesseract_task_composer/planning/src/nodes/fix_state_collision_task.cpp +++ b/tesseract_task_composer/planning/src/nodes/fix_state_collision_task.cpp @@ -408,9 +408,7 @@ std::unique_ptr FixStateCollisionTask::runImpl(TaskCompose return info; } - std::shared_ptr env = - env_poly.as>()->clone(); - info->data_storage.setData("environment", env); + auto env = env_poly.as>(); auto input_data_poly = getData(*context.data_storage, INOUT_PROGRAM_PORT); if (input_data_poly.getType() != std::type_index(typeid(CompositeInstruction)))