Skip to content

Commit

Permalink
For completeness' sake, add Serde support also for BptreeMap and Hash…
Browse files Browse the repository at this point in the history
…Trie. (#107)
  • Loading branch information
adamreichold authored Mar 6, 2024
1 parent ea6d612 commit b114036
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 30 deletions.
59 changes: 59 additions & 0 deletions src/bptree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
#[cfg(feature = "asynch")]
pub mod asynch;

#[cfg(feature = "serde")]
use serde::{
de::{Deserialize, Deserializer},
ser::{Serialize, SerializeMap, Serializer},
};

#[cfg(feature = "serde")]
use crate::utils::MapCollector;

use crate::internals::lincowcell::{LinCowCell, LinCowCellReadTxn, LinCowCellWriteTxn};

include!("impl.rs");
Expand Down Expand Up @@ -37,6 +46,42 @@ impl<K: Clone + Ord + Debug + Sync + Send + 'static, V: Clone + Sync + Send + 's
}
}

#[cfg(feature = "serde")]
impl<K, V> Serialize for BptreeMap<K, V>
where
K: Serialize + Clone + Ord + Debug + Sync + Send + 'static,
V: Serialize + Clone + Sync + Send + 'static,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let txn = self.read();

let mut state = serializer.serialize_map(Some(txn.len()))?;

for (key, val) in txn.iter() {
state.serialize_entry(key, val)?;
}

state.end()
}
}

#[cfg(feature = "serde")]
impl<'de, K, V> Deserialize<'de> for BptreeMap<K, V>
where
K: Deserialize<'de> + Clone + Ord + Debug + Sync + Send + 'static,
V: Deserialize<'de> + Clone + Sync + Send + 'static,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_map(MapCollector::new())
}
}

#[cfg(test)]
mod tests {
use super::BptreeMap;
Expand Down Expand Up @@ -602,4 +647,18 @@ mod tests {
// Done!
}
*/

#[cfg(feature = "serde")]
#[test]
fn test_bptreee2_serialize_deserialize() {
let map: BptreeMap<usize, usize> = vec![(10, 11), (15, 16), (20, 21)].into_iter().collect();

let value = serde_json::to_value(&map).unwrap();
assert_eq!(value, serde_json::json!({ "10": 11, "15": 16, "20": 21 }));

let map: BptreeMap<usize, usize> = serde_json::from_value(value).unwrap();
let mut vec: Vec<(usize, usize)> = map.read().iter().map(|(k, v)| (*k, *v)).collect();
vec.sort_unstable();
assert_eq!(vec, [(10, 11), (15, 16), (20, 21)]);
}
}
35 changes: 5 additions & 30 deletions src/hashmap/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,15 @@
#[cfg(feature = "asynch")]
pub mod asynch;

#[cfg(feature = "serde")]
use std::fmt;
#[cfg(feature = "serde")]
use std::iter;
#[cfg(feature = "serde")]
use std::marker::PhantomData;

#[cfg(feature = "serde")]
use serde::{
de::{Deserialize, Deserializer, MapAccess, Visitor},
de::{Deserialize, Deserializer},
ser::{Serialize, SerializeMap, Serializer},
};

#[cfg(feature = "serde")]
use crate::utils::MapCollector;

use crate::internals::lincowcell::{LinCowCell, LinCowCellReadTxn, LinCowCellWriteTxn};

include!("impl.rs");
Expand Down Expand Up @@ -163,28 +159,7 @@ where
where
D: Deserializer<'de>,
{
struct Collector<K, V>(PhantomData<(K, V)>);

impl<'de, K, V> Visitor<'de> for Collector<K, V>
where
K: Deserialize<'de> + Hash + Eq + Clone + Debug + Sync + Send + 'static,
V: Deserialize<'de> + Clone + Sync + Send + 'static,
{
type Value = HashMap<K, V>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a map")
}

fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
iter::from_fn(|| access.next_entry().transpose()).collect()
}
}

