From d0c5f6754970336be2478b0da5b4b9b2f30c892d Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 29 Oct 2024 16:10:25 +0100 Subject: [PATCH] Also concurrently process outgoing messages --- crates/protocol/src/execute_protocol.rs | 45 ++++++++++++++----------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/crates/protocol/src/execute_protocol.rs b/crates/protocol/src/execute_protocol.rs index 76ab9bbfa..007009dd4 100644 --- a/crates/protocol/src/execute_protocol.rs +++ b/crates/protocol/src/execute_protocol.rs @@ -89,21 +89,36 @@ where loop { let mut accum = session.make_accumulator(); + let current_round = session.current_round(); + let session_arc = Arc::new(session); + let destinations = session_arc.message_destinations(); + // Channel for receiving message artifacts + let (artifact_tx, mut artifact_rx) = mpsc::channel(destinations.len()); // Send out messages - let destinations = session.message_destinations(); - // TODO (#641): this can happen in a spawned task for destination in destinations.iter() { - let (message, artifact) = session.make_message(&mut OsRng, destination)?; - tx.send(ProtocolMessage::new(&my_id, destination, message))?; + let session_arc = session_arc.clone(); + let tx = tx.clone(); + let my_id = my_id.clone(); + let artifact_tx = artifact_tx.clone(); + let destination = destination.clone(); + tokio::spawn(async move { + let (message, artifact) = + session_arc.make_message(&mut OsRng, &destination).unwrap(); + tx.send(ProtocolMessage::new(&my_id, &destination, message)).unwrap(); + artifact_tx.send(artifact).await.unwrap(); + }); + } - // This will happen in a host task - accum.add_artifact(artifact)?; + for _ in 0..destinations.len() { + if let Some(artifact) = artifact_rx.recv().await { + accum.add_artifact(artifact)?; + } } for preprocessed in cached_messages { // TODO (#641): this may happen in a spawned task. - let processed = session.process_message(&mut OsRng, preprocessed)?; + let processed = session_arc.process_message(&mut OsRng, preprocessed)?; // This will happen in a host task. accum.add_processed_message(processed)??; @@ -111,8 +126,6 @@ where // Channel for receiving results of processing messages let (process_tx, mut process_rx) = mpsc::channel(1024); - let current_round = session.current_round(); - let session_arc = Arc::new(session); while !session_arc.can_finalize(&accum)? { tokio::select! { @@ -129,18 +142,17 @@ where session_arc.preprocess_message(&mut accum, &message.from, *payload)?; if let Some(preprocessed) = preprocessed { - let session_clone = session_arc.clone(); - let tx_clone = process_tx.clone(); + let session_arc = session_arc.clone(); + let tx = process_tx.clone(); tokio::spawn(async move { - let result = session_clone.process_message(&mut OsRng, preprocessed); - if tx_clone.send(result).await.is_err() { + let result = session_arc.process_message(&mut OsRng, preprocessed); + if tx.send(result).await.is_err() { tracing::error!("Protocol finished before message processing result sent"); } }); } } else { tracing::warn!("Got protocol message with incorrect session ID - putting back in queue"); - // messages_for_later.push_back(message); tx.incoming_sender.send(message).await?; } } else { @@ -157,11 +169,6 @@ where } } - // Put messages which were not for this session back onto the incoming message channel - // for message in messages_for_later.into_iter() { - // tx.incoming_sender.send(message).await?; - // } - // Get session back out of Arc let session_inner = Arc::try_unwrap(session_arc).map_err(|_| GenericProtocolError::ArcUnwrapError)?;