hermit/fd/socket/vsock.rs

use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::future;
use core::mem::MaybeUninit;
use core::task::Poll;

use async_trait::async_trait;
use virtio::vsock::{Hdr, Op, Type};
use virtio::{le16, le32, le64};

#[cfg(not(feature = "pci"))]
use crate::arch::kernel::mmio as hardware;
#[cfg(feature = "pci")]
use crate::drivers::pci as hardware;
use crate::executor::vsock::{VSOCK_MAP, VsockState};
use crate::fd::{self, Endpoint, ListenEndpoint, ObjectInterface, PollEvent};
use crate::io::{self, Error};

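/// Local endpoint a vsock listener binds to: a port plus an optional context
/// ID (CID). `None` stands for "any CID" and is stored as `u32::MAX`.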
#[derive(Debug)]
pub struct VsockListenEndpoint {
	pub port: u32,
	pub cid: Option<u32>,
}

impl VsockListenEndpoint {
	pub const fn new(port: u32, cid: Option<u32>) -> Self {
		Self { port, cid }
	}
}

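/// Fully specified vsock endpoint (port and context ID) identifying a peer.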
#[derive(Debug)]
pub struct VsockEndpoint {
	pub port: u32,
	pub cid: u32,
}

impl VsockEndpoint {
	pub const fn new(port: u32, cid: u32) -> Self {
		Self { port, cid }
	}
}

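/// Dummy socket object handed out by `accept`. It carries no state of its
/// own, and its `ObjectInterface` implementation relies entirely on the
/// trait's default methods.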
#[derive(Debug)]
pub struct NullSocket;

impl NullSocket {
	pub const fn new() -> Self {
		Self {}
	}
}

#[async_trait]
impl ObjectInterface for async_lock::RwLock<NullSocket> {}

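/// A vsock stream socket. `port` indexes this socket's state in `VSOCK_MAP`,
/// `cid` is the context ID it is bound or connected to (`u32::MAX` means
/// "any"), and `is_nonblocking` mirrors the `O_NONBLOCK` status flag.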
#[derive(Debug)]
pub struct Socket {
	port: u32,
	cid: u32,
	is_nonblocking: bool,
}

impl Socket {
	pub fn new() -> Self {
		Self {
			port: 0,
			cid: u32::MAX,
			is_nonblocking: false,
		}
	}

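	/// Reports which of the requested `event` flags are currently ready.
	///
	/// Shut-down sockets and sockets with a pending connection request count
	/// as both readable and writable (`POLLHUP` if none of those flags were
	/// requested). Listening and connecting sockets stay pending. Connected
	/// sockets are readable while the receive buffer holds data and writable
	/// while the peer still advertises buffer credit; otherwise the matching
	/// wakers are registered and the poll stays pending.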
	async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
		future::poll_fn(|cx| {
			let mut guard = VSOCK_MAP.lock();
			let raw = guard.get_mut_socket(self.port).ok_or(Error::EINVAL)?;

			match raw.state {
				VsockState::Shutdown | VsockState::ReceiveRequest => {
					let available = PollEvent::POLLOUT
						| PollEvent::POLLWRNORM
						| PollEvent::POLLWRBAND
						| PollEvent::POLLIN
						| PollEvent::POLLRDNORM
						| PollEvent::POLLRDBAND;

					let ret = event & available;

					if ret.is_empty() {
						Poll::Ready(Ok(PollEvent::POLLHUP))
					} else {
						Poll::Ready(Ok(ret))
					}
				}
				VsockState::Listen | VsockState::Connecting => {
					raw.rx_waker.register(cx.waker());
					raw.tx_waker.register(cx.waker());
					Poll::Pending
				}
				VsockState::Connected => {
					let mut available = PollEvent::empty();

					if !raw.buffer.is_empty() {
						available.insert(
							PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND,
						);
					}

					let diff = raw.tx_cnt.abs_diff(raw.peer_fwd_cnt);
					if diff < raw.peer_buf_alloc {
						available.insert(
							PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND,
						);
					}

					let ret = event & available;

					if ret.is_empty() {
						if event.intersects(
							PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND,
						) {
							raw.rx_waker.register(cx.waker());
						}

						if event.intersects(
							PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND,
						) {
							raw.tx_waker.register(cx.waker());
						}

						Poll::Pending
					} else {
						Poll::Ready(Ok(ret))
					}
				}
			}
		})
		.await
	}

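	/// Binds the socket to the local port (and optional CID) in `endpoint`
	/// and registers the port in `VSOCK_MAP`.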
	async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> {
		match endpoint {
			ListenEndpoint::Vsock(ep) => {
				self.port = ep.port;
				if let Some(cid) = ep.cid {
					self.cid = cid;
				} else {
					self.cid = u32::MAX;
				}
				VSOCK_MAP.lock().bind(ep.port)
			}
			#[cfg(any(feature = "tcp", feature = "udp"))]
			_ => Err(io::Error::EINVAL),
		}
	}

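	/// Connects to the peer in `endpoint`: reserves a local port in
	/// `VSOCK_MAP`, sends an `Op::Request` packet, and then waits until the
	/// connection reaches `VsockState::Connected`.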
	async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
		match endpoint {
			Endpoint::Vsock(ep) => {
				const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
				let port = VSOCK_MAP.lock().connect(ep.port, ep.cid)?;
				self.port = port;
				self.cid = ep.cid;

				future::poll_fn(|_cx| {
					if let Some(mut driver_guard) = hardware::get_vsock_driver().unwrap().try_lock()
					{
						let local_cid = driver_guard.get_cid();

						driver_guard.send_packet(HEADER_SIZE, |buffer| {
							let response = unsafe { &mut *buffer.as_mut_ptr().cast::<Hdr>() };

							response.src_cid = le64::from_ne(local_cid);
							response.dst_cid = le64::from_ne(ep.cid.into());
							response.src_port = le32::from_ne(port);
							response.dst_port = le32::from_ne(ep.port);
							response.len = le32::from_ne(0);
							response.type_ = le16::from_ne(Type::Stream.into());
							response.op = le16::from_ne(Op::Request.into());
							response.flags = le32::from_ne(0);
							response.buf_alloc = le32::from_ne(
								crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32,
							);
							response.fwd_cnt = le32::from_ne(0);
						});

						Poll::Ready(())
					} else {
						Poll::Pending
					}
				})
				.await;

				future::poll_fn(|cx| {
					let mut guard = VSOCK_MAP.lock();
					let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?;

					match raw.state {
						VsockState::Connected => Poll::Ready(Ok(())),
						VsockState::Connecting => {
							raw.rx_waker.register(cx.waker());
							Poll::Pending
						}
						_ => Poll::Ready(Err(io::Error::EBADF)),
					}
				})
				.await
			}
			#[cfg(any(feature = "tcp", feature = "udp"))]
			_ => Err(io::Error::EINVAL),
		}
	}

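	/// Returns the peer's endpoint (remote port and CID) of this connection.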
	async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
		let guard = VSOCK_MAP.lock();
		let raw = guard.get_socket(self.port).ok_or(Error::EINVAL)?;

		Ok(Some(Endpoint::Vsock(VsockEndpoint::new(
			raw.remote_port,
			raw.remote_cid,
		))))
	}

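	/// Returns the local endpoint: the bound port and the CID reported by the
	/// vsock driver.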
	async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
		let local_cid = hardware::get_vsock_driver().unwrap().lock().get_cid();

		Ok(Some(Endpoint::Vsock(VsockEndpoint::new(
			self.port,
			local_cid.try_into().unwrap(),
		))))
	}

	async fn listen(&self, _backlog: i32) -> io::Result<()> {
		Ok(())
	}

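	/// Waits for an incoming connection request on the listening port and
	/// answers it with `Op::Response` (or `Op::Rst` if the request does not
	/// match the bound CID), marking the connection as `Connected`. Returns a
	/// `NullSocket` together with the peer's endpoint; fails with `EAGAIN` if
	/// the socket is non-blocking and no request is pending.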
	async fn accept(&mut self) -> io::Result<(NullSocket, Endpoint)> {
		let port = self.port;
		let cid = self.cid;

		let endpoint = future::poll_fn(|cx| {
			let mut guard = VSOCK_MAP.lock();
			let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?;

			match raw.state {
				VsockState::Listen => {
					if self.is_nonblocking {
						Poll::Ready(Err(io::Error::EAGAIN))
					} else {
						raw.rx_waker.register(cx.waker());
						Poll::Pending
					}
				}
				VsockState::ReceiveRequest => {
					let result = {
						const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
						let mut driver_guard = hardware::get_vsock_driver().unwrap().lock();
						let local_cid = driver_guard.get_cid();

						driver_guard.send_packet(HEADER_SIZE, |buffer| {
							let response = unsafe { &mut *buffer.as_mut_ptr().cast::<Hdr>() };

							response.src_cid = le64::from_ne(local_cid);
							response.dst_cid = le64::from_ne(raw.remote_cid.into());
							response.src_port = le32::from_ne(port);
							response.dst_port = le32::from_ne(raw.remote_port);
							response.len = le32::from_ne(0);
							response.type_ = le16::from_ne(Type::Stream.into());
							if local_cid != u64::from(cid) && cid != u32::MAX {
								response.op = le16::from_ne(Op::Rst.into());
							} else {
								response.op = le16::from_ne(Op::Response.into());
							}
							response.flags = le32::from_ne(0);
							response.buf_alloc = le32::from_ne(
								crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32,
							);
							response.fwd_cnt = le32::from_ne(raw.fwd_cnt);
						});

						raw.state = VsockState::Connected;

						Ok(VsockEndpoint::new(raw.remote_port, raw.remote_cid))
					};

					Poll::Ready(result)
				}
				_ => Poll::Ready(Err(Error::EBADF)),
			}
		})
		.await?;

		Ok((NullSocket::new(), Endpoint::Vsock(endpoint)))
	}

	async fn shutdown(&self, _how: i32) -> io::Result<()> {
		Ok(())
	}

	async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
		let status_flags = if self.is_nonblocking {
			fd::StatusFlags::O_NONBLOCK
		} else {
			fd::StatusFlags::empty()
		};

		Ok(status_flags)
	}

	async fn set_status_flags(&mut self, status_flags: fd::StatusFlags) -> io::Result<()> {
		self.is_nonblocking = status_flags.contains(fd::StatusFlags::O_NONBLOCK);
		Ok(())
	}

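	/// Copies buffered receive data into `buffer`. On a connected socket with
	/// an empty receive buffer this fails with `EAGAIN` (non-blocking) or
	/// waits on the receive waker; after a shutdown, the remaining data is
	/// drained and then `0` is returned.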
	async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
		let port = self.port;
		future::poll_fn(|cx| {
			let mut guard = VSOCK_MAP.lock();
			let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?;

			match raw.state {
				VsockState::Connected => {
					let len = core::cmp::min(buffer.len(), raw.buffer.len());

					if len == 0 {
						if self.is_nonblocking {
							Poll::Ready(Err(io::Error::EAGAIN))
						} else {
							raw.rx_waker.register(cx.waker());
							Poll::Pending
						}
					} else {
						let tmp: Vec<_> = raw.buffer.drain(..len).collect();
						buffer[..len].write_copy_of_slice(tmp.as_slice());

						Poll::Ready(Ok(len))
					}
				}
				VsockState::Shutdown => {
					let len = core::cmp::min(buffer.len(), raw.buffer.len());

					if len == 0 {
						Poll::Ready(Ok(0))
					} else {
						let tmp: Vec<_> = raw.buffer.drain(..len).collect();
						buffer[..len].write_copy_of_slice(tmp.as_slice());

						Poll::Ready(Ok(len))
					}
				}
				_ => Poll::Ready(Err(Error::EIO)),
			}
		})
		.await
	}

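	/// Sends `buffer` as an `Op::Rw` packet, limited by the peer's remaining
	/// buffer credit (`peer_buf_alloc` minus the bytes still in flight). With
	/// no credit left, the call fails with `EAGAIN` (non-blocking) or waits
	/// on the transmit waker. Returns the number of bytes actually sent.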
	async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
		let port = self.port;
		future::poll_fn(|cx| {
			let mut guard = VSOCK_MAP.lock();
			let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?;
			let diff = raw.tx_cnt.abs_diff(raw.peer_fwd_cnt);

			match raw.state {
				VsockState::Connected => {
					if diff >= raw.peer_buf_alloc {
						if self.is_nonblocking {
							Poll::Ready(Err(io::Error::EAGAIN))
						} else {
							raw.tx_waker.register(cx.waker());
							Poll::Pending
						}
					} else {
						const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
						let mut driver_guard = hardware::get_vsock_driver().unwrap().lock();
						let local_cid = driver_guard.get_cid();
						let len = core::cmp::min(
							buffer.len(),
							usize::try_from(raw.peer_buf_alloc - diff).unwrap(),
						);

						driver_guard.send_packet(HEADER_SIZE + len, |virtio_buffer| {
							let response =
								unsafe { &mut *virtio_buffer.as_mut_ptr().cast::<Hdr>() };

							raw.tx_cnt = raw.tx_cnt.wrapping_add(len.try_into().unwrap());
							response.src_cid = le64::from_ne(local_cid);
							response.dst_cid = le64::from_ne(raw.remote_cid.into());
							response.src_port = le32::from_ne(port);
							response.dst_port = le32::from_ne(raw.remote_port);
							response.len = le32::from_ne(len.try_into().unwrap());
							response.type_ = le16::from_ne(Type::Stream.into());
							response.op = le16::from_ne(Op::Rw.into());
							response.flags = le32::from_ne(0);
							response.buf_alloc = le32::from_ne(
								crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32,
							);
							response.fwd_cnt = le32::from_ne(raw.fwd_cnt);

							virtio_buffer[HEADER_SIZE..HEADER_SIZE + len]
								.copy_from_slice(&buffer[..len]);
						});

						Poll::Ready(Ok(len))
					}
				}
				_ => Poll::Ready(Err(Error::EIO)),
			}
		})
		.await
	}
}

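// Dropping a socket releases its entry in the global `VSOCK_MAP`.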
impl Drop for Socket {
	fn drop(&mut self) {
		let mut guard = VSOCK_MAP.lock();
		guard.remove_socket(self.port);
	}
}

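// Forward all `ObjectInterface` calls to the inner `Socket`, taking the
// `RwLock` for reading or writing as the respective method requires.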
#[async_trait]
impl ObjectInterface for async_lock::RwLock<Socket> {
	async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
		self.read().await.poll(event).await
	}

	async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
		self.read().await.read(buffer).await
	}

	async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
		self.read().await.write(buffer).await
	}

	async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
		self.write().await.bind(endpoint).await
	}

	async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
		self.write().await.connect(endpoint).await
	}

	async fn accept(&self) -> io::Result<(Arc<dyn ObjectInterface>, Endpoint)> {
		let (handle, endpoint) = self.write().await.accept().await?;
		Ok((Arc::new(async_lock::RwLock::new(handle)), endpoint))
	}

	async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
		self.read().await.getpeername().await
	}

	async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
		self.read().await.getsockname().await
	}

	async fn listen(&self, backlog: i32) -> io::Result<()> {
		self.write().await.listen(backlog).await
	}

	async fn shutdown(&self, how: i32) -> io::Result<()> {
		self.read().await.shutdown(how).await
	}

	async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
		self.read().await.status_flags().await
	}

	async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
		self.write().await.set_status_flags(status_flags).await
	}
}