deserializer.deserialize_map(Collector(PhantomData))
deserializer.deserialize_map(MapCollector::new())
}
}

Expand Down
59 changes: 59 additions & 0 deletions src/hashtrie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,18 @@
#[cfg(feature = "asynch")]
pub mod asynch;

#[cfg(feature = "serde")]
use serde::{
de::{Deserialize, Deserializer},
ser::{Serialize, SerializeMap, Serializer},
};

#[cfg(feature = "arcache")]
use crate::internals::hashtrie::cursor::Datum;

#[cfg(feature = "serde")]
use crate::utils::MapCollector;

use crate::internals::lincowcell::{LinCowCell, LinCowCellReadTxn, LinCowCellWriteTxn};

include!("impl.rs");
Expand Down Expand Up @@ -121,6 +130,42 @@ impl<K: Hash + Eq + Clone + Debug + Sync + Send + 'static, V: Clone + Sync + Sen
}
}

#[cfg(feature = "serde")]
impl<K, V> Serialize for HashTrie<K, V>
where
K: Serialize + Hash + Eq + Clone + Debug + Sync + Send + 'static,
V: Serialize + Clone + Sync + Send + 'static,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let txn = self.read();

let mut state = serializer.serialize_map(Some(txn.len()))?;

for (key, val) in txn.iter() {
state.serialize_entry(key, val)?;
}

state.end()
}
}

#[cfg(feature = "serde")]
impl<'de, K, V> Deserialize<'de> for HashTrie<K, V>
where
K: Deserialize<'de> + Hash + Eq + Clone + Debug + Sync + Send + 'static,
V: Deserialize<'de> + Clone + Sync + Send + 'static,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_map(MapCollector::new())
}
}

#[cfg(test)]
mod tests {
use super::HashTrie;
Expand Down Expand Up @@ -226,4 +271,18 @@ mod tests {
tx.remove(&13);
}
}

#[cfg(feature = "serde")]
#[test]
fn test_hashmap_serialize_deserialize() {
let hmap: HashTrie<usize, usize> = vec![(10, 11), (15, 16), (20, 21)].into_iter().collect();

let value = serde_json::to_value(&hmap).unwrap();
assert_eq!(value, serde_json::json!({ "10": 11, "15": 16, "20": 21 }));

let hmap: HashTrie<usize, usize> = serde_json::from_value(value).unwrap();
let mut vec: Vec<(usize, usize)> = hmap.read().iter().map(|(k, v)| (*k, *v)).collect();
vec.sort_unstable();
assert_eq!(vec, [(10, 11), (15, 16), (20, 21)]);
}
}
40 changes: 40 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
use std::borrow::Borrow;
use std::cmp::Ordering;
// use std::mem::MaybeUninit;
#[cfg(feature = "serde")]
use std::fmt;
#[cfg(feature = "serde")]
use std::iter;
#[cfg(feature = "serde")]
use std::marker::PhantomData;
use std::ptr;

#[cfg(feature = "serde")]
use serde::de::{Deserialize, MapAccess, Visitor};

pub(crate) unsafe fn slice_insert<T>(slice: &mut [T], new: T, idx: usize) {
ptr::copy(
slice.as_ptr().add(idx),
Expand Down Expand Up @@ -80,3 +89,34 @@ where
}
Err(slice.len())
}

#[cfg(feature = "serde")]
pub struct MapCollector<T, K, V>(PhantomData<(T, K, V)>);

#[cfg(feature = "serde")]
impl<T, K, V> MapCollector<T, K, V> {
pub fn new() -> Self {
Self(PhantomData)
}
}

#[cfg(feature = "serde")]
impl<'de, T, K, V> Visitor<'de> for MapCollector<T, K, V>
where
T: FromIterator<(K, V)>,
K: Deserialize<'de>,
V: Deserialize<'de>,
{
type Value = T;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a map")
}

fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
iter::from_fn(|| access.next_entry().transpose()).collect()
}
}

0 comments on commit b114036

Please sign in to comment.