diff --git a/node/Cargo.lock b/node/Cargo.lock
index 566f9ae1..beb226da 100644
--- a/node/Cargo.lock
+++ b/node/Cargo.lock
@@ -2624,6 +2624,7 @@ name = "zksync_concurrency"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "assert_matches",
  "once_cell",
  "pin-project",
  "rand 0.8.5",
diff --git a/node/actors/bft/src/leader/replica_commit.rs b/node/actors/bft/src/leader/replica_commit.rs
index b8a56ab4..f0d31d4a 100644
--- a/node/actors/bft/src/leader/replica_commit.rs
+++ b/node/actors/bft/src/leader/replica_commit.rs
@@ -161,7 +161,7 @@ impl StateMachine {
             )),
             recipient: Target::Broadcast,
         };
-        self.pipe.send(output_message.into());
+        self.outbound_pipe.send(output_message.into());
 
         // Clean the caches.
         self.prepare_message_cache.retain(|k, _| k >= &self.view);
diff --git a/node/actors/bft/src/leader/state_machine.rs b/node/actors/bft/src/leader/state_machine.rs
index d37a87e5..ad08d118 100644
--- a/node/actors/bft/src/leader/state_machine.rs
+++ b/node/actors/bft/src/leader/state_machine.rs
@@ -6,8 +6,8 @@ use std::{
 };
 use tracing::instrument;
 use zksync_concurrency::{ctx, error::Wrap as _, metrics::LatencyHistogramExt as _, sync, time};
-use zksync_consensus_network::io::{ConsensusInputMessage, Target};
-use zksync_consensus_roles::validator;
+use zksync_consensus_network::io::{ConsensusInputMessage, ConsensusReq, Target};
+use zksync_consensus_roles::validator::{self, ConsensusMsg, Signed};
 
 /// The StateMachine struct contains the state of the leader. This is a simple state machine. We just store
 /// replica messages and produce leader messages (including proposing blocks) when we reach the threshold for
@@ -15,8 +15,10 @@ use zksync_consensus_roles::validator;
 pub(crate) struct StateMachine {
     /// Consensus configuration and output channel.
     pub(crate) config: Arc<Config>,
-    /// Pipe through with leader sends network messages.
-    pub(crate) pipe: OutputSender,
+    /// Pipe through which leader sends network messages.
+    pub(crate) outbound_pipe: OutputSender,
+    /// Pipe through which leader receives network requests.
+    inbound_pipe: sync::prunable_mpsc::Receiver<ConsensusReq>,
     /// The current view number. This might not match the replica's view number, we only have this here
     /// to make the leader advance monotonically in time and stop it from accepting messages from the past.
     pub(crate) view: validator::ViewNumber,
@@ -28,7 +30,7 @@ pub(crate) struct StateMachine {
     /// A cache of replica prepare messages indexed by view number and validator.
     pub(crate) prepare_message_cache: BTreeMap<
         validator::ViewNumber,
-        HashMap<validator::PublicKey, validator::Signed<validator::ReplicaPrepare>>,
+        HashMap<validator::PublicKey, Signed<validator::ReplicaPrepare>>,
     >,
     /// Prepare QCs indexed by view number.
     pub(crate) prepare_qcs: BTreeMap<validator::ViewNumber, validator::PrepareQC>,
@@ -37,19 +39,29 @@ pub(crate) struct StateMachine {
     /// A cache of replica commit messages indexed by view number and validator.
     pub(crate) commit_message_cache: BTreeMap<
         validator::ViewNumber,
-        HashMap<validator::PublicKey, validator::Signed<validator::ReplicaCommit>>,
+        HashMap<validator::PublicKey, Signed<validator::ReplicaCommit>>,
     >,
     /// Commit QCs indexed by view number.
     pub(crate) commit_qcs: BTreeMap<validator::ViewNumber, validator::CommitQC>,
 }
 
 impl StateMachine {
-    /// Creates a new StateMachine struct.
+    /// Creates a new [`StateMachine`] instance.
+    ///
+    /// Returns a tuple containing:
+    /// * The newly created [`StateMachine`] instance.
+    /// * A sender handle that should be used to send values to be processed by the instance, asynchronously.
     #[instrument(level = "trace")]
-    pub fn new(ctx: &ctx::Ctx, config: Arc<Config>, pipe: OutputSender) -> Self {
-        StateMachine {
+    pub fn new(
+        ctx: &ctx::Ctx,
+        config: Arc<Config>,
+        outbound_pipe: OutputSender,
+    ) -> (Self, sync::prunable_mpsc::Sender<ConsensusReq>) {
+        let (send, recv) = sync::prunable_mpsc::channel(StateMachine::inbound_pruning_predicate);
+
+        let this = StateMachine {
             config,
-            pipe,
+            outbound_pipe,
             view: validator::ViewNumber(0),
             phase: validator::Phase::Prepare,
             phase_start: ctx.now(),
@@ -58,49 +70,54 @@ impl StateMachine {
             commit_message_cache: BTreeMap::new(),
             prepare_qc: sync::watch::channel(None).0,
             commit_qcs: BTreeMap::new(),
-        }
+            inbound_pipe: recv,
+        };
+
+        (this, send)
     }
 
-    /// Process an input message (leaders don't time out waiting for a message). This is the
-    /// main entry point for the state machine. We need read-access to the inner consensus struct.
-    /// As a result, we can modify our state machine or send a message to the executor.
-    #[instrument(level = "trace", skip(self), ret)]
-    pub(crate) async fn process_input(
-        &mut self,
-        ctx: &ctx::Ctx,
-        input: validator::Signed<validator::ConsensusMsg>,
-    ) -> ctx::Result<()> {
-        let now = ctx.now();
-        let label = match &input.msg {
-            validator::ConsensusMsg::ReplicaPrepare(_) => {
-                let res = match self
-                    .process_replica_prepare(ctx, input.cast().unwrap())
-                    .await
-                    .wrap("process_replica_prepare()")
-                {
-                    Ok(()) => Ok(()),
-                    Err(super::replica_prepare::Error::Internal(err)) => {
-                        return Err(err);
-                    }
-                    Err(err) => {
-                        tracing::warn!("process_replica_prepare: {err:#}");
-                        Err(())
-                    }
-                };
-                metrics::ConsensusMsgLabel::ReplicaPrepare.with_result(&res)
-            }
-            validator::ConsensusMsg::ReplicaCommit(_) => {
-                let res = self
-                    .process_replica_commit(ctx, input.cast().unwrap())
-                    .map_err(|err| {
-                        tracing::warn!("process_replica_commit: {err:#}");
-                    });
-                metrics::ConsensusMsgLabel::ReplicaCommit.with_result(&res)
-            }
-            _ => unreachable!(),
-        };
-        metrics::METRICS.leader_processing_latency[&label].observe_latency(ctx.now() - now);
-        Ok(())
+    /// Runs a loop to process incoming messages.
+    /// This is the main entry point for the state machine,
+    /// potentially triggering state modifications and message sending to the executor.
+    pub(crate) async fn run(mut self, ctx: &ctx::Ctx) -> ctx::Result<()> {
+        loop {
+            let req = self.inbound_pipe.recv(ctx).await?;
+
+            let now = ctx.now();
+            let label = match &req.msg.msg {
+                ConsensusMsg::ReplicaPrepare(_) => {
+                    let res = match self
+                        .process_replica_prepare(ctx, req.msg.cast().unwrap())
+                        .await
+                        .wrap("process_replica_prepare()")
+                    {
+                        Ok(()) => Ok(()),
+                        Err(super::replica_prepare::Error::Internal(err)) => {
+                            return Err(err);
+                        }
+                        Err(err) => {
+                            tracing::warn!("process_replica_prepare: {err:#}");
+                            Err(())
+                        }
+                    };
+                    metrics::ConsensusMsgLabel::ReplicaPrepare.with_result(&res)
+                }
+                ConsensusMsg::ReplicaCommit(_) => {
+                    let res = self
+                        .process_replica_commit(ctx, req.msg.cast().unwrap())
+                        .map_err(|err| {
+                            tracing::warn!("process_replica_commit: {err:#}");
+                        });
+                    metrics::ConsensusMsgLabel::ReplicaCommit.with_result(&res)
+                }
+                _ => unreachable!(),
+            };
+            metrics::METRICS.leader_processing_latency[&label].observe_latency(ctx.now() - now);
+
+            // Notify network actor that the message has been processed.
+            // Ignore sending error.
+            let _ = req.ack.send(());
+        }
     }
 
     /// In a loop, receives a PrepareQC and sends a LeaderPrepare containing it.
@@ -213,4 +230,16 @@ impl StateMachine {
         );
         Ok(())
     }
+
+    #[allow(clippy::match_like_matches_macro)]
+    fn inbound_pruning_predicate(pending_req: &ConsensusReq, new_req: &ConsensusReq) -> bool {
+        if pending_req.msg.key != new_req.msg.key {
+            return false;
+        }
+        match (&pending_req.msg.msg, &new_req.msg.msg) {
+            (ConsensusMsg::ReplicaPrepare(_), ConsensusMsg::ReplicaPrepare(_)) => true,
+            (ConsensusMsg::ReplicaCommit(_), ConsensusMsg::ReplicaCommit(_)) => true,
+            _ => false,
+        }
+    }
 }
diff --git a/node/actors/bft/src/lib.rs b/node/actors/bft/src/lib.rs
index a4309c9c..74b750c2 100644
--- a/node/actors/bft/src/lib.rs
+++ b/node/actors/bft/src/lib.rs
@@ -14,10 +14,12 @@
 //! - [Notes on modern consensus algorithms](https://timroughgarden.github.io/fob21/andy.pdf)
 //! - [Blog post comparing several consensus algorithms](https://decentralizedthoughts.github.io/2023-04-01-hotstuff-2/)
 //! - Blog posts explaining [safety](https://seafooler.com/2022/01/24/understanding-safety-hotstuff/) and [responsiveness](https://seafooler.com/2022/04/02/understanding-responsiveness-hotstuff/)
+
 use crate::io::{InputMessage, OutputMessage};
+pub use config::Config;
 use std::sync::Arc;
 use zksync_concurrency::{ctx, scope};
-use zksync_consensus_roles::validator;
+use zksync_consensus_roles::validator::{self, ConsensusMsg};
 use zksync_consensus_utils::pipe::ActorPipe;
 
 mod config;
@@ -30,8 +32,6 @@ pub mod testonly;
 #[cfg(test)]
 mod tests;
 
-pub use config::Config;
-
 /// Protocol version of this BFT implementation.
 pub const PROTOCOL_VERSION: validator::ProtocolVersion = validator::ProtocolVersion::EARLIEST;
@@ -65,15 +65,19 @@ impl Config {
         mut pipe: ActorPipe<InputMessage, OutputMessage>,
     ) -> anyhow::Result<()> {
         let cfg = Arc::new(self);
+        let (leader, leader_send) = leader::StateMachine::new(ctx, cfg.clone(), pipe.send.clone());
+        let (replica, replica_send) =
+            replica::StateMachine::start(ctx, cfg.clone(), pipe.send.clone()).await?;
+
         let res = scope::run!(ctx, |ctx, s| async {
-            let mut replica =
-                replica::StateMachine::start(ctx, cfg.clone(), pipe.send.clone()).await?;
-            let mut leader = leader::StateMachine::new(ctx, cfg.clone(), pipe.send.clone());
+            let prepare_qc_recv = leader.prepare_qc.subscribe();
+
+            s.spawn_bg(replica.run(ctx));
+            s.spawn_bg(leader.run(ctx));
             s.spawn_bg(leader::StateMachine::run_proposer(
                 ctx,
                 &cfg,
-                leader.prepare_qc.subscribe(),
+                prepare_qc_recv,
                 &pipe.send,
             ));
 
@@ -82,11 +86,7 @@ impl Config {
             // This is the infinite loop where the consensus actually runs. The validator waits for either
             // a message from the network or for a timeout, and processes each accordingly.
             loop {
-                let input = pipe
-                    .recv
-                    .recv(&ctx.with_deadline(replica.timeout_deadline))
-                    .await
-                    .ok();
+                let input = pipe.recv.recv(ctx).await;
 
                 // We check if the context is active before processing the input. If the context is not active,
                 // we stop.
@@ -94,24 +94,15 @@ impl Config {
                     return Ok(());
                 }
 
-                let Some(InputMessage::Network(req)) = input else {
-                    replica.start_new_view(ctx).await?;
-                    continue;
-                };
-
-                use validator::ConsensusMsg as Msg;
-                let res = match &req.msg.msg {
-                    Msg::ReplicaPrepare(_) | Msg::ReplicaCommit(_) => {
-                        leader.process_input(ctx, req.msg).await
+                let InputMessage::Network(req) = input.unwrap();
+                match &req.msg.msg {
+                    ConsensusMsg::ReplicaPrepare(_) | ConsensusMsg::ReplicaCommit(_) => {
+                        leader_send.send(req);
                     }
-                    Msg::LeaderPrepare(_) | Msg::LeaderCommit(_) => {
-                        replica.process_input(ctx, req.msg).await
+                    ConsensusMsg::LeaderPrepare(_) | ConsensusMsg::LeaderCommit(_) => {
+                        replica_send.send(req);
                     }
-                };
-                // Notify network actor that the message has been processed.
-                // Ignore sending error.
-                let _ = req.ack.send(());
-                res?;
+                }
             }
         })
         .await;
diff --git a/node/actors/bft/src/replica/leader_prepare.rs b/node/actors/bft/src/replica/leader_prepare.rs
index 50e9dcf4..04b6a2e9 100644
--- a/node/actors/bft/src/replica/leader_prepare.rs
+++ b/node/actors/bft/src/replica/leader_prepare.rs
@@ -335,7 +335,7 @@ impl StateMachine {
                 .sign_msg(validator::ConsensusMsg::ReplicaCommit(commit_vote)),
             recipient: Target::Validator(author.clone()),
         };
-        self.pipe.send(output_message.into());
+        self.outbound_pipe.send(output_message.into());
 
         Ok(())
     }
diff --git a/node/actors/bft/src/replica/new_view.rs b/node/actors/bft/src/replica/new_view.rs
index 6495dd3d..b1f73d7a 100644
--- a/node/actors/bft/src/replica/new_view.rs
+++ b/node/actors/bft/src/replica/new_view.rs
@@ -38,7 +38,7 @@ impl StateMachine {
             )),
             recipient: Target::Validator(self.config.view_leader(next_view)),
         };
-        self.pipe.send(output_message.into());
+        self.outbound_pipe.send(output_message.into());
 
         // Reset the timer.
         self.reset_timer(ctx);
diff --git a/node/actors/bft/src/replica/state_machine.rs b/node/actors/bft/src/replica/state_machine.rs
index 327cacbd..15e62335 100644
--- a/node/actors/bft/src/replica/state_machine.rs
+++ b/node/actors/bft/src/replica/state_machine.rs
@@ -3,9 +3,9 @@ use std::{
     collections::{BTreeMap, HashMap},
     sync::Arc,
 };
-use tracing::instrument;
-use zksync_concurrency::{ctx, error::Wrap as _, metrics::LatencyHistogramExt as _, time};
-use zksync_consensus_roles::validator;
+use zksync_concurrency::{ctx, error::Wrap as _, metrics::LatencyHistogramExt as _, sync, time};
+use zksync_consensus_network::io::ConsensusReq;
+use zksync_consensus_roles::{validator, validator::ConsensusMsg};
 use zksync_consensus_storage as storage;
 
 /// The StateMachine struct contains the state of the replica. This is the most complex state machine and is responsible
@@ -15,7 +15,9 @@ pub(crate) struct StateMachine {
     /// Consensus configuration and output channel.
     pub(crate) config: Arc<Config>,
     /// Pipe through which replica sends network messages.
-    pub(super) pipe: OutputSender,
+    pub(super) outbound_pipe: OutputSender,
+    /// Pipe through which replica receives network requests.
+    inbound_pipe: sync::prunable_mpsc::Receiver<ConsensusReq>,
     /// The current view number.
     pub(crate) view: validator::ViewNumber,
     /// The current phase.
@@ -32,13 +34,17 @@ pub(crate) struct StateMachine {
 }
 
 impl StateMachine {
-    /// Creates a new StateMachine struct. We try to recover a past state from the storage module,
-    /// otherwise we initialize the state machine with whatever head block we have.
+    /// Creates a new [`StateMachine`] instance, attempting to recover a past state from the storage module,
+    /// otherwise initializes the state machine with the current head block.
+    ///
+    /// Returns a tuple containing:
+    /// * The newly created [`StateMachine`] instance.
+    /// * A sender handle that should be used to send values to be processed by the instance, asynchronously.
     pub(crate) async fn start(
         ctx: &ctx::Ctx,
         config: Arc<Config>,
-        pipe: OutputSender,
-    ) -> ctx::Result<Self> {
+        outbound_pipe: OutputSender,
+    ) -> ctx::Result<(Self, sync::prunable_mpsc::Sender<ConsensusReq>)> {
         let backup = match config.replica_store.state(ctx).await? {
             Some(backup) => backup,
             None => config.block_store.subscribe().borrow().last.clone().into(),
@@ -51,9 +57,12 @@ impl StateMachine {
                 .insert(proposal.payload.hash(), proposal.payload);
         }
 
+        let (send, recv) = sync::prunable_mpsc::channel(StateMachine::inbound_pruning_predicate);
+
         let mut this = Self {
             config,
-            pipe,
+            outbound_pipe,
+            inbound_pipe: recv,
             view: backup.view,
             phase: backup.phase,
             high_vote: backup.high_vote,
@@ -61,56 +70,74 @@ impl StateMachine {
             block_proposal_cache,
             timeout_deadline: time::Deadline::Infinite,
         };
+
         // We need to start the replica before processing inputs.
         this.start_new_view(ctx).await.wrap("start_new_view()")?;
-        Ok(this)
+
+        Ok((this, send))
     }
 
-    /// Process an input message (it will be None if the channel timed out waiting for a message). This is
-    /// the main entry point for the state machine. We need read-access to the inner consensus struct.
-    /// As a result, we can modify our state machine or send a message to the executor.
-    #[instrument(level = "trace", ret)]
-    pub(crate) async fn process_input(
-        &mut self,
-        ctx: &ctx::Ctx,
-        input: validator::Signed<validator::ConsensusMsg>,
-    ) -> ctx::Result<()> {
-        let now = ctx.now();
-        let label = match &input.msg {
-            validator::ConsensusMsg::LeaderPrepare(_) => {
-                let res = match self
-                    .process_leader_prepare(ctx, input.cast().unwrap())
-                    .await
-                    .wrap("process_leader_prepare()")
-                {
-                    Err(super::leader_prepare::Error::Internal(err)) => return Err(err),
-                    Err(err) => {
-                        tracing::warn!("process_leader_prepare(): {err:#}");
-                        Err(())
-                    }
-                    Ok(()) => Ok(()),
-                };
-                metrics::ConsensusMsgLabel::LeaderPrepare.with_result(&res)
-            }
-            validator::ConsensusMsg::LeaderCommit(_) => {
-                let res = match self
-                    .process_leader_commit(ctx, input.cast().unwrap())
-                    .await
-                    .wrap("process_leader_commit()")
-                {
-                    Err(super::leader_commit::Error::Internal(err)) => return Err(err),
-                    Err(err) => {
-                        tracing::warn!("process_leader_commit(): {err:#}");
-                        Err(())
-                    }
-                    Ok(()) => Ok(()),
-                };
-                metrics::ConsensusMsgLabel::LeaderCommit.with_result(&res)
+    /// Runs a loop to process incoming messages (the receive may time out while waiting for a message,
+    /// in which case a new view is started).
+    /// This is the main entry point for the state machine,
+    /// potentially triggering state modifications and message sending to the executor.
+    pub(crate) async fn run(mut self, ctx: &ctx::Ctx) -> ctx::Result<()> {
+        loop {
+            let recv = self
+                .inbound_pipe
+                .recv(&ctx.with_deadline(self.timeout_deadline))
+                .await;
+
+            // Check for non-timeout cancellation.
+            if !ctx.is_active() {
+                return Ok(());
             }
-            _ => unreachable!(),
-        };
-        metrics::METRICS.replica_processing_latency[&label].observe_latency(ctx.now() - now);
-        Ok(())
+
+            // Check for timeout.
+            let Some(req) = recv.ok() else {
+                self.start_new_view(ctx).await?;
+                continue;
+            };
+
+            let now = ctx.now();
+            let label = match &req.msg.msg {
+                ConsensusMsg::LeaderPrepare(_) => {
+                    let res = match self
+                        .process_leader_prepare(ctx, req.msg.cast().unwrap())
+                        .await
+                        .wrap("process_leader_prepare()")
+                    {
+                        Err(super::leader_prepare::Error::Internal(err)) => return Err(err),
+                        Err(err) => {
+                            tracing::warn!("process_leader_prepare(): {err:#}");
+                            Err(())
+                        }
+                        Ok(()) => Ok(()),
+                    };
+                    metrics::ConsensusMsgLabel::LeaderPrepare.with_result(&res)
+                }
+                ConsensusMsg::LeaderCommit(_) => {
+                    let res = match self
+                        .process_leader_commit(ctx, req.msg.cast().unwrap())
+                        .await
+                        .wrap("process_leader_commit()")
+                    {
+                        Err(super::leader_commit::Error::Internal(err)) => return Err(err),
+                        Err(err) => {
+                            tracing::warn!("process_leader_commit(): {err:#}");
+                            Err(())
+                        }
+                        Ok(()) => Ok(()),
+                    };
+                    metrics::ConsensusMsgLabel::LeaderCommit.with_result(&res)
+                }
+                _ => unreachable!(),
+            };
+            metrics::METRICS.replica_processing_latency[&label].observe_latency(ctx.now() - now);
+
+            // Notify network actor that the message has been processed.
+            // Ignore sending error.
+            let _ = req.ack.send(());
+        }
     }
 
     /// Backups the replica state to disk.
@@ -136,4 +163,16 @@ impl StateMachine {
             .wrap("put_replica_state")?;
         Ok(())
     }
+
+    #[allow(clippy::match_like_matches_macro)]
+    fn inbound_pruning_predicate(pending_req: &ConsensusReq, new_req: &ConsensusReq) -> bool {
+        if pending_req.msg.key != new_req.msg.key {
+            return false;
+        }
+        match (&pending_req.msg.msg, &new_req.msg.msg) {
+            (ConsensusMsg::LeaderPrepare(_), ConsensusMsg::LeaderPrepare(_)) => true,
+            (ConsensusMsg::LeaderCommit(_), ConsensusMsg::LeaderCommit(_)) => true,
+            _ => false,
+        }
+    }
 }
diff --git a/node/actors/bft/src/replica/timer.rs b/node/actors/bft/src/replica/timer.rs
index c5231fa8..df5ef3b3 100644
--- a/node/actors/bft/src/replica/timer.rs
+++ b/node/actors/bft/src/replica/timer.rs
@@ -13,6 +13,7 @@ impl StateMachine {
     pub(crate) fn reset_timer(&mut self, ctx: &ctx::Ctx) {
         let timeout =
             Self::BASE_DURATION * 2u32.pow((self.view.0 - self.high_qc.message.view.0) as u32);
+        metrics::METRICS.replica_view_timeout.set_latency(timeout);
 
         self.timeout_deadline = time::Deadline::Finite(ctx.now() + timeout);
     }
diff --git a/node/actors/bft/src/testonly/ut_harness.rs b/node/actors/bft/src/testonly/ut_harness.rs
index f3d1b164..22eda0a1 100644
--- a/node/actors/bft/src/testonly/ut_harness.rs
+++ b/node/actors/bft/src/testonly/ut_harness.rs
@@ -67,8 +67,8 @@ impl UTHarness {
             payload_manager,
             max_payload_size: MAX_PAYLOAD_SIZE,
         });
-        let leader = leader::StateMachine::new(ctx, cfg.clone(), send.clone());
-        let replica = replica::StateMachine::start(ctx, cfg.clone(), send.clone())
+        let (leader, _) = leader::StateMachine::new(ctx, cfg.clone(), send.clone());
+        let (replica, _) = replica::StateMachine::start(ctx, cfg.clone(), send.clone())
             .await
             .unwrap();
         let mut this = UTHarness {
@@ -214,9 +214,14 @@ impl UTHarness {
         self.leader.process_replica_prepare(ctx, msg).await?;
         if prepare_qc.has_changed().unwrap() {
             let prepare_qc = prepare_qc.borrow().clone().unwrap();
-            leader::StateMachine::propose(ctx, &self.leader.config, prepare_qc, &self.leader.pipe)
-                .await
-                .unwrap();
+            leader::StateMachine::propose(
+                ctx,
+                &self.leader.config,
+                prepare_qc,
+                &self.leader.outbound_pipe,
+            )
+            .await
+            .unwrap();
         }
         Ok(self.try_recv())
     }
diff --git a/node/libs/concurrency/Cargo.toml b/node/libs/concurrency/Cargo.toml
index ba8ed85d..2597d7fa 100644
--- a/node/libs/concurrency/Cargo.toml
+++ b/node/libs/concurrency/Cargo.toml
@@ -20,4 +20,7 @@ tracing-subscriber.workspace = true
 vise.workspace = true
 
 [lints]
-workspace = true
\ No newline at end of file
+workspace = true
+
+[dev-dependencies]
+assert_matches.workspace = true
\ No newline at end of file
diff --git a/node/libs/concurrency/src/sync/mod.rs b/node/libs/concurrency/src/sync/mod.rs
index f424b4b7..7b5587a6 100644
--- a/node/libs/concurrency/src/sync/mod.rs
+++ b/node/libs/concurrency/src/sync/mod.rs
@@ -10,6 +10,7 @@ pub use tokio::{
     task::yield_now,
 };
 
+pub mod prunable_mpsc;
 #[cfg(test)]
 mod tests;
 
diff --git a/node/libs/concurrency/src/sync/prunable_mpsc/mod.rs b/node/libs/concurrency/src/sync/prunable_mpsc/mod.rs
new file mode 100644
index 00000000..b559112b
--- /dev/null
+++ b/node/libs/concurrency/src/sync/prunable_mpsc/mod.rs
@@ -0,0 +1,101 @@
+//! Prunable, multi-producer, single-consumer, unbounded FIFO queue for communicating between asynchronous tasks.
+//! The pruning takes place whenever a new value is sent, based on a specified predicate.
+//!
+//! The separation of [`Sender`] and [`Receiver`] is employed primarily because [`Receiver`] requires
+//! a mutable reference to the signaling channel, unlike [`Sender`], hence making it undesirable to
+//! be used in conjunction.
+//!
+use crate::{
+    ctx,
+    sync::{self, watch},
+};
+use std::{collections::VecDeque, fmt, sync::Arc};
+
+#[cfg(test)]
+mod tests;
+
+/// Creates a channel and returns the [`Sender`] and [`Receiver`] handles.
+/// All values sent on [`Sender`] will become available on [`Receiver`] in the same order as they were sent,
+/// unless they are pruned before being received.
+/// The Sender can be cloned to send to the same channel from multiple code locations. Only one Receiver is supported.
+///
+/// * `T`: The type of data that will be sent through the channel.
+/// * `pruning_predicate`: A function that determines whether an unreceived, pending value in the buffer
+///   (represented by the first `T`) should be pruned, based on a newly sent value (represented by the second `T`).
+pub fn channel<T>(
+    pruning_predicate: impl 'static + Sync + Send + Fn(&T, &T) -> bool,
+) -> (Sender<T>, Receiver<T>) {
+    let buf = VecDeque::new();
+    let (send, recv) = watch::channel(buf);
+
+    let shared = Arc::new(Shared { send });
+
+    let send = Sender {
+        shared: shared.clone(),
+        pruning_predicate: Box::new(pruning_predicate),
+    };
+
+    let recv = Receiver {
+        shared: shared.clone(),
+        recv,
+    };
+
+    (send, recv)
+}
+
+struct Shared<T> {
+    send: watch::Sender<VecDeque<T>>,
+}
+
+/// Sends values to the associated [`Receiver`].
+/// Instances are created by the [`channel`] function.
+#[allow(clippy::type_complexity)]
+pub struct Sender<T> {
+    shared: Arc<Shared<T>>,
+    pruning_predicate: Box<dyn 'static + Sync + Send + Fn(&T, &T) -> bool>,
+}
+
+impl<T> Sender<T> {
+    /// Sends a value.
+    /// This initiates the pruning procedure which operates in O(N) time complexity
+    /// on the buffer of pending values.
+    pub fn send(&self, value: T) {
+        self.shared.send.send_modify(|buf| {
+            buf.retain(|pending_value| !(self.pruning_predicate)(pending_value, &value));
+            buf.push_back(value);
+        });
+    }
+}
+
+impl<T> fmt::Debug for Sender<T> {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        fmt.debug_struct("Sender").finish()
+    }
+}
+
+/// Receives values from the associated [`Sender`].
+/// Instances are created by the [`channel`] function.
+pub struct Receiver<T> {
+    shared: Arc<Shared<T>>,
+    recv: watch::Receiver<VecDeque<T>>,
+}
+
+impl<T> Receiver<T> {
+    /// Receives the next value for this receiver.
+    /// If there are no messages in the buffer, this method will wait until a message is sent.
+    pub async fn recv(&mut self, ctx: &ctx::Ctx) -> ctx::OrCanceled<T> {
+        sync::wait_for(ctx, &mut self.recv, |buf| !buf.is_empty()).await?;
+
+        let mut value: Option<T> = None;
+        self.shared.send.send_modify(|buf| value = buf.pop_front());
+
+        // `None` is unexpected because we waited for new values, and there's only a single receiver.
+        Ok(value.unwrap())
+    }
+}
+
+impl<T> fmt::Debug for Receiver<T> {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        fmt.debug_struct("Receiver").finish()
+    }
+}
diff --git a/node/libs/concurrency/src/sync/prunable_mpsc/tests.rs b/node/libs/concurrency/src/sync/prunable_mpsc/tests.rs
new file mode 100644
index 00000000..38e26eb8
--- /dev/null
+++ b/node/libs/concurrency/src/sync/prunable_mpsc/tests.rs
@@ -0,0 +1,68 @@
+use super::*;
+use crate::{ctx, time};
+use assert_matches::assert_matches;
+
+// Test scenario:
+// 1. Pre-send two sets of 1000 values, so that the first set is expected to be pruned.
+// 2. Send a third set of 1000 values in parallel to receiving.
+#[tokio::test]
+async fn test_prunable_mpsc() {
+    crate::testonly::abort_on_panic();
+    let ctx = ctx::test_root(&ctx::RealClock);
+
+    #[derive(Debug, Clone)]
+    struct ValueType(usize, usize);
+
+    #[allow(clippy::type_complexity)]
+    let (send, recv): (Sender<ValueType>, Receiver<ValueType>) =
+        channel(|a: &ValueType, b: &ValueType| {
+            // Prune values with the same i.
+            a.1 == b.1
+        });
+
+    let res: Result<(), ctx::Canceled> = crate::scope::run!(&ctx, |ctx, s| async move {
+        // Pre-send sets 0 and 1, 1000 values each.
+        // Set 0 is expected to be pruned and dropped.
+        let values = (0..2000).map(|i| ValueType(i / 1000, i % 1000));
+        for val in values {
+            send.send(val);
+        }
+        // Send set 2.
+        s.spawn(async {
+            let send = send;
+            let values = (1000..2000).map(|i| ValueType(2, i));
+            for val in values {
+                send.send(val.clone());
+            }
+            Ok(())
+        });
+        // Receive.
+        s.spawn(async {
+            let mut recv = recv;
+            let mut i = 0;
+            loop {
+                let val = recv.recv(ctx).await.unwrap();
+                assert_eq!(val.1, i);
+                // Assert the expected set.
+                match val.1 {
+                    0..=999 => assert_eq!(val.0, 1),
+                    1000..=1999 => assert_eq!(val.0, 2),
+                    _ => unreachable!(),
+                };
+                i += 1;
+                if i == 2000 {
+                    assert_matches!(
+                        recv.recv(&ctx.with_timeout(time::Duration::milliseconds(10)))
+                            .await,
+                        Err(ctx::Canceled),
+                        "recv() is expected to hang and be canceled since all values have been exhausted"
+                    );
+                    break;
+                }
+            }
+            Ok(())
+        });
+        Ok(())
+    })
+    .await;
+    assert_eq!(Ok(()), res);
+}
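Usage sketch for the new prunable_mpsc channel, mirroring how the BFT actor uses it in this diff: the pruning predicate keeps at most one pending message per logical slot, just as inbound_pruning_predicate retains only the newest ReplicaPrepare/ReplicaCommit (or LeaderPrepare/LeaderCommit) per validator key. This is a minimal sketch, not part of the patch; the Msg type and its signer/kind/view fields are hypothetical stand-ins for ConsensusReq, and only the channel/send/recv API is taken from the diff above.

use zksync_concurrency::{ctx, sync::prunable_mpsc};

// Hypothetical stand-in for ConsensusReq: one pending slot per (signer, kind) pair.
#[derive(Debug, Clone)]
struct Msg {
    signer: u64, // stand-in for validator::PublicKey
    kind: u8,    // stand-in for the ConsensusMsg variant
    view: u64,
}

#[tokio::main]
async fn main() {
    let ctx = ctx::root();

    // A pending message is pruned when a newer one arrives from the same signer
    // with the same kind, so the buffer holds at most one message per slot.
    let (send, mut recv) = prunable_mpsc::channel(|pending: &Msg, new: &Msg| {
        pending.signer == new.signer && pending.kind == new.kind
    });

    send.send(Msg { signer: 1, kind: 0, view: 7 });
    send.send(Msg { signer: 1, kind: 0, view: 8 }); // prunes the view-7 message
    send.send(Msg { signer: 2, kind: 0, view: 7 }); // different signer: kept

    assert_eq!(recv.recv(&ctx).await.unwrap().view, 8);
    assert_eq!(recv.recv(&ctx).await.unwrap().signer, 2);
}

With this scheme the consumer's backlog stays bounded by the number of distinct slots (here, validators times message kinds) rather than by network traffic, which is what lets the BFT actor route consensus inputs through an unbounded queue.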