hermit/fd/
eventfd.rs

1use alloc::boxed::Box;
2use alloc::collections::vec_deque::VecDeque;
3use core::future::{self, Future};
4use core::mem::{self, MaybeUninit};
5use core::task::{Poll, Waker, ready};
6
7use async_lock::Mutex;
8use async_trait::async_trait;
9
10use crate::fd::{EventFlags, ObjectInterface, PollEvent};
11use crate::io;
12
/// Mutable state behind an event file descriptor: the counter value together with the
/// wakers of tasks that are waiting for it to change.
#[derive(Debug)]
struct EventState {
	/// Current value of the counter.
	pub counter: u64,
	/// Wakers of tasks waiting for the counter to become readable.
	pub read_queue: VecDeque<Waker>,
	/// Wakers of tasks waiting for room to add to the counter.
	pub write_queue: VecDeque<Waker>,
}

impl EventState {
	/// Builds a state holding `counter`, with both wait queues empty.
	pub fn new(counter: u64) -> Self {
		let read_queue = VecDeque::new();
		let write_queue = VecDeque::new();
		Self {
			counter,
			read_queue,
			write_queue,
		}
	}
}
29
30#[derive(Debug)]
31pub(crate) struct EventFd {
32	state: Mutex<EventState>,
33	flags: EventFlags,
34}
35
36impl EventFd {
37	pub fn new(initval: u64, flags: EventFlags) -> Self {
38		debug!("Create EventFd {initval}, {flags:?}");
39		Self {
40			state: Mutex::new(EventState::new(initval)),
41			flags,
42		}
43	}
44}
45
46#[async_trait]
47impl ObjectInterface for EventFd {
48	async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
49		let len = mem::size_of::<u64>();
50
51		if buf.len() < len {
52			return Err(io::Error::EINVAL);
53		}
54
55		future::poll_fn(|cx| {
56			if self.flags.contains(EventFlags::EFD_SEMAPHORE) {
57				let mut pinned = core::pin::pin!(self.state.lock());
58				let mut guard = ready!(pinned.as_mut().poll(cx));
59				if guard.counter > 0 {
60					guard.counter -= 1;
61					buf[..len].write_copy_of_slice(&u64::to_ne_bytes(1));
62					if let Some(cx) = guard.write_queue.pop_front() {
63						cx.wake_by_ref();
64					}
65					Poll::Ready(Ok(len))
66				} else {
67					guard.read_queue.push_back(cx.waker().clone());
68					Poll::Pending
69				}
70			} else {
71				let mut pinned = core::pin::pin!(self.state.lock());
72				let mut guard = ready!(pinned.as_mut().poll(cx));
73				let tmp = guard.counter;
74				if tmp > 0 {
75					guard.counter = 0;
76					buf[..len].write_copy_of_slice(&u64::to_ne_bytes(tmp));
77					if let Some(cx) = guard.read_queue.pop_front() {
78						cx.wake_by_ref();
79					}
80					Poll::Ready(Ok(len))
81				} else if self.flags.contains(EventFlags::EFD_NONBLOCK) {
82					Poll::Ready(Err(io::Error::EAGAIN))
83				} else {
84					guard.read_queue.push_back(cx.waker().clone());
85					Poll::Pending
86				}
87			}
88		})
89		.await
90	}
91
92	async fn write(&self, buf: &[u8]) -> io::Result<usize> {
93		let len = mem::size_of::<u64>();
94
95		if buf.len() < len {
96			return Err(io::Error::EINVAL);
97		}
98
99		let c = u64::from_ne_bytes(buf[..len].try_into().unwrap());
100
101		future::poll_fn(|cx| {
102			let mut pinned = core::pin::pin!(self.state.lock());
103			let mut guard = ready!(pinned.as_mut().poll(cx));
104			if u64::MAX - guard.counter > c {
105				guard.counter += c;
106				if self.flags.contains(EventFlags::EFD_SEMAPHORE) {
107					for _i in 0..c {
108						if let Some(cx) = guard.read_queue.pop_front() {
109							cx.wake_by_ref();
110						} else {
111							break;
112						}
113					}
114				} else if let Some(cx) = guard.read_queue.pop_front() {
115					cx.wake_by_ref();
116				}
117
118				Poll::Ready(Ok(len))
119			} else if self.flags.contains(EventFlags::EFD_NONBLOCK) {
120				Poll::Ready(Err(io::Error::EAGAIN))
121			} else {
122				guard.write_queue.push_back(cx.waker().clone());
123				Poll::Pending
124			}
125		})
126		.await
127	}
128
129	async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
130		let guard = self.state.lock().await;
131
132		let mut available = PollEvent::empty();
133
134		if guard.counter < u64::MAX - 1 {
135			available.insert(PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND);
136		}
137
138		if guard.counter > 0 {
139			available.insert(PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND);
140		}
141
142		drop(guard);
143
144		let ret = event & available;
145
146		future::poll_fn(|cx| {
147			if ret.is_empty() {
148				let mut pinned = core::pin::pin!(self.state.lock());
149				let mut guard = ready!(pinned.as_mut().poll(cx));
150				if event
151					.intersects(PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDNORM)
152				{
153					guard.read_queue.push_back(cx.waker().clone());
154					Poll::Pending
155				} else if event
156					.intersects(PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND)
157				{
158					guard.write_queue.push_back(cx.waker().clone());
159					Poll::Pending
160				} else {
161					Poll::Ready(Ok(ret))
162				}
163			} else {
164				Poll::Ready(Ok(ret))
165			}
166		})
167		.await
168	}
169}