diff --git a/CHANGELOG.md b/CHANGELOG.md index f6ea9b6..a4f93a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Changed + +* There is now a timeout for how long a peer can take to accept an error message. +* Application errors (`ErrorKind::OTHER`) are now truncated to fit into a single frame. + +### Fixed + +* The IO layer will no longer drop frames if no multi-frame payloads are sent while a non-multi-frame payload has been moved to the wait queue due to exceeding the in-flight request limit. +* The outgoing request queue will now process much faster in some cases when filled with large numbers of requests. +* The `io` layer will no longer attempt to allocate incredibly large amounts of memory under certain error conditions. + ## [0.2.0] - 2023-11-24 ### Changed diff --git a/src/io.rs b/src/io.rs index 346f921..9d27998 100644 --- a/src/io.rs +++ b/src/io.rs @@ -25,13 +25,14 @@ //! [`IoCore`] to close the connection. use std::{ - collections::{BTreeSet, VecDeque}, + collections::VecDeque, fmt::{self, Display, Formatter}, io, sync::{ atomic::{AtomicU64, Ordering}, Arc, }, + time::Duration, }; use bimap::BiMap; @@ -55,6 +56,9 @@ use crate::{ ChannelId, Id, Outcome, }; +/// Maximum number of bytes to pre-allocate in buffers. +const MAX_ALLOC: usize = 32 * 1024; // 32 KiB + /// An item in the outgoing queue. /// /// Requests are not transformed into messages in the queue to conserve limited request ID space. @@ -181,6 +185,13 @@ pub enum CoreError { /// Failed to write using underlying writer. #[error("write failed")] WriteFailed(#[source] io::Error), + + /// Could not send an error in time. + /// + /// A limit is imposed on how long a peer may take to receive an error to avoid denial of + /// service through receiving these very slowly. + #[error("peer did not accept error in timely manner")] + ErrorWriteTimeout, /// Remote peer will/has disconnect(ed), but sent us an error message before. #[error("remote peer sent error [channel {}/id {}]: {} (payload: {} bytes)", header.channel(), @@ -245,8 +256,10 @@ pub struct IoCore { /// /// Used to ensure we don't attempt to parse too often. next_parse_at: usize, - /// Whether or not we are shutting down due to an error. - shutting_down_due_to_err: bool, + /// The error queued to be sent before shutting down. + pending_error: Option, + /// The maximum time allowed for a peer to receive an error. + error_timeout: Duration, /// The frame in the process of being sent, which may be partially transferred already. current_frame: Option, @@ -260,11 +273,9 @@ pub struct IoCore { receiver: UnboundedReceiver, /// Mapping for outgoing requests, mapping internal IDs to public ones. request_map: BiMap, - /// A set of channels whose wait queues should be checked again for data to send. - dirty_channels: BTreeSet, } -/// Shared data between a handles and the core itself. +/// Shared data between handles and the core itself. #[derive(Debug)] #[repr(transparent)] struct IoShared { @@ -370,6 +381,8 @@ pub struct IoCoreBuilder { protocol: ProtocolBuilder, /// Number of additional requests to buffer, per channel. buffer_size: [usize; N], + /// The maximum time allowed for a peer to receive an error. + error_timeout: Duration, } impl IoCoreBuilder { @@ -388,6 +401,7 @@ impl IoCoreBuilder { Self { protocol, buffer_size: [default_buffer_size; N], + error_timeout: Duration::from_secs(10), } } @@ -404,6 +418,15 @@ impl IoCoreBuilder { self } + /// Sets the maximum time a peer is allowed to take to receive an error. + /// + /// This is a grace time given to peers to be notified of bad behavior before the connection + /// will be closed. + pub const fn error_timeout(mut self, error_timeout: Duration) -> Self { + self.error_timeout = error_timeout; + self + } + /// Builds a new [`IoCore`] with a [`RequestHandle`]. /// /// See [`IoCore::next_event`] for details on how to handle the core. The [`RequestHandle`] can @@ -417,14 +440,14 @@ impl IoCoreBuilder { writer, buffer: BytesMut::new(), next_parse_at: 0, - shutting_down_due_to_err: false, + pending_error: None, + error_timeout: self.error_timeout, current_frame: None, active_multi_frame: [Default::default(); N], ready_queue: Default::default(), wait_queue: array_init::array_init(|_| Default::default()), receiver, request_map: Default::default(), - dirty_channels: Default::default(), }; let shared = Arc::new(IoShared { @@ -456,9 +479,22 @@ where /// Polling of this function must continue only until `Err(_)` or `Ok(None)` is returned, /// indicating that the connection should be closed or has been closed. pub async fn next_event(&mut self) -> Result, CoreError> { - loop { - self.process_dirty_channels()?; + if let Some(ref mut pending_error) = self.pending_error { + tokio::time::timeout(self.error_timeout, self.writer.write_all_buf(pending_error)) + .await + .map_err(|_elapsed| CoreError::ErrorWriteTimeout)? + .map_err(CoreError::WriteFailed)?; + + // We succeeded writing, clear the error. + let peers_crime = self + .pending_error + .take() + .expect("pending_error should not have disappeared") + .header(); + return Err(CoreError::RemoteProtocolViolation(peers_crime)); + } + loop { if self.next_parse_at <= self.buffer.remaining() { // Simplify reasoning about this code. self.next_parse_at = 0; @@ -473,6 +509,12 @@ where self.inject_error(err_msg); } Outcome::Success(successful_read) => { + // If we received a response, we may have additional capacity available to + // send out more requests, so we process the wait queue. + if let CompletedRead::ReceivedResponse { channel, .. } = &successful_read { + self.process_wait_queue(*channel)?; + } + // Check if we have produced an event. return self.handle_completed_read(successful_read).map(Some); } @@ -504,22 +546,30 @@ where write_result.map_err(CoreError::WriteFailed)?; - // If we just finished sending an error, it's time to exit. - let frame_sent = self.current_frame.take().unwrap(); - - #[cfg(feature = "tracing")] - { + // Clear `current_frame` via `Option::take` and examine what was sent. + if let Some(frame_sent) = self.current_frame.take() { + #[cfg(feature = "tracing")] tracing::trace!(frame=%frame_sent, "sent"); - } - if frame_sent.header().is_error() { - // We finished sending an error frame, time to exit. - return Err(CoreError::RemoteProtocolViolation(frame_sent.header())); + if frame_sent.header().is_error() { + // We finished sending an error frame, time to exit. + return Err(CoreError::RemoteProtocolViolation(frame_sent.header())); + } + + // TODO: We should restrict the dirty-queue processing here a little bit + // (only check when completing a multi-frame message). + // A message has completed sending, process the wait queue in case we have + // to start sending a multi-frame message like a response that was delayed + // only because of the one-multi-frame-per-channel restriction. + self.process_wait_queue(frame_sent.header().channel())?; + } else { + #[cfg(feature = "tracing")] + tracing::error!("current frame should not disappear"); } } // Reading incoming data. - read_result = read_until_bytesmut(&mut self.reader, &mut self.buffer, self.next_parse_at), if !self.shutting_down_due_to_err => { + read_result = read_until_bytesmut(&mut self.reader, &mut self.buffer, self.next_parse_at) => { // Our read function will not return before `read_until_bytesmut` has completed. let read_complete = read_result.map_err(CoreError::ReadFailed)?; @@ -532,14 +582,13 @@ where } // Processing locally queued things. - incoming = self.receiver.recv(), if !self.shutting_down_due_to_err => { + incoming = self.receiver.recv() => { match incoming { Some(item) => { self.handle_incoming_item(item)?; } None => { - // If the receiver was closed it means that we locally shut down the - // connection. + // If the receiver was closed we locally shut down the connection. #[cfg(feature = "tracing")] tracing::info!("local shutdown"); return Ok(None); @@ -554,7 +603,7 @@ where Err(TryRecvError::Disconnected) => { // While processing incoming items, the last handle was closed. #[cfg(feature = "tracing")] - tracing::debug!("last local io handle closed, shutting down"); + tracing::info!("last local io handle closed, shutting down"); return Ok(None); } Err(TryRecvError::Empty) => { @@ -571,13 +620,10 @@ where /// Ensures the next message sent is an error message. /// /// Clears all buffers related to sending and closes the local incoming channel. - fn inject_error(&mut self, err_msg: OutgoingMessage) { + fn inject_error(&mut self, mut err_msg: OutgoingMessage) { // Stop accepting any new local data. self.receiver.close(); - // Set the error state. - self.shutting_down_due_to_err = true; - // We do not continue parsing, ever again. self.next_parse_at = usize::MAX; @@ -589,8 +635,15 @@ where queue.clear(); } - // Ensure the error message is the next frame sent. - self.ready_queue.push_front(err_msg.frames()); + // Ensure the error message is the next frame sent, truncating as needed. + let max_frame_size = self.juliet.max_frame_size(); + err_msg.truncate_to_single_frame(max_frame_size); + let (frame, _remainder) = err_msg.frames().next_owned(max_frame_size); + debug_assert!( + _remainder.is_none(), + "should not have more than one frame after truncating to fit into single frame" + ); + self.pending_error = Some(frame); } /// Processes a completed read into a potential event. @@ -649,7 +702,7 @@ where /// Handles a new item to send out that arrived through the incoming channel. fn handle_incoming_item(&mut self, item: QueuedItem) -> Result<(), LocalProtocolViolation> { // Check if the item is sendable immediately. - if let Some(channel) = item_should_wait(&item, &self.juliet, &self.active_multi_frame) { + if let Some(channel) = item_should_wait(&item, &self.juliet, &self.active_multi_frame)? { #[cfg(feature = "tracing")] tracing::debug!(%item, "postponing send"); self.wait_queue[channel.get() as usize].push_back(item); @@ -658,18 +711,11 @@ where #[cfg(feature = "tracing")] tracing::debug!(%item, "ready to send"); - self.send_to_ready_queue(item, false) + self.send_to_ready_queue(item) } /// Sends an item directly to the ready queue, causing it to be sent out eventually. - /// - /// `item` is passed as a mutable reference for compatibility with functions like `retain_mut`, - /// but will be left with all payloads removed, thus should likely not be reused. - fn send_to_ready_queue( - &mut self, - item: QueuedItem, - check_for_cancellation: bool, - ) -> Result<(), LocalProtocolViolation> { + fn send_to_ready_queue(&mut self, item: QueuedItem) -> Result<(), LocalProtocolViolation> { match item { QueuedItem::Request { io_id, @@ -677,18 +723,10 @@ where payload, permit, } => { - // "Chase" our own requests here -- if the request was still in the wait queue, - // we can cancel it by checking if the `IoId` has been removed in the meantime. - // - // Note that this only cancels multi-frame requests. - if check_for_cancellation && !self.request_map.contains_left(&io_id) { - // We just ignore the request, as it has been cancelled in the meantime. - } else { - let msg = self.juliet.create_request(channel, payload)?; - let id = msg.header().id(); - self.request_map.insert(io_id, (channel, id)); - self.ready_queue.push_back(msg.frames()); - } + let msg = self.juliet.create_request(channel, payload)?; + let id = msg.header().id(); + self.request_map.insert(io_id, (channel, id)); + self.ready_queue.push_back(msg.frames()); drop(permit); } @@ -736,14 +774,15 @@ where /// Clears a potentially finished frame and returns the next frame to send. /// - /// Returns `None` if no frames are ready to be sent. Note that there may be frames waiting - /// that cannot be sent due them being multi-frame messages when there already is a multi-frame - /// message in progress, or request limits are being hit. + /// Note that there may be frames waiting that cannot be sent due them being multi-frame + /// messages when there already is a multi-frame message in progress, or request limits are + /// being hit. + /// + /// The caller needs to ensure that the current frame is empty (i.e. has been sent). fn ready_next_frame(&mut self) -> Result<(), LocalProtocolViolation> { debug_assert!(self.current_frame.is_none()); // Must be guaranteed by caller. - // Try to fetch a frame from the ready queue. If there is nothing, we are stuck until the - // next time the wait queue is processed or new data arrives. + // Try to fetch a frame from the ready queue. let (frame, additional_frames) = match self.ready_queue.pop_front() { Some(item) => item, None => return Ok(()), @@ -763,9 +802,6 @@ where // Once the scheduled frame is processed, we will finished the multi-frame // transfer, so we can allow for the next multi-frame transfer to be scheduled. self.active_multi_frame[about_to_finish.channel().get() as usize] = None; - - // There is a chance another multi-frame messages became ready now. - self.dirty_channels.insert(about_to_finish.channel()); } } } @@ -774,56 +810,49 @@ where Ok(()) } - /// Process the wait queue of all channels marked dirty, promoting messages that are ready to be - /// sent to the ready queue. - fn process_dirty_channels(&mut self) -> Result<(), CoreError> { - while let Some(channel) = self.dirty_channels.pop_first() { - let wait_queue_len = self.wait_queue[channel.get() as usize].len(); - - // The code below is not as bad it looks complexity wise, anticipating two common cases: - // - // 1. A multi-frame read has finished, with capacity for requests to spare. Only - // multi-frame requests will be waiting in the wait queue, so we will likely pop the - // first item, only scanning the rest once. - // 2. One or more requests finished, so we also have a high chance of picking the first - // few requests out of the queue. - - for _ in 0..(wait_queue_len) { - let item = self.wait_queue[channel.get() as usize].pop_front().ok_or( - CoreError::InternalError("did not expect wait_queue to disappear"), - )?; - - if item_should_wait(&item, &self.juliet, &self.active_multi_frame).is_some() { - // Put it right back into the queue. - self.wait_queue[channel.get() as usize].push_back(item); - } else { - self.send_to_ready_queue(item, true)?; + /// Process the wait queue of a given channel, promoting messages that are ready to be sent. + fn process_wait_queue(&mut self, channel: ChannelId) -> Result<(), LocalProtocolViolation> { + let mut remaining = self.wait_queue[channel.get() as usize].len(); + + while let Some(item) = self.wait_queue[channel.get() as usize].pop_front() { + if item_should_wait(&item, &self.juliet, &self.active_multi_frame)?.is_some() { + // Put it right back into the queue. + self.wait_queue[channel.get() as usize].push_back(item); + } else { + self.send_to_ready_queue(item)?; + + // No need to look further if we have saturated the channel. + if !self.juliet.allowed_to_send_request(channel)? { + break; } } + + // Ensure we do not loop endlessly if we cannot find anything. + remaining -= 1; + if remaining == 0 { + break; + } } Ok(()) } } -/// Determines whether an item is ready to be moved from the wait queue from the ready queue. +/// Determines whether an item is ready to be moved from the wait queue to the ready queue. /// -/// Returns `None` if the item does not need to wait. Otherwise, the items channel ID is returned. +/// Returns `None` if the item does not need to wait. Otherwise, the item's channel ID is returned. fn item_should_wait( item: &QueuedItem, juliet: &JulietProtocol, active_multi_frame: &[Option
; N], -) -> Option { +) -> Result, LocalProtocolViolation> { let (payload, channel) = match item { QueuedItem::Request { channel, payload, .. } => { // Check if we cannot schedule due to the message exceeding the request limit. - if !juliet - .allowed_to_send_request(*channel) - .expect("should not be called with invalid channel") - { - return Some(*channel); + if !juliet.allowed_to_send_request(*channel)? { + return Ok(Some(*channel)); } (payload, channel) @@ -835,7 +864,7 @@ fn item_should_wait( // Other messages are always ready. QueuedItem::RequestCancellation { .. } | QueuedItem::ResponseCancellation { .. } - | QueuedItem::Error { .. } => return None, + | QueuedItem::Error { .. } => return Ok(None), }; let active_multi_frame = active_multi_frame[channel.get() as usize]; @@ -845,13 +874,13 @@ fn item_should_wait( if active_multi_frame.is_some() { if let Some(payload) = payload { if payload_is_multi_frame(juliet.max_frame_size(), payload.len()) { - return Some(*channel); + return Ok(Some(*channel)); } } } // Otherwise, this should be a legitimate add to the run queue. - None + Ok(None) } /// A handle to the input queue to the [`IoCore`] that allows sending requests and responses. @@ -1074,6 +1103,8 @@ impl Handle { /// /// Enqueuing an error causes the [`IoCore`] to begin shutting down immediately, only making an /// effort to finish sending the error before doing so. + /// + /// If payload exceeds what is possible to send in a single frame, it is truncated. pub fn enqueue_error( &self, channel: ChannelId, @@ -1114,7 +1145,9 @@ where R: AsyncReadExt + Sized + Unpin, { let extra_required = target.saturating_sub(buf.remaining()); - buf.reserve(extra_required); + // Note: `reserve` is purely an optimization -- `BufMut::remaining_mut(&mut buf)` will always + // return 2**64-1, which is the number `read_buf` looks at for exiting early. + buf.reserve(extra_required.min(MAX_ALLOC)); while buf.remaining() < target { match reader.read_buf(buf).await { diff --git a/src/protocol.rs b/src/protocol.rs index fa2b102..13f1845 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -95,11 +95,18 @@ impl MaxFrameSize { self.0 as usize } - /// Returns the maximum frame size without the header size. + /// Returns the maximum frame size with the header size subtracted. #[inline(always)] pub const fn without_header(self) -> usize { self.get_usize() - Header::SIZE } + + /// Returns the maximum frame size with the preamble size subtracted, assuming it includes a + /// payload length for `payload_len`. + #[inline(always)] + pub const fn without_preamble(self, payload_len: u32) -> usize { + self.without_header() - Varint32::length_of(payload_len) + } } impl Default for MaxFrameSize { @@ -1012,8 +1019,7 @@ pub const fn payload_is_multi_frame(max_frame_size: MaxFrameSize, payload_len: u "payload cannot exceed `u32::MAX`" ); - payload_len as u64 + Header::SIZE as u64 + (Varint32::encode(payload_len as u32)).len() as u64 - > max_frame_size.get() as u64 + max_frame_size.without_preamble(payload_len as u32) < payload_len } #[cfg(test)] @@ -1051,6 +1057,8 @@ mod tests { SingleFrame, /// A payload that spans more than one frame. MultiFrame, + /// A payload that spans a large number of frames. + LargeMultiFrame, /// A payload that exceeds the request size limit. TooLarge, } @@ -1062,6 +1070,7 @@ mod tests { VaryingPayload::None, VaryingPayload::SingleFrame, VaryingPayload::MultiFrame, + VaryingPayload::LargeMultiFrame, ] .into_iter() } @@ -1072,6 +1081,7 @@ mod tests { VaryingPayload::None => true, VaryingPayload::SingleFrame => false, VaryingPayload::MultiFrame => false, + VaryingPayload::LargeMultiFrame => false, VaryingPayload::TooLarge => false, } } @@ -1111,6 +1121,11 @@ mod tests { b"large payload large payload large payload large payload large payload large payload"; const_assert!(LONG_PAYLOAD.len() > TestingSetup::MAX_FRAME_SIZE as usize); + const VERY_LONG_PAYLOAD: &[u8] = + b"very very large payload very large payload very large payload very large payload very large payload very large payload very very large payload very large payload very large payload very large payload very large payload very large payload very very large payload very large payload very large payload very large payload very large payload very large payload very very large payload very large payload very large payload very large payload very large payload very large payload"; + const_assert!(VERY_LONG_PAYLOAD.len() > TestingSetup::MAX_FRAME_SIZE as usize); + const_assert!(VERY_LONG_PAYLOAD.len() <= TestingSetup::MAX_PAYLOAD_SIZE as usize); + const OVERLY_LONG_PAYLOAD: &[u8] = b"abcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefgh"; const_assert!(OVERLY_LONG_PAYLOAD.len() > TestingSetup::MAX_PAYLOAD_SIZE as usize); @@ -1118,6 +1133,7 @@ mod tests { VaryingPayload::None => None, VaryingPayload::SingleFrame => Some(SHORT_PAYLOAD), VaryingPayload::MultiFrame => Some(LONG_PAYLOAD), + VaryingPayload::LargeMultiFrame => Some(VERY_LONG_PAYLOAD), VaryingPayload::TooLarge => Some(OVERLY_LONG_PAYLOAD), } } @@ -2608,4 +2624,22 @@ mod tests { }) ); } + + #[test] + fn can_send_back_to_back_multi_frame_requests() { + let big_payload_1 = VaryingPayload::LargeMultiFrame; + let big_payload_2 = VaryingPayload::LargeMultiFrame; + + let mut env = TestingSetup::new(); + + let resp1 = env + .create_and_send_request(Alice, big_payload_1.get()) + .expect("should be able to send multiframe request"); + let resp2 = env + .create_and_send_request(Alice, big_payload_2.get()) + .expect("should be able to send multiframe request"); + + assert!(matches!(resp1, CompletedRead::NewRequest { id, .. } if id == Id(1))); + assert!(matches!(resp2, CompletedRead::NewRequest { id, .. } if id == Id(2))); + } } diff --git a/src/protocol/outgoing_message.rs b/src/protocol/outgoing_message.rs index a264027..7a41fbf 100644 --- a/src/protocol/outgoing_message.rs +++ b/src/protocol/outgoing_message.rs @@ -142,6 +142,18 @@ impl OutgoingMessage { let mut everything = self.iter_bytes(max_frame_size); everything.copy_to_bytes(everything.remaining()) } + + /// Truncates the message payload so that it fits into the first frame. + #[inline] + pub fn truncate_to_single_frame(&mut self, max_frame_size: MaxFrameSize) { + if self.is_multi_frame(max_frame_size) { + // Note: There are some edge cases where we might miss one byte due to the `Varint32` + // shrinking, but we're accepting that. + if let Some(ref mut payload) = self.payload { + payload.truncate(max_frame_size.without_preamble(payload.len() as u32)); + } + } + } } /// Combination of header and potential message payload length. diff --git a/src/rpc.rs b/src/rpc.rs index b3d93fe..30c91c6 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -847,17 +847,17 @@ where #[cfg(test)] mod tests { - use std::{collections::BinaryHeap, sync::Arc, time::Duration}; + use std::{collections::BinaryHeap, iter, sync::Arc, time::Duration}; use bytes::Bytes; use futures::FutureExt; use tokio::io::{DuplexStream, ReadHalf, WriteHalf}; - use tracing::{span, Instrument, Level}; + use tracing::{error_span, info, span, Instrument, Level}; use crate::{ - io::IoCoreBuilder, + io::{CoreError, IoCoreBuilder}, protocol::ProtocolBuilder, - rpc::{RequestError, RpcBuilder}, + rpc::{RequestError, RpcBuilder, RpcServerError}, ChannelConfiguration, ChannelId, }; @@ -1168,4 +1168,351 @@ mod tests { assert!(drain_heap_while(&mut empty_heap, |_| true).next().is_none()); } + + /// Parameters for a "large volume" test. + #[derive(Copy, Clone, Debug)] + + struct LargeVolumeTestSpec { + /// Maximum frame size to use. + max_frame_size: u32, + /// Per-channel in-flight request limit. + /// + /// All channels use the same in-flight limit. + request_limit: u16, + /// The "step size" of a payload. + /// + /// Any payload from Bob to Alice will have a size that is a multiple of + /// `payload_step_size`. + payload_step_size: u32, + /// Maximum multiplier for the payload. + /// + /// A random multiplier is chosen for payloads up to `payload_max_multiplier` for those sent + /// from Bob to Alice. + payload_max_multiplier: u32, + /// How many bytes to buffer in the internal in-memory buffer of the transport. + pipe_buffer: usize, + /// How many bytes of payload data to send before ending the test. + /// + /// Measures the amount of data Alice receives from Bob. + min_send_bytes: usize, + /// Timeout for a single message. + timeout: Duration, + } + + impl Default for LargeVolumeTestSpec { + fn default() -> Self { + Self { + max_frame_size: 37, + request_limit: 3, + payload_step_size: 20, + payload_max_multiplier: 10, + pipe_buffer: 80, + min_send_bytes: 100 * 1024 * 1024, // 100 MiB + timeout: Duration::from_millis(250), + } + } + } + + impl LargeVolumeTestSpec { + fn max_payload_size(&self) -> u32 { + self.payload_step_size * self.payload_max_multiplier + } + + fn default_buffer_size(&self) -> usize { + self.request_limit as usize * 2 + } + + /// Generates a "random" payload size. + /// + /// `count` is used as a seed, using very weak randomness. + fn gen_payload_size(&self, count: usize) -> usize { + let multiplier = ((count * 239) % self.payload_max_multiplier as usize) + 1; + self.payload_step_size as usize * multiplier + } + + /// Setup function for RPC testing. + /// + /// Creates two "nodes" linked using an in-memory transport, hopefully with deterministic + /// behavior. + fn mk_rpc(&self) -> (CompleteSetup, CompleteSetup) { + let channel_cfg = ChannelConfiguration::new() + .with_max_request_payload_size(self.max_payload_size()) + .with_max_response_payload_size(self.max_payload_size()) + .with_request_limit(self.request_limit); + + let protocol_builder = ProtocolBuilder::with_default_channel_config(channel_cfg) + .max_frame_size(self.max_frame_size); + let rpc_builder: RpcBuilder = + RpcBuilder::new(IoCoreBuilder::with_default_buffer_size( + protocol_builder, + self.default_buffer_size(), + )) + .with_bubble_timeouts(true) + .with_default_timeout(self.timeout); + + let (alice_stream, bob_stream) = tokio::io::duplex(self.pipe_buffer); + + let alice = CompleteSetup::new(&rpc_builder, alice_stream); + let bob = CompleteSetup::new(&rpc_builder, bob_stream); + + (alice, bob) + } + } + + struct CompleteSetup { + client: JulietRpcClient, + server: JulietRpcServer, WriteHalf>, + } + + impl CompleteSetup { + fn new(builder: &RpcBuilder, duplex: DuplexStream) -> Self { + let (reader, writer) = tokio::io::split(duplex); + let (client, server) = builder.build(reader, writer); + CompleteSetup { client, server } + } + } + + #[tokio::test] + async fn large_volume_setup_smoke_test() { + let (mut alice, mut bob) = LargeVolumeTestSpec::<4>::default().mk_rpc(); + + tokio::spawn(async move { + while let Some(request) = alice + .server + .next_request() + .await + .expect("next request failed") + { + // Simply echo back the payload. + let pl = request.payload().clone(); + request.respond(pl); + } + }); + + tokio::spawn(async move { bob.server.next_request().await }); + + for i in 0i32..10 { + let num: Box<[u8]> = i.to_be_bytes().into(); + let pl = Bytes::from(num); + let handle = bob + .client + .create_request(ChannelId::new(2)) + .with_payload(pl.clone()) + .queue_for_sending() + .await; + + let resp = handle + .wait_for_response() + .await + .expect("should get response") + .expect("should have payload"); + + assert_eq!(resp, pl); + } + } + + #[tokio::test] + async fn run_large_volume_test_single_channel_single_request() { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .ok(); + + let spec = LargeVolumeTestSpec { + request_limit: 1, + max_frame_size: 17, + // 10 Bytes requests means they all fit in one frame. + payload_max_multiplier: 1, + payload_step_size: 10, + ..Default::default() + }; + + large_volume_test::<1>(spec).await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn run_large_volume_test_with_default_values_10_channels() { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .ok(); + + large_volume_test::<10>(Default::default()).await; + } + + async fn large_volume_test(spec: LargeVolumeTestSpec) { + // Our setup is as follows: + // + // 1. All messages are `ACK`'d with empty responses. + // 2. Alice will send a constant stream of small messages to Bob. + // 3. Bob will send a larger message every time he receives a small message from Alice, on + // the same channel. + + let channel_ids: Vec = (0..N).map(|id| ChannelId::new(id as u8)).collect(); + + let (mut alice, mut bob) = LargeVolumeTestSpec::::default().mk_rpc(); + + // Alice server. Will close the connection after enough bytes have been sent. + let mut remaining = spec.min_send_bytes; + let alice_server = tokio::spawn( + async move { + while let Some(request) = alice + .server + .next_request() + .await + .expect("next request failed") + { + let payload_size = request + .payload() + .as_ref() + .expect("should have payload in bobs request") + .len(); + // Just discard the message payload, but acknowledge receiving it. + request.respond(None); + + remaining = remaining.saturating_sub(payload_size); + if remaining == 0 { + // We've reached the volume we were looking for, end test. + break; + } + } + + info!("exiting"); + } + .instrument(error_span!("alice_server")), + ); + + let small_payload: Bytes = iter::repeat(0xFF) + .take(spec.max_frame_size as usize / 2) + .collect::>() + .into(); + + // Alice client. Will shut down once bob closes the connection. + let alice_client = tokio::spawn( + async move { + let mut next_channel = channel_ids.iter().cloned().cycle(); + + let mut alice_counter = 0; + loop { + let small_request = alice + .client + .create_request(next_channel.next().unwrap()) + .with_payload(small_payload.clone()) + .queue_for_sending() + .await; + info!(alice_counter, "alice enqueued request"); + alice_counter += 1; + + match small_request.try_get_response() { + Ok(Ok(_)) => { + // A surprise to be sure, but a welcome one (very fast answer). + } + Ok(Err(err)) => match err { + RequestError::RemoteClosed(_) | RequestError::Shutdown => break, + RequestError::TimedOut + | RequestError::TimeoutOverflow(_) + | RequestError::RemoteCancelled + | RequestError::Cancelled + | RequestError::Error(_) => { + panic!("{}", err); + } + }, + + Err(guard) => { + // Not ready, but we are not going to wait. + tokio::spawn(async move { + if let Err(err) = guard.wait_for_response().await { + match err { + RequestError::RemoteClosed(_) | RequestError::Shutdown => {} + err => panic!("{}", err), + } + } + }); + } + } + } + + info!("exiting"); + } + .instrument(error_span!("alice_client")), + ); + + // Bob server. + let bob_server = tokio::spawn( + async move { + let mut bob_counter = 0; + while let Some(request) = bob + .server + .next_request() + .await + .or_else(|err| match err { + RpcServerError::CoreError(ref core_err) => match core_err { + CoreError::ReadFailed(_) + | CoreError::WriteFailed(_) + | CoreError::ErrorWriteTimeout => Ok(None), // Ignore these IO errors. + _ => Err(err), + }, + other => Err(other), + }) + .expect("next request failed") + { + let channel = request.channel(); + // Just discard the message payload, but acknowledge receiving it. + request.respond(None); + + let payload_size = spec.gen_payload_size(bob_counter); + let large_payload: Bytes = iter::repeat(0xFF) + .take(payload_size) + .collect::>() + .into(); + + // Send another request back. + let bobs_request: RequestGuard = bob + .client + .create_request(channel) + .with_payload(large_payload.clone()) + .queue_for_sending() + .await; + + info!(bob_counter, "bob enqueued request"); + bob_counter += 1; + + match bobs_request.try_get_response() { + Ok(Ok(_)) => {} + Ok(Err(err)) => match err { + RequestError::RemoteClosed(_) | RequestError::Shutdown => break, + RequestError::TimedOut + | RequestError::TimeoutOverflow(_) + | RequestError::RemoteCancelled + | RequestError::Cancelled + | RequestError::Error(_) => { + panic!("{}", err); + } + }, + + Err(guard) => { + // Do not wait, instead attempt to retrieve next request. + tokio::spawn(async move { + if let Err(err) = guard.wait_for_response().await { + match err { + RequestError::RemoteClosed(_) | RequestError::Shutdown => {} + err => panic!("{}", err), + } + } + }); + } + } + } + + info!("exiting"); + } + .instrument(error_span!("bob_server")), + ); + + alice_server.await.expect("failed to join alice server"); + alice_client.await.expect("failed to join alice client"); + bob_server.await.expect("failed to join bob server"); + + info!("all joined"); + } }