one_shot_mutex/
rwlock.rs

1use core::sync::atomic::{AtomicUsize, Ordering};
2
3use lock_api::{
4    GuardSend, RawRwLock, RawRwLockDowngrade, RawRwLockRecursive, RawRwLockUpgrade,
5    RawRwLockUpgradeDowngrade,
6};
7
/// A one-shot readers-writer lock that panics instead of (dead)locking on contention.
///
/// This lock allows no contention and panics on [`lock_shared`], [`lock_exclusive`],
/// [`lock_upgradable`], and [`upgrade`] if it is already locked conflictingly.
/// The corresponding `try_*` methods report the conflict by returning `false` instead.
/// This is useful in situations where contention would be a bug,
/// such as in single-threaded programs that would deadlock on contention.
///
/// [`lock_shared`]: RawOneShotRwLock::lock_shared
/// [`lock_exclusive`]: RawOneShotRwLock::lock_exclusive
/// [`lock_upgradable`]: RawOneShotRwLock::lock_upgradable
/// [`upgrade`]: RawOneShotRwLock::upgrade
///
/// # Examples
///
/// ```
/// use one_shot_mutex::OneShotRwLock;
///
/// static X: OneShotRwLock<i32> = OneShotRwLock::new(42);
///
/// // This is equivalent to `X.try_write().unwrap()`.
/// let x = X.write();
///
/// // This panics instead of deadlocking.
/// // let x2 = X.write();
///
/// // Once the write guard is dropped, the lock can be taken again.
/// drop(x);
/// let x = X.write();
/// ```
pub struct RawOneShotRwLock {
    /// Packed lock state: bit 0 is the exclusive flag, bit 1 is the upgradable flag,
    /// and the remaining bits count ordinary shared locks.
    lock: AtomicUsize,
}
39
/// Exclusive (writer) lock flag, stored in bit 0 of the lock state.
const EXCLUSIVE: usize = 1;
/// Flag for the single special upgradable shared lock, stored in bit 1.
const UPGRADABLE: usize = 1 << 1;
/// Increment for one ordinary shared lock; the reader count lives in bits 2 and up.
const SHARED: usize = 1 << 2;
46
47impl RawOneShotRwLock {
48    #[inline]
49    fn is_locked_shared(&self) -> bool {
50        self.lock.load(Ordering::Relaxed) & !(EXCLUSIVE | UPGRADABLE) != 0
51    }
52
53    #[inline]
54    fn is_locked_upgradable(&self) -> bool {
55        self.lock.load(Ordering::Relaxed) & UPGRADABLE == UPGRADABLE
56    }
57
58    /// Acquire a shared lock, returning the new lock value.
59    #[inline]
60    fn acquire_shared(&self) -> usize {
61        let value = self.lock.fetch_add(SHARED, Ordering::Acquire);
62
63        // An arbitrary cap that allows us to catch overflows long before they happen
64        if value > usize::MAX / 2 {
65            self.lock.fetch_sub(SHARED, Ordering::Relaxed);
66            panic!("Too many shared locks, cannot safely proceed");
67        }
68
69        value
70    }
71}
72
73unsafe impl RawRwLock for RawOneShotRwLock {
74    #[allow(clippy::declare_interior_mutable_const)]
75    const INIT: Self = Self {
76        lock: AtomicUsize::new(0),
77    };
78
79    type GuardMarker = GuardSend;
80
81    #[inline]
82    fn lock_shared(&self) {
83        assert!(
84            self.try_lock_shared(),
85            "called `lock_shared` on a `RawOneShotRwLock` that is already locked exclusively"
86        );
87    }
88
89    #[inline]
90    fn try_lock_shared(&self) -> bool {
91        let value = self.acquire_shared();
92
93        let acquired = value & EXCLUSIVE != EXCLUSIVE;
94
95        if !acquired {
96            unsafe {
97                self.unlock_shared();
98            }
99        }
100
101        acquired
102    }
103
104    #[inline]
105    unsafe fn unlock_shared(&self) {
106        debug_assert!(self.is_locked_shared());
107
108        self.lock.fetch_sub(SHARED, Ordering::Release);
109    }
110
111    #[inline]
112    fn lock_exclusive(&self) {
113        assert!(
114            self.try_lock_exclusive(),
115            "called `lock_exclusive` on a `RawOneShotRwLock` that is already locked"
116        );
117    }
118
119    #[inline]
120    fn try_lock_exclusive(&self) -> bool {
121        self.lock
122            .compare_exchange(0, EXCLUSIVE, Ordering::Acquire, Ordering::Relaxed)
123            .is_ok()
124    }
125
126    #[inline]
127    unsafe fn unlock_exclusive(&self) {
128        debug_assert!(self.is_locked_exclusive());
129
130        self.lock.fetch_and(!EXCLUSIVE, Ordering::Release);
131    }
132
133    #[inline]
134    fn is_locked(&self) -> bool {
135        self.lock.load(Ordering::Relaxed) != 0
136    }
137
138    #[inline]
139    fn is_locked_exclusive(&self) -> bool {
140        self.lock.load(Ordering::Relaxed) & EXCLUSIVE == EXCLUSIVE
141    }
142}
143
144unsafe impl RawRwLockRecursive for RawOneShotRwLock {
145    #[inline]
146    fn lock_shared_recursive(&self) {
147        self.lock_shared();
148    }
149
150    #[inline]
151    fn try_lock_shared_recursive(&self) -> bool {
152        self.try_lock_shared()
153    }
154}
155
156unsafe impl RawRwLockDowngrade for RawOneShotRwLock {
157    #[inline]
158    unsafe fn downgrade(&self) {
159        // Reserve the shared guard for ourselves
160        self.acquire_shared();
161
162        unsafe {
163            self.unlock_exclusive();
164        }
165    }
166}
167
168unsafe impl RawRwLockUpgrade for RawOneShotRwLock {
169    #[inline]
170    fn lock_upgradable(&self) {
171        assert!(
172            self.try_lock_upgradable(),
173            "called `lock_upgradable` on a `RawOneShotRwLock` that is already locked upgradably or exclusively"
174        );
175    }
176
177    #[inline]
178    fn try_lock_upgradable(&self) -> bool {
179        let value = self.lock.fetch_or(UPGRADABLE, Ordering::Acquire);
180
181        let acquired = value & (UPGRADABLE | EXCLUSIVE) == 0;
182
183        if !acquired && value & UPGRADABLE == 0 {
184            unsafe {
185                self.unlock_upgradable();
186            }
187        }
188
189        acquired
190    }
191
192    #[inline]
193    unsafe fn unlock_upgradable(&self) {
194        debug_assert!(self.is_locked_upgradable());
195
196        self.lock.fetch_and(!UPGRADABLE, Ordering::Release);
197    }
198
199    #[inline]
200    unsafe fn upgrade(&self) {
201        assert!(
202            self.try_upgrade(),
203            "called `upgrade` on a `RawOneShotRwLock` that is also locked shared by others"
204        );
205    }
206
207    #[inline]
208    unsafe fn try_upgrade(&self) -> bool {
209        self.lock
210            .compare_exchange(UPGRADABLE, EXCLUSIVE, Ordering::Acquire, Ordering::Relaxed)
211            .is_ok()
212    }
213}
214
215unsafe impl RawRwLockUpgradeDowngrade for RawOneShotRwLock {
216    #[inline]
217    unsafe fn downgrade_upgradable(&self) {
218        self.acquire_shared();
219
220        unsafe {
221            self.unlock_upgradable();
222        }
223    }
224
225    #[inline]
226    unsafe fn downgrade_to_upgradable(&self) {
227        debug_assert!(self.is_locked_exclusive());
228
229        self.lock
230            .fetch_xor(UPGRADABLE | EXCLUSIVE, Ordering::Release);
231    }
232}
233
234/// A [`lock_api::RwLock`] based on [`RawOneShotRwLock`].
235pub type OneShotRwLock<T> = lock_api::RwLock<RawOneShotRwLock, T>;
236
237/// A [`lock_api::RwLockReadGuard`] based on [`RawOneShotRwLock`].
238pub type OneShotRwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, RawOneShotRwLock, T>;
239
240/// A [`lock_api::RwLockUpgradableReadGuard`] based on [`RawOneShotRwLock`].
241pub type OneShotRwLockUpgradableReadGuard<'a, T> =
242    lock_api::RwLockUpgradableReadGuard<'a, RawOneShotRwLock, T>;
243
244/// A [`lock_api::RwLockWriteGuard`] based on [`RawOneShotRwLock`].
245pub type OneShotRwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, RawOneShotRwLock, T>;
246
#[cfg(test)]
mod tests {
    use lock_api::{RwLockUpgradableReadGuard, RwLockWriteGuard};

