spinning_top/
rw_spinlock.rs

1use core::marker::PhantomData;
2use core::sync::atomic::{AtomicUsize, Ordering};
3
4use lock_api::{
5    GuardSend, RawRwLock, RawRwLockDowngrade, RawRwLockRecursive, RawRwLockUpgrade,
6    RawRwLockUpgradeDowngrade,
7};
8
9use crate::relax::{Backoff, Relax, Spin};
10
11/// A simple, spinning, read-preferring readers-writer lock.
12// Adapted from `spin::rwlock::RwLock`, but
13// - with separation of `UPGRADABLE` and `EXCLUSIVE`,
14// - with optional exponential backoff,
15// - with `impl RawRwLockRecursive`.
16// <https://github.com/mvdnes/spin-rs/blob/d064a66b450c6c90e49a7e73fbea161e39f7a724/src/rwlock.rs>
17#[derive(Debug)]
18pub struct RawRwSpinlock<R: Relax = Spin> {
19    lock: AtomicUsize,
20    relax: PhantomData<R>,
21}
22
// Bit layout of `RawRwSpinlock::lock`:
//
//   bit 0    -> `EXCLUSIVE`  (writer flag)
//   bit 1    -> `UPGRADABLE` (upgradable-reader flag)
//   bits 2.. -> shared reader count, incremented in steps of `SHARED`

