Skip to content

Commit

Permalink
Also concurrently process outgoing messages
Browse files Browse the repository at this point in the history
  • Loading branch information
ameba23 committed Oct 29, 2024
1 parent a4d864d commit d0c5f67
Showing 1 changed file with 26 additions and 19 deletions.
45 changes: 26 additions & 19 deletions crates/protocol/src/execute_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,30 +89,43 @@ where

loop {
let mut accum = session.make_accumulator();
let current_round = session.current_round();
let session_arc = Arc::new(session);
let destinations = session_arc.message_destinations();

// Channel for receiving message artifacts
let (artifact_tx, mut artifact_rx) = mpsc::channel(destinations.len());
// Send out messages
let destinations = session.message_destinations();
// TODO (#641): this can happen in a spawned task
for destination in destinations.iter() {
let (message, artifact) = session.make_message(&mut OsRng, destination)?;
tx.send(ProtocolMessage::new(&my_id, destination, message))?;
let session_arc = session_arc.clone();
let tx = tx.clone();
let my_id = my_id.clone();
let artifact_tx = artifact_tx.clone();
let destination = destination.clone();
tokio::spawn(async move {
let (message, artifact) =
session_arc.make_message(&mut OsRng, &destination).unwrap();
tx.send(ProtocolMessage::new(&my_id, &destination, message)).unwrap();
artifact_tx.send(artifact).await.unwrap();
});
}

// This will happen in a host task
accum.add_artifact(artifact)?;
for _ in 0..destinations.len() {
if let Some(artifact) = artifact_rx.recv().await {
accum.add_artifact(artifact)?;
}
}

for preprocessed in cached_messages {
// TODO (#641): this may happen in a spawned task.
let processed = session.process_message(&mut OsRng, preprocessed)?;
let processed = session_arc.process_message(&mut OsRng, preprocessed)?;

// This will happen in a host task.
accum.add_processed_message(processed)??;
}

// Channel for receiving results of processing messages
let (process_tx, mut process_rx) = mpsc::channel(1024);
let current_round = session.current_round();
let session_arc = Arc::new(session);

while !session_arc.can_finalize(&accum)? {
tokio::select! {
Expand All @@ -129,18 +142,17 @@ where
session_arc.preprocess_message(&mut accum, &message.from, *payload)?;

if let Some(preprocessed) = preprocessed {
let session_clone = session_arc.clone();
let tx_clone = process_tx.clone();
let session_arc = session_arc.clone();
let tx = process_tx.clone();
tokio::spawn(async move {
let result = session_clone.process_message(&mut OsRng, preprocessed);
if tx_clone.send(result).await.is_err() {
let result = session_arc.process_message(&mut OsRng, preprocessed);
if tx.send(result).await.is_err() {
tracing::error!("Protocol finished before message processing result sent");
}
});
}
} else {
tracing::warn!("Got protocol message with incorrect session ID - putting back in queue");
// messages_for_later.push_back(message);
tx.incoming_sender.send(message).await?;
}
} else {
Expand All @@ -157,11 +169,6 @@ where
}
}

// Put messages which were not for this session back onto the incoming message channel
// for message in messages_for_later.into_iter() {
// tx.incoming_sender.send(message).await?;
// }

// Get session back out of Arc
let session_inner =
Arc::try_unwrap(session_arc).map_err(|_| GenericProtocolError::ArcUnwrapError)?;
Expand Down

0 comments on commit d0c5f67

Please sign in to comment.