Skip to content

Commit

Permalink
Update TaskComposerNodeInfo to allow searching graph
Browse files Browse the repository at this point in the history
  • Loading branch information
Levi-Armstrong committed Aug 14, 2024
1 parent 4895a8e commit 287fde4
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ class TaskComposerGraph : public TaskComposerNode
*/
void setTerminalTriggerAbortByIndex(int terminal_index);

/** Get the abort terminal uuid if set */
boost::uuids::uuid getAbortTerminal() const;

/** Get the abort terminal index if set */
int getAbortTerminalIndex() const;

/**
* @brief Check if the current state of the graph is valid
* @todo Replace return type with std::expected when upgraded to use c++23
Expand Down Expand Up @@ -149,6 +155,7 @@ class TaskComposerGraph : public TaskComposerNode

std::map<boost::uuids::uuid, TaskComposerNode::Ptr> nodes_;
std::vector<boost::uuids::uuid> terminals_;
int abort_terminal_{ -1 };
};

} // namespace tesseract_planning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ TESSERACT_COMMON_IGNORE_WARNINGS_POP

#include <tesseract_task_composer/core/task_composer_keys.h>
#include <tesseract_task_composer/core/task_composer_node_ports.h>
#include <tesseract_task_composer/core/task_composer_node_info.h>

namespace YAML
{
Expand All @@ -53,6 +52,7 @@ namespace tesseract_planning
class TaskComposerDataStorage;
class TaskComposerContext;
class TaskComposerExecutor;
class TaskComposerNodeInfo;

enum class TaskComposerNodeType
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,12 @@ TESSERACT_COMMON_IGNORE_WARNINGS_POP

#include <tesseract_common/any_poly.h>

#include <tesseract_task_composer/core/task_composer_node.h>
#include <tesseract_task_composer/core/task_composer_keys.h>
#include <tesseract_task_composer/core/task_composer_data_storage.h>

namespace tesseract_planning
{
class TaskComposerNode;

/** Stores information about a node */
class TaskComposerNodeInfo
{
Expand Down Expand Up @@ -79,6 +78,15 @@ class TaskComposerNodeInfo
*/
boost::uuids::uuid parent_uuid{};

/** @brief The node type */
TaskComposerNodeType type;

/** @brief The task type hash code from std::type_index */
std::size_t type_hash_code{ 0 };

/** @brief The task is conditional or not */
bool conditional{ false };

/** @brief The nodes inbound edges */
std::vector<boost::uuids::uuid> inbound_edges;

Expand All @@ -91,6 +99,12 @@ class TaskComposerNodeInfo
/** @brief The output keys */
TaskComposerKeys output_keys;

/** @brief The graph of pipeline terminals */
std::vector<boost::uuids::uuid> terminals;

/** @brief Indicate if abort terminal was assigned. Only valid for graph and pipelines */
int abort_terminal{ -1 };

/** @brief Value returned from the Task on completion */
int return_value{ -1 };

Expand Down Expand Up @@ -189,6 +203,12 @@ class TaskComposerNodeInfoContainer
/** @brief Merge the contents of another container's info map */
void mergeInfoMap(TaskComposerNodeInfoContainer&& container);

/** @brief Set the root node */
void setRootNode(const boost::uuids::uuid& node_uuid);

/** @brief Get the root node */
boost::uuids::uuid getRootNode() const;

/**
* @brief Called if aborted
* @details This is set if abort is called in input
Expand Down Expand Up @@ -217,6 +237,7 @@ class TaskComposerNodeInfoContainer
void serialize(Archive& ar, const unsigned int version); // NOLINT

mutable std::shared_mutex mutex_;
boost::uuids::uuid root_node_{};
boost::uuids::uuid aborting_node_{};
std::map<boost::uuids::uuid, TaskComposerNodeInfo::UPtr> info_map_;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ std::unique_ptr<TaskComposerNodeInfo> HasDataStorageEntryTask::runImpl(TaskCompo

bool HasDataStorageEntryTask::operator==(const HasDataStorageEntryTask& rhs) const
{
return (TaskComposerNode::operator==(rhs));
return (TaskComposerTask::operator==(rhs));
}
bool HasDataStorageEntryTask::operator!=(const HasDataStorageEntryTask& rhs) const { return !operator==(rhs); }

Expand Down
4 changes: 3 additions & 1 deletion tesseract_task_composer/core/src/task_composer_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ std::unique_ptr<TaskComposerFuture> TaskComposerExecutor::run(const TaskComposer
std::shared_ptr<TaskComposerDataStorage> data_storage,
bool dotgraph)
{
return run(node, std::make_shared<TaskComposerContext>(node.getName(), std::move(data_storage), dotgraph));
auto context = std::make_shared<TaskComposerContext>(node.getName(), std::move(data_storage), dotgraph);
context->task_infos.setRootNode(node.getUUID());
return run(node, context);
}

bool TaskComposerExecutor::operator==(const TaskComposerExecutor& rhs) const { return (name_ == rhs.name_); }
Expand Down
44 changes: 37 additions & 7 deletions tesseract_task_composer/core/src/task_composer_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,17 +321,31 @@ void TaskComposerGraph::setTerminalTriggerAbort(boost::uuids::uuid terminal)
{
if (!terminal.is_nil())
{
auto& n = nodes_.at(terminal);
if (n->getType() == TaskComposerNodeType::TASK)
static_cast<TaskComposerTask&>(*n).setTriggerAbort(true);
else
throw std::runtime_error("Tasks can only trigger abort!");
abort_terminal_ = -1;
for (std::size_t i = 0; i < terminals_.size(); ++i)
{
const boost::uuids::uuid& uuid = terminals_[i];
if (uuid == terminal)
{
abort_terminal_ = static_cast<int>(i);
auto& n = nodes_.at(terminal);
if (n->getType() == TaskComposerNodeType::TASK)
static_cast<TaskComposerTask&>(*n).setTriggerAbort(true);
else
throw std::runtime_error("Tasks can only trigger abort!");

break;
}
}
if (abort_terminal_ < 0)
throw std::runtime_error("Task with uuid: " + boost::uuids::to_string(terminal) + " is not a terminal node");
}
else
{
for (const auto& terminal : terminals_)
abort_terminal_ = -1;
for (const auto& t : terminals_)
{
auto& n = nodes_.at(terminal);
auto& n = nodes_.at(t);
if (n->getType() == TaskComposerNodeType::TASK)
static_cast<TaskComposerTask&>(*n).setTriggerAbort(false);
}
Expand All @@ -342,6 +356,7 @@ void TaskComposerGraph::setTerminalTriggerAbortByIndex(int terminal_index)
{
if (terminal_index >= 0)
{
abort_terminal_ = terminal_index;
auto& n = nodes_.at(terminals_.at(static_cast<std::size_t>(terminal_index)));
if (n->getType() == TaskComposerNodeType::TASK)
static_cast<TaskComposerTask&>(*n).setTriggerAbort(true);
Expand All @@ -350,6 +365,7 @@ void TaskComposerGraph::setTerminalTriggerAbortByIndex(int terminal_index)
}
else
{
abort_terminal_ = -1;
for (const auto& terminal : terminals_)
{
auto& n = nodes_.at(terminal);
Expand All @@ -359,6 +375,16 @@ void TaskComposerGraph::setTerminalTriggerAbortByIndex(int terminal_index)
}
}

boost::uuids::uuid TaskComposerGraph::getAbortTerminal() const
{
if (abort_terminal_ >= 0)
return terminals_.at(static_cast<std::size_t>(abort_terminal_));

return {};
}

int TaskComposerGraph::getAbortTerminalIndex() const { return abort_terminal_; }

std::pair<bool, std::string> TaskComposerGraph::isValid() const
{
int root_node_cnt{ 0 };
Expand Down Expand Up @@ -409,6 +435,7 @@ TaskComposerGraph::dump(std::ostream& os,
<< "\\nUUID: " << uuid_str_ << "\\l";
os << "Inputs:\\l" << input_keys_;
os << "Outputs:\\l" << output_keys_;
os << "Abort Terminal: " << abort_terminal_ << "\\l";
os << "Conditional: " << ((conditional_) ? "True" : "False") << "\\l";
if (getType() == TaskComposerNodeType::PIPELINE || getType() == TaskComposerNodeType::GRAPH)
{
Expand Down Expand Up @@ -436,6 +463,7 @@ TaskComposerGraph::dump(std::ostream& os,
<< "\\l";
os << "Inputs:\\l" << input_keys;
os << "Outputs:\\l" << output_keys;
os << "Abort Terminal: " << static_cast<const TaskComposerGraph&>(*node).abort_terminal_ << "\\l";
os << "Conditional: " << ((node->isConditional()) ? "True" : "False") << "\\l";
if (it != results_map.end())
os << "Time: " << std::fixed << std::setprecision(3) << it->second->elapsed_time << "s\\l";
Expand Down Expand Up @@ -506,6 +534,7 @@ bool TaskComposerGraph::operator==(const TaskComposerGraph& rhs) const
}
}
equal &= (terminals_ == rhs.terminals_);
equal &= (abort_terminal_ == rhs.abort_terminal_);
equal &= TaskComposerNode::operator==(rhs);
return equal;
}
Expand All @@ -519,6 +548,7 @@ void TaskComposerGraph::serialize(Archive& ar, const unsigned int /*version*/)
{
ar& boost::serialization::make_nvp("nodes", nodes_);
ar& boost::serialization::make_nvp("terminals", terminals_);
ar& boost::serialization::make_nvp("abort_terminal", abort_terminal_);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(TaskComposerNode);
}

