Skip to content

Commit

Permalink
Apply formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
quackzar committed Jun 13, 2024
1 parent fb90918 commit bb81a97
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 102 deletions.
4 changes: 1 addition & 3 deletions src/net/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ impl<R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send> SplitChannel for
fn split(&mut self) -> (&mut Self::Sender, &mut Self::Receiver) {
(&mut self.sender, &mut self.receiver)
}

}

pub type TcpConnection = Connection<OwnedReadHalf, OwnedWriteHalf>;
Expand All @@ -173,7 +172,7 @@ impl TcpConnection {
Self::new(reader, writer)
}

pub fn to_tcp_stream(self) -> TcpStream {
pub fn to_tcp_stream(self) -> TcpStream {
let (r, w) = self.destroy();
// UNWRAP: Should never fail, as we build the connection from two
// streams before. However! One could construct TcpConnection manually
Expand All @@ -187,7 +186,6 @@ impl TcpConnection {
}
}


/// Connection to a in-memory data stream.
/// This always have a corresponding other connection in the same process.
pub type DuplexConnection = Connection<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>>;
Expand Down
189 changes: 103 additions & 86 deletions src/net/mux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
use std::error::Error;
use std::sync::Arc;


use futures::future::{join_all, try_join_all};
use futures::future::try_join_all;
use num_traits::ToPrimitive;
use thiserror::Error;
use tokio::sync::{mpsc::{self, unbounded_channel, UnboundedSender}, oneshot};
use tokio::sync::{
mpsc::{self, unbounded_channel, UnboundedSender},
oneshot,
};
use tokio_util::bytes::{Buf, BufMut, Bytes, BytesMut};

use crate::{
Expand Down Expand Up @@ -63,7 +65,7 @@ impl SendBytes for MuxedSender {

async fn send_bytes(&mut self, bytes: tokio_util::bytes::Bytes) -> Result<(), Self::SendError> {
if let Ok(err) = self.error.try_recv() {
return Err(err)
return Err(err);
};

self.gateway
Expand Down Expand Up @@ -138,7 +140,7 @@ impl SplitChannel for MuxConn {
/// The multiplexed channels must be *driven* by the gateway
/// (see [``Gateway::drive``]) otherwise the multiplexed channels won't
/// be able to communicate.
///
///
/// ## Example:
/// ```
/// use caring::net::{connection::Connection, mux::Gateway, RecvBytes, SendBytes};
Expand Down Expand Up @@ -182,78 +184,78 @@ where
mailboxes: Vec<mpsc::UnboundedSender<BytesMut>>,
inbox: mpsc::UnboundedReceiver<MultiplexedMessage>,
errors: Vec<[oneshot::Sender<MuxError>; 2]>,
outbox: mpsc::WeakUnboundedSender<MultiplexedMessage>
outbox: mpsc::WeakUnboundedSender<MultiplexedMessage>,
}


#[derive(Debug, Error)]
pub enum GatewayError<E: Error + Send + 'static> {
#[error("Multiplexed connection {0} disappered")]
MailboxNotFound(usize),
#[error("Underlying connection died: {0}")]
DeadConnection(#[from] Arc<E>)
DeadConnection(#[from] Arc<E>),
}

impl<C: SplitChannel + Send> Gateway<C> {
/// Drive a gateway until all multiplexed connections are complete
///
/// # Errors
///
///
/// - [``GatewayError::MailboxNotFound``] if a given multiplexed connections has been
/// dropped and is receiving messages.
/// - [``GatewayError::DeadConnection``] if the underlying connection have failed.
///
pub async fn drive(mut self) -> Result<Self, GatewayError<C::Error>> {
let (sending, recving) = self.channel.split();

let send_out = async {
loop {
if let Some(msg) = self.inbox.recv().await {
match sending.send_bytes(msg.make_bytes()).await {
Ok(()) => continue,
Err(e) => break Err(e),
}
}
break Ok(());
// TODO: maybe have this be nonconsuming so it can be resumed after new muxes are added?
// This however would compromise the possible destruction when errors occur,
// thus leaving the error handling in a bad state.
let (sending, recving) = self.channel.split();
let send_out = async {
loop {
if let Some(msg) = self.inbox.recv().await {
match sending.send_bytes(msg.make_bytes()).await {
Ok(()) => continue,
Err(e) => break Err(e),
}
};
}
break Ok(());
}
};

let recv_in = async {
loop {
match recving.recv_bytes().await {
Ok(mut msg) => {
let id = msg.get_u32() as usize;
let bytes = msg;
let Some(mailbox) = self.mailboxes.get_mut(id) else {
break Err(GatewayError::MailboxNotFound(id))
};
let Ok(()) = mailbox.send(bytes) else {
break Err(GatewayError::MailboxNotFound(id))
};
}
Err(e) => break Ok(e),
let recv_in = async {
loop {
match recving.recv_bytes().await {
Ok(mut msg) => {
let id = msg.get_u32() as usize;
let bytes = msg;
let Some(mailbox) = self.mailboxes.get_mut(id) else {
break Err(GatewayError::MailboxNotFound(id));
};
let Ok(()) = mailbox.send(bytes) else {
break Err(GatewayError::MailboxNotFound(id));
};
}
Err(e) => break Ok(e),
}
};

}
};

tokio::select! { // Drive both futures to completion.
res = send_out => {
match res {
Ok(()) => {
Ok(self)
},
Err(err) => {
Err(self.propogate_error(err))
}
tokio::select! { // Drive both futures to completion.
res = send_out => {
match res {
Ok(()) => {
Ok(self)
},
Err(err) => {
Err(self.propogate_error(err))
}
},
err = recv_in => {
let err : C::Error = err?; // return early on missing mailbox.
println!("Got an error! {err:?}");
Err(self.propogate_error(err))
},
}
}
},
err = recv_in => {
let err : C::Error = err?; // return early on missing mailbox.
println!("Got an error! {err:?}");
Err(self.propogate_error(err))
},
}
}

fn propogate_error<E: Error + Send + Sync + 'static>(mut self, err: E) -> GatewayError<E> {
Expand All @@ -265,13 +267,12 @@ impl<C: SplitChannel + Send> Gateway<C> {
let _ = c2.send(MuxError::Connection(err.clone()));
}
GatewayError::DeadConnection(err)
}

}

pub fn single(channel: C) -> (Self, MuxConn) {
let (outbox, inbox) = unbounded_channel();
let gateway = outbox.clone();
let outbox= outbox.downgrade();
let outbox = outbox.downgrade();
let mut new = Self {
channel,
mailboxes: vec![],
Expand All @@ -281,7 +282,6 @@ impl<C: SplitChannel + Send> Gateway<C> {
};
let con = new.add_mux(gateway);
(new, con)

}

pub fn destroy(self) -> C {
Expand All @@ -294,42 +294,53 @@ impl<C: SplitChannel + Send> Gateway<C> {
/// * `n`: Number of new connections to multiplex into
///
/// Returns a gateway which the ``MuxConn`` communicate through, along with the MuxConn
pub fn multiplex(con: C, n: usize) -> (Self, Vec<MuxConn>) {
let (mut gateway, con) = Self::single(con);
pub fn multiplex(con: C, n: usize) -> (Self, Vec<MuxConn>) {
let (mut gateway, con) = Self::single(con);
let mut muxes = vec![con];
for _ in 1..n {
muxes.push(gateway.muxify());
}
(gateway, muxes)
}


fn add_mux(&mut self, gateway: UnboundedSender<MultiplexedMessage>) -> MuxConn {
let id = self.mailboxes.len();
let (errors_coms1, error) = oneshot::channel();
let mx_sender = MuxedSender {
id,
gateway,
error,
};
let (errors_coms1, error) = oneshot::channel();
let mx_sender = MuxedSender { id, gateway, error };
let (outbox, mailbox) = tokio::sync::mpsc::unbounded_channel();
let (errors_coms2, error) = oneshot::channel();
let mx_receiver = MuxedReceiver {
id,
mailbox,
error,
};
let (errors_coms2, error) = oneshot::channel();
let mx_receiver = MuxedReceiver { id, mailbox, error };
self.errors.push([errors_coms1, errors_coms2]);
self.mailboxes.push(outbox);
MuxConn(mx_sender, mx_receiver)
}

pub fn muxify(&mut self) -> MuxConn {
let gateway = self.outbox.clone().upgrade().expect("We are holding the receiver");
let gateway = self
.outbox
.clone()
.upgrade()
.expect("We are holding the receiver");
self.add_mux(gateway)
}
}

pub struct ActiveGateway<C: SplitChannel>(
tokio::task::JoinHandle<Result<Gateway<C>, GatewayError<C::Error>>>,
);

impl<C: SplitChannel + Send + 'static> Gateway<C> {
pub fn go(self) -> ActiveGateway<C> {
ActiveGateway(tokio::spawn(self.drive()))
}
}

impl<C: SplitChannel + Send + 'static> ActiveGateway<C> {
pub async fn deactivate(self) -> Result<Gateway<C>, GatewayError<C::Error>> {
self.0.await.unwrap()
}
}

pub struct NetworkGateway<C: SplitChannel> {
gateways: Vec<Gateway<C>>,
index: usize,
Expand All @@ -341,7 +352,8 @@ impl<C> NetworkGateway<C>
where
C: SplitChannel + Send,
{
#[must_use] pub fn multiplex(net: Network<C>, n: usize) -> (NetworkGateway<C>, Vec<MuxNet>) {
#[must_use]
pub fn multiplex(net: Network<C>, n: usize) -> (NetworkGateway<C>, Vec<MuxNet>) {
let mut gateways = Vec::new();
let mut matrix = Vec::new();
let index = net.index;
Expand All @@ -350,7 +362,7 @@ where
matrix.push(muxes);
gateways.push(gateway);
}
let gateway = NetworkGateway { gateways, index, };
let gateway = NetworkGateway { gateways, index };

let matrix = help::transpose(matrix);
let muxnets: Vec<_> = matrix
Expand All @@ -361,32 +373,38 @@ where
(gateway, muxnets)
}

pub fn multiplex_borrow(net: &mut Network<C>, n: usize)
-> (NetworkGateway<&mut C>, Vec<MuxNet>)
{
pub fn multiplex_borrow(
net: &mut Network<C>,
n: usize,
) -> (NetworkGateway<&mut C>, Vec<MuxNet>) {
let net = net.as_mut();
NetworkGateway::<&mut C>::multiplex(net, n)
}

pub async fn drive(self) -> Result<Self, GatewayError<C::Error>> {
let gateways = try_join_all(self.gateways.into_iter().map(Gateway::drive)).await?;
Ok(Self { gateways, index: self.index })
Ok(Self {
gateways,
index: self.index,
})
}

#[must_use]
pub fn destroy(mut self) -> Network<C> {
let index= self.index;
let connections : Vec<_> = self.gateways.drain(..).map(Gateway::destroy).collect();
let index = self.index;
let connections: Vec<_> = self.gateways.drain(..).map(Gateway::destroy).collect();
Network { connections, index }
}

pub fn new_mux(&mut self) -> MuxNet {
let connections = self.gateways.iter_mut().map(Gateway::muxify ).collect();
MuxNet { connections, index: self.index }
let connections = self.gateways.iter_mut().map(Gateway::muxify).collect();
MuxNet {
connections,
index: self.index,
}
}
}


#[cfg(test)]
mod test {
use std::time::Duration;
Expand Down Expand Up @@ -447,7 +465,7 @@ mod test {
};
let (s, gateway) = (s, gateway.drive()).join().await;
let mut gateway = gateway.unwrap();
let _ : String = gateway.channel.recv().await.unwrap();
let _: String = gateway.channel.recv().await.unwrap();
gateway.channel.shutdown().await.unwrap();
s
};
Expand Down Expand Up @@ -511,7 +529,6 @@ mod test {
.unwrap();
}


#[tokio::test]
async fn network_borrowed() {
crate::testing::Cluster::new(3)
Expand Down
Loading

0 comments on commit bb81a97

Please sign in to comment.