/// Exclusive lock flag
const EXCLUSIVE: usize = 0b001;
/// Special upgradable shared lock flag
const UPGRADABLE: usize = 0b010;
/// Normal shared lock counter
const SHARED: usize = 0b100;
29
30impl<R: Relax> RawRwSpinlock<R> {
31    #[inline]
32    fn is_locked_shared(&self) -> bool {
33        self.lock.load(Ordering::Relaxed) & !(EXCLUSIVE | UPGRADABLE) != 0
34    }
35
36    #[inline]
37    fn is_locked_upgradable(&self) -> bool {
38        self.lock.load(Ordering::Relaxed) & UPGRADABLE == UPGRADABLE
39    }
40
41    /// Acquire a shared lock, returning the new lock value.
42    #[inline]
43    fn acquire_shared(&self) -> usize {
44        let value = self.lock.fetch_add(SHARED, Ordering::Acquire);
45
46        // An arbitrary cap that allows us to catch overflows long before they happen
47        if value > usize::MAX / 2 {
48            self.lock.fetch_sub(SHARED, Ordering::Relaxed);
49            panic!("Too many shared locks, cannot safely proceed");
50        }
51
52        value
53    }
54}
55
56unsafe impl<R: Relax> RawRwLock for RawRwSpinlock<R> {
57    #[allow(clippy::declare_interior_mutable_const)]
58    const INIT: Self = Self {
59        lock: AtomicUsize::new(0),
60        relax: PhantomData,
61    };
62
63    type GuardMarker = GuardSend;
64
65    #[inline]
66    fn lock_shared(&self) {
67        let mut relax = R::default();
68
69        while !self.try_lock_shared() {
70            relax.relax();
71        }
72    }
73
74    #[inline]
75    fn try_lock_shared(&self) -> bool {
76        let value = self.acquire_shared();
77
78        let acquired = value & EXCLUSIVE != EXCLUSIVE;
79
80        if !acquired {
81            unsafe {
82                self.unlock_shared();
83            }
84        }
85
86        acquired
87    }
88
89    #[inline]
90    unsafe fn unlock_shared(&self) {
91        debug_assert!(self.is_locked_shared());
92
93        self.lock.fetch_sub(SHARED, Ordering::Release);
94    }
95
96    #[inline]
97    fn lock_exclusive(&self) {
98        let mut relax = R::default();
99
100        while !self.try_lock_exclusive() {
101            relax.relax();
102        }
103    }
104
105    #[inline]
106    fn try_lock_exclusive(&self) -> bool {
107        self.lock
108            .compare_exchange(0, EXCLUSIVE, Ordering::Acquire, Ordering::Relaxed)
109            .is_ok()
110    }
111
112    #[inline]
113    unsafe fn unlock_exclusive(&self) {
114        debug_assert!(self.is_locked_exclusive());
115
116        self.lock.fetch_and(!EXCLUSIVE, Ordering::Release);
117    }
118
119    #[inline]
120    fn is_locked(&self) -> bool {
121        self.lock.load(Ordering::Relaxed) != 0
122    }
123
124    #[inline]
125    fn is_locked_exclusive(&self) -> bool {
126        self.lock.load(Ordering::Relaxed) & EXCLUSIVE == EXCLUSIVE
127    }
128}
129
130unsafe impl<R: Relax> RawRwLockRecursive for RawRwSpinlock<R> {
131    #[inline]
132    fn lock_shared_recursive(&self) {
133        self.lock_shared();
134    }
135
136    #[inline]
137    fn try_lock_shared_recursive(&self) -> bool {
138        self.try_lock_shared()
139    }
140}
141
142unsafe impl<R: Relax> RawRwLockDowngrade for RawRwSpinlock<R> {
143    #[inline]
144    unsafe fn downgrade(&self) {
145        // Reserve the shared guard for ourselves
146        self.acquire_shared();
147
148        unsafe {
149            self.unlock_exclusive();
150        }
151    }
152}
153
154unsafe impl<R: Relax> RawRwLockUpgrade for RawRwSpinlock<R> {
155    #[inline]
156    fn lock_upgradable(&self) {
157        let mut relax = R::default();
158
159        while !self.try_lock_upgradable() {
160            relax.relax();
161        }
162    }
163
164    #[inline]
165    fn try_lock_upgradable(&self) -> bool {
166        let value = self.lock.fetch_or(UPGRADABLE, Ordering::Acquire);
167
168        let acquired = value & (UPGRADABLE | EXCLUSIVE) == 0;
169
170        if !acquired && value & UPGRADABLE == 0 {
171            unsafe {
172                self.unlock_upgradable();
173            }
174        }
175
176        acquired
177    }
178
179    #[inline]
180    unsafe fn unlock_upgradable(&self) {
181        debug_assert!(self.is_locked_upgradable());
182
183        self.lock.fetch_and(!UPGRADABLE, Ordering::Release);
184    }
185
186    #[inline]
187    unsafe fn upgrade(&self) {
188        let mut relax = R::default();
189
190        while !self.try_upgrade() {
191            relax.relax();
192        }
193    }
194
195    #[inline]
196    unsafe fn try_upgrade(&self) -> bool {
197        self.lock
198            .compare_exchange(UPGRADABLE, EXCLUSIVE, Ordering::Acquire, Ordering::Relaxed)
199            .is_ok()
200    }
201}
202
203unsafe impl<R: Relax> RawRwLockUpgradeDowngrade for RawRwSpinlock<R> {
204    #[inline]
205    unsafe fn downgrade_upgradable(&self) {
206        self.acquire_shared();
207
208        unsafe {
209            self.unlock_upgradable();
210        }
211    }
212
213    #[inline]
214    unsafe fn downgrade_to_upgradable(&self) {
215        debug_assert!(self.is_locked_exclusive());
216
217        self.lock
218            .fetch_xor(UPGRADABLE | EXCLUSIVE, Ordering::Release);
219    }
220}
221
222/// A [`lock_api::RwLock`] based on [`RawRwSpinlock`].
223pub type RwSpinlock<T> = lock_api::RwLock<RawRwSpinlock<Spin>, T>;
224
225/// A [`lock_api::RwLockReadGuard`] based on [`RawRwSpinlock`].
226pub type RwSpinlockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, RawRwSpinlock<Spin>, T>;
227
228/// A [`lock_api::RwLockUpgradableReadGuard`] based on [`RawRwSpinlock`].
229pub type RwSpinlockUpgradableReadGuard<'a, T> =
230    lock_api::RwLockUpgradableReadGuard<'a, RawRwSpinlock<Spin>, T>;
231
232/// A [`lock_api::RwLockWriteGuard`] based on [`RawRwSpinlock`].
233pub type RwSpinlockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, RawRwSpinlock<Spin>, T>;
234
235/// A [`lock_api::ArcRwLockReadGuard`] based on [`RawRwSpinlock`].
236#[cfg(feature = "arc_lock")]
237pub type ArcRwSpinlockReadGuard<T> = lock_api::ArcRwLockReadGuard<RawRwSpinlock<Spin>, T>;
238
239/// A [`lock_api::ArcRwLockUpgradableReadGuard`] based on [`RawRwSpinlock`].
240#[cfg(feature = "arc_lock")]
241pub type ArcRwSpinlockUpgradableReadGuard<T> =
242    lock_api::ArcRwLockUpgradableReadGuard<RawRwSpinlock<Spin>, T>;
243
244/// A [`lock_api::ArcRwLockWriteGuard`] based on [`RawRwSpinlock`].
245#[cfg(feature = "arc_lock")]
246pub type ArcRwSpinlockWriteGuard<T> = lock_api::ArcRwLockWriteGuard<RawRwSpinlock<Spin>, T>;
247
248/// A [`lock_api::RwLock`] based on [`RawRwSpinlock`]`<`[`Backoff`]`>`.
249pub type BackoffRwSpinlock<T> = lock_api::RwLock<RawRwSpinlock<Backoff>, T>;
250
251/// A [`lock_api::RwLockReadGuard`] based on [`RawRwSpinlock`]`<`[`Backoff`]`>`.
252pub type BackoffRwSpinlockReadGuard<'a, T> =
253    lock_api::RwLockReadGuard<'a, RawRwSpinlock<Backoff>, T>;
254
255/// A [`lock_api::RwLockUpgradableReadGuard`] based on [`RawRwSpinlock`]`<`[`Backoff`]`>`.
256pub type BackoffRwSpinlockUpgradableReadGuard<'a, T> =
257    lock_api::RwLockUpgradableReadGuard<'a, RawRwSpinlock<Backoff>, T>;
258
259/// A [`lock_api::RwLockWriteGuard`] based on [`RawRwSpinlock`]`<`[`Backoff`]`>`.
260pub type BackoffRwSpinlockWriteGuard<'a, T> =
261    lock_api::RwLockWriteGuard<'a, RawRwSpinlock<Backoff>, T>;
262
263/// A [`lock_api::ArcRwLockReadGuard`] based on [`RawRwSpinlock`]`<`[`Backoff`]`>`.
264#[cfg(feature = "arc_lock")]
265pub type ArcBackoffRwSpinlockReadGuard<T> = lock_api::ArcRwLockReadGuard<RawRwSpinlock<Backoff>, T>;
266
267/// A [`lock_api::ArcRwLockUpgradableReadGuard`] based on [`RawRwSpinlock`]`<`[`Backoff`]`>`.
268#[cfg(feature = "arc_lock")]
269pub type ArcBackoffRwSpinlockUpgradableReadGuard<T> =
270    lock_api::ArcRwLockUpgradableReadGuard<RawRwSpinlock<Backoff>, T>;
271
272/// A [`lock_api::ArcRwLockWriteGuard`] based on [`RawRwSpinlock`]`<`[`Backoff`]`>`.
273#[cfg(feature = "arc_lock")]
274pub type ArcBackoffRwSpinlockWriteGuard<T> =
275    lock_api::ArcRwLockWriteGuard<RawRwSpinlock<Backoff>, T>;
276
// Adapted from `spin::rwlock`.
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::mpsc::channel;
    use std::sync::Arc;
    use std::{mem, thread};

    use lock_api::{RwLockUpgradableReadGuard, RwLockWriteGuard};

    use super::*;

    #[test]
    fn test_unlock_shared() {
        let m: RawRwSpinlock = RawRwSpinlock::INIT;
        m.lock_shared();
        m.lock_shared();
        m.lock_shared();
        assert!(!m.try_lock_exclusive());
        unsafe {
            m.unlock_shared();
            m.unlock_shared();
        }
        assert!(!m.try_lock_exclusive());
        unsafe {
            m.unlock_shared();
        }
        assert!(m.try_lock_exclusive());
    }

    #[test]
    fn test_unlock_exclusive() {
        let m: RawRwSpinlock = RawRwSpinlock::INIT;
        m.lock_exclusive();
        assert!(!m.try_lock_shared());
        unsafe {
            m.unlock_exclusive();
        }
        assert!(m.try_lock_shared());
    }

    #[test]
    fn smoke() {
        let l = RwSpinlock::new(());
        drop(l.read());
        drop(l.write());
        drop((l.read(), l.read()));
        drop(l.write());
    }

    #[test]
    fn frob() {
        use rand::Rng;

        static R: RwSpinlock<usize> = RwSpinlock::new(0);
        const N: usize = 10;
        const M: usize = 1000;

        let (tx, rx) = channel::<()>();
        for _ in 0..N {
            let tx = tx.clone();
            thread::spawn(move || {
                let mut rng = rand::thread_rng();
                for _ in 0..M {
                    if rng.gen_bool(1.0 / N as f64) {
                        drop(R.write());
                    } else {
                        drop(R.read());
                    }
                }
                drop(tx);
            });
        }
        drop(tx);
        let _ = rx.recv();
    }

    #[test]
    fn test_rw_arc() {
        let arc = Arc::new(RwSpinlock::new(0));
        let arc2 = arc.clone();
        let (tx, rx) = channel();

        thread::spawn(move || {
            let mut lock = arc2.write();
            for _ in 0..10 {
                let tmp = *lock;
                *lock = -1;
                thread::yield_now();
                *lock = tmp + 1;
            }
            tx.send(()).unwrap();
        });

        // Readers try to catch the writer in the act
        let mut children = Vec::new();
        for _ in 0..5 {
            let arc3 = arc.clone();
            children.push(thread::spawn(move || {
                let lock = arc3.read();
                assert!(*lock >= 0);
            }));
        }

        // Wait for children to pass their asserts
        for r in children {
            assert!(r.join().is_ok());
        }

        // Wait for writer to finish
        rx.recv().unwrap();
        let lock = arc.read();
        assert_eq!(*lock, 10);
    }

    #[test]
    fn test_rw_access_in_unwind() {
        let arc = Arc::new(RwSpinlock::new(1));
        let arc2 = arc.clone();
        let _ = thread::spawn(move || -> () {
            struct Unwinder {
                i: Arc<RwSpinlock<isize>>,
            }
            impl Drop for Unwinder {
                fn drop(&mut self) {
                    let mut lock = self.i.write();
                    *lock += 1;
                }
            }
            let _u = Unwinder { i: arc2 };
            panic!();
        })
        .join();
        let lock = arc.read();
        assert_eq!(*lock, 2);
    }

    #[test]
    fn test_rwlock_unsized() {
        let rw: &RwSpinlock<[i32]> = &RwSpinlock::new([1, 2, 3]);
        {
            let b = &mut *rw.write();
            b[0] = 4;
            b[2] = 5;
        }
        let comp: &[i32] = &[4, 2, 5];
        assert_eq!(&*rw.read(), comp);
    }

    #[test]
    fn test_rwlock_try_write() {
        let lock = RwSpinlock::new(0isize);
        let read_guard = lock.read();

        assert!(
            lock.try_write().is_none(),
            "try_write should not succeed while read_guard is in scope"
        );

        drop(read_guard);
    }

    #[test]
    fn test_rw_try_read() {
        let m = RwSpinlock::new(0);
        mem::forget(m.write());
        assert!(m.try_read().is_none());
    }

    #[test]
    fn test_into_inner() {
        let m = RwSpinlock::new(Box::new(10));
        assert_eq!(m.into_inner(), Box::new(10));
    }

    #[test]
    fn test_into_inner_drop() {
        struct Foo(Arc<AtomicUsize>);
        impl Drop for Foo {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }
        let num_drops = Arc::new(AtomicUsize::new(0));
        let m = RwSpinlock::new(Foo(num_drops.clone()));
        assert_eq!(num_drops.load(Ordering::SeqCst), 0);
        {
            let _inner = m.into_inner();
            assert_eq!(num_drops.load(Ordering::SeqCst), 0);
        }
        assert_eq!(num_drops.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_upgrade_downgrade() {
        let m = RwSpinlock::new(());
        {
            let _r = m.read();
            let upg = m.try_upgradable_read().unwrap();
            assert!(m.try_read().is_some());
            assert!(m.try_write().is_none());
            assert!(RwLockUpgradableReadGuard::try_upgrade(upg).is_err());
        }
        {
            let w = m.write();
            assert!(m.try_upgradable_read().is_none());
            let _r = RwLockWriteGuard::downgrade(w);
            assert!(m.try_upgradable_read().is_some());
            assert!(m.try_read().is_some());
            assert!(m.try_write().is_none());
        }
        {
            let _u = m.upgradable_read();
            assert!(m.try_upgradable_read().is_none());
        }

        assert!(RwLockUpgradableReadGuard::try_upgrade(m.try_upgradable_read().unwrap()).is_ok());
    }

    /// A failed upgradable attempt against a held exclusive lock must leave no
    /// stray flag behind, and `downgrade_to_upgradable` must leave exactly the
    /// upgradable flag set.
    #[test]
    fn test_raw_downgrade_to_upgradable() {
        let m: RawRwSpinlock = RawRwSpinlock::INIT;
        m.lock_exclusive();

        assert!(!m.try_lock_upgradable());
        assert_eq!(m.lock.load(Ordering::Relaxed), EXCLUSIVE);

        unsafe { m.downgrade_to_upgradable() };
        assert!(m.is_locked_upgradable());
        assert!(!m.is_locked_exclusive());
        assert!(!m.try_lock_upgradable());
        assert!(!m.try_lock_exclusive());

        assert!(m.try_lock_shared());
        unsafe { m.unlock_shared() };

        assert!(unsafe { m.try_upgrade() });
        assert!(m.is_locked_exclusive());
        unsafe { m.unlock_exclusive() };
        assert!(!m.is_locked());
    }

    #[test]
    fn test_downgrade_to_upgradable() {
        let m = RwSpinlock::new(0);
        let mut w = m.write();
        *w = 1;

        let u = RwLockWriteGuard::downgrade_to_upgradable(w);
        assert_eq!(*u, 1);
        assert!(m.try_read().is_some());
        assert!(m.try_write().is_none());
        assert!(m.try_upgradable_read().is_none());

        let mut w = RwLockUpgradableReadGuard::upgrade(u);
        *w = 2;
        drop(w);
        assert_eq!(*m.read(), 2);
    }
}
499}