Expand Down
38 changes: 37 additions & 1 deletion tesseract_task_composer/core/src/task_composer_node_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ TESSERACT_COMMON_IGNORE_WARNINGS_PUSH
TESSERACT_COMMON_IGNORE_WARNINGS_POP

#include <tesseract_task_composer/core/task_composer_node_info.h>
#include <tesseract_task_composer/core/task_composer_node.h>
#include <tesseract_task_composer/core/task_composer_graph.h>

namespace tesseract_planning
{
Expand All @@ -48,9 +48,20 @@ TaskComposerNodeInfo::TaskComposerNodeInfo(const TaskComposerNode& node)
, ns(node.ns_)
, uuid(node.uuid_)
, parent_uuid(node.parent_uuid_)
, type(node.type_)
, type_hash_code(std::type_index(typeid(node)).hash_code())
, conditional(node.conditional_)
, inbound_edges(node.inbound_edges_)
, outbound_edges(node.outbound_edges_)
, input_keys(node.input_keys_)
, output_keys(node.output_keys_)
{
if (type == TaskComposerNodeType::GRAPH || type == TaskComposerNodeType::PIPELINE)
{
const auto& graph = static_cast<const TaskComposerGraph&>(node);
terminals = graph.getTerminals();
abort_terminal = graph.getAbortTerminalIndex();
}
}

bool TaskComposerNodeInfo::operator==(const TaskComposerNodeInfo& rhs) const
Expand All @@ -62,6 +73,9 @@ bool TaskComposerNodeInfo::operator==(const TaskComposerNodeInfo& rhs) const
equal &= ns == rhs.ns;
equal &= uuid == rhs.uuid;
equal &= parent_uuid == rhs.parent_uuid;
equal &= type == rhs.type;
equal &= type_hash_code == rhs.type_hash_code;
equal &= conditional == rhs.conditional;
equal &= return_value == rhs.return_value;
equal &= status_code == rhs.status_code;
equal &= status_message == rhs.status_message;
Expand All @@ -71,6 +85,8 @@ bool TaskComposerNodeInfo::operator==(const TaskComposerNodeInfo& rhs) const
equal &= tesseract_common::isIdentical(outbound_edges, rhs.outbound_edges, true);
equal &= input_keys == rhs.input_keys;
equal &= output_keys == rhs.output_keys;
equal &= terminals == rhs.terminals;
equal &= abort_terminal == rhs.abort_terminal;
equal &= color == rhs.color;
equal &= dotgraph == rhs.dotgraph;
equal &= data_storage == rhs.data_storage;
Expand All @@ -89,6 +105,9 @@ void TaskComposerNodeInfo::serialize(Archive& ar, const unsigned int /*version*/
ar& boost::serialization::make_nvp("ns", ns);
ar& boost::serialization::make_nvp("uuid", uuid);
ar& boost::serialization::make_nvp("parent_uuid", parent_uuid);
ar& boost::serialization::make_nvp("type", type);
ar& boost::serialization::make_nvp("type_hash_code", type_hash_code);
ar& boost::serialization::make_nvp("conditional", conditional);
ar& boost::serialization::make_nvp("return_value", return_value);
ar& boost::serialization::make_nvp("status_code", status_code);
ar& boost::serialization::make_nvp("status_message", status_message);
Expand All @@ -99,6 +118,8 @@ void TaskComposerNodeInfo::serialize(Archive& ar, const unsigned int /*version*/
ar& boost::serialization::make_nvp("outbound_edges", outbound_edges);
ar& boost::serialization::make_nvp("input_keys", input_keys);
ar& boost::serialization::make_nvp("output_keys", output_keys);
ar& boost::serialization::make_nvp("terminals", terminals);
ar& boost::serialization::make_nvp("abort_terminal", abort_terminal);
ar& boost::serialization::make_nvp("color", color);
ar& boost::serialization::make_nvp("dotgraph", dotgraph);
ar& boost::serialization::make_nvp("data_storage", data_storage);
Expand Down Expand Up @@ -178,6 +199,18 @@ TaskComposerNodeInfoContainer::find(const std::function<bool(const TaskComposerN
return results;
}

void TaskComposerNodeInfoContainer::setRootNode(const boost::uuids::uuid& node_uuid)
{
std::unique_lock<std::shared_mutex> lock(mutex_);
root_node_ = node_uuid;
}

boost::uuids::uuid TaskComposerNodeInfoContainer::getRootNode() const
{
std::shared_lock<std::shared_mutex> lock(mutex_);
return root_node_;
}

void TaskComposerNodeInfoContainer::setAborted(const boost::uuids::uuid& node_uuid)
{
assert(!node_uuid.is_nil());
Expand Down Expand Up @@ -264,6 +297,8 @@ bool TaskComposerNodeInfoContainer::operator==(const TaskComposerNodeInfoContain
std::scoped_lock lock{ lhs_lock, rhs_lock };

bool equal = true;
equal &= root_node_ == rhs.root_node_;
equal &= aborting_node_ == rhs.aborting_node_;
auto equality = [](const TaskComposerNodeInfo::UPtr& p1, const TaskComposerNodeInfo::UPtr& p2) {
return (p1 && p2 && *p1 == *p2) || (!p1 && !p2);
};
Expand All @@ -285,6 +320,7 @@ template <class Archive>
void TaskComposerNodeInfoContainer::serialize(Archive& ar, const unsigned int /*version*/)
{
std::unique_lock<std::shared_mutex> lock(mutex_);
ar& BOOST_SERIALIZATION_NVP(root_node_);
ar& BOOST_SERIALIZATION_NVP(aborting_node_);
ar& BOOST_SERIALIZATION_NVP(info_map_);
}
Expand Down
2 changes: 0 additions & 2 deletions tesseract_task_composer/core/src/task_composer_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ std::unique_ptr<TaskComposerNodeInfo> TaskComposerPipeline::runImpl(TaskComposer
{
timer.stop();
auto info = std::make_unique<TaskComposerNodeInfo>(*this);
info->input_keys = input_keys_;
info->output_keys = output_keys_;
info->return_value = static_cast<int>(i);
info->color = node_info->color;
info->status_code = node_info->status_code;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ class MotionPlannerTask : public TaskComposerTask
return info;
}

std::shared_ptr<const tesseract_environment::Environment> env =
env_poly.template as<std::shared_ptr<const tesseract_environment::Environment>>()->clone();
info->data_storage.setData("environment", env);
auto env = env_poly.template as<std::shared_ptr<const tesseract_environment::Environment>>();

auto input_data_poly = getData(*context.data_storage, INOUT_PROGRAM_PORT);
if (input_data_poly.getType() != std::type_index(typeid(CompositeInstruction)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@ ContinuousContactCheckTask::runImpl(TaskComposerContext& context, OptionalTaskCo
return info;
}

std::shared_ptr<const tesseract_environment::Environment> env =
env_poly.as<std::shared_ptr<const tesseract_environment::Environment>>()->clone();
info->data_storage.setData("environment", env);
auto env = env_poly.as<std::shared_ptr<const tesseract_environment::Environment>>();

auto input_data_poly = getData(*context.data_storage, INPUT_PROGRAM_PORT);
if (input_data_poly.getType() != std::type_index(typeid(CompositeInstruction)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@ std::unique_ptr<TaskComposerNodeInfo> DiscreteContactCheckTask::runImpl(TaskComp
return info;
}

std::shared_ptr<const tesseract_environment::Environment> env =
env_poly.as<std::shared_ptr<const tesseract_environment::Environment>>()->clone();
info->data_storage.setData("environment", env);
auto env = env_poly.as<std::shared_ptr<const tesseract_environment::Environment>>();

auto input_data_poly = getData(*context.data_storage, INPUT_PROGRAM_PORT);
if (input_data_poly.getType() != std::type_index(typeid(CompositeInstruction)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,7 @@ std::unique_ptr<TaskComposerNodeInfo> FixStateCollisionTask::runImpl(TaskCompose
return info;
}

std::shared_ptr<const tesseract_environment::Environment> env =
env_poly.as<std::shared_ptr<const tesseract_environment::Environment>>()->clone();
info->data_storage.setData("environment", env);
auto env = env_poly.as<std::shared_ptr<const tesseract_environment::Environment>>();

auto input_data_poly = getData(*context.data_storage, INOUT_PROGRAM_PORT);
if (input_data_poly.getType() != std::type_index(typeid(CompositeInstruction)))
Expand Down

0 comments on commit 287fde4

Please sign in to comment.