From 06d318426ba059481b465885a17c8f851a57b732 Mon Sep 17 00:00:00 2001 From: Parsa Ghadimi Date: Thu, 9 May 2024 08:37:04 -0400 Subject: [PATCH] fix(node): allow dropping a nested runtime --- core/node/src/lib.rs | 115 ++++++++++++++++++++++++++++--------------- 1 file changed, 76 insertions(+), 39 deletions(-) diff --git a/core/node/src/lib.rs b/core/node/src/lib.rs index 6ddf269bf..ac1884748 100644 --- a/core/node/src/lib.rs +++ b/core/node/src/lib.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::marker::PhantomData; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; @@ -11,6 +12,9 @@ use tokio::time::timeout; /// A single [Node] instance that has ownership over its tokio runtime. pub struct ContainedNode { + /// The name of this contained node. + name: String, + /// The dependency injection data provider which can contain most of the items that make up /// a node. provider: fdi::MultiThreadedProvider, @@ -26,6 +30,8 @@ pub struct ContainedNode { impl ContainedNode { pub fn new(provider: fdi::MultiThreadedProvider, name: Option<String>) -> Self { + let name = name.unwrap_or_else(|| "LIGHTNING".into()); + // Create and insert the shutdown controller to the provider. let trace_shutdown = std::env::var("TRACE_SHUTDOWN").is_ok(); let shutdown = ShutdownController::new(trace_shutdown); @@ -40,10 +46,11 @@ impl ContainedNode { // Create the tokio runtime. 
let worker_id = AtomicUsize::new(0); + let node_name = name.clone(); let runtime = tokio::runtime::Builder::new_multi_thread() .thread_name_fn(move || { let id = worker_id.fetch_add(1, Ordering::SeqCst); - format!("{}#{id}", name.as_deref().unwrap_or("LIGHTNING")) + format!("{node_name}#{id}") }) .on_thread_start(move || { let permit = permit.clone(); @@ -59,6 +66,7 @@ impl ContainedNode { .expect("Failed to build tokio runtime for node container."); Self { + name, provider, runtime, shutdown, @@ -90,50 +98,79 @@ impl ContainedNode { }) } - pub async fn shutdown(mut self) { + /// Shut down the node and return a future that will be resolved when the node is fully down. + /// + /// Unlike other async methods, this function can trigger the shutdown without being polled. + /// In other words you can still trigger the shutdown event by calling this method and never + /// awaiting the returned future. + pub fn shutdown(mut self) -> impl Future<Output = ()> { // Tell the controller it's time to go down. self.shutdown.trigger_shutdown(); - // Give the runtime 30 seconds to stop. - self.runtime.shutdown_timeout(Duration::from_secs(30)); - - for i in 0.. { - if timeout(Duration::from_secs(3), self.shutdown.wait_for_completion()) - .await - .is_ok() - { - // shutdown completed. - return; - } + let task_name = format!("{}::RuntimeDrop", self.name); - match i { - 0 | 1 => { - // 3s, 6s - tracing::trace!("Still shutting down..."); - continue; - }, - 2 => { - // 9s - tracing::warn!("Still shutting down..."); + // Give the runtime 30 seconds to stop. + match tokio::runtime::Handle::try_current() { + Ok(handle) => { + tokio::task::Builder::new() + .name(&task_name) + .spawn_blocking_on( + || { + self.runtime.shutdown_timeout(Duration::from_secs(30)); + }, + &handle, + ) + .unwrap(); + }, + Err(_) => { + std::thread::Builder::new() + .name(task_name) + .spawn(|| { + self.runtime.shutdown_timeout(Duration::from_secs(30)); + }) + .unwrap(); + }, + }; + + async move { + for i in 0.. 
{ + if timeout(Duration::from_secs(3), self.shutdown.wait_for_completion()) + .await + .is_ok() + { + // shutdown completed. + return; + } + + match i { + 0 | 1 => { + // 3s, 6s + tracing::trace!("Still shutting down..."); + continue; + }, + 2 => { + // 9s + tracing::warn!("Still shutting down..."); + continue; + }, + _ => { + // 12s + tracing::error!("Shutdown taking too long..") + }, + } + + if i == 10 { + // 33s + break; + } + + let Some(iter) = self.shutdown.pending_backtraces() else { continue; - }, - _ => { - // 12s - tracing::error!("Shutdown taking too long..") - }, - } - - if i == 10 { - // 33s - break; - } - - let Some(iter) = self.shutdown.pending_backtraces() else { - continue; - }; + }; - for (i, trace) in iter.enumerate() { - eprintln!("Pending task backtrace #{i}:\n{trace:#?}"); + for (i, trace) in iter.enumerate() { + eprintln!("Pending task backtrace #{i}:\n{trace:#?}"); + } } } }