hermit_sync/mutex/
ticket.rs

1use core::sync::atomic::{AtomicUsize, Ordering};
2
3use lock_api::{GuardSend, RawMutex, RawMutexFair};
4use spinning_top::relax::{Backoff, Relax};
5
/// A [fair] [ticket lock] that spins with [exponential backoff] while waiting.
///
/// Every thread that wants the lock draws a ticket from `next_ticket` and then
/// waits until `next_serving` reaches that ticket. Tickets are drawn in increasing
/// order and released one at a time, so the lock is granted in first-come,
/// first-served order.
///
/// [fair]: https://en.wikipedia.org/wiki/Unbounded_nondeterminism
/// [ticket lock]: https://en.wikipedia.org/wiki/Ticket_lock
/// [exponential backoff]: https://en.wikipedia.org/wiki/Exponential_backoff
// Based on `spin::mutex::TicketMutex`, but with backoff.
pub struct RawTicketMutex {
    /// Ticket that the next caller of `lock` will receive.
    next_ticket: AtomicUsize,
    /// Ticket that currently owns the lock (or will own it once its holder notices).
    next_serving: AtomicUsize,
}
16
17unsafe impl RawMutex for RawTicketMutex {
18    #[allow(clippy::declare_interior_mutable_const)]
19    const INIT: Self = Self {
20        next_ticket: AtomicUsize::new(0),
21        next_serving: AtomicUsize::new(0),
22    };
23
24    type GuardMarker = GuardSend;
25
26    #[inline]
27    fn lock(&self) {
28        let ticket = self.next_ticket.fetch_add(1, Ordering::Relaxed);
29
30        let mut backoff = Backoff::default();
31        while self.next_serving.load(Ordering::Acquire) != ticket {
32            backoff.relax();
33        }
34    }
35
36    #[inline]
37    fn try_lock(&self) -> bool {
38        let ticket = self
39            .next_ticket
40            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |ticket| {
41                if self.next_serving.load(Ordering::Acquire) == ticket {
42                    Some(ticket + 1)
43                } else {
44                    None
45                }
46            });
47
48        ticket.is_ok()
49    }
50
51    #[inline]
52    unsafe fn unlock(&self) {
53        self.next_serving.fetch_add(1, Ordering::Release);
54    }
55
56    #[inline]
57    fn is_locked(&self) -> bool {
58        let ticket = self.next_ticket.load(Ordering::Relaxed);
59        self.next_serving.load(Ordering::Relaxed) != ticket
60    }
61}
62
63unsafe impl RawMutexFair for RawTicketMutex {
64    #[inline]
65    unsafe fn unlock_fair(&self) {
66        unsafe { self.unlock() }
67    }
68
69    #[inline]
70    unsafe fn bump(&self) {
71        let ticket = self.next_ticket.load(Ordering::Relaxed);
72        let serving = self.next_serving.load(Ordering::Relaxed);
73        if serving + 1 != ticket {
74            unsafe {
75                self.unlock_fair();
76                self.lock();
77            }
78        }
79    }
80}
81
82/// A [`lock_api::Mutex`] based on [`RawTicketMutex`].
83pub type TicketMutex<T> = lock_api::Mutex<RawTicketMutex, T>;
84
85/// A [`lock_api::MutexGuard`] based on [`RawTicketMutex`].
86pub type TicketMutexGuard<'a, T> = lock_api::MutexGuard<'a, RawTicketMutex, T>;
87
// Adapted from the tests of `spin::mutex::ticket`.
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::mpsc::channel;
    use std::sync::Arc;
    use std::thread;

    use super::*;

    #[test]
    fn smoke() {
        let m = TicketMutex::<_>::new(());
        drop(m.lock());
        drop(m.lock());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn lots_and_lots() {
        static M: TicketMutex<()> = TicketMutex::<_>::new(());
        static mut CNT: u32 = 0;
        const J: u32 = 1000;
        const K: u32 = 3;

        fn inc() {
            for _ in 0..J {
                unsafe {
                    let _g = M.lock();
                    CNT += 1;
                }
            }
        }

        let (tx, rx) = channel();
        for _ in 0..K {
            let tx2 = tx.clone();
            thread::spawn(move || {
                inc();
                tx2.send(()).unwrap();
            });
            let tx2 = tx.clone();
            thread::spawn(move || {
                inc();
                tx2.send(()).unwrap();
            });
        }

        drop(tx);
        for _ in 0..2 * K {
            rx.recv().unwrap();
        }
        assert_eq!(unsafe { CNT }, J * K * 2);
    }

    #[test]
    fn try_lock() {
        let mutex = TicketMutex::<_>::new(42);

        // First lock succeeds
        let a = mutex.try_lock();
        assert_eq!(a.as_ref().map(|r| **r), Some(42));

        // Additional lock fails
        let b = mutex.try_lock();
        assert!(b.is_none());

        // After dropping lock, it succeeds again
        ::core::mem::drop(a);
        let c = mutex.try_lock();
        assert_eq!(c.as_ref().map(|r| **r), Some(42));
    }

    #[test]
    fn unlock_fair_and_bump() {
        let mutex = TicketMutex::<_>::new(0);

        let mut guard = mutex.lock();
        *guard += 1;

        // Nobody is waiting, so bumping must leave the lock held by this guard.
        TicketMutexGuard::bump(&mut guard);
        assert!(mutex.is_locked());
        *guard += 1;

        // A fair unlock must release the lock just like a normal one.
        TicketMutexGuard::unlock_fair(guard);
        assert!(!mutex.is_locked());
        assert_eq!(*mutex.lock(), 2);
    }

    #[test]
    fn test_into_inner() {
        let m = TicketMutex::<_>::new(Box::new(10));
        assert_eq!(m.into_inner(), Box::new(10));
    }

    #[test]
    fn test_into_inner_drop() {
        struct Foo(Arc<AtomicUsize>);
        impl Drop for Foo {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }
        let num_drops = Arc::new(AtomicUsize::new(0));
        let m = TicketMutex::<_>::new(Foo(num_drops.clone()));
        assert_eq!(num_drops.load(Ordering::SeqCst), 0);
        {
            let _inner = m.into_inner();
            assert_eq!(num_drops.load(Ordering::SeqCst), 0);
        }
        assert_eq!(num_drops.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_mutex_arc_nested() {
        // Tests nested mutexes and access
        // to underlying data.
        let arc = Arc::new(TicketMutex::<_>::new(1));
        let arc2 = Arc::new(TicketMutex::<_>::new(arc));
        let (tx, rx) = channel();
        let _t = thread::spawn(move || {
            let lock = arc2.lock();
            let lock2 = lock.lock();
            assert_eq!(*lock2, 1);
            tx.send(()).unwrap();
        });
        rx.recv().unwrap();
    }

    #[test]
    fn test_mutex_arc_access_in_unwind() {
        let arc = Arc::new(TicketMutex::<_>::new(1));
        let arc2 = arc.clone();
        // The explicit `-> ()` pins the closure's return type; the body only diverges.
        let _ = thread::spawn(move || -> () {
            struct Unwinder {
                i: Arc<TicketMutex<i32>>,
            }
            impl Drop for Unwinder {
                fn drop(&mut self) {
                    *self.i.lock() += 1;
                }
            }
            let _u = Unwinder { i: arc2 };
            panic!();
        })
        .join();
        let lock = arc.lock();
        assert_eq!(*lock, 2);
    }

    #[test]
    fn test_mutex_unsized() {
        let mutex: &TicketMutex<[i32]> = &TicketMutex::<_>::new([1, 2, 3]);
        {
            let b = &mut *mutex.lock();
            b[0] = 4;
            b[2] = 5;
        }
        let comp: &[i32] = &[4, 2, 5];
        assert_eq!(&*mutex.lock(), comp);
    }

    #[test]
    fn is_locked() {
        let mutex = TicketMutex::<_>::new(());
        assert!(!mutex.is_locked());
        let lock = mutex.lock();
        assert!(mutex.is_locked());
        drop(lock);
        assert!(!mutex.is_locked());
    }
}