Skip to content

Commit

Permalink
Add a python backend node with orchestrator (#54)
Browse files Browse the repository at this point in the history
* Refs #21703: Add missing dependency

Signed-off-by: eProsima <[email protected]>

* Refs #21703: Refactor Orchestrator to be python node compatible

Signed-off-by: eProsima <[email protected]>

* Refs #21703: Add Orchestrator as a backend Node python module

Signed-off-by: eProsima <[email protected]>

* Refs #21703: Fix NITs

Signed-off-by: eProsima <[email protected]>

* Refs #21703: Adap test to the changes performed

Signed-off-by: eProsima <[email protected]>

* Refs #21703: Add Orchestrator python node test in modules setup

Signed-off-by: eProsima <[email protected]>

* Refs #21703: Apply rev suggestions

Signed-off-by: eProsima <[email protected]>

* Refs #21703: Address missing comments from rev suggestion

Signed-off-by: eProsima <[email protected]>

---------

Signed-off-by: eProsima <[email protected]>
  • Loading branch information
JesusPoderoso authored Oct 14, 2024
1 parent d52ba11 commit 441377f
Show file tree
Hide file tree
Showing 18 changed files with 722 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
* @file OrchestratorNode.hpp
*/


#ifndef SUSTAINMLCPP_ORCHESTRATOR_ORCHESTRATORNODE_HPP
#define SUSTAINMLCPP_ORCHESTRATOR_ORCHESTRATORNODE_HPP

Expand Down Expand Up @@ -93,10 +92,23 @@ class OrchestratorNode
types::MLModel,
types::UserInput>;

/**
* @brief Construct a new Orchestrator Node object
*
* @param handler OrchestratorNodeHandle object to handle the callbacks
* @param domain Domain ID to use for the DDS entities
*
* @note The deletion of the handler is responsibility of the user.
*/
OrchestratorNode(
std::shared_ptr<OrchestratorNodeHandle> handler,
OrchestratorNodeHandle& handler,
uint32_t domain = 0);

/**
* @brief Destroy the Orchestrator Node object
*
* @note The deletion of the handler is responsibility of the user.
*/
~OrchestratorNode();

/**
Expand Down Expand Up @@ -174,7 +186,7 @@ class OrchestratorNode
/**
* @brief Used to retrieve the associated OrchestratorNodeHandle.
*/
inline std::weak_ptr<OrchestratorNodeHandle> get_handler()
inline OrchestratorNodeHandle* get_handler() const
{
return handler_;
}
Expand All @@ -184,6 +196,22 @@ class OrchestratorNode
*/
void print_db();

/**
* @brief Called by the user to run the run.
*/
void spin();

/**
* @brief Remove all Fast DDS entities and clean up the OrchestratorNode and OrchestratorNodeHandle.
*
*/
void destroy();

/**
* @brief Stops the execution of the node.
*/
void terminate();

protected:

/**
Expand All @@ -199,7 +227,11 @@ class OrchestratorNode

uint32_t domain_;

std::shared_ptr<OrchestratorNodeHandle> handler_;
/**
* @brief Handle to manage the node status and node output callbacks
* @note The deletion of the handler is responsibility of the user.
*/
OrchestratorNodeHandle* handler_;

eprosima::fastdds::dds::DomainParticipant* participant_;

Expand All @@ -223,6 +255,7 @@ class OrchestratorNode
std::mutex mtx_;

std::atomic_bool initialized_{false};
std::atomic_bool terminated_{false};
std::condition_variable initialization_cv_;

/**
Expand All @@ -246,6 +279,10 @@ class OrchestratorNode
OrchestratorNode* orchestrator_{nullptr};

};

std::condition_variable spin_cv_;
std::atomic_bool terminate_;

std::unique_ptr<OrchestratorParticipantListener> participant_listener_;

};
Expand Down
4 changes: 2 additions & 2 deletions sustainml_cpp/src/cpp/orchestrator/ModuleNodeProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ ModuleNodeProxy::~ModuleNodeProxy()
void ModuleNodeProxy::notify_status_change()
{
std::lock_guard<std::mutex> lock(orchestrator_->get_mutex());
std::shared_ptr<OrchestratorNodeHandle> handler_ptr = orchestrator_->get_handler().lock();
OrchestratorNodeHandle* handler_ptr = orchestrator_->get_handler();
if (handler_ptr != nullptr)
{
handler_ptr->on_node_status_change(node_id_, status_);
Expand All @@ -208,7 +208,7 @@ void ModuleNodeProxy::notify_new_node_ouput()
store_data_in_db();
void* untyped_data = get_tmp_untyped_data();
std::lock_guard<std::mutex> lock(orchestrator_->get_mutex());
std::shared_ptr<OrchestratorNodeHandle> handler_ptr = orchestrator_->get_handler().lock();
OrchestratorNodeHandle* handler_ptr = orchestrator_->get_handler();
if (handler_ptr != nullptr)
{
handler_ptr->on_new_node_output(node_id_, untyped_data);
Expand Down
116 changes: 75 additions & 41 deletions sustainml_cpp/src/cpp/orchestrator/OrchestratorNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,33 +69,38 @@ void OrchestratorNode::OrchestratorParticipantListener::on_participant_discovery
NodeID node_id = common::get_node_id_from_name(participant_name);

std::lock_guard<std::mutex> lock(orchestrator_->proxies_mtx_);
if (reason == eprosima::fastdds::rtps::ParticipantDiscoveryStatus::DISCOVERED_PARTICIPANT &&
orchestrator_->node_proxies_[static_cast<uint32_t>(node_id)] == nullptr)
{
EPROSIMA_LOG_INFO(ORCHESTRATOR, "Creating node proxy for " << participant_name << " node");
ModuleNodeProxyFactory::make_node_proxy(
node_id,
orchestrator_,
orchestrator_->task_db_,
orchestrator_->node_proxies_[static_cast<uint32_t>(node_id)]);
}
else if ((reason == eprosima::fastdds::rtps::ParticipantDiscoveryStatus::DROPPED_PARTICIPANT ||
reason == eprosima::fastdds::rtps::ParticipantDiscoveryStatus::REMOVED_PARTICIPANT) &&
orchestrator_->node_proxies_[static_cast<uint32_t>(node_id)] != nullptr)

// check if the node has been terminated
if (!orchestrator_->terminated_.load())
{
EPROSIMA_LOG_INFO(ORCHESTRATOR, "Setting inactive " << participant_name << " node");
types::NodeStatus status = orchestrator_->node_proxies_[static_cast<uint32_t>(node_id)]->get_status();
status.node_status(Status::NODE_INACTIVE);
orchestrator_->node_proxies_[static_cast<uint32_t>(node_id)]->set_status(status);
orchestrator_->handler_->on_node_status_change(node_id, status);
if (reason == eprosima::fastdds::rtps::ParticipantDiscoveryStatus::DISCOVERED_PARTICIPANT &&
orchestrator_->node_proxies_[static_cast<uint32_t>(node_id)] == nullptr)
{
EPROSIMA_LOG_INFO(ORCHESTRATOR, "Creating node proxy for " << participant_name << " node");
ModuleNodeProxyFactory::make_node_proxy(
node_id,
orchestrator_,
orchestrator_->task_db_,
orchestrator_->node_proxies_[static_cast<uint32_t>(node_id)]);
}
else if ((reason == eprosima::fastdds::rtps::ParticipantDiscoveryStatus::DROPPED_PARTICIPANT ||
reason == eprosima::fastdds::rtps::ParticipantDiscoveryStatus::REMOVED_PARTICIPANT) &&
orchestrator_->node_proxies_[static_cast<uint32_t>(node_id)] != nullptr)
{
EPROSIMA_LOG_INFO(ORCHESTRATOR, "Setting inactive " << participant_name << " node");
types::NodeStatus status = orchestrator_->node_proxies_[static_cast<uint32_t>(node_id)]->get_status();
status.node_status(Status::NODE_INACTIVE);
orchestrator_->node_proxies_[static_cast<uint32_t>(node_id)]->set_status(status);
orchestrator_->handler_->on_node_status_change(node_id, status);
}
}
}

OrchestratorNode::OrchestratorNode(
std::shared_ptr<OrchestratorNodeHandle> handle,
OrchestratorNodeHandle& handle,
uint32_t domain)
: domain_(domain)
, handler_(handle)
, handler_(&handle)
, node_proxies_({
nullptr,
nullptr,
Expand All @@ -116,34 +121,46 @@ OrchestratorNode::OrchestratorNode(

OrchestratorNode::~OrchestratorNode()
{
std::lock_guard<std::mutex> lock(mtx_);
for (size_t i = 0; i < (size_t)NodeID::MAX; i++)
destroy();
}

void OrchestratorNode::destroy()
{
if (!terminated_.load())
{
std::lock_guard<std::mutex> lock(proxies_mtx_);
if (node_proxies_[i] != nullptr)
std::lock_guard<std::mutex> lock_proxies(proxies_mtx_);
std::lock_guard<std::mutex> lock(mtx_);
for (size_t i = 0; i < (size_t)NodeID::MAX; i++)
{
delete node_proxies_[i];
if (node_proxies_[i] != nullptr)
{
delete node_proxies_[i];
node_proxies_[i] = nullptr;
}
}
}

if (sub_ != nullptr)
{
sub_->delete_contained_entities();
}
if (sub_ != nullptr)
{
sub_->delete_contained_entities();
}

if (pub_ != nullptr)
{
pub_->delete_contained_entities();
}
if (pub_ != nullptr)
{
pub_->delete_contained_entities();
}

if (participant_ != nullptr)
{
participant_->delete_contained_entities();
}
if (participant_ != nullptr)
{
participant_->delete_contained_entities();
}

DomainParticipantFactory::get_instance()->delete_participant(participant_);

DomainParticipantFactory::get_instance()->delete_participant(participant_);
delete task_man_;

delete task_man_;
handler_ = nullptr;
terminated_.store(true);
}
}

void OrchestratorNode::print_db()
Expand Down Expand Up @@ -464,6 +481,23 @@ void OrchestratorNode::send_control_command(
control_writer_->write(cmd.get_impl());
}

void OrchestratorNode::spin()
{
EPROSIMA_LOG_INFO(ORCHESTRATOR, "Spinning Orchestrator... ");
std::unique_lock<std::mutex> lock(mtx_);
spin_cv_.wait(lock, [&]
{
return terminate_.load();
});
}

void OrchestratorNode::terminate()
{
terminate_.store(true);
destroy();
spin_cv_.notify_all();
}

} // namespace orchestrator
} // namespace orchestrator
} // namespace sustainml

Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ TEST(OrchestratorNode, OrchestratorInitializesProperlyWhenNodesAreALive)

tonh->prepare_expected_data(nodes_ready_expected_data);

orchestrator::OrchestratorNode orchestrator(tonh);
orchestrator::OrchestratorNode orchestrator(*(tonh.get()));

MLModelMetadataManagedNode ml_met_node;
MLModelManagedNode ml_node;
Expand All @@ -111,6 +111,7 @@ TEST(OrchestratorNode, OrchestratorInitializesProperlyWhenNodesAreALive)
hw_cons_node.start();

ASSERT_TRUE(tonh->wait_for_data(std::chrono::seconds(5)));
orchestrator.destroy();
}

TEST(OrchestratorNode, AlateJoinerOrchestratorInitializesProperly)
Expand Down Expand Up @@ -145,16 +146,17 @@ TEST(OrchestratorNode, AlateJoinerOrchestratorInitializesProperly)

std::this_thread::sleep_for(std::chrono::seconds(2));

orchestrator::OrchestratorNode orchestrator(tonh);
orchestrator::OrchestratorNode orchestrator(*(tonh.get()));

ASSERT_TRUE(tonh->wait_for_data(std::chrono::seconds(5)));
orchestrator.destroy();
}

TEST(OrchestratorNode, OrchestratorReceivesNodeOutputs)
{
std::shared_ptr<TestOrchestratorNodeHandle> tonh = std::make_shared<TestOrchestratorNodeHandle>();

orchestrator::OrchestratorNode orchestrator(tonh);
orchestrator::OrchestratorNode orchestrator(*(tonh.get()));

MLModelMetadataManagedNode te_node;
MLModelManagedNode ml_node;
Expand Down Expand Up @@ -196,13 +198,14 @@ TEST(OrchestratorNode, OrchestratorReceivesNodeOutputs)
orchestrator.start_task(task.first, task.second);

ASSERT_TRUE(tonh->wait_for_data(std::chrono::seconds(5)));
orchestrator.destroy();
}

TEST(OrchestratorNode, OrchestratorGetTaskData)
{
std::shared_ptr<TestOrchestratorNodeHandle> tonh = std::make_shared<TestOrchestratorNodeHandle>();

orchestrator::OrchestratorNode orchestrator(tonh);
orchestrator::OrchestratorNode orchestrator(*(tonh.get()));

MLModelMetadataManagedNode te_node;
MLModelManagedNode ml_node;
Expand Down Expand Up @@ -261,13 +264,14 @@ TEST(OrchestratorNode, OrchestratorGetTaskData)
ASSERT_EQ(((types::MLModelMetadata*)enc_task)->task_id().problem_id(), 2);
ASSERT_EQ(orchestrator.get_task_data({2, 1}, NodeID::ID_HW_RESOURCES, hw), RetCode_t::RETCODE_OK);
ASSERT_EQ(((types::HWResource*)hw)->task_id().problem_id(), 2);
orchestrator.destroy();
}

TEST(OrchestratorNode, OrchestratorGetNodeStatus)
{
std::shared_ptr<TestOrchestratorNodeHandle> tonh = std::make_shared<TestOrchestratorNodeHandle>();

orchestrator::OrchestratorNode orchestrator(tonh);
orchestrator::OrchestratorNode orchestrator(*(tonh.get()));

MLModelMetadataManagedNode te_node;
MLModelManagedNode ml_node;
Expand Down Expand Up @@ -322,13 +326,14 @@ TEST(OrchestratorNode, OrchestratorGetNodeStatus)
ASSERT_EQ(status->node_status(), Status::NODE_IDLE);
orchestrator.get_node_status(NodeID::ID_APP_REQUIREMENTS, status);
ASSERT_EQ(status->node_status(), Status::NODE_IDLE);
orchestrator.destroy();
}

TEST(OrchestratorNode, OrchestratorTaskIteration)
{
std::shared_ptr<TestOrchestratorNodeHandle> tonh = std::make_shared<TestOrchestratorNodeHandle>();

orchestrator::OrchestratorNode orchestrator(tonh);
orchestrator::OrchestratorNode orchestrator(*(tonh.get()));

MLModelMetadataManagedNode te_node;
MLModelManagedNode ml_node;
Expand Down Expand Up @@ -420,13 +425,14 @@ TEST(OrchestratorNode, OrchestratorTaskIteration)
ASSERT_EQ(1, carbon_iterated_data->task_id().problem_id());
ASSERT_EQ(2, carbon_iterated_data->task_id().iteration_id());
ASSERT_GT(carbon_iterated_data->energy_consumption(), 300);
orchestrator.destroy();
}

TEST(OrchestratorNode, OrchestratorGetTaskDataDoesNotAccumulate)
{
std::shared_ptr<TestOrchestratorNodeHandle> tonh = std::make_shared<TestOrchestratorNodeHandle>();

orchestrator::OrchestratorNode orchestrator(tonh);
orchestrator::OrchestratorNode orchestrator(*(tonh.get()));

MLModelMetadataManagedNode te_node;
MLModelManagedNode ml_node;
Expand Down Expand Up @@ -497,4 +503,5 @@ TEST(OrchestratorNode, OrchestratorGetTaskDataDoesNotAccumulate)
orchestrator.start_task(task.first, task.second);

ASSERT_TRUE(tonh->wait_for_data(std::chrono::seconds(10)));
orchestrator.destroy();
}
4 changes: 3 additions & 1 deletion sustainml_modules/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ opencv-python>=4.9.0.80
transformers>=4.37.0
networkx>=3.0
ollama
ultralytics
optimum
onnx
onnxruntime
rdflib
timm
ultralytics
Loading

0 comments on commit 441377f

Please sign in to comment.