hermit_sync/mutex/ticket.rs

use core::sync::atomic::{AtomicUsize, Ordering};

use lock_api::{GuardSend, RawMutex, RawMutexFair};
use spinning_top::relax::{Backoff, Relax};

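/// A fair ticket-based spinlock.
///
/// Each `lock` call takes the next ticket from `next_ticket` and spins until
/// `next_serving` reaches that ticket, so waiters acquire the lock in FIFO
/// order.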
pub struct RawTicketMutex {
    next_ticket: AtomicUsize,
    next_serving: AtomicUsize,
}

unsafe impl RawMutex for RawTicketMutex {
    #[allow(clippy::declare_interior_mutable_const)]
    const INIT: Self = Self {
        next_ticket: AtomicUsize::new(0),
        next_serving: AtomicUsize::new(0),
    };

    type GuardMarker = GuardSend;

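    // Take the next ticket, then spin (backing off) until that ticket is served.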
    #[inline]
    fn lock(&self) {
        let ticket = self.next_ticket.fetch_add(1, Ordering::Relaxed);

        let mut backoff = Backoff::default();
        while self.next_serving.load(Ordering::Acquire) != ticket {
            backoff.relax();
        }
    }

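    // Take a ticket only if it would be served immediately; otherwise leave the
    // counters untouched and fail.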
    #[inline]
    fn try_lock(&self) -> bool {
        let ticket = self
            .next_ticket
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |ticket| {
                if self.next_serving.load(Ordering::Acquire) == ticket {
                    Some(ticket + 1)
                } else {
                    None
                }
            });

        ticket.is_ok()
    }

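    // Advance `next_serving`, handing the lock to the oldest waiter (if any).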
    #[inline]
    unsafe fn unlock(&self) {
        self.next_serving.fetch_add(1, Ordering::Release);
    }

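    // The lock is held as long as the two ticket counters differ.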
    #[inline]
    fn is_locked(&self) -> bool {
        let ticket = self.next_ticket.load(Ordering::Relaxed);
        self.next_serving.load(Ordering::Relaxed) != ticket
    }
}

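// Ticket locks hand the lock to waiters in FIFO order, so the regular unlock
// is already fair.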
unsafe impl RawMutexFair for RawTicketMutex {
    #[inline]
    unsafe fn unlock_fair(&self) {
        unsafe { self.unlock() }
    }

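    // Yield the lock to the next waiter, if there is one: release fairly and
    // immediately take a fresh ticket. With no waiters, keep the lock as is.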
    #[inline]
    unsafe fn bump(&self) {
        let ticket = self.next_ticket.load(Ordering::Relaxed);
        let serving = self.next_serving.load(Ordering::Relaxed);
        if serving + 1 != ticket {
            unsafe {
                self.unlock_fair();
                self.lock();
            }
        }
    }
}

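/// A [`lock_api::Mutex`] backed by [`RawTicketMutex`].
///
/// ```ignore
/// // Minimal usage sketch; assumes `TicketMutex` is in scope.
/// let counter = TicketMutex::new(0);
/// *counter.lock() += 1;
/// assert_eq!(*counter.lock(), 1);
/// ```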
pub type TicketMutex<T> = lock_api::Mutex<RawTicketMutex, T>;

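/// A RAII guard returned by the locking methods of [`TicketMutex`].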
pub type TicketMutexGuard<'a, T> = lock_api::MutexGuard<'a, RawTicketMutex, T>;

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::mpsc::channel;
    use std::sync::Arc;
    use std::thread;

    use super::*;

    #[test]
    fn smoke() {
        let m = TicketMutex::<_>::new(());
        drop(m.lock());
        drop(m.lock());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn lots_and_lots() {
        static M: TicketMutex<()> = TicketMutex::<_>::new(());
        static mut CNT: u32 = 0;
        const J: u32 = 1000;
        const K: u32 = 3;

        fn inc() {
            for _ in 0..J {
                unsafe {
                    let _g = M.lock();
                    CNT += 1;
                }
            }
        }

        let (tx, rx) = channel();
        for _ in 0..K {
            let tx2 = tx.clone();
            thread::spawn(move || {
                inc();
                tx2.send(()).unwrap();
            });
            let tx2 = tx.clone();
            thread::spawn(move || {
                inc();
                tx2.send(()).unwrap();
            });
        }

        drop(tx);
        for _ in 0..2 * K {
            rx.recv().unwrap();
        }
        assert_eq!(unsafe { CNT }, J * K * 2);
    }

    #[test]
    fn try_lock() {
        let mutex = TicketMutex::<_>::new(42);

        let a = mutex.try_lock();
        assert_eq!(a.as_ref().map(|r| **r), Some(42));

        let b = mutex.try_lock();
        assert!(b.is_none());

        ::core::mem::drop(a);
        let c = mutex.try_lock();
        assert_eq!(c.as_ref().map(|r| **r), Some(42));
    }

    #[test]
    fn test_into_inner() {
        let m = TicketMutex::<_>::new(Box::new(10));
        assert_eq!(m.into_inner(), Box::new(10));
    }

    #[test]
    fn test_into_inner_drop() {
        struct Foo(Arc<AtomicUsize>);
        impl Drop for Foo {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }
        let num_drops = Arc::new(AtomicUsize::new(0));
        let m = TicketMutex::<_>::new(Foo(num_drops.clone()));
        assert_eq!(num_drops.load(Ordering::SeqCst), 0);
        {
            let _inner = m.into_inner();
            assert_eq!(num_drops.load(Ordering::SeqCst), 0);
        }
        assert_eq!(num_drops.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_mutex_arc_nested() {
        let arc = Arc::new(TicketMutex::<_>::new(1));
        let arc2 = Arc::new(TicketMutex::<_>::new(arc));
        let (tx, rx) = channel();
        let _t = thread::spawn(move || {
            let lock = arc2.lock();
            let lock2 = lock.lock();
            assert_eq!(*lock2, 1);
            tx.send(()).unwrap();
        });
        rx.recv().unwrap();
    }

    #[test]
    fn test_mutex_arc_access_in_unwind() {
        let arc = Arc::new(TicketMutex::<_>::new(1));
        let arc2 = arc.clone();
        let _ = thread::spawn(move || -> () {
            struct Unwinder {
                i: Arc<TicketMutex<i32>>,
            }
            impl Drop for Unwinder {
                fn drop(&mut self) {
                    *self.i.lock() += 1;
                }
            }
            let _u = Unwinder { i: arc2 };
            panic!();
        })
        .join();
        let lock = arc.lock();
        assert_eq!(*lock, 2);
    }

    #[test]
    fn test_mutex_unsized() {
        let mutex: &TicketMutex<[i32]> = &TicketMutex::<_>::new([1, 2, 3]);
        {
            let b = &mut *mutex.lock();
            b[0] = 4;
            b[2] = 5;
        }
        let comp: &[i32] = &[4, 2, 5];
        assert_eq!(&*mutex.lock(), comp);
    }

    #[test]
    fn is_locked() {
        let mutex = TicketMutex::<_>::new(());
        assert!(!mutex.is_locked());
        let lock = mutex.lock();
        assert!(mutex.is_locked());
        drop(lock);
        assert!(!mutex.is_locked());
    }
}