From bb81a9701c4493336d8f728ca340264ce8d8d9ec Mon Sep 17 00:00:00 2001 From: Mikkel Wienberg Madsen Date: Thu, 13 Jun 2024 16:41:18 +0200 Subject: [PATCH] Apply formatting --- src/net/connection.rs | 4 +- src/net/mux.rs | 189 +++++++++++++++++++++++------------------- src/net/network.rs | 33 +++++--- 3 files changed, 124 insertions(+), 102 deletions(-) diff --git a/src/net/connection.rs b/src/net/connection.rs index 8aac8f9..47bb45a 100644 --- a/src/net/connection.rs +++ b/src/net/connection.rs @@ -160,7 +160,6 @@ impl SplitChannel for fn split(&mut self) -> (&mut Self::Sender, &mut Self::Receiver) { (&mut self.sender, &mut self.receiver) } - } pub type TcpConnection = Connection; @@ -173,7 +172,7 @@ impl TcpConnection { Self::new(reader, writer) } - pub fn to_tcp_stream(self) -> TcpStream { + pub fn to_tcp_stream(self) -> TcpStream { let (r, w) = self.destroy(); // UNWRAP: Should never fail, as we build the connection from two // streams before. However! One could construct TcpConnection manually @@ -187,7 +186,6 @@ impl TcpConnection { } } - /// Connection to a in-memory data stream. /// This always have a corresponding other connection in the same process. pub type DuplexConnection = Connection, WriteHalf>; diff --git a/src/net/mux.rs b/src/net/mux.rs index f35ddda..153d549 100644 --- a/src/net/mux.rs +++ b/src/net/mux.rs @@ -8,11 +8,13 @@ use std::error::Error; use std::sync::Arc; - -use futures::future::{join_all, try_join_all}; +use futures::future::try_join_all; use num_traits::ToPrimitive; use thiserror::Error; -use tokio::sync::{mpsc::{self, unbounded_channel, UnboundedSender}, oneshot}; +use tokio::sync::{ + mpsc::{self, unbounded_channel, UnboundedSender}, + oneshot, +}; use tokio_util::bytes::{Buf, BufMut, Bytes, BytesMut}; use crate::{ @@ -63,7 +65,7 @@ impl SendBytes for MuxedSender { async fn send_bytes(&mut self, bytes: tokio_util::bytes::Bytes) -> Result<(), Self::SendError> { if let Ok(err) = self.error.try_recv() { - return Err(err) + return Err(err); }; self.gateway @@ -138,7 +140,7 @@ impl SplitChannel for MuxConn { /// The multiplexed channels must be *driven* by the gateway /// (see [``Gateway::drive``]) otherwise the multiplexed channels won't /// be able to communicate. -/// +/// /// ## Example: /// ``` /// use caring::net::{connection::Connection, mux::Gateway, RecvBytes, SendBytes}; @@ -182,78 +184,78 @@ where mailboxes: Vec>, inbox: mpsc::UnboundedReceiver, errors: Vec<[oneshot::Sender; 2]>, - outbox: mpsc::WeakUnboundedSender + outbox: mpsc::WeakUnboundedSender, } - #[derive(Debug, Error)] pub enum GatewayError { #[error("Multiplexed connection {0} disappered")] MailboxNotFound(usize), #[error("Underlying connection died: {0}")] - DeadConnection(#[from] Arc) + DeadConnection(#[from] Arc), } impl Gateway { /// Drive a gateway until all multiplexed connections are complete /// /// # Errors - /// + /// /// - [``GatewayError::MailboxNotFound``] if a given multiplexed connections has been /// dropped and is receiving messages. /// - [``GatewayError::DeadConnection``] if the underlying connection have failed. /// pub async fn drive(mut self) -> Result> { - let (sending, recving) = self.channel.split(); - - let send_out = async { - loop { - if let Some(msg) = self.inbox.recv().await { - match sending.send_bytes(msg.make_bytes()).await { - Ok(()) => continue, - Err(e) => break Err(e), - } - } - break Ok(()); + // TODO: maybe have this be nonconsuming so it can be resumed after new muxes are added? + // This however would compromise the possible destruction when errors occur, + // thus leaving the error handling in a bad state. + let (sending, recving) = self.channel.split(); + let send_out = async { + loop { + if let Some(msg) = self.inbox.recv().await { + match sending.send_bytes(msg.make_bytes()).await { + Ok(()) => continue, + Err(e) => break Err(e), } - }; + } + break Ok(()); + } + }; - let recv_in = async { - loop { - match recving.recv_bytes().await { - Ok(mut msg) => { - let id = msg.get_u32() as usize; - let bytes = msg; - let Some(mailbox) = self.mailboxes.get_mut(id) else { - break Err(GatewayError::MailboxNotFound(id)) - }; - let Ok(()) = mailbox.send(bytes) else { - break Err(GatewayError::MailboxNotFound(id)) - }; - } - Err(e) => break Ok(e), + let recv_in = async { + loop { + match recving.recv_bytes().await { + Ok(mut msg) => { + let id = msg.get_u32() as usize; + let bytes = msg; + let Some(mailbox) = self.mailboxes.get_mut(id) else { + break Err(GatewayError::MailboxNotFound(id)); + }; + let Ok(()) = mailbox.send(bytes) else { + break Err(GatewayError::MailboxNotFound(id)); + }; } + Err(e) => break Ok(e), } - }; - + } + }; - tokio::select! { // Drive both futures to completion. - res = send_out => { - match res { - Ok(()) => { - Ok(self) - }, - Err(err) => { - Err(self.propogate_error(err)) - } + tokio::select! { // Drive both futures to completion. + res = send_out => { + match res { + Ok(()) => { + Ok(self) + }, + Err(err) => { + Err(self.propogate_error(err)) } - }, - err = recv_in => { - let err : C::Error = err?; // return early on missing mailbox. - println!("Got an error! {err:?}"); - Err(self.propogate_error(err)) - }, - } + } + }, + err = recv_in => { + let err : C::Error = err?; // return early on missing mailbox. + println!("Got an error! {err:?}"); + Err(self.propogate_error(err)) + }, + } } fn propogate_error(mut self, err: E) -> GatewayError { @@ -265,13 +267,12 @@ impl Gateway { let _ = c2.send(MuxError::Connection(err.clone())); } GatewayError::DeadConnection(err) - } - + } pub fn single(channel: C) -> (Self, MuxConn) { let (outbox, inbox) = unbounded_channel(); let gateway = outbox.clone(); - let outbox= outbox.downgrade(); + let outbox = outbox.downgrade(); let mut new = Self { channel, mailboxes: vec![], @@ -281,7 +282,6 @@ impl Gateway { }; let con = new.add_mux(gateway); (new, con) - } pub fn destroy(self) -> C { @@ -294,8 +294,8 @@ impl Gateway { /// * `n`: Number of new connections to multiplex into /// /// Returns a gateway which the ``MuxConn`` communicate through, along with the MuxConn - pub fn multiplex(con: C, n: usize) -> (Self, Vec) { - let (mut gateway, con) = Self::single(con); + pub fn multiplex(con: C, n: usize) -> (Self, Vec) { + let (mut gateway, con) = Self::single(con); let mut muxes = vec![con]; for _ in 1..n { muxes.push(gateway.muxify()); @@ -303,33 +303,44 @@ impl Gateway { (gateway, muxes) } - fn add_mux(&mut self, gateway: UnboundedSender) -> MuxConn { let id = self.mailboxes.len(); - let (errors_coms1, error) = oneshot::channel(); - let mx_sender = MuxedSender { - id, - gateway, - error, - }; + let (errors_coms1, error) = oneshot::channel(); + let mx_sender = MuxedSender { id, gateway, error }; let (outbox, mailbox) = tokio::sync::mpsc::unbounded_channel(); - let (errors_coms2, error) = oneshot::channel(); - let mx_receiver = MuxedReceiver { - id, - mailbox, - error, - }; + let (errors_coms2, error) = oneshot::channel(); + let mx_receiver = MuxedReceiver { id, mailbox, error }; self.errors.push([errors_coms1, errors_coms2]); self.mailboxes.push(outbox); MuxConn(mx_sender, mx_receiver) } pub fn muxify(&mut self) -> MuxConn { - let gateway = self.outbox.clone().upgrade().expect("We are holding the receiver"); + let gateway = self + .outbox + .clone() + .upgrade() + .expect("We are holding the receiver"); self.add_mux(gateway) } } +pub struct ActiveGateway( + tokio::task::JoinHandle, GatewayError>>, +); + +impl Gateway { + pub fn go(self) -> ActiveGateway { + ActiveGateway(tokio::spawn(self.drive())) + } +} + +impl ActiveGateway { + pub async fn deactivate(self) -> Result, GatewayError> { + self.0.await.unwrap() + } +} + pub struct NetworkGateway { gateways: Vec>, index: usize, @@ -341,7 +352,8 @@ impl NetworkGateway where C: SplitChannel + Send, { - #[must_use] pub fn multiplex(net: Network, n: usize) -> (NetworkGateway, Vec) { + #[must_use] + pub fn multiplex(net: Network, n: usize) -> (NetworkGateway, Vec) { let mut gateways = Vec::new(); let mut matrix = Vec::new(); let index = net.index; @@ -350,7 +362,7 @@ where matrix.push(muxes); gateways.push(gateway); } - let gateway = NetworkGateway { gateways, index, }; + let gateway = NetworkGateway { gateways, index }; let matrix = help::transpose(matrix); let muxnets: Vec<_> = matrix @@ -361,32 +373,38 @@ where (gateway, muxnets) } - pub fn multiplex_borrow(net: &mut Network, n: usize) - -> (NetworkGateway<&mut C>, Vec) - { + pub fn multiplex_borrow( + net: &mut Network, + n: usize, + ) -> (NetworkGateway<&mut C>, Vec) { let net = net.as_mut(); NetworkGateway::<&mut C>::multiplex(net, n) } pub async fn drive(self) -> Result> { let gateways = try_join_all(self.gateways.into_iter().map(Gateway::drive)).await?; - Ok(Self { gateways, index: self.index }) + Ok(Self { + gateways, + index: self.index, + }) } #[must_use] pub fn destroy(mut self) -> Network { - let index= self.index; - let connections : Vec<_> = self.gateways.drain(..).map(Gateway::destroy).collect(); + let index = self.index; + let connections: Vec<_> = self.gateways.drain(..).map(Gateway::destroy).collect(); Network { connections, index } } pub fn new_mux(&mut self) -> MuxNet { - let connections = self.gateways.iter_mut().map(Gateway::muxify ).collect(); - MuxNet { connections, index: self.index } + let connections = self.gateways.iter_mut().map(Gateway::muxify).collect(); + MuxNet { + connections, + index: self.index, + } } } - #[cfg(test)] mod test { use std::time::Duration; @@ -447,7 +465,7 @@ mod test { }; let (s, gateway) = (s, gateway.drive()).join().await; let mut gateway = gateway.unwrap(); - let _ : String = gateway.channel.recv().await.unwrap(); + let _: String = gateway.channel.recv().await.unwrap(); gateway.channel.shutdown().await.unwrap(); s }; @@ -511,7 +529,6 @@ mod test { .unwrap(); } - #[tokio::test] async fn network_borrowed() { crate::testing::Cluster::new(3) diff --git a/src/net/network.rs b/src/net/network.rs index 3d58c45..a67fa10 100644 --- a/src/net/network.rs +++ b/src/net/network.rs @@ -150,8 +150,11 @@ impl Network { T: serde::Serialize + serde::de::DeserializeOwned + Sync, { let my_id = self.index; - let (mut tx, mut rx): (Vec<_>, Vec<_>) = - self.connections.iter_mut().map(super::SplitChannel::split).unzip(); + let (mut tx, mut rx): (Vec<_>, Vec<_>) = self + .connections + .iter_mut() + .map(super::SplitChannel::split) + .unzip(); let packet: Bytes = bincode::serialize(&msg).unwrap().into(); let outgoing = tx.iter_mut().enumerate().map(|(id, conn)| { @@ -211,8 +214,11 @@ impl Network { let my_id = self.index; let my_own_msg = msgs.remove(my_id); - let (mut tx, mut rx): (Vec<_>, Vec<_>) = - self.connections.iter_mut().map(super::SplitChannel::split).unzip(); + let (mut tx, mut rx): (Vec<_>, Vec<_>) = self + .connections + .iter_mut() + .map(super::SplitChannel::split) + .unzip(); let outgoing = tx .iter_mut() @@ -302,10 +308,12 @@ impl Network { todo!("Initiate a drop vote"); } - pub(crate) fn as_mut(&mut self) -> Network<&mut C> { let connections = self.connections.iter_mut().collect(); - Network { connections, index: self.index } + Network { + connections, + index: self.index, + } } } @@ -483,9 +491,7 @@ impl InMemoryNetwork { let futs = self .connections .into_iter() - .map(|conn| async move { - conn.shutdown().await - }); + .map(|conn| async move { conn.shutdown().await }); join_all(futs).await.into_iter().map_ok(|_| {}).collect() } } @@ -534,7 +540,10 @@ impl TcpNetwork { stream.set_nodelay(true).unwrap(); } - let connections = parties.into_iter().map(Connection::from_tcp_stream).collect(); + let connections = parties + .into_iter() + .map(Connection::from_tcp_stream) + .collect(); let mut network = Self { connections, @@ -549,9 +558,7 @@ impl TcpNetwork { let futs = self .connections .into_iter() - .map(|conn| async move { - conn.shutdown().await - }); + .map(|conn| async move { conn.shutdown().await }); join_all(futs).await.into_iter().map_ok(|_| {}).collect() } }