Skip to content

Commit

Permalink
Simplify and fix TestStore
Browse files Browse the repository at this point in the history
.. as we don't require all that logic anymore now that we don't return
an `FilesystemWriter` anymore etc.
  • Loading branch information
tnull committed Aug 24, 2023
1 parent a9bb2e9 commit 75ffc50
Showing 1 changed file with 30 additions and 54 deletions.
84 changes: 30 additions & 54 deletions src/test/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ use rand::{thread_rng, Rng};
use std::collections::hash_map;
use std::collections::HashMap;
use std::env;
use std::io::{Cursor, Read, Write};
use std::io::{self, Write};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::sync::{Arc, Mutex};
use std::time::Duration;

macro_rules! expect_event {
Expand All @@ -42,25 +42,20 @@ macro_rules! expect_event {
pub(crate) use expect_event;

pub(crate) struct TestStore {
persisted_bytes: RwLock<HashMap<String, HashMap<String, Arc<RwLock<Vec<u8>>>>>>,
persisted_bytes: Mutex<HashMap<String, HashMap<String, Vec<u8>>>>,
did_persist: Arc<AtomicBool>,
}

impl TestStore {
pub fn new() -> Self {
let persisted_bytes = RwLock::new(HashMap::new());
let persisted_bytes = Mutex::new(HashMap::new());
let did_persist = Arc::new(AtomicBool::new(false));
Self { persisted_bytes, did_persist }
}

pub fn get_persisted_bytes(&self, namespace: &str, key: &str) -> Option<Vec<u8>> {
if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) {
if let Some(inner_ref) = outer_ref.get(key) {
let locked = inner_ref.read().unwrap();
return Some((*locked).clone());
}
}
None
let persisted_lock = self.persisted_bytes.lock().unwrap();
persisted_lock.get(namespace).and_then(|e| e.get(key).cloned())
}

pub fn get_and_clear_did_persist(&self) -> bool {
Expand All @@ -69,46 +64,45 @@ impl TestStore {
}

impl KVStore for TestStore {
type Reader = TestReader;
type Reader = io::Cursor<Vec<u8>>;

fn read(&self, namespace: &str, key: &str) -> std::io::Result<Self::Reader> {
if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) {
fn read(&self, namespace: &str, key: &str) -> io::Result<Self::Reader> {
let persisted_lock = self.persisted_bytes.lock().unwrap();
if let Some(outer_ref) = persisted_lock.get(namespace) {
if let Some(inner_ref) = outer_ref.get(key) {
Ok(TestReader::new(Arc::clone(inner_ref)))
let bytes = inner_ref.clone();
Ok(io::Cursor::new(bytes))
} else {
let msg = format!("Key not found: {}", key);
Err(std::io::Error::new(std::io::ErrorKind::NotFound, msg))
Err(io::Error::new(io::ErrorKind::NotFound, "Key not found"))
}
} else {
let msg = format!("Namespace not found: {}", namespace);
Err(std::io::Error::new(std::io::ErrorKind::NotFound, msg))
Err(io::Error::new(io::ErrorKind::NotFound, "Namespace not found"))
}
}

fn write(&self, namespace: &str, key: &str, buf: &[u8]) -> std::io::Result<()> {
let mut guard = self.persisted_bytes.write().unwrap();
let outer_e = guard.entry(namespace.to_string()).or_insert(HashMap::new());
let inner_e = outer_e.entry(key.to_string()).or_insert(Arc::new(RwLock::new(Vec::new())));

let mut guard = inner_e.write().unwrap();
guard.write_all(buf)?;
fn write(&self, namespace: &str, key: &str, buf: &[u8]) -> io::Result<()> {
let mut persisted_lock = self.persisted_bytes.lock().unwrap();
let outer_e = persisted_lock.entry(namespace.to_string()).or_insert(HashMap::new());
let mut bytes = Vec::new();
bytes.write_all(buf)?;
outer_e.insert(key.to_string(), bytes);
self.did_persist.store(true, Ordering::SeqCst);
Ok(())
}

fn remove(&self, namespace: &str, key: &str) -> std::io::Result<()> {
match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) {
hash_map::Entry::Occupied(mut e) => {
self.did_persist.store(true, Ordering::SeqCst);
e.get_mut().remove(&key.to_string());
Ok(())
}
hash_map::Entry::Vacant(_) => Ok(()),
fn remove(&self, namespace: &str, key: &str) -> io::Result<()> {
let mut persisted_lock = self.persisted_bytes.lock().unwrap();
if let Some(outer_ref) = persisted_lock.get_mut(namespace) {
outer_ref.remove(&key.to_string());
self.did_persist.store(true, Ordering::SeqCst);
}

Ok(())
}

fn list(&self, namespace: &str) -> std::io::Result<Vec<String>> {
match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) {
fn list(&self, namespace: &str) -> io::Result<Vec<String>> {
let mut persisted_lock = self.persisted_bytes.lock().unwrap();
match persisted_lock.entry(namespace.to_string()) {
hash_map::Entry::Occupied(e) => Ok(e.get().keys().cloned().collect()),
hash_map::Entry::Vacant(_) => Ok(Vec::new()),
}
Expand Down Expand Up @@ -139,24 +133,6 @@ impl KVStorePersister for TestStore {
}
}

pub struct TestReader {
entry_ref: Arc<RwLock<Vec<u8>>>,
}

impl TestReader {
pub fn new(entry_ref: Arc<RwLock<Vec<u8>>>) -> Self {
Self { entry_ref }
}
}

impl Read for TestReader {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let bytes = self.entry_ref.read().unwrap().clone();
let mut reader = Cursor::new(bytes);
reader.read(buf)
}
}

// Copied over from upstream LDK
#[allow(dead_code)]
pub struct TestLogger {
Expand Down

0 comments on commit 75ffc50

Please sign in to comment.