diff --git a/src/test/utils.rs b/src/test/utils.rs index 31c73d288..a695bb82d 100644 --- a/src/test/utils.rs +++ b/src/test/utils.rs @@ -18,11 +18,11 @@ use rand::{thread_rng, Rng}; use std::collections::hash_map; use std::collections::HashMap; use std::env; -use std::io::{Cursor, Read, Write}; +use std::io::{self, Write}; use std::path::PathBuf; use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Mutex, RwLock}; +use std::sync::{Arc, Mutex}; use std::time::Duration; macro_rules! expect_event { @@ -42,25 +42,20 @@ macro_rules! expect_event { pub(crate) use expect_event; pub(crate) struct TestStore { - persisted_bytes: RwLock>>>>>, + persisted_bytes: Mutex>>>, did_persist: Arc, } impl TestStore { pub fn new() -> Self { - let persisted_bytes = RwLock::new(HashMap::new()); + let persisted_bytes = Mutex::new(HashMap::new()); let did_persist = Arc::new(AtomicBool::new(false)); Self { persisted_bytes, did_persist } } pub fn get_persisted_bytes(&self, namespace: &str, key: &str) -> Option> { - if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) { - if let Some(inner_ref) = outer_ref.get(key) { - let locked = inner_ref.read().unwrap(); - return Some((*locked).clone()); - } - } - None + let persisted_lock = self.persisted_bytes.lock().unwrap(); + persisted_lock.get(namespace).and_then(|e| e.get(key).cloned()) } pub fn get_and_clear_did_persist(&self) -> bool { @@ -69,46 +64,45 @@ impl TestStore { } impl KVStore for TestStore { - type Reader = TestReader; + type Reader = io::Cursor>; - fn read(&self, namespace: &str, key: &str) -> std::io::Result { - if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) { + fn read(&self, namespace: &str, key: &str) -> io::Result { + let persisted_lock = self.persisted_bytes.lock().unwrap(); + if let Some(outer_ref) = persisted_lock.get(namespace) { if let Some(inner_ref) = outer_ref.get(key) { - Ok(TestReader::new(Arc::clone(inner_ref))) + let bytes = inner_ref.clone(); + Ok(io::Cursor::new(bytes)) } else { - let msg = format!("Key not found: {}", key); - Err(std::io::Error::new(std::io::ErrorKind::NotFound, msg)) + Err(io::Error::new(io::ErrorKind::NotFound, "Key not found")) } } else { - let msg = format!("Namespace not found: {}", namespace); - Err(std::io::Error::new(std::io::ErrorKind::NotFound, msg)) + Err(io::Error::new(io::ErrorKind::NotFound, "Namespace not found")) } } - fn write(&self, namespace: &str, key: &str, buf: &[u8]) -> std::io::Result<()> { - let mut guard = self.persisted_bytes.write().unwrap(); - let outer_e = guard.entry(namespace.to_string()).or_insert(HashMap::new()); - let inner_e = outer_e.entry(key.to_string()).or_insert(Arc::new(RwLock::new(Vec::new()))); - - let mut guard = inner_e.write().unwrap(); - guard.write_all(buf)?; + fn write(&self, namespace: &str, key: &str, buf: &[u8]) -> io::Result<()> { + let mut persisted_lock = self.persisted_bytes.lock().unwrap(); + let outer_e = persisted_lock.entry(namespace.to_string()).or_insert(HashMap::new()); + let mut bytes = Vec::new(); + bytes.write_all(buf)?; + outer_e.insert(key.to_string(), bytes); self.did_persist.store(true, Ordering::SeqCst); Ok(()) } - fn remove(&self, namespace: &str, key: &str) -> std::io::Result<()> { - match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) { - hash_map::Entry::Occupied(mut e) => { - self.did_persist.store(true, Ordering::SeqCst); - e.get_mut().remove(&key.to_string()); - Ok(()) - } - hash_map::Entry::Vacant(_) => Ok(()), + fn remove(&self, namespace: &str, key: &str) -> io::Result<()> { + let mut persisted_lock = self.persisted_bytes.lock().unwrap(); + if let Some(outer_ref) = persisted_lock.get_mut(namespace) { + outer_ref.remove(&key.to_string()); + self.did_persist.store(true, Ordering::SeqCst); } + + Ok(()) } - fn list(&self, namespace: &str) -> std::io::Result> { - match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) { + fn list(&self, namespace: &str) -> io::Result> { + let mut persisted_lock = self.persisted_bytes.lock().unwrap(); + match persisted_lock.entry(namespace.to_string()) { hash_map::Entry::Occupied(e) => Ok(e.get().keys().cloned().collect()), hash_map::Entry::Vacant(_) => Ok(Vec::new()), } @@ -139,24 +133,6 @@ impl KVStorePersister for TestStore { } } -pub struct TestReader { - entry_ref: Arc>>, -} - -impl TestReader { - pub fn new(entry_ref: Arc>>) -> Self { - Self { entry_ref } - } -} - -impl Read for TestReader { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - let bytes = self.entry_ref.read().unwrap().clone(); - let mut reader = Cursor::new(bytes); - reader.read(buf) - } -} - // Copied over from upstream LDK #[allow(dead_code)] pub struct TestLogger {