diff --git a/curp/src/server/raw_curp/mod.rs b/curp/src/server/raw_curp/mod.rs index 270e45ed1b..5122f4ba6d 100644 --- a/curp/src/server/raw_curp/mod.rs +++ b/curp/src/server/raw_curp/mod.rs @@ -736,6 +736,7 @@ impl RawCurp { .iter() .filter_map(|(id, c)| c.is_conflict_with_cmd(cmd).then_some(*id)), ) + .unique() .collect_vec() }; if ids.is_empty() { diff --git a/xline/src/server/barriers.rs b/xline/src/server/barriers.rs index da9de88c54..50bd31e723 100644 --- a/xline/src/server/barriers.rs +++ b/xline/src/server/barriers.rs @@ -1,4 +1,7 @@ -use std::collections::{BTreeMap, HashMap}; +use std::{ + cmp::Reverse, + collections::{BinaryHeap, HashMap}, +}; use clippy_utilities::OverflowArithmetic; use curp::cmd::ProposeId; @@ -17,8 +20,9 @@ impl IndexBarrier { pub(crate) fn new() -> Self { IndexBarrier { inner: Mutex::new(IndexBarrierInner { - last_trigger_index: 0, - barriers: BTreeMap::new(), + next: 1, + indices: BinaryHeap::new(), + barriers: HashMap::new(), }), } } @@ -27,7 +31,7 @@ impl IndexBarrier { pub(crate) async fn wait(&self, index: u64) { let listener = { let mut inner_l = self.inner.lock(); - if inner_l.last_trigger_index >= index { + if inner_l.next > index { return; } inner_l @@ -42,13 +46,18 @@ impl IndexBarrier { /// Trigger all barriers whose index is less than or equal to the given index. pub(crate) fn trigger(&self, index: u64) { let mut inner_l = self.inner.lock(); - if inner_l.last_trigger_index < index { - inner_l.last_trigger_index = index; - } - let mut split_barriers = inner_l.barriers.split_off(&(index.overflow_add(1))); - std::mem::swap(&mut inner_l.barriers, &mut split_barriers); - for (_, barrier) in split_barriers { - barrier.notify(usize::MAX); + inner_l.indices.push(Reverse(index)); + while inner_l + .indices + .peek() + .map_or(false, |i| i.0.eq(&inner_l.next)) + { + let next = inner_l.next; + let _ignore = inner_l.indices.pop(); + if let Some(event) = inner_l.barriers.remove(&next) { + event.notify(usize::MAX); + } + inner_l.next = next.overflow_add(1); } } } @@ -56,10 +65,12 @@ impl IndexBarrier { /// Inner of index barrier. #[derive(Debug)] struct IndexBarrierInner { - /// The last index that the barrier has triggered. - last_trigger_index: u64, - /// Barrier of index. - barriers: BTreeMap, + /// The next index that haven't been triggered + next: u64, + /// Store all keys that larger than `next` + indices: BinaryHeap>, + /// Events + barriers: HashMap, } /// Barrier for id @@ -131,12 +142,28 @@ mod test { #[abort_on_panic] async fn test_index_barrier() { let index_barrier = Arc::new(IndexBarrier::new()); - let barriers = (0..5).map(|i| { - let id_barrier = Arc::clone(&index_barrier); - tokio::spawn(async move { - id_barrier.wait(i).await; + let (done_tx, done_rx) = flume::bounded(5); + let barriers = (1..=5) + .map(|i| { + let index_barrier = Arc::clone(&index_barrier); + let done_tx_c = done_tx.clone(); + tokio::spawn(async move { + index_barrier.wait(i).await; + done_tx_c.send(i).unwrap(); + }) }) - }); + .collect::>(); + + index_barrier.trigger(2); + index_barrier.trigger(3); + sleep(Duration::from_millis(100)).await; + assert!(done_rx.try_recv().is_err()); + index_barrier.trigger(1); + sleep(Duration::from_millis(100)).await; + assert_eq!(done_rx.try_recv().unwrap(), 1); + assert_eq!(done_rx.try_recv().unwrap(), 2); + assert_eq!(done_rx.try_recv().unwrap(), 3); + index_barrier.trigger(4); index_barrier.trigger(5); timeout(Duration::from_millis(100), index_barrier.wait(3))