Skip to content

Commit

Permalink
Upgraded concurrency framework to support panic=unwind mode (#16)
Browse files Browse the repository at this point in the history
We are going to use era-consensus code in zksync-era, which doesn't have
panic=abort enabled for the time being. This PR adjusts the concurrency
framework so that it behaves in the expected way in the presence of
panics:
* scope::run/run_blocking will wait for all tasks to either complete OR
panic
* scope's context gets cancelled immediately if any of the tasks returns
an error OR panics
* scope::run/run_blocking will panic if any of the scope's tasks
panicked (but only after all tasks complete or panic). This also means
that even if task A returns an error before task B panics, the whole
scope will panic anyway (A's error will be ignored).
* the root task of the scope::run/run_blocking call will NOT be executed
in the same tokio task as the run call any more (so that a panic in the
root task won't prevent the scope call from awaiting the completion of all
tasks).
  • Loading branch information
pompon0 committed Oct 31, 2023
1 parent 0217158 commit cb34e4c
Show file tree
Hide file tree
Showing 9 changed files with 322 additions and 154 deletions.
86 changes: 45 additions & 41 deletions node/actors/network/src/mux/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,53 +293,57 @@ async fn test_transport_closed() {
write_frame_size: 100,
});
scope::run!(ctx, |ctx, s| async {
let (s1, s2) = noise::testonly::pipe(ctx).await;
// Spawns a peer that establishes a transient stream and then tries to read from it.
// Reading should terminate as soon as we close the transport.
s.spawn(async {
let mut stream = scope::run!(ctx, |ctx, s| async {
let mut mux = mux::Mux {
cfg: cfg.clone(),
accept: BTreeMap::default(),
connect: BTreeMap::default(),
};
let q = mux::StreamQueue::new(1);
mux.connect.insert(cap, q.clone());
s.spawn_bg(async { expected(mux.run(ctx, s2).await).context("mux2.run()") });
anyhow::Ok(q.open(ctx).await.unwrap())
let streams = s
.spawn(async {
let (s1, s2) = noise::testonly::pipe(ctx).await;
// Establish a transient stream, then close the transport (s1 and s2).
scope::run!(ctx, |ctx, s| async {
let outbound = s.spawn(async {
let mut mux = mux::Mux {
cfg: cfg.clone(),
accept: BTreeMap::default(),
connect: BTreeMap::default(),
};
let q = mux::StreamQueue::new(1);
mux.connect.insert(cap, q.clone());
s.spawn_bg(async {
expected(mux.run(ctx, s2).await).context("[connect] mux.run()")
});
q.open(ctx).await.context("[connect] q.open()")
});
let inbound = s.spawn(async {
let mut mux = mux::Mux {
cfg: cfg.clone(),
accept: BTreeMap::default(),
connect: BTreeMap::default(),
};
let q = mux::StreamQueue::new(1);
mux.accept.insert(cap, q.clone());
s.spawn_bg(async {
expected(mux.run(ctx, s1).await).context("[accept] mux.run()")
});
q.open(ctx).await.context("[accept] q.open()")
});
Ok([
inbound.join(ctx).await.context("inbound")?,
outbound.join(ctx).await.context("outbound")?,
])
})
.await
})
.await
.unwrap();

.join(ctx)
.await?;
// Check how the streams without transport behave.
for mut s in streams {
let mut buf = bytes::Buffer::new(100);
// Read is expected to succeed, but no data should be read.
stream.read.read_exact(ctx, &mut buf).await.unwrap();
s.read.read_exact(ctx, &mut buf).await.unwrap();
assert_eq!(buf.len(), 0);
// Writing will succeed (thanks to buffering), but flushing should fail
// because the transport is closed.
stream.write.write_all(ctx, &[1, 2, 3]).await.unwrap();
assert!(stream.write.flush(ctx).await.is_err());
anyhow::Ok(())
});

let mut mux = mux::Mux {
cfg: cfg.clone(),
accept: BTreeMap::default(),
connect: BTreeMap::default(),
};
let q = mux::StreamQueue::new(1);
mux.accept.insert(cap, q.clone());
// Accept the stream and drop the connection completely.
s.spawn_bg(async { expected(mux.run(ctx, s1).await).context("mux1.run()") });
let mut stream = q.open(ctx).await.unwrap();

let mut buf = bytes::Buffer::new(100);
stream.read.read_exact(ctx, &mut buf).await.unwrap();
// The peer multiplexer is dropped to complete `read_exact()`, so we don't have a deadlock here.
assert_eq!(buf.len(), 0);
stream.write.write_all(ctx, &[1, 2, 3]).await.unwrap();
assert!(stream.write.flush(ctx).await.is_err());

s.write.write_all(ctx, &[1, 2, 3]).await.unwrap();
assert!(s.write.flush(ctx).await.is_err());
}
Ok(())
})
.await
Expand Down
2 changes: 1 addition & 1 deletion node/actors/sync_blocks/src/tests/end_to_end.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl GossipNetwork<InitialNodeHandle> {
}

#[async_trait]
trait GossipNetworkTest: fmt::Debug {
trait GossipNetworkTest: fmt::Debug + Send {
/// Returns the number of nodes in the gossip network and number of peers for each node.
fn network_params(&self) -> (usize, usize);

Expand Down
40 changes: 21 additions & 19 deletions node/actors/sync_blocks/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,28 +139,30 @@ async fn subscribing_to_state_updates() {
anyhow::Ok(())
});

let initial_state = state_subscriber.borrow_and_update();
assert_eq!(
initial_state.first_stored_block,
genesis_block.justification
);
assert_eq!(
initial_state.last_contiguous_stored_block,
genesis_block.justification
);
assert_eq!(initial_state.last_stored_block, genesis_block.justification);
drop(initial_state);
{
let initial_state = state_subscriber.borrow_and_update();
assert_eq!(
initial_state.first_stored_block,
genesis_block.justification
);
assert_eq!(
initial_state.last_contiguous_stored_block,
genesis_block.justification
);
assert_eq!(initial_state.last_stored_block, genesis_block.justification);
}

storage.put_block(ctx, &block_1).await.unwrap();

let new_state = sync::changed(ctx, &mut state_subscriber).await?;
assert_eq!(new_state.first_stored_block, genesis_block.justification);
assert_eq!(
new_state.last_contiguous_stored_block,
block_1.justification
);
assert_eq!(new_state.last_stored_block, block_1.justification);
drop(new_state);
{
let new_state = sync::changed(ctx, &mut state_subscriber).await?;
assert_eq!(new_state.first_stored_block, genesis_block.justification);
assert_eq!(
new_state.last_contiguous_stored_block,
block_1.justification
);
assert_eq!(new_state.last_stored_block, block_1.justification);
}

storage.put_block(ctx, &block_3).await.unwrap();

Expand Down
112 changes: 74 additions & 38 deletions node/libs/concurrency/src/scope/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ mod state;
mod task;

pub use macros::*;
use must_complete::must_complete;
use state::{CancelGuard, State, TerminateGuard};
use state::{CancelGuard, OrPanic, State, TerminateGuard};
use task::{Task, Terminated};
use tracing::Instrument as _;

Expand Down Expand Up @@ -100,6 +99,7 @@ impl<'env, T> JoinHandle<'env, T> {
/// Awaits completion of the task.
/// Returns the result of the task.
/// Returns `Canceled` if the context has been canceled before the task.
/// Panics if the awaited task panicked.
///
/// Caller is expected to provide their local context as an argument, which
/// is necessarily a descendant of the scope's context (because JoinHandle cannot
Expand All @@ -126,6 +126,14 @@ impl<'env, T> JoinHandle<'env, T> {
Err(ctx::Canceled)
})
}

/// Unconditional join used in run/run_blocking to await the root task.
async fn join_raw(self) -> ctx::OrCanceled<T> {
match self.0.await {
Ok(Ok(v)) => Ok(v),
_ => Err(ctx::Canceled),
}
}
}

/// Scope represents a concurrent computation bounded by lifetime `'env`.
Expand Down Expand Up @@ -244,17 +252,13 @@ impl<'env, E: 'static + Send> Scope<'env, E> {
/// the lifetime of the call.
///
/// `root_task` is executed as a root task of this scope.
/// Although both blocking and async tasks can be executed
/// in this scope, the root task is executed inline in the caller's
/// thread (i.e. we don't call `tokio::spawn(root_task)`) so it has to
/// be async.
///
// Safety:
// - we are assuming that `run` is only called via `run!` macro
// - <'env> is exactly equal to the lifetime of the `Scope::new(...).run(...).await` call.
// - in particular `run(...)` cannot be assigned to a local variable, because
// it would reference the temporal `Scope::new(...)` object.
// - the returned future is wrapped in `must_complete` so it will abort
// - the returned future uses `must_complete::Guard` so it will abort
// the whole process if dropped before completion.
// - before the first `poll()` call of the `run(...)` future may be forgotten (via `std::mem::forget`)
// directly or indirectly and that's safe, because no unsafe code has been executed yet.
Expand All @@ -266,22 +270,43 @@ impl<'env, E: 'static + Send> Scope<'env, E> {
// transitively use references stored in the root task, so they stay valid as well,
// until `run()` future is dropped.
#[doc(hidden)]
pub fn run<T, F, Fut>(&'env mut self, root_task: F) -> impl 'env + Future<Output = Result<T, E>>
pub async fn run<T: 'static + Send, F, Fut>(&'env mut self, root_task: F) -> Result<T, E>
where
F: 'env + FnOnce(&'env ctx::Ctx, &'env Self) -> Fut,
Fut: 'env + Future<Output = Result<T, E>>,
Fut: 'env + Send + Future<Output = Result<T, E>>,
{
must_complete(async move {
let guard = Arc::new(State::make(self.ctx.clone()));
self.cancel_guard = Arc::downgrade(&guard);
self.terminate_guard = Arc::downgrade(guard.terminate_guard());
let state = guard.terminate_guard().state().clone();
let root_res = Task::Main(guard).run(root_task(&self.ctx, self)).await;
// Wait for the scope termination.
state.terminated().await;
// Return the error, or the result of the root_task.
state.take_err().map_or_else(|| Ok(root_res.unwrap()), Err)
})
// Abort if run() future is dropped before completion.
let must_complete = must_complete::Guard;

let guard = Arc::new(State::make(self.ctx.clone()));
self.cancel_guard = Arc::downgrade(&guard);
self.terminate_guard = Arc::downgrade(guard.terminate_guard());
let state = guard.terminate_guard().state().clone();
// Spawn the root task. We cannot run it directly in this task,
// because if the root task panicked, we wouldn't be able to
// wait for other tasks to finish.
let root_task = self.spawn(root_task(&self.ctx, self));
// Once we spawned the root task we can drop the guard.
drop(guard);
// Await the completion of the root_task.
let root_task_result = root_task.join_raw().await;
// Wait for the scope termination.
state.terminated().await;

// All tasks completed.
must_complete.defuse();

// Return the result of the root_task, the error, or propagate the panic.
match state.take_err() {
// All tasks have completed successfully, so in particular root_task has returned Ok.
None => Ok(root_task_result.unwrap()),
// One of the tasks returned an error, but no panic occurred.
Some(OrPanic::Err(err)) => Err(err),
// Note that panic is propagated only once all of the tasks are run to completion.
Some(OrPanic::Panic) => {
panic!("one of the tasks panicked, look for a stack trace above")
}
}
}

/// not public; used by run_blocking! macro.
Expand All @@ -291,7 +316,7 @@ impl<'env, E: 'static + Send> Scope<'env, E> {
/// task (in particular, not from async code).
/// Behaves analogically to `run`.
#[doc(hidden)]
pub fn run_blocking<T, F>(&'env mut self, root_task: F) -> Result<T, E>
pub fn run_blocking<T: 'static + Send, F: Send>(&'env mut self, root_task: F) -> Result<T, E>
where
E: 'static + Send,
F: 'env + FnOnce(&'env ctx::Ctx, &'env Self) -> Result<T, E>,
Expand All @@ -300,28 +325,39 @@ impl<'env, E: 'static + Send> Scope<'env, E> {
self.cancel_guard = Arc::downgrade(&guard);
self.terminate_guard = Arc::downgrade(guard.terminate_guard());
let state = guard.terminate_guard().state().clone();
let root_res = Task::Main(guard).run_blocking(|| root_task(&self.ctx, self));
// Spawn the root task. We cannot run it directly in this task,
// because if the root task panicked, we wouldn't be able to
// wait for other tasks to finish.
let root_task = self.spawn_blocking(|| root_task(&self.ctx, self));
// Once we spawned the root task we can drop the guard.
drop(guard);
// Await the completion of the root_task.
let root_task_result = ctx::block_on(root_task.join_raw());
// Wait for the scope termination.
ctx::block_on(state.terminated());
// Return the error, or the result of the root_task.

// Return the result of the root_task, the error, or propagate the panic.
match state.take_err() {
Some(err) => Err(err),
None => Ok(root_res.unwrap()),
// All tasks have completed successfully, so in particular root_task has returned Ok.
None => Ok(root_task_result.unwrap()),
// One of the tasks returned an error, but no panic occurred.
Some(OrPanic::Err(err)) => Err(err),
// Note that panic is propagated only once all of the tasks are run to completion.
Some(OrPanic::Panic) => {
panic!("one of the tasks panicked, look for a stack trace above")
}
}
}
}

