From 1a5b88b1ead2452678c0baa0d9a9bb507b794676 Mon Sep 17 00:00:00 2001 From: Timo Glane Date: Mon, 13 May 2024 18:33:50 +0200 Subject: [PATCH] Drop the join waker of a task eagerly when the task completes and there is no join interest --- tokio/src/runtime/task/harness.rs | 64 +++++++++--- tokio/src/runtime/task/state.rs | 103 +++++++++++++------ tokio/src/runtime/tests/loom_multi_thread.rs | 30 ++++++ 3 files changed, 151 insertions(+), 46 deletions(-) diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 996f0f2d9b4..01a8e63abc4 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -284,21 +284,31 @@ where } pub(super) fn drop_join_handle_slow(self) { + use super::state::TransitionToJoinHandleDrop; // Try to unset `JOIN_INTEREST`. This must be done as a first step in // case the task concurrently completed. - if self.state().unset_join_interested().is_err() { - // It is our responsibility to drop the output. This is critical as - // the task output may not be `Send` and as such must remain with - // the scheduler or `JoinHandle`. i.e. if the output remains in the - // task structure until the task is deallocated, it may be dropped - // by a Waker on any arbitrary thread. - // - // Panics are delivered to the user via the `JoinHandle`. Given that - // they are dropping the `JoinHandle`, we assume they are not - // interested in the panic and swallow it. - let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { - self.core().drop_future_or_output(); - })); + let transition = self.state().transition_to_join_handle_drop(); + match transition { + TransitionToJoinHandleDrop::Failed => { + // It is our responsibility to drop the output. This is critical as + // the task output may not be `Send` and as such must remain with + // the scheduler or `JoinHandle`. i.e. if the output remains in the + // task structure until the task is deallocated, it may be dropped + // by a Waker on any arbitrary thread. 
+ // + // Panics are delivered to the user via the `JoinHandle`. Given that + // they are dropping the `JoinHandle`, we assume they are not + // interested in the panic and swallow it. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + self.core().drop_future_or_output(); + })); + } + TransitionToJoinHandleDrop::OkDropJoinWaker => unsafe { + // If there is a waker associated with this task when the `JoinHandle` is about to get + // dropped we want to also drop this waker if the task is already completed. + self.trailer().set_waker(None); + }, + TransitionToJoinHandleDrop::OkDoNothing => (), } // Drop the `JoinHandle` reference, possibly deallocating the task @@ -309,6 +319,7 @@ where /// Completes the task. This method assumes that the state is RUNNING. fn complete(self) { + use super::state::TransitionToTerminal; // The future has completed and its output has been written to the task // stage. We transition from running to complete. @@ -346,8 +357,29 @@ where // The task has completed execution and will no longer be scheduled. let num_release = self.release(); - if self.state().transition_to_terminal(num_release) { - self.dealloc(); + match self.state().transition_to_terminal(num_release) { + TransitionToTerminal::OkDoNothing => (), + TransitionToTerminal::OkDealloc => { + self.dealloc(); + } + TransitionToTerminal::FailedDropJoinWaker => { + // Safety: In this case we are the only one referencing the task and the active + // waker is the only one preventing the task from being deallocated so no one else + // will try to access the waker here. + unsafe { + self.trailer().set_waker(None); + } + + // We do not expect the transition below to fail since `TransitionToTerminal::FailedDropJoinWaker` + // will only be returned when, after dropping the join waker, the task can be + // safely deallocated. 
Because after this failed transition the COMPLETE bit is still set + // its fine to transition to terminal in two steps here + if let TransitionToTerminal::OkDealloc = + self.state().transition_to_terminal(num_release) + { + self.dealloc(); + } + } } } @@ -387,7 +419,7 @@ fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool { debug_assert!(snapshot.is_join_interested()); - if !snapshot.is_complete() { + if !snapshot.is_complete() && !snapshot.is_terminal() { // If the task is not complete, try storing the provided waker in the // task's waker field. diff --git a/tokio/src/runtime/task/state.rs b/tokio/src/runtime/task/state.rs index 0fc7bb0329b..d5bdad6fcc3 100644 --- a/tokio/src/runtime/task/state.rs +++ b/tokio/src/runtime/task/state.rs @@ -36,8 +36,12 @@ const JOIN_WAKER: usize = 0b10_000; /// The task has been forcibly cancelled. const CANCELLED: usize = 0b100_000; +const TERMINAL: usize = 0b1_000_000; + /// All bits. -const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED; +// const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED; +const STATE_MASK: usize = + LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED | TERMINAL; /// Bits used by the ref count portion of the state. const REF_COUNT_MASK: usize = !STATE_MASK; @@ -89,6 +93,20 @@ pub(crate) enum TransitionToNotifiedByRef { Submit, } +#[must_use] +pub(crate) enum TransitionToJoinHandleDrop { + Failed, + OkDoNothing, + OkDropJoinWaker, +} + +#[must_use] +pub(crate) enum TransitionToTerminal { + OkDoNothing, + OkDealloc, + FailedDropJoinWaker, +} + /// All transitions are performed via RMW operations. This establishes an /// unambiguous modification order. 
impl State { @@ -174,6 +192,24 @@ impl State { }) } + pub(super) fn transition_to_join_handle_drop(&self) -> TransitionToJoinHandleDrop { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_join_interested() { + snapshot.unset_join_interested() + } + + if snapshot.is_complete() && (!snapshot.is_terminal() || !snapshot.is_join_waker_set()) + { + (TransitionToJoinHandleDrop::Failed, None) + } else if snapshot.is_join_waker_set() { + snapshot.unset_join_waker(); + (TransitionToJoinHandleDrop::OkDropJoinWaker, Some(snapshot)) + } else { + (TransitionToJoinHandleDrop::OkDoNothing, Some(snapshot)) + } + }) + } + /// Transitions the task from `Running` -> `Complete`. pub(super) fn transition_to_complete(&self) -> Snapshot { const DELTA: usize = RUNNING | COMPLETE; @@ -181,6 +217,7 @@ impl State { let prev = Snapshot(self.val.fetch_xor(DELTA, AcqRel)); assert!(prev.is_running()); assert!(!prev.is_complete()); + assert!(!prev.is_terminal()); Snapshot(prev.0 ^ DELTA) } @@ -188,16 +225,37 @@ impl State { /// Transitions from `Complete` -> `Terminal`, decrementing the reference /// count the specified number of times. /// - /// Returns true if the task should be deallocated. - pub(super) fn transition_to_terminal(&self, count: usize) -> bool { - let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel)); - assert!( - prev.ref_count() >= count, - "current: {}, sub: {}", - prev.ref_count(), - count - ); - prev.ref_count() == count + /// Returns `TransitionToTerminal::OkDoNothing` if the transition was successful but the task + /// cannot be deallocated yet. + /// Returns `TransitionToTerminal::OkDealloc` if the task should be deallocated. + /// Returns `TransitionToTerminal::FailedDropJoinWaker` if the transition failed because a + /// join waker is still set although there is no join interest left. In this case the reference count will not be decremented + /// but the `JOIN_WAKER` bit will be unset. 
+ pub(super) fn transition_to_terminal(&self, count: usize) -> TransitionToTerminal { + self.fetch_update_action(|mut snapshot| { + assert!(!snapshot.is_running()); + assert!(snapshot.is_complete()); + assert!(!snapshot.is_terminal()); + assert!( + snapshot.ref_count() >= count, + "current: {}, sub: {}", + snapshot.ref_count(), + count + ); + + if snapshot.ref_count() == count { + snapshot.0 -= count * REF_ONE; + snapshot.0 |= TERMINAL; + (TransitionToTerminal::OkDealloc, Some(snapshot)) + } else if !snapshot.is_join_interested() && snapshot.is_join_waker_set() { + snapshot.unset_join_waker(); + (TransitionToTerminal::FailedDropJoinWaker, Some(snapshot)) + } else { + snapshot.0 -= count * REF_ONE; + snapshot.0 |= TERMINAL; + (TransitionToTerminal::OkDoNothing, Some(snapshot)) + } + }) } /// Transitions the state to `NOTIFIED`. @@ -371,25 +429,6 @@ impl State { .map_err(|_| ()) } - /// Tries to unset the `JOIN_INTEREST` flag. - /// - /// Returns `Ok` if the operation happens before the task transitions to a - /// completed state, `Err` otherwise. - pub(super) fn unset_join_interested(&self) -> UpdateResult { - self.fetch_update(|curr| { - assert!(curr.is_join_interested()); - - if curr.is_complete() { - return None; - } - - let mut next = curr; - next.unset_join_interested(); - - Some(next) - }) - } - /// Sets the `JOIN_WAKER` bit. /// /// Returns `Ok` if the bit is set, `Err` otherwise. 
This operation fails if @@ -557,6 +596,10 @@ impl Snapshot { self.0 & COMPLETE == COMPLETE } + pub(super) fn is_terminal(self) -> bool { + self.0 & TERMINAL == TERMINAL + } + pub(super) fn is_join_interested(self) -> bool { self.0 & JOIN_INTEREST == JOIN_INTEREST } diff --git a/tokio/src/runtime/tests/loom_multi_thread.rs b/tokio/src/runtime/tests/loom_multi_thread.rs index ddd14b7fb3f..e2706e65c65 100644 --- a/tokio/src/runtime/tests/loom_multi_thread.rs +++ b/tokio/src/runtime/tests/loom_multi_thread.rs @@ -10,6 +10,7 @@ mod yield_now; /// In order to speed up the C use crate::runtime::tests::loom_oneshot as oneshot; use crate::runtime::{self, Runtime}; +use crate::sync::mpsc::channel; use crate::{spawn, task}; use tokio_test::assert_ok; @@ -459,3 +460,32 @@ impl Future for Track { }) } } + +#[test] +fn drop_tasks_with_reference_cycle() { + loom::model(|| { + let pool = mk_pool(2); + + pool.block_on(async move { + let (tx, mut rx) = channel(1); + + let (a_closer, mut wait_for_close_a) = channel::<()>(1); + let (b_closer, mut wait_for_close_b) = channel::<()>(1); + + let a = spawn(async move { + let b = rx.recv().await.unwrap(); + + futures::future::select(std::pin::pin!(b), std::pin::pin!(a_closer.send(()))).await; + }); + + let b = spawn(async move { + let _ = a.await; + let _ = b_closer.send(()).await; + }); + + tx.send(b).await.unwrap(); + + futures::future::join(wait_for_close_a.recv(), wait_for_close_b.recv()).await; + }); + }); +}