Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgraded concurrency framework to support panic=unwind mode #16

Merged
merged 9 commits into from
Oct 31, 2023
Merged
86 changes: 45 additions & 41 deletions node/actors/network/src/mux/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,53 +293,57 @@ async fn test_transport_closed() {
write_frame_size: 100,
});
scope::run!(ctx, |ctx, s| async {
let (s1, s2) = noise::testonly::pipe(ctx).await;
// Spawns a peer that establishes the a transient stream and then tries to read from it.
// Reading should terminate as soon as we close the transport.
s.spawn(async {
let mut stream = scope::run!(ctx, |ctx, s| async {
let mut mux = mux::Mux {
cfg: cfg.clone(),
accept: BTreeMap::default(),
connect: BTreeMap::default(),
};
let q = mux::StreamQueue::new(1);
mux.connect.insert(cap, q.clone());
s.spawn_bg(async { expected(mux.run(ctx, s2).await).context("mux2.run()") });
anyhow::Ok(q.open(ctx).await.unwrap())
let streams = s
.spawn(async {
let (s1, s2) = noise::testonly::pipe(ctx).await;
// Establish a transient stream, then close the transport (s1 and s2).
scope::run!(ctx, |ctx, s| async {
let outbound = s.spawn(async {
let mut mux = mux::Mux {
cfg: cfg.clone(),
accept: BTreeMap::default(),
connect: BTreeMap::default(),
};
let q = mux::StreamQueue::new(1);
mux.connect.insert(cap, q.clone());
s.spawn_bg(async {
expected(mux.run(ctx, s2).await).context("[connect] mux.run()")
});
q.open(ctx).await.context("[connect] q.open()")
});
let inbound = s.spawn(async {
let mut mux = mux::Mux {
cfg: cfg.clone(),
accept: BTreeMap::default(),
connect: BTreeMap::default(),
};
let q = mux::StreamQueue::new(1);
mux.accept.insert(cap, q.clone());
s.spawn_bg(async {
expected(mux.run(ctx, s1).await).context("[accept] mux.run()")
});
q.open(ctx).await.context("[accept] q.open()")
});
Ok([
inbound.join(ctx).await.context("inbound")?,
outbound.join(ctx).await.context("outbound")?,
])
})
.await
})
.await
.unwrap();

.join(ctx)
.await?;
// Check how the streams without transport behave.
for mut s in streams {
let mut buf = bytes::Buffer::new(100);
// Read is expected to succeed, but no data should be read.
stream.read.read_exact(ctx, &mut buf).await.unwrap();
s.read.read_exact(ctx, &mut buf).await.unwrap();
assert_eq!(buf.len(), 0);
// Writing will succeed (thanks to buffering), but flushing should fail
// because the transport is closed.
stream.write.write_all(ctx, &[1, 2, 3]).await.unwrap();
assert!(stream.write.flush(ctx).await.is_err());
anyhow::Ok(())
});

let mut mux = mux::Mux {
cfg: cfg.clone(),
accept: BTreeMap::default(),
connect: BTreeMap::default(),
};
let q = mux::StreamQueue::new(1);
mux.accept.insert(cap, q.clone());
// Accept the stream and drop the connection completely.
s.spawn_bg(async { expected(mux.run(ctx, s1).await).context("mux1.run()") });
let mut stream = q.open(ctx).await.unwrap();

let mut buf = bytes::Buffer::new(100);
stream.read.read_exact(ctx, &mut buf).await.unwrap();
// The peer multiplexer is dropped to complete `read_exact()`, so we don't have a deadlock here.
assert_eq!(buf.len(), 0);
stream.write.write_all(ctx, &[1, 2, 3]).await.unwrap();
assert!(stream.write.flush(ctx).await.is_err());

s.write.write_all(ctx, &[1, 2, 3]).await.unwrap();
assert!(s.write.flush(ctx).await.is_err());
}
Ok(())
})
.await
Expand Down
2 changes: 1 addition & 1 deletion node/actors/sync_blocks/src/tests/end_to_end.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl GossipNetwork<InitialNodeHandle> {
}

#[async_trait]
trait GossipNetworkTest: fmt::Debug {
trait GossipNetworkTest: fmt::Debug + Send {
/// Returns the number of nodes in the gossip network and number of peers for each node.
fn network_params(&self) -> (usize, usize);

Expand Down
40 changes: 21 additions & 19 deletions node/actors/sync_blocks/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,28 +139,30 @@ async fn subscribing_to_state_updates() {
anyhow::Ok(())
});

let initial_state = state_subscriber.borrow_and_update();
assert_eq!(
initial_state.first_stored_block,
genesis_block.justification
);
assert_eq!(
initial_state.last_contiguous_stored_block,
genesis_block.justification
);
assert_eq!(initial_state.last_stored_block, genesis_block.justification);
drop(initial_state);
{
let initial_state = state_subscriber.borrow_and_update();
assert_eq!(
initial_state.first_stored_block,
genesis_block.justification
);
assert_eq!(
initial_state.last_contiguous_stored_block,
genesis_block.justification
);
assert_eq!(initial_state.last_stored_block, genesis_block.justification);
}

storage.put_block(ctx, &block_1).await.unwrap();

let new_state = sync::changed(ctx, &mut state_subscriber).await?;
assert_eq!(new_state.first_stored_block, genesis_block.justification);
assert_eq!(
new_state.last_contiguous_stored_block,
block_1.justification
);
assert_eq!(new_state.last_stored_block, block_1.justification);
drop(new_state);
{
let new_state = sync::changed(ctx, &mut state_subscriber).await?;
assert_eq!(new_state.first_stored_block, genesis_block.justification);
assert_eq!(
new_state.last_contiguous_stored_block,
block_1.justification
);
assert_eq!(new_state.last_stored_block, block_1.justification);
}

storage.put_block(ctx, &block_3).await.unwrap();

Expand Down
112 changes: 74 additions & 38 deletions node/libs/concurrency/src/scope/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ mod state;
mod task;

pub use macros::*;
use must_complete::must_complete;
use state::{CancelGuard, State, TerminateGuard};
use state::{CancelGuard, OrPanic, State, TerminateGuard};
use task::{Task, Terminated};
use tracing::Instrument as _;

Expand Down Expand Up @@ -100,6 +99,7 @@ impl<'env, T> JoinHandle<'env, T> {
/// Awaits completion of the task.
/// Returns the result of the task.
/// Returns `Canceled` if the context has been canceled before the task.
/// Panics if the awaited task panicked.
///
/// Caller is expected to provide their local context as an argument, which
/// is neccessarily a descendant of the scope's context (because JoinHandle cannot
Expand All @@ -126,6 +126,14 @@ impl<'env, T> JoinHandle<'env, T> {
Err(ctx::Canceled)
})
}

/// Unconditional join used in run/run_blocking to await the root task.
async fn join_raw(self) -> ctx::OrCanceled<T> {
match self.0.await {
Ok(Ok(v)) => Ok(v),
_ => Err(ctx::Canceled),
}
}
}

/// Scope represents a concurrent computation bounded by lifetime `'env`.
Expand Down Expand Up @@ -244,17 +252,13 @@ impl<'env, E: 'static + Send> Scope<'env, E> {
/// the lifetime of the call.
///
/// `root_task` is executed as a root task of this scope.
/// Although both blocking and async tasks can be executed
/// in this scope, the root task is executed inline in the caller's
/// thread (i.e. we don't call `tokio::spawn(root_task)`) so it has to
/// be async.
///
// Safety:
// - we are assuming that `run` is only called via `run!` macro
// - <'env> is exactly equal to the lifetime of the `Scope::new(...).run(...).await` call.
// - in particular `run(...)` cannot be assigned to a local variable, because
// it would reference the temporal `Scope::new(...)` object.
// - the returned future is wrapped in `must_complete` so it will abort
// - the returned future uses `must_complete::Guard` so it will abort
// the whole process if dropped before completion.
// - before the first `poll()` call of the `run(...)` future may be forgotten (via `std::mem::forget`)
// directly or indirectly and that's safe, because no unsafe code has been executed yet.
Expand All @@ -266,22 +270,43 @@ impl<'env, E: 'static + Send> Scope<'env, E> {
// transitively use references stored in the root task, so they stay valid as well,
// until `run()` future is dropped.
#[doc(hidden)]
pub fn run<T, F, Fut>(&'env mut self, root_task: F) -> impl 'env + Future<Output = Result<T, E>>
pub async fn run<T: 'static + Send, F, Fut>(&'env mut self, root_task: F) -> Result<T, E>
where
F: 'env + FnOnce(&'env ctx::Ctx, &'env Self) -> Fut,
Fut: 'env + Future<Output = Result<T, E>>,
Fut: 'env + Send + Future<Output = Result<T, E>>,
{
must_complete(async move {
let guard = Arc::new(State::make(self.ctx.clone()));
self.cancel_guard = Arc::downgrade(&guard);
self.terminate_guard = Arc::downgrade(guard.terminate_guard());
let state = guard.terminate_guard().state().clone();
let root_res = Task::Main(guard).run(root_task(&self.ctx, self)).await;
// Wait for the scope termination.
state.terminated().await;
// Return the error, or the result of the root_task.
state.take_err().map_or_else(|| Ok(root_res.unwrap()), Err)
})
// Abort if run() future is dropped before completion.
let must_complete = must_complete::Guard;

let guard = Arc::new(State::make(self.ctx.clone()));
self.cancel_guard = Arc::downgrade(&guard);
self.terminate_guard = Arc::downgrade(guard.terminate_guard());
let state = guard.terminate_guard().state().clone();
// Spawn the root task. We cannot run it directly in this task,
// because if the root task panicked, we wouldn't be able to
// wait for other tasks to finish.
let root_task = self.spawn(root_task(&self.ctx, self));
// Once we spawned the root task we can drop the guard.
drop(guard);
// Await for the completion of the root_task.
let root_task_result = root_task.join_raw().await;
// Wait for the scope termination.
state.terminated().await;

// All tasks completed.
must_complete.defuse();

// Return the result of the root_task, the error, or propagate the panic.
match state.take_err() {
// All tasks have completed successfully, so in particular root_task has returned Ok.
None => Ok(root_task_result.unwrap()),
// One of the tasks returned an error, but no panic occurred.
Some(OrPanic::Err(err)) => Err(err),
// Note that panic is propagated only once all of the tasks are run to completion.
Some(OrPanic::Panic) => {
panic!("one of the tasks panicked, look for a stack trace above")
}
}
}

/// not public; used by run_blocking! macro.
Expand All @@ -291,7 +316,7 @@ impl<'env, E: 'static + Send> Scope<'env, E> {
/// task (in particular, not from async code).
/// Behaves analogically to `run`.
#[doc(hidden)]
pub fn run_blocking<T, F>(&'env mut self, root_task: F) -> Result<T, E>
pub fn run_blocking<T: 'static + Send, F: Send>(&'env mut self, root_task: F) -> Result<T, E>
where
E: 'static + Send,
F: 'env + FnOnce(&'env ctx::Ctx, &'env Self) -> Result<T, E>,
Expand All @@ -300,28 +325,39 @@ impl<'env, E: 'static + Send> Scope<'env, E> {
self.cancel_guard = Arc::downgrade(&guard);
self.terminate_guard = Arc::downgrade(guard.terminate_guard());
let state = guard.terminate_guard().state().clone();
let root_res = Task::Main(guard).run_blocking(|| root_task(&self.ctx, self));
// Spawn the root task. We cannot run it directly in this task,
// because if the root task panicked, we wouldn't be able to
// wait for other tasks to finish.
let root_task = self.spawn_blocking(|| root_task(&self.ctx, self));
// Once we spawned the root task we can drop the guard.
drop(guard);
// Await for the completion of the root_task.
let root_task_result = ctx::block_on(root_task.join_raw());
// Wait for the scope termination.
ctx::block_on(state.terminated());
// Return the error, or the result of the root_task.

// Return the result of the root_task, the error, or propagate the panic.
match state.take_err() {
Some(err) => Err(err),
None => Ok(root_res.unwrap()),
// All tasks have completed successfully, so in particular root_task has returned Ok.
None => Ok(root_task_result.unwrap()),
// One of the tasks returned an error, but no panic occurred.
Some(OrPanic::Err(err)) => Err(err),
// Note that panic is propagated only once all of the tasks are run to completion.
Some(OrPanic::Panic) => {
panic!("one of the tasks panicked, look for a stack trace above")
}
}
}
}