/// Spawns the provided blocking closure `f` and waits until it completes or the context gets canceled.
pub async fn wait_blocking<'a, R, E>(
ctx: &'a ctx::Ctx,
f: impl FnOnce() -> Result<R, E> + Send + 'a,
) -> Result<R, E>
where
R: 'static + Send,
E: 'static + From<ctx::Canceled> + Send,
{
run!(ctx, |ctx, s| async {
Ok(s.spawn_blocking(f).join(ctx).await?)
})
.await
/// Spawns the blocking closure `f` and unconditionally awaits its completion.
/// Panics if `f` panics.
/// Aborts if dropped before completion.
pub async fn wait_blocking<'a, T: 'static + Send>(f: impl 'a + Send + FnOnce() -> T) -> T {
let must_complete = must_complete::Guard;
let res = unsafe { spawn_blocking(Box::new(|| Ok(f()))) }
.join_raw()
.await;
must_complete.defuse();
res.expect("awaited task panicked")
}
27 changes: 12 additions & 15 deletions node/libs/concurrency/src/scope/must_complete.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,25 @@
//! must_complete wraps a future, so that it aborts if it is dropped before completion.
//! Note that it ABORTS the process rather than just panic, so that we get a strong guarantee
//! of completion in both `panic=abort` and `panic=unwind` compilation modes.
//! Guard that aborts the process if dropped without being defused.
//! Note that it ABORTS the process rather than just panicking, so it behaves consistently
//! in both `panic=abort` and `panic=unwind` compilation modes.
//!
//! It should be used to prevent a future from being dropped before completion.
//! Possibility that a future can be dropped/aborted at every await makes the control flow unnecessarily complicated.
//! In fact, only few basic futures (like io primitives) actually need to be abortable, so
//! that they can be put together into a tokio::select block. All the higher level logic
//! would greatly benefit (in terms of readability and bug-resistance) from being non-abortable.
//! Rust doesn't support linear types as of now, so best we can do is a runtime check.
use std::future::Future;

/// must_complete wraps a future, so that it aborts if it is dropped before completion.
pub(super) fn must_complete<Fut: Future>(fut: Fut) -> impl Future<Output = Fut::Output> {
let guard = Guard;
async move {
let res = fut.await;
std::mem::forget(guard);
res
/// Guard which aborts the process when dropped.
/// Use `Guard::defuse()` to avoid aborting.
pub(super) struct Guard;

impl Guard {
/// Drops the guard silently, so that it doesn't abort the process.
pub(crate) fn defuse(self) {
std::mem::forget(self)
}
}

/// Guard which aborts the process when dropped.
/// Use std::mem::ManuallyDrop to avoid the drop call.
struct Guard;

impl Drop for Guard {
fn drop(&mut self) {
// We always abort here, no matter if compiled with panic=abort or panic=unwind.
Expand Down
Loading

0 comments on commit cb34e4c

Please sign in to comment.