diff --git a/crates/prism/src/node_types/sequencer.rs b/crates/prism/src/node_types/sequencer.rs index 91d4439e..502995c6 100644 --- a/crates/prism/src/node_types/sequencer.rs +++ b/crates/prism/src/node_types/sequencer.rs @@ -10,7 +10,7 @@ use prism_common::{ }, }; use std::{self, collections::VecDeque, sync::Arc}; -use tokio::sync::{broadcast, Mutex}; +use tokio::sync::{broadcast, RwLock}; use sp1_sdk::{ProverClient, SP1ProvingKey, SP1Stdin, SP1VerifyingKey}; @@ -39,9 +39,9 @@ pub struct Sequencer { // [`pending_operations`] is a buffer for operations that have not yet been // posted to the DA layer. - pending_operations: Arc>>, - tree: Arc>>>, - prover_client: Arc>, + pending_operations: Arc>>, + tree: Arc>>>, + prover_client: Arc>, proving_key: SP1ProvingKey, verifying_key: SP1VerifyingKey, @@ -76,7 +76,7 @@ impl Sequencer { let ws = cfg.webserver.context("Missing webserver configuration")?; let start_height = cfg.celestia_config.unwrap_or_default().start_height; - let tree = Arc::new(Mutex::new(KeyDirectoryTree::new(db.clone()))); + let tree = Arc::new(RwLock::new(KeyDirectoryTree::new(db.clone()))); let prover_client = ProverClient::new(); let (pk, vk) = prover_client.setup(PRISM_ELF); @@ -89,9 +89,9 @@ impl Sequencer { verifying_key: vk, key, start_height, - prover_client: Arc::new(Mutex::new(prover_client)), + prover_client: Arc::new(RwLock::new(prover_client)), tree, - pending_operations: Arc::new(Mutex::new(Vec::new())), + pending_operations: Arc::new(RwLock::new(Vec::new())), }) } @@ -152,7 +152,7 @@ impl Sequencer { // Get pending operations let pending_operations = { - let mut ops = self.pending_operations.lock().await; + let mut ops = self.pending_operations.write().await; std::mem::take(&mut *ops) }; @@ -316,7 +316,6 @@ impl Sequencer { Ok(()) } - async fn prove_epoch( &self, height: u64, @@ -332,8 +331,7 @@ impl Sequencer { let mut stdin = SP1Stdin::new(); stdin.write(&batch); - - let client = self.prover_client.lock().await; + let client = self.prover_client.read().await; info!("generating proof for epoch height {}", height); #[cfg(not(feature = "plonk"))] @@ -357,17 +355,15 @@ impl Sequencer { epoch_json.insert_signature(&self.key); Ok(epoch_json) } - pub async fn get_commitment(&self) -> Result { - let tree = self.tree.lock().await; + let tree = self.tree.read().await; tree.get_commitment().context("Failed to get commitment") } - pub async fn get_hashchain( &self, id: &String, ) -> Result> { - let tree = self.tree.lock().await; + let tree = self.tree.read().await; let hashed_id = hash(id.as_bytes()); let key_hash = KeyHash::with::(hashed_id); @@ -376,7 +372,7 @@ impl Sequencer { /// Updates the state from an already verified pending operation. async fn process_operation(&self, operation: &Operation) -> Result { - let mut tree = self.tree.lock().await; + let mut tree = self.tree.write().await; tree.process_operation(operation) } @@ -387,7 +383,7 @@ impl Sequencer { ) -> Result<()> { // TODO: this is only basic validation. The validation over if an entry can be added to the hashchain or not is done in the process_operation function incoming_operation.validate()?; - let mut pending = self.pending_operations.lock().await; + let mut pending = self.pending_operations.write().await; pending.push(incoming_operation.clone()); Ok(()) } @@ -504,7 +500,6 @@ mod tests { Sequencer::new(db.clone(), da_layer, Config::default(), signing_key.clone()).unwrap(), ) } - #[tokio::test] #[serial] async fn test_validate_and_queue_update() { @@ -519,12 +514,11 @@ mod tests { .await .unwrap(); - let pending_ops = sequencer.pending_operations.lock().await; + let pending_ops = sequencer.pending_operations.read().await; assert_eq!(pending_ops.len(), 1); teardown_db(sequencer.db.clone()); } - #[tokio::test] #[serial] async fn test_process_operation() { @@ -555,7 +549,6 @@ mod tests { teardown_db(sequencer.db.clone()); } - #[tokio::test] #[serial] async fn test_execute_block() { @@ -580,12 +573,10 @@ mod tests { teardown_db(sequencer.db.clone()); } - #[tokio::test] #[serial] async fn test_finalize_new_epoch() { let sequencer = create_test_sequencer().await; - let tree = sequencer.tree.lock().await; let signing_key_1 = create_mock_signing_key(); let signing_key_2 = create_mock_signing_key(); @@ -601,10 +592,10 @@ mod tests { add_key("user1@example.com", 0, new_key, signing_key_1), ]; - let prev_commitment = tree.get_commitment().unwrap(); + let prev_commitment = sequencer.get_commitment().await.unwrap(); sequencer.finalize_new_epoch(0, operations).await.unwrap(); - let new_commitment = tree.get_commitment().unwrap(); + let new_commitment = sequencer.get_commitment().await.unwrap(); assert_ne!(prev_commitment, new_commitment); teardown_db(sequencer.db.clone());