    use super::*;

    #[test]
    fn lock_exclusive() {
        let lock = OneShotRwLock::new(42);
        let mut guard = lock.write();
        assert_eq!(*guard, 42);

        *guard += 1;
        drop(guard);
        let guard = lock.write();
        assert_eq!(*guard, 43);
    }

    #[test]
    #[should_panic]
    fn lock_exclusive_panic() {
        let lock = OneShotRwLock::new(42);
        let _guard = lock.write();
        let _guard2 = lock.write();
    }

    /// Taking a write lock while a read lock is held must panic.
    #[test]
    #[should_panic]
    fn lock_exclusive_shared_panic() {
        let lock = OneShotRwLock::new(42);
        let _guard = lock.read();
        let _guard2 = lock.write();
    }

    #[test]
    fn try_lock_exclusive() {
        let lock = OneShotRwLock::new(42);
        let mut guard = lock.try_write().unwrap();
        assert_eq!(*guard, 42);
        assert!(lock.try_write().is_none());

        *guard += 1;
        drop(guard);
        let guard = lock.try_write().unwrap();
        assert_eq!(*guard, 43);
    }

    #[test]
    fn lock_shared() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.read();
        assert_eq!(*guard, 42);
        let guard2 = lock.read();
        assert_eq!(*guard2, 42);
    }

    /// Taking a read lock while a write lock is held must panic.
    #[test]
    #[should_panic]
    fn lock_shared_panic() {
        let lock = OneShotRwLock::new(42);
        let _guard = lock.write();
        let _guard2 = lock.read();
    }

    #[test]
    fn try_lock_shared() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.try_read().unwrap();
        assert_eq!(*guard, 42);
        assert!(lock.try_write().is_none());

        let guard2 = lock.try_read().unwrap();
        assert_eq!(*guard2, 42);
    }

    #[test]
    fn lock_upgradable() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.upgradable_read();
        assert_eq!(*guard, 42);
        assert!(lock.try_write().is_none());

        let mut upgraded = RwLockUpgradableReadGuard::upgrade(guard);
        *upgraded += 1;
        drop(upgraded);
        let guard2 = lock.upgradable_read();
        assert_eq!(*guard2, 43);
    }

    #[test]
    #[should_panic]
    fn lock_upgradable_panic() {
        let lock = OneShotRwLock::new(42);
        let _guard = lock.upgradable_read();
        let _guard2 = lock.upgradable_read();
    }

    #[test]
    #[should_panic]
    fn lock_upgradable_write_panic() {
        let lock = OneShotRwLock::new(42);
        let _guard = lock.write();
        let _guard2 = lock.upgradable_read();
    }

    #[test]
    fn try_lock_upgradable() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.try_upgradable_read().unwrap();
        assert_eq!(*guard, 42);
        assert!(lock.try_write().is_none());

        let mut upgraded = RwLockUpgradableReadGuard::try_upgrade(guard).unwrap();
        *upgraded += 1;
        drop(upgraded);
        let guard2 = lock.try_upgradable_read().unwrap();
        assert_eq!(*guard2, 43);
    }

    #[test]
    #[should_panic]
    fn upgrade_panic() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.upgradable_read();
        let _guard2 = lock.read();
        let _guard3 = RwLockUpgradableReadGuard::upgrade(guard);
    }

    /// A failed upgradable attempt against a writer must not leave any state behind.
    #[test]
    fn failed_try_upgradable_leaves_no_residue() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.write();
        assert!(lock.try_upgradable_read().is_none());
        drop(guard);
        assert!(lock.try_write().is_some());
        assert!(lock.try_upgradable_read().is_some());
    }

    #[test]
    fn downgrade() {
        let lock = OneShotRwLock::new(42);
        let mut guard = lock.write();
        *guard += 1;

        let reader = RwLockWriteGuard::downgrade(guard);
        assert_eq!(*reader, 43);
        assert!(lock.try_write().is_none());

        let reader2 = lock.read();
        assert_eq!(*reader2, 43);
        drop(reader);
        drop(reader2);
        assert!(lock.try_write().is_some());
    }

    #[test]
    fn downgrade_to_upgradable() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.write();

        let upgradable = RwLockWriteGuard::downgrade_to_upgradable(guard);
        assert!(lock.try_upgradable_read().is_none());
        assert!(lock.try_write().is_none());

        let reader = lock.read();
        assert_eq!(*reader, 42);
        drop(reader);

        let writer = RwLockUpgradableReadGuard::upgrade(upgradable);
        drop(writer);
        assert!(lock.try_write().is_some());
    }

    #[test]
    fn downgrade_upgradable() {
        let lock = OneShotRwLock::new(42);
        let upgradable = lock.upgradable_read();

        let reader = RwLockUpgradableReadGuard::downgrade(upgradable);
        // The upgradable slot is free again once downgraded.
        let upgradable2 = lock.upgradable_read();
        assert_eq!(*reader, 42);
        assert_eq!(*upgradable2, 42);
        assert!(lock.try_write().is_none());
    }
}