Skip to content

Commit

Permalink
Gateway better error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
quackzar committed Jun 13, 2024
1 parent be3f2fe commit fb90918
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 44 deletions.
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#![allow(refining_impl_trait)]
#![allow(dead_code)]
#![feature(async_fn_traits)]
#![allow(clippy::cast_possible_truncation)]

mod algebra;
pub mod net;
Expand Down
2 changes: 1 addition & 1 deletion src/net/agency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ impl<B: Broadcast, D: Digest> VerifiedBroadcast<B, D> {
// 3. Hash the hashes together and broadcast that
event!(Level::INFO, "Broadcast sum of all commits");
let mut digest = D::new();
for hash in msg_hashes.iter() {
for hash in &msg_hashes {
digest.update(hash);
}
let sum: Box<[u8]> = digest.finalize().to_vec().into_boxed_slice();
Expand Down
14 changes: 12 additions & 2 deletions src/net/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
use std::error::Error;

use futures::{SinkExt, StreamExt};
use futures_concurrency::future::Join;
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf},
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf},
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
Expand Down Expand Up @@ -201,8 +202,17 @@ impl DuplexConnection {
(Self::new(r1, w1), Self::new(r2, w2))
}

/// Gracefully shutdown the connection.
///
/// # Errors
///
/// This function will return an error if the connection could not be shutdown cleanly.
pub async fn shutdown(self) -> Result<(), std::io::Error> {
let (r,w) = self.destroy();
let (mut r, mut w) = self.destroy();
// HACK: A single read/write so we don't exit too early.
// We ignore the errors, since we don't care if we can't send or receive,
// it is supposed to be closed.
let (_, _) = (r.read_u8(), w.write_u8(0)).join().await;
let mut stream = r.unsplit(w);
stream.shutdown().await
}
Expand Down
4 changes: 2 additions & 2 deletions src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ impl<R: RecvBytes> RecvBytes for &mut R {
/// A communication medium between you and another party.
///
/// Allows you to send and receive arbitrary messages.
pub trait Channel: SendBytes + RecvBytes {
type Error: Error + Send;
pub trait Channel: SendBytes<SendError = Self::Error> + RecvBytes<RecvError = Self::Error> {
type Error: Error + Send + Sync + 'static;
}
impl<C: Channel> Channel for &mut C {
type Error = C::Error;
Expand Down
119 changes: 82 additions & 37 deletions src/net/mux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use std::error::Error;
use std::sync::Arc;


use futures::future::join_all;
use futures::future::{join_all, try_join_all};

Check failure on line 12 in src/net/mux.rs

View workflow job for this annotation

GitHub Actions / cargo test

unused import: `join_all`

Check failure on line 12 in src/net/mux.rs

View workflow job for this annotation

GitHub Actions / cargo test

unused import: `join_all`
use num_traits::ToPrimitive;
use thiserror::Error;
use tokio::sync::{mpsc::{self, unbounded_channel, UnboundedSender}, oneshot};

Check warning on line 15 in src/net/mux.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/caring/caring/src/net/mux.rs

Check warning on line 15 in src/net/mux.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/caring/caring/src/net/mux.rs
use tokio_util::bytes::{Buf, BufMut, Bytes, BytesMut};
Expand Down Expand Up @@ -51,7 +52,7 @@ impl MultiplexedMessage {
let bytes = self.0;
let id = self.1;
let mut msg = BytesMut::new();
msg.put_u32(id as u32);
msg.put_u32(id.to_u32().expect("Too many multiplexed connections!"));
msg.put(bytes);
msg.freeze()
}
Expand Down Expand Up @@ -95,7 +96,7 @@ impl RecvBytes for MuxedReceiver {
//
/// Multiplexed Connection
///
/// Aqquirred by constructing a [Gateway] using [Gateway::multiplex]
/// Aqquirred by constructing a [Gateway] using [``Gateway::multiplex``]
pub struct MuxConn(MuxedSender, MuxedReceiver);

impl Channel for MuxConn {
Expand Down Expand Up @@ -135,7 +136,7 @@ impl SplitChannel for MuxConn {
///
/// Enables splitting a channel into multiple multiplexed channels.
/// The multiplexed channels must be *driven* by the gateway

Check warning on line 138 in src/net/mux.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/caring/caring/src/net/mux.rs

Check warning on line 138 in src/net/mux.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/caring/caring/src/net/mux.rs
/// (see [Gateway::drive]) otherwise the multiplexed channels won't
/// (see [``Gateway::drive``]) otherwise the multiplexed channels won't
/// be able to communicate.
///
/// ## Example:
Expand All @@ -154,7 +155,7 @@ impl SplitChannel for MuxConn {
/// tokio::spawn(async move {
/// m2.send(&"Hello MUX2!".to_owned()).await.unwrap();
/// });
/// gateway.drive().await;
/// gateway.drive().await.unwrap();
/// });
///
/// tokio::spawn( async {// party 2
Expand All @@ -169,7 +170,7 @@ impl SplitChannel for MuxConn {
/// let msg : String = m2.recv().await.unwrap();
/// assert_eq!(msg, "Hello MUX2!");
/// });
/// gateway.drive().await;
/// gateway.drive().await.unwrap();
/// });
/// })
/// ```
Expand All @@ -184,17 +185,37 @@ where
outbox: mpsc::WeakUnboundedSender<MultiplexedMessage>
}


#[derive(Debug, Error)]
pub enum GatewayError<E: Error + Send + 'static> {
#[error("Multiplexed connection {0} disappered")]
MailboxNotFound(usize),

Check warning on line 192 in src/net/mux.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/caring/caring/src/net/mux.rs

Check warning on line 192 in src/net/mux.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/caring/caring/src/net/mux.rs
#[error("Underlying connection died: {0}")]
DeadConnection(#[from] Arc<E>)
}

impl<C: SplitChannel + Send> Gateway<C> {
pub async fn drive(self) -> Self {
let mut gateway = self;
{
let (sending, recving) = gateway.channel.split();
/// Drive a gateway until all multiplexed connections are complete

Check warning on line 198 in src/net/mux.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/caring/caring/src/net/mux.rs

Check warning on line 198 in src/net/mux.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/caring/caring/src/net/mux.rs
///
/// # Errors
///
/// - [``GatewayError::MailboxNotFound``] if a given multiplexed connections has been
/// dropped and is receiving messages.
/// - [``GatewayError::DeadConnection``] if the underlying connection have failed.
///
pub async fn drive(mut self) -> Result<Self, GatewayError<C::Error>> {
let (sending, recving) = self.channel.split();

let send_out = async {
while let Some(msg) = gateway.inbox.recv().await {
// TODO: Error propagation.
sending.send_bytes(msg.make_bytes()).await.unwrap();
}
loop {
if let Some(msg) = self.inbox.recv().await {
match sending.send_bytes(msg.make_bytes()).await {
Ok(()) => continue,
Err(e) => break Err(e),
}
}
break Ok(());
}
};

let recv_in = async {
Expand All @@ -203,29 +224,49 @@ impl<C: SplitChannel + Send> Gateway<C> {
Ok(mut msg) => {
let id = msg.get_u32() as usize;
let bytes = msg;
gateway.mailboxes[id].send(bytes).unwrap();
let Some(mailbox) = self.mailboxes.get_mut(id) else {
break Err(GatewayError::MailboxNotFound(id))
};
let Ok(()) = mailbox.send(bytes) else {
break Err(GatewayError::MailboxNotFound(id))
};
}
Err(e) => break e,
Err(e) => break Ok(e),
}
}
};


tokio::select! { // Drive both futures to completion.
() = send_out => {},
err = recv_in => {
let err = Arc::new(err);
for [c1, c2] in gateway.errors.drain(..) {
// ignore dropped connections,
// they can't handle errors when they don't exist.
let _ = c1.send(MuxError::Connection(err.clone()));
let _ = c2.send(MuxError::Connection(err.clone()));
res = send_out => {
match res {
Ok(()) => {
Ok(self)
},
Err(err) => {
Err(self.propogate_error(err))
}
}
},
};
err = recv_in => {
let err : C::Error = err?; // return early on missing mailbox.
println!("Got an error! {err:?}");
Err(self.propogate_error(err))
},
}
}

fn propogate_error<E: Error + Send + Sync + 'static>(mut self, err: E) -> GatewayError<E> {
let err = Arc::new(err);
for [c1, c2] in self.errors.drain(..) {
// ignore dropped connections,
// they can't handle errors when they don't exist.
let _ = c1.send(MuxError::Connection(err.clone()));
let _ = c2.send(MuxError::Connection(err.clone()));
}
GatewayError::DeadConnection(err)
}

gateway
}

pub fn single(channel: C) -> (Self, MuxConn) {
let (outbox, inbox) = unbounded_channel();
Expand All @@ -252,7 +293,7 @@ impl<C: SplitChannel + Send> Gateway<C> {
/// * `net`: Connection to use as a gateway for multiplexing
/// * `n`: Number of new connections to multiplex into
///
/// Returns a gateway which the MuxConn communicate through, along with the MuxConn
/// Returns a gateway which the ``MuxConn`` communicate through, along with the MuxConn
pub fn multiplex(con: C, n: usize) -> (Self, Vec<MuxConn>) {
let (mut gateway, con) = Self::single(con);
let mut muxes = vec![con];
Expand Down Expand Up @@ -300,7 +341,7 @@ impl<C> NetworkGateway<C>
where
C: SplitChannel + Send,
{
pub fn multiplex(net: Network<C>, n: usize) -> (NetworkGateway<C>, Vec<MuxNet>) {
#[must_use] pub fn multiplex(net: Network<C>, n: usize) -> (NetworkGateway<C>, Vec<MuxNet>) {
let mut gateways = Vec::new();
let mut matrix = Vec::new();
let index = net.index;
Expand All @@ -327,19 +368,20 @@ where
NetworkGateway::<&mut C>::multiplex(net, n)
}

pub async fn drive(self) -> Self {
let gateways = join_all(self.gateways.into_iter().map(|c| c.drive())).await;
Self { gateways, index: self.index }
pub async fn drive(self) -> Result<Self, GatewayError<C::Error>> {
let gateways = try_join_all(self.gateways.into_iter().map(Gateway::drive)).await?;
Ok(Self { gateways, index: self.index })
}

#[must_use]
pub fn destroy(mut self) -> Network<C> {
let index= self.index;
let connections : Vec<_> = self.gateways.drain(..).map(|g| g.destroy()).collect();
let connections : Vec<_> = self.gateways.drain(..).map(Gateway::destroy).collect();
Network { connections, index }
}

pub fn new_mux(&mut self) -> MuxNet {
let connections = self.gateways.iter_mut().map(|g| g.muxify() ).collect();
let connections = self.gateways.iter_mut().map(Gateway::muxify ).collect();
MuxNet { connections, index: self.index }
}
}
Expand Down Expand Up @@ -385,7 +427,8 @@ mod test {
s1 + &s2 + &s3
};

let (s, mut gateway) = join!(s, gateway.drive());
let (s, gateway) = join!(s, gateway.drive());
let mut gateway = gateway.unwrap();
gateway.channel.send(&"bye".to_owned()).await.unwrap();
gateway.channel.shutdown().await.unwrap();
s
Expand All @@ -402,7 +445,8 @@ mod test {
);
s1 + &s2 + &s3
};
let (s, mut gateway) = (s, gateway.drive()).join().await;
let (s, gateway) = (s, gateway.drive()).join().await;
let mut gateway = gateway.unwrap();
let _ : String = gateway.channel.recv().await.unwrap();
gateway.channel.shutdown().await.unwrap();
s
Expand Down Expand Up @@ -438,7 +482,7 @@ mod test {
drop(c2);
};

let (_, _) = futures::join!(p1, p2);
let (_, ()) = futures::join!(p1, p2);
}

#[tokio::test]
Expand All @@ -458,6 +502,7 @@ mod test {
assert_eq!(res, vec!["World"; 3]);
});
let (r1, r2, gateway) = futures::join!(h1, h2, gateway.drive());
let gateway = gateway.unwrap();
gateway.destroy().shutdown().await.unwrap();
r1.unwrap();
r2.unwrap();
Expand Down
9 changes: 7 additions & 2 deletions src/net/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl<C: SplitChannel> Network<C> {
{
let my_id = self.index;
let (mut tx, mut rx): (Vec<_>, Vec<_>) =
self.connections.iter_mut().map(|c| c.split()).unzip();
self.connections.iter_mut().map(super::SplitChannel::split).unzip();

let packet: Bytes = bincode::serialize(&msg).unwrap().into();
let outgoing = tx.iter_mut().enumerate().map(|(id, conn)| {
Expand Down Expand Up @@ -199,6 +199,11 @@ impl<C: SplitChannel> Network<C> {
/// will be send to party `i`.
///
/// * `msg`: message to send and receive
///
/// # Errors
///
/// Returns a [``NetworkError``] with an ``id`` if the underlying connection to that ``id`` fails.
///
pub async fn symmetric_unicast<T>(&mut self, mut msgs: Vec<T>) -> NetResult<Vec<T>, C>
where
T: serde::Serialize + serde::de::DeserializeOwned + Sync,
Expand All @@ -207,7 +212,7 @@ impl<C: SplitChannel> Network<C> {
let my_own_msg = msgs.remove(my_id);

let (mut tx, mut rx): (Vec<_>, Vec<_>) =
self.connections.iter_mut().map(|c| c.split()).unzip();
self.connections.iter_mut().map(super::SplitChannel::split).unzip();

let outgoing = tx
.iter_mut()
Expand Down

0 comments on commit fb90918

Please sign in to comment.