diff --git a/src/rpc.rs b/src/rpc.rs index 1bd80a5..cebb8bd 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -766,7 +766,7 @@ impl IncomingRequest { // Do nothing, just discard the response. } EnqueueError::BufferLimitHit(_) => { - // TODO: Add seperate type to avoid this. + // TODO: Add separate type to avoid this. unreachable!("cannot hit request limit when responding") } } @@ -851,7 +851,10 @@ mod tests { use bytes::Bytes; use futures::FutureExt; - use tokio::io::{DuplexStream, ReadHalf, WriteHalf}; + use tokio::{ + io::{DuplexStream, ReadHalf, WriteHalf}, + sync::mpsc, + }; use tracing::{error_span, info, span, Instrument, Level}; use crate::{ @@ -1330,7 +1333,7 @@ mod tests { large_volume_test::<1>(spec).await; } - #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] async fn run_large_volume_test_with_default_values_10_channels() { tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) @@ -1352,7 +1355,7 @@ mod tests { let (mut alice, mut bob) = LargeVolumeTestSpec::::default().mk_rpc(); - // Alice server. Will close the connection after enough bytes have been sent. + // Alice server. Will close the connection after enough bytes have been received. let mut remaining = spec.min_send_bytes; let alice_server = tokio::spawn( async move { @@ -1371,6 +1374,7 @@ mod tests { request.respond(None); remaining = remaining.saturating_sub(payload_size); + tracing::debug!("payload_size: {payload_size}, remaining: {remaining}"); if remaining == 0 { // We've reached the volume we were looking for, end test. break; @@ -1420,14 +1424,18 @@ mod tests { Err(guard) => { // Not ready, but we are not going to wait. - tokio::spawn(async move { - if let Err(err) = guard.wait_for_response().await { - match err { - RequestError::RemoteClosed(_) | RequestError::Shutdown => {} - err => panic!("{}", err), + tokio::spawn( + async move { + if let Err(err) = guard.wait_for_response().await { + match err { + RequestError::RemoteClosed(_) + | RequestError::Shutdown => {} + err => panic!("{}", err), + } } } - }); + .in_current_span(), + ); } } } @@ -1437,10 +1445,11 @@ mod tests { .instrument(error_span!("alice_client")), ); - // Bob server. + // A channel to allow Bob's server to notify Bob's client to send a new request to Alice. + let (notify_tx, mut notify_rx) = mpsc::unbounded_channel(); + // Bob server. Will shut down once Alice closes the connection. let bob_server = tokio::spawn( async move { - let mut bob_counter = 0; while let Some(request) = bob .server .next_request() @@ -1459,7 +1468,19 @@ mod tests { let channel = request.channel(); // Just discard the message payload, but acknowledge receiving it. request.respond(None); + // Notify Bob client to send a new request to Alice. + notify_tx.send(channel).unwrap(); + } + info!("exiting"); + } + .instrument(error_span!("bob_server")), + ); + // Bob client. Will shut down once Alice closes the connection. + let bob_client = tokio::spawn( + async move { + let mut bob_counter = 0; + while let Some(channel) = notify_rx.recv().await { let payload_size = spec.gen_payload_size(bob_counter); let large_payload: Bytes = iter::repeat(0xFF) .take(payload_size) @@ -1470,11 +1491,11 @@ mod tests { let bobs_request: RequestGuard = bob .client .create_request(channel) - .with_payload(large_payload.clone()) + .with_payload(large_payload) .queue_for_sending() .await; - info!(bob_counter, "bob enqueued request"); + info!(bob_counter, payload_size, "bob enqueued request"); bob_counter += 1; match bobs_request.try_get_response() { @@ -1492,26 +1513,30 @@ mod tests { Err(guard) => { // Do not wait, instead attempt to retrieve next request. - tokio::spawn(async move { - if let Err(err) = guard.wait_for_response().await { - match err { - RequestError::RemoteClosed(_) | RequestError::Shutdown => {} - err => panic!("{}", err), + tokio::spawn( + async move { + if let Err(err) = guard.wait_for_response().await { + match err { + RequestError::RemoteClosed(_) + | RequestError::Shutdown => {} + err => panic!("{}", err), + } } } - }); + .in_current_span(), + ); } } } - info!("exiting"); } - .instrument(error_span!("bob_server")), + .instrument(error_span!("bob_client")), ); alice_server.await.expect("failed to join alice server"); alice_client.await.expect("failed to join alice client"); bob_server.await.expect("failed to join bob server"); + bob_client.await.expect("failed to join bob client"); info!("all joined"); }