diff --git a/crates/protocol/src/execute_protocol.rs b/crates/protocol/src/execute_protocol.rs index c7df58fe0..a6f127995 100644 --- a/crates/protocol/src/execute_protocol.rs +++ b/crates/protocol/src/execute_protocol.rs @@ -18,7 +18,7 @@ use num::bigint::BigUint; use rand_core::{CryptoRngCore, OsRng}; use sp_core::{sr25519, Pair}; -use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; use subxt::utils::AccountId32; use synedrion::{ ecdsa::VerifyingKey, @@ -69,7 +69,7 @@ impl RandomizedPrehashSigner for PairWrapper { } } -pub async fn execute_protocol_generic( +pub async fn execute_protocol_generic( mut chans: Channels, session: Session, session_id_hash: [u8; 32], @@ -105,52 +105,81 @@ pub async fn execute_protocol_generic( accum.add_processed_message(processed)??; } - while !session.can_finalize(&accum)? { - let mut messages_for_later = VecDeque::new(); - let (from, payload) = loop { - let message = rx.recv().await.ok_or_else(|| { - GenericProtocolError::::IncomingStream(format!( - "{:?}", - session.current_round() - )) - })?; - - if let ProtocolMessagePayload::MessageBundle(payload) = message.payload.clone() { - if payload.session_id() == &session_id { - break (message.from, *payload); + // Channel for receiving results of processing messages + let (process_tx, mut process_rx) = mpsc::unbounded_channel(); + let current_round = session.current_round(); + let session_arc = Arc::new(Mutex::new(session)); + + loop { + { + let session = session_arc.lock().unwrap(); + if session.can_finalize(&accum)? { + break; + } + } + tokio::select! { + // Incoming message from remote peer + maybe_message = rx.recv() => { + let message = maybe_message.ok_or_else(|| { + GenericProtocolError::IncomingStream(format!("{:?}", current_round)) + })?; + + if let ProtocolMessagePayload::MessageBundle(payload) = message.payload.clone() { + if payload.session_id() == &session_id { + let preprocessed = { + let session = session_arc.lock().unwrap(); + // Perform quick checks before proceeding with the verification. + session.preprocess_message(&mut accum, &message.from, *payload)? + }; + + if let Some(preprocessed) = preprocessed { + let session_clone = session_arc.clone(); + let tx_clone = process_tx.clone(); + tokio::spawn(async move { + let session = session_clone.lock().unwrap(); + let result = session.process_message(&mut OsRng, preprocessed).unwrap(); + tx_clone.send(result).unwrap(); + }); + } + } else { + tracing::warn!("Got protocol message with incorrect session ID - putting back in queue"); + // messages_for_later.push_back(message); + tx.incoming_sender.send(message).await?; + } } else { - tracing::warn!("Got protocol message with incorrect session ID - putting back in queue"); - messages_for_later.push_back(message); + tracing::warn!("Got verifying key during protocol - ignoring"); } - } else { - tracing::warn!("Got verifying key during protocol - ignoring"); } - }; - // Put messages which were not for this session back onto the incoming message channel - for message in messages_for_later.into_iter() { - tx.incoming_sender.send(message).await?; - } - // Perform quick checks before proceeding with the verification. - let preprocessed = session.preprocess_message(&mut accum, &from, payload)?; - - if let Some(preprocessed) = preprocessed { - // TODO (#641): this may happen in a spawned task. - let result = session.process_message(&mut OsRng, preprocessed)?; - // This will happen in a host task. - accum.add_processed_message(result)??; + // Result from processing a message + maybe_result = process_rx.recv() => { + if let Some(result) = maybe_result { + accum.add_processed_message(result)??; + } + } } } - match session.finalize_round(&mut OsRng, accum)? { - FinalizeOutcome::Success(res) => break Ok((res, chans)), - FinalizeOutcome::AnotherRound { - session: new_session, - cached_messages: new_cached_messages, - } => { - session = new_session; - cached_messages = new_cached_messages; - }, + // Put messages which were not for this session back onto the incoming message channel + // for message in messages_for_later.into_iter() { + // tx.incoming_sender.send(message).await?; + // } + + // Get session back out of Arc and Mutex + if let Ok(session_inner) = Arc::try_unwrap(session_arc) { + let session_inner = session_inner.into_inner().unwrap(); + match session_inner.finalize_round(&mut OsRng, accum)? { + FinalizeOutcome::Success(res) => break Ok((res, chans)), + FinalizeOutcome::AnotherRound { + session: new_session, + cached_messages: new_cached_messages, + } => { + session = new_session; + cached_messages = new_cached_messages; + }, + } + } else { + panic!("Cannot get session out of Arc"); } } }