/// Spawns the provided blocking closure `f` and waits until it completes or the context gets canceled.
pub async fn wait_blocking<'a, R, E>(
ctx: &'a ctx::Ctx,
f: impl FnOnce() -> Result<R, E> + Send + 'a,
) -> Result<R, E>
where
R: 'static + Send,
E: 'static + From<ctx::Canceled> + Send,
{
run!(ctx, |ctx, s| async {
Ok(s.spawn_blocking(f).join(ctx).await?)
})
.await
/// Spawns the blocking closure `f` and unconditionally awaits for completion.
/// Panics if `f` panics.
/// Aborts if dropped before completion.
pub async fn wait_blocking<'a, T: 'static + Send>(f: impl 'a + Send + FnOnce() -> T) -> T {
let must_complete = must_complete::Guard;
let res = unsafe { spawn_blocking(Box::new(|| Ok(f()))) }
.join_raw()
.await;
must_complete.defuse();
res.expect("awaited task panicked")
}
27 changes: 12 additions & 15 deletions node/libs/concurrency/src/scope/must_complete.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,25 @@
//! must_complete wraps a future, so that it aborts if it is dropped before completion.
//! Note that it ABORTS the process rather than just panic, so that we get a strong guarantee
//! of completion in both `panic=abort` and `panic=unwind` compilation modes.
//! Guard that aborts the process if dropped without being defused.
//! Note that it ABORTS the process rather than just panic, so it behaves consistently
//! in both `panic=abort` and `panic=unwind` compilation modes.
//!
//! It should be used to prevent a future from being dropped before completion.
//! Possibility that a future can be dropped/aborted at every await makes the control flow unnecessarily complicated.
//! In fact, only few basic futures (like io primitives) actually need to be abortable, so
//! that they can be put together into a tokio::select block. All the higher level logic
//! would greatly benefit (in terms of readability and bug-resistance) from being non-abortable.
//! Rust doesn't support linear types as of now, so best we can do is a runtime check.
use std::future::Future;

/// must_complete wraps a future, so that it aborts if it is dropped before completion.
pub(super) fn must_complete<Fut: Future>(fut: Fut) -> impl Future<Output = Fut::Output> {
let guard = Guard;
async move {
let res = fut.await;
std::mem::forget(guard);
res
/// Guard which aborts the process when dropped.
/// Use `Guard::defuse()` to avoid aborting.
pub(super) struct Guard;

impl Guard {
/// Drops the guard silently, so that it doesn't abort the process.
pub(crate) fn defuse(self) {
std::mem::forget(self)
}
}

/// Guard which aborts the process when dropped.
/// Use std::mem::ManuallyDrop to avoid the drop call.
struct Guard;

impl Drop for Guard {
fn drop(&mut self) {
// We always abort here, no matter if compiled with panic=abort or panic=unwind.
Expand Down
Loading
Loading