forked from m-ou-se/rust-atomics-and-locks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rwlock_2.rs
104 lines (92 loc) · 2.85 KB
/
rwlock_2.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
use atomic_wait::{wait, wake_all, wake_one};
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release};
/// A reader-writer lock whose whole lock state lives in one `AtomicU32`.
///
/// Writers sleep on a separate counter, so that waking them never disturbs
/// readers sleeping on the state word.
pub struct RwLock<T> {
    /// Count of readers currently holding the lock; the sentinel `u32::MAX`
    /// means a writer holds it exclusively.
    state: AtomicU32,
    /// Futex word that blocked writers wait on. It is bumped each time the
    /// lock is released (by the last reader or by a writer).
    writer_wake_counter: AtomicU32,
    /// The data guarded by the lock.
    value: UnsafeCell<T>,
}
// SAFETY: several `ReadGuard`s on different threads may hand out `&T` at the
// same time, which requires `T: Sync`; a `WriteGuard` hands out `&mut T` to
// whichever thread acquired it, which requires `T: Send`. Access to the
// `UnsafeCell` is only granted through those guards.
unsafe impl<T: Send + Sync> Sync for RwLock<T> {}
impl<T> RwLock<T> {
    /// Creates an unlocked `RwLock` that owns `value`.
    pub const fn new(value: T) -> Self {
        Self {
            state: AtomicU32::new(0),
            writer_wake_counter: AtomicU32::new(0),
            value: UnsafeCell::new(value),
        }
    }

    /// Acquires shared (read) access, blocking while a writer holds the lock.
    ///
    /// A reader only blocks while the lock is write-locked (`state == u32::MAX`).
    /// A writer that is merely waiting does not keep new readers out, so a
    /// continuous stream of readers can starve writers.
    ///
    /// # Panics
    ///
    /// Panics with "too many readers" if `u32::MAX - 1` readers are already
    /// active, because one more would collide with the write-locked sentinel.
    pub fn read(&self) -> ReadGuard<T> {
        let mut observed = self.state.load(Relaxed);
        loop {
            if observed == u32::MAX {
                // Write-locked: `wait` only sleeps while the word still equals
                // u32::MAX; afterwards re-read the state and retry.
                wait(&self.state, u32::MAX);
                observed = self.state.load(Relaxed);
                continue;
            }
            assert!(observed < u32::MAX - 1, "too many readers");
            match self.state.compare_exchange_weak(observed, observed + 1, Acquire, Relaxed) {
                Ok(_) => return ReadGuard { rwlock: self },
                Err(actual) => observed = actual,
            }
        }
    }

    /// Acquires exclusive (write) access, blocking until the lock has neither
    /// readers nor another writer.
    pub fn write(&self) -> WriteGuard<T> {
        loop {
            if self.state.compare_exchange(0, u32::MAX, Acquire, Relaxed).is_ok() {
                return WriteGuard { rwlock: self };
            }
            // Snapshot the wake counter *before* re-checking the state. An
            // unlock that lands in between bumps the counter, so the `wait`
            // below returns immediately instead of missing that wakeup.
            let wake_count = self.writer_wake_counter.load(Acquire);
            if self.state.load(Relaxed) != 0 {
                wait(&self.writer_wake_counter, wake_count);
            }
        }
    }
}
/// Shared-access guard returned by `RwLock::read`.
///
/// It dereferences to `&T`. Dropping it gives up this reader's share of the
/// lock, and the last reader to leave wakes a waiting writer.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct ReadGuard<'a, T> {
    rwlock: &'a RwLock<T>,
}
/// Exclusive-access guard returned by `RwLock::write`.
///
/// It dereferences to `&T` / `&mut T`. Dropping it unlocks the `RwLock` and
/// wakes one waiting writer plus all waiting readers.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct WriteGuard<'a, T> {
    rwlock: &'a RwLock<T>,
}
impl<T> Deref for WriteGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let data = self.rwlock.value.get();
        // SAFETY: a WriteGuard is only created after `state` was swapped from
        // 0 to u32::MAX, so no readers or other writers exist. Borrowing
        // `self` prevents an overlapping `&mut T` from `deref_mut`.
        unsafe { &*data }
    }
}
impl<T> DerefMut for WriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        let data = self.rwlock.value.get();
        // SAFETY: the guard holds the lock exclusively (`state == u32::MAX`),
        // and `&mut self` ensures no other reference derived from this guard
        // is alive for the returned lifetime.
        unsafe { &mut *data }
    }
}
impl<T> Deref for ReadGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let data = self.rwlock.value.get();
        // SAFETY: while this guard is alive the reader count is non-zero and
        // below u32::MAX, so `write` cannot succeed and no `&mut T` exists.
        unsafe { &*data }
    }
}
impl<T> Drop for ReadGuard<'_, T> {
    fn drop(&mut self) {
        let lock = self.rwlock;
        let readers_before = lock.state.fetch_sub(1, Release);
        if readers_before != 1 {
            // Other readers still hold the lock; nobody can be woken yet.
            return;
        }
        // We were the last reader, so the lock is now free. Bump the counter
        // that writers sleep on (so a writer that is about to wait sees the
        // change) and wake one of them.
        lock.writer_wake_counter.fetch_add(1, Release);
        wake_one(&lock.writer_wake_counter);
    }
}
impl<T> Drop for WriteGuard<'_, T> {
    fn drop(&mut self) {
        let lock = self.rwlock;
        // Unlock first, so whichever threads we wake can take the lock.
        lock.state.store(0, Release);
        // Signal writers: bump their futex word and wake one of them...
        lock.writer_wake_counter.fetch_add(1, Release);
        wake_one(&lock.writer_wake_counter);
        // ...and wake every reader blocked on the state word in `read`.
        wake_all(&lock.state);
    }
}