hermit/fd/socket/vsock.rs

use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::future;
use core::mem::MaybeUninit;
use core::task::Poll;

use async_trait::async_trait;
use virtio::vsock::{Hdr, Op, Type};
use virtio::{le16, le32, le64};

#[cfg(not(feature = "pci"))]
use crate::arch::kernel::mmio as hardware;
#[cfg(feature = "pci")]
use crate::drivers::pci as hardware;
use crate::executor::vsock::{VSOCK_MAP, VsockState};
use crate::fd::{self, Endpoint, ListenEndpoint, ObjectInterface, PollEvent};
use crate::io::{self, Error};

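/// Local endpoint for a listening vsock socket: a port and an optional CID
/// filter (`None` accepts connections on any CID).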
#[derive(Debug)]
pub struct VsockListenEndpoint {
	pub port: u32,
	pub cid: Option<u32>,
}

impl VsockListenEndpoint {
	pub const fn new(port: u32, cid: Option<u32>) -> Self {
		Self { port, cid }
	}
}

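/// Remote endpoint of a vsock connection, identified by context ID (CID) and port.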
#[derive(Debug)]
pub struct VsockEndpoint {
	pub port: u32,
	pub cid: u32,
}

impl VsockEndpoint {
	pub const fn new(port: u32, cid: u32) -> Self {
		Self { port, cid }
	}
}

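/// Placeholder object returned by `accept`. It performs no I/O of its own;
/// the connection state stays in the listening socket's `VSOCK_MAP` entry.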
#[derive(Debug)]
pub struct NullSocket;

impl NullSocket {
	pub const fn new() -> Self {
		Self {}
	}
}

#[async_trait]
impl ObjectInterface for async_lock::RwLock<NullSocket> {}

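/// A vsock stream socket, identified by its local port. The connection state
/// (receive buffer, credit counters, wakers) lives in the global `VSOCK_MAP`.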
#[derive(Debug)]
pub struct Socket {
	port: u32,
	cid: u32,
	is_nonblocking: bool,
}

impl Socket {
	pub fn new() -> Self {
		Self {
			port: 0,
			cid: u32::MAX,
			is_nonblocking: false,
		}
	}

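	/// Reports which of the requested events are ready. While the connection is
	/// still being established, the current task's waker is registered so the
	/// poll is retried once the state changes.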
	async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
		future::poll_fn(|cx| {
			let mut guard = VSOCK_MAP.lock();
			let raw = guard.get_mut_socket(self.port).ok_or(Error::EINVAL)?;

			match raw.state {
				VsockState::Shutdown | VsockState::ReceiveRequest => {
					let available = PollEvent::POLLOUT
						| PollEvent::POLLWRNORM
						| PollEvent::POLLWRBAND
						| PollEvent::POLLIN
						| PollEvent::POLLRDNORM
						| PollEvent::POLLRDBAND;

					let ret = event & available;

					if ret.is_empty() {
						Poll::Ready(Ok(PollEvent::POLLHUP))
					} else {
						Poll::Ready(Ok(ret))
					}
				}
				VsockState::Listen | VsockState::Connecting => {
					raw.rx_waker.register(cx.waker());
					raw.tx_waker.register(cx.waker());
					Poll::Pending
				}
				VsockState::Connected => {
					let mut available = PollEvent::empty();

					if !raw.buffer.is_empty() {
						// Data may already be buffered, e.g. received right after a
						// fresh connection was established in non-blocking mode.
						available.insert(
							PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND,
						);
					}

					let diff = raw.tx_cnt.abs_diff(raw.peer_fwd_cnt);
					if diff < raw.peer_buf_alloc {
						available.insert(
							PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND,
						);
					}

					let ret = event & available;

					if ret.is_empty() {
						if event.intersects(
							PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND,
						) {
							raw.rx_waker.register(cx.waker());
						}

						if event.intersects(
							PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND,
						) {
							raw.tx_waker.register(cx.waker());
						}

						Poll::Pending
					} else {
						Poll::Ready(Ok(ret))
					}
				}
			}
		})
		.await
	}

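	/// Binds the socket to a local port (and an optional CID) and registers the
	/// port in the global `VSOCK_MAP`.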
	async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> {
		match endpoint {
			ListenEndpoint::Vsock(ep) => {
				self.port = ep.port;
				self.cid = ep.cid.unwrap_or(u32::MAX);
				VSOCK_MAP.lock().bind(ep.port)
			}
			#[cfg(any(feature = "tcp", feature = "udp"))]
			_ => Err(io::Error::EINVAL),
		}
	}

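	/// Sends a connection request (`Op::Request`) to the peer and waits until the
	/// connection is reported as established.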
	async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
		match endpoint {
			Endpoint::Vsock(ep) => {
				const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
				let port = VSOCK_MAP.lock().connect(ep.port, ep.cid)?;
				self.port = port;
				self.cid = ep.cid;

				future::poll_fn(|_cx| {
					if let Some(mut driver_guard) = hardware::get_vsock_driver().unwrap().try_lock()
					{
						let local_cid = driver_guard.get_cid();

						driver_guard.send_packet(HEADER_SIZE, |buffer| {
							let response = unsafe { &mut *buffer.as_mut_ptr().cast::<Hdr>() };

							response.src_cid = le64::from_ne(local_cid);
							response.dst_cid = le64::from_ne(ep.cid.into());
							response.src_port = le32::from_ne(port);
							response.dst_port = le32::from_ne(ep.port);
							response.len = le32::from_ne(0);
							response.type_ = le16::from_ne(Type::Stream.into());
							response.op = le16::from_ne(Op::Request.into());
							response.flags = le32::from_ne(0);
							response.buf_alloc = le32::from_ne(
								crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32,
							);
							response.fwd_cnt = le32::from_ne(0);
						});

						Poll::Ready(())
					} else {
						Poll::Pending
					}
				})
				.await;

				future::poll_fn(|cx| {
					let mut guard = VSOCK_MAP.lock();
					let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?;

					match raw.state {
						VsockState::Connected => Poll::Ready(Ok(())),
						VsockState::Connecting => {
							raw.rx_waker.register(cx.waker());
							Poll::Pending
						}
						_ => Poll::Ready(Err(io::Error::EBADF)),
					}
				})
				.await
			}
			#[cfg(any(feature = "tcp", feature = "udp"))]
			_ => Err(io::Error::EINVAL),
		}
	}

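	/// Returns the remote endpoint stored for this connection in `VSOCK_MAP`.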
	async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
		let guard = VSOCK_MAP.lock();
		let raw = guard.get_socket(self.port).ok_or(Error::EINVAL)?;

		Ok(Some(Endpoint::Vsock(VsockEndpoint::new(
			raw.remote_port,
			raw.remote_cid,
		))))
	}

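	/// Returns the local endpoint: the bound port and the CID reported by the
	/// vsock driver.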
	async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
		let local_cid = hardware::get_vsock_driver().unwrap().lock().get_cid();

		Ok(Some(Endpoint::Vsock(VsockEndpoint::new(
			self.port,
			local_cid.try_into().unwrap(),
		))))
	}

	async fn listen(&self, _backlog: i32) -> io::Result<()> {
		Ok(())
	}

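	/// Waits for an incoming connection request and answers it with `Op::Response`,
	/// or with `Op::Rst` if the socket was bound to a CID other than the local one.
	/// Returns the peer endpoint together with a `NullSocket` placeholder.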
	async fn accept(&mut self) -> io::Result<(NullSocket, Endpoint)> {
		let port = self.port;
		let cid = self.cid;

		let endpoint = future::poll_fn(|cx| {
			let mut guard = VSOCK_MAP.lock();
			let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?;

			match raw.state {
				VsockState::Listen => {
					if self.is_nonblocking {
						Poll::Ready(Err(io::Error::EAGAIN))
					} else {
						raw.rx_waker.register(cx.waker());
						Poll::Pending
					}
				}
				VsockState::ReceiveRequest => {
					let result = {
						const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
						let mut driver_guard = hardware::get_vsock_driver().unwrap().lock();
						let local_cid = driver_guard.get_cid();

						driver_guard.send_packet(HEADER_SIZE, |buffer| {
							let response = unsafe { &mut *buffer.as_mut_ptr().cast::<Hdr>() };

							response.src_cid = le64::from_ne(local_cid);
							response.dst_cid = le64::from_ne(raw.remote_cid.into());
							response.src_port = le32::from_ne(port);
							response.dst_port = le32::from_ne(raw.remote_port);
							response.len = le32::from_ne(0);
							response.type_ = le16::from_ne(Type::Stream.into());
							if local_cid != u64::from(cid) && cid != u32::MAX {
								response.op = le16::from_ne(Op::Rst.into());
							} else {
								response.op = le16::from_ne(Op::Response.into());
							}
							response.flags = le32::from_ne(0);
							response.buf_alloc = le32::from_ne(
								crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32,
							);
							response.fwd_cnt = le32::from_ne(raw.fwd_cnt);
						});

						raw.state = VsockState::Connected;

						Ok(VsockEndpoint::new(raw.remote_port, raw.remote_cid))
					};

					Poll::Ready(result)
				}
				_ => Poll::Ready(Err(Error::EBADF)),
			}
		})
		.await?;

		Ok((NullSocket::new(), Endpoint::Vsock(endpoint)))
	}

	async fn shutdown(&self, _how: i32) -> io::Result<()> {
		Ok(())
	}

	async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
		let status_flags = if self.is_nonblocking {
			fd::StatusFlags::O_NONBLOCK
		} else {
			fd::StatusFlags::empty()
		};

		Ok(status_flags)
	}

	async fn set_status_flags(&mut self, status_flags: fd::StatusFlags) -> io::Result<()> {
		self.is_nonblocking = status_flags.contains(fd::StatusFlags::O_NONBLOCK);
		Ok(())
	}

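	/// Copies buffered data from the connection into `buffer`. Returns `Ok(0)`
	/// once the peer has shut the connection down and the buffer is drained.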
	async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
		let port = self.port;
		future::poll_fn(|cx| {
			let mut guard = VSOCK_MAP.lock();
			let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?;

			match raw.state {
				VsockState::Connected => {
					let len = core::cmp::min(buffer.len(), raw.buffer.len());

					if len == 0 {
						if self.is_nonblocking {
							Poll::Ready(Err(io::Error::EAGAIN))
						} else {
							raw.rx_waker.register(cx.waker());
							Poll::Pending
						}
					} else {
						let tmp: Vec<_> = raw.buffer.drain(..len).collect();
						buffer[..len].write_copy_of_slice(tmp.as_slice());

						Poll::Ready(Ok(len))
					}
				}
				VsockState::Shutdown => {
					let len = core::cmp::min(buffer.len(), raw.buffer.len());

					if len == 0 {
						Poll::Ready(Ok(0))
					} else {
						let tmp: Vec<_> = raw.buffer.drain(..len).collect();
						buffer[..len].write_copy_of_slice(tmp.as_slice());

						Poll::Ready(Ok(len))
					}
				}
				_ => Poll::Ready(Err(Error::EIO)),
			}
		})
		.await
	}

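	/// Sends as much of `buffer` as the peer's advertised credit allows as a
	/// single `Op::Rw` packet and returns the number of bytes written.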
	async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
		let port = self.port;
		future::poll_fn(|cx| {
			let mut guard = VSOCK_MAP.lock();
			let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?;
			let diff = raw.tx_cnt.abs_diff(raw.peer_fwd_cnt);

			match raw.state {
				VsockState::Connected => {
					if diff >= raw.peer_buf_alloc {
						if self.is_nonblocking {
							Poll::Ready(Err(io::Error::EAGAIN))
						} else {
							raw.tx_waker.register(cx.waker());
							Poll::Pending
						}
					} else {
						const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
						let mut driver_guard = hardware::get_vsock_driver().unwrap().lock();
						let local_cid = driver_guard.get_cid();
						let len = core::cmp::min(
							buffer.len(),
							usize::try_from(raw.peer_buf_alloc - diff).unwrap(),
						);

						driver_guard.send_packet(HEADER_SIZE + len, |virtio_buffer| {
							let response =
								unsafe { &mut *virtio_buffer.as_mut_ptr().cast::<Hdr>() };

							raw.tx_cnt = raw.tx_cnt.wrapping_add(len.try_into().unwrap());
							response.src_cid = le64::from_ne(local_cid);
							response.dst_cid = le64::from_ne(raw.remote_cid.into());
							response.src_port = le32::from_ne(port);
							response.dst_port = le32::from_ne(raw.remote_port);
							response.len = le32::from_ne(len.try_into().unwrap());
							response.type_ = le16::from_ne(Type::Stream.into());
							response.op = le16::from_ne(Op::Rw.into());
							response.flags = le32::from_ne(0);
							response.buf_alloc = le32::from_ne(
								crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32,
							);
							response.fwd_cnt = le32::from_ne(raw.fwd_cnt);

							virtio_buffer[HEADER_SIZE..HEADER_SIZE + len]
								.copy_from_slice(&buffer[..len]);
						});

						Poll::Ready(Ok(len))
					}
				}
				_ => Poll::Ready(Err(Error::EIO)),
			}
		})
		.await
	}
}

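// Release the socket's entry in the global map once the socket is dropped.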
impl Drop for Socket {
	fn drop(&mut self) {
		let mut guard = VSOCK_MAP.lock();
		guard.remove_socket(self.port);
	}
}

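// The file-descriptor interface wraps the socket in an `async_lock::RwLock`
// and forwards each operation to the inner `Socket`.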
#[async_trait]
impl ObjectInterface for async_lock::RwLock<Socket> {
	async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
		self.read().await.poll(event).await
	}

	async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
		self.read().await.read(buffer).await
	}

	async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
		self.read().await.write(buffer).await
	}

	async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
		self.write().await.bind(endpoint).await
	}

	async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
		self.write().await.connect(endpoint).await
	}

	async fn accept(&self) -> io::Result<(Arc<dyn ObjectInterface>, Endpoint)> {
		let (handle, endpoint) = self.write().await.accept().await?;
		Ok((Arc::new(async_lock::RwLock::new(handle)), endpoint))
	}

	async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
		self.read().await.getpeername().await
	}

	async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
		self.read().await.getsockname().await
	}

	async fn listen(&self, backlog: i32) -> io::Result<()> {
		self.write().await.listen(backlog).await
	}

	async fn shutdown(&self, how: i32) -> io::Result<()> {
		self.read().await.shutdown(how).await
	}

	async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
		self.read().await.status_flags().await
	}

	async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
		self.write().await.set_status_flags(status_flags).await
	}
}