hermit/executor/
vsock.rs

1use alloc::collections::BTreeMap;
2use alloc::vec::Vec;
3use core::future;
4use core::task::Poll;
5
6use hermit_sync::InterruptTicketMutex;
7use virtio::vsock::{Hdr, Op, Type};
8use virtio::{le16, le32};
9
10#[cfg(not(feature = "pci"))]
11use crate::arch::kernel::mmio as hardware;
12#[cfg(feature = "pci")]
13use crate::drivers::pci as hardware;
14use crate::executor::{WakerRegistration, spawn};
15use crate::io;
16use crate::io::Error::EADDRINUSE;
17
18pub(crate) static VSOCK_MAP: InterruptTicketMutex<VsockMap> =
19	InterruptTicketMutex::new(VsockMap::new());
20
/// Connection state of a vsock stream socket.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum VsockState {
	/// Bound to a local port and waiting for a connection request.
	Listen,
	/// A connection request from a peer arrived on a listening socket.
	ReceiveRequest,
	/// The connection is established.
	Connected,
	/// An outgoing connection request is pending a response from the peer.
	Connecting,
	/// The peer shut the connection down.
	Shutdown,
}
29
/// Receive buffer capacity of every socket (256 KiB); also advertised to peers as `buf_alloc`.
pub(crate) const RAW_SOCKET_BUFFER_SIZE: usize = 1 << 18;
31
32#[derive(Debug)]
33pub(crate) struct RawSocket {
34	pub remote_cid: u32,
35	pub remote_port: u32,
36	pub fwd_cnt: u32,
37	pub peer_fwd_cnt: u32,
38	pub peer_buf_alloc: u32,
39	pub tx_cnt: u32,
40	pub state: VsockState,
41	pub rx_waker: WakerRegistration,
42	pub tx_waker: WakerRegistration,
43	pub buffer: Vec<u8>,
44}
45
46impl RawSocket {
47	pub fn new(state: VsockState) -> Self {
48		Self {
49			remote_cid: 0,
50			remote_port: 0,
51			fwd_cnt: 0,
52			peer_fwd_cnt: 0,
53			peer_buf_alloc: 0,
54			tx_cnt: 0,
55			state,
56			rx_waker: WakerRegistration::new(),
57			tx_waker: WakerRegistration::new(),
58			buffer: Vec::with_capacity(RAW_SOCKET_BUFFER_SIZE),
59		}
60	}
61}
62
63async fn vsock_run() {
64	future::poll_fn(|_cx| {
65		if let Some(driver) = hardware::get_vsock_driver() {
66			const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
67			let mut driver_guard = driver.lock();
68			let mut hdr: Option<Hdr> = None;
69			let mut fwd_cnt: u32 = 0;
70
71			driver_guard.process_packet(|header, data| {
72				let op = Op::try_from(header.op.to_ne()).unwrap();
73				let port = header.dst_port.to_ne();
74				let type_ = Type::try_from(header.type_.to_ne()).unwrap();
75				let mut vsock_guard = VSOCK_MAP.lock();
76				let header_cid: u32 = header.src_cid.to_ne().try_into().unwrap();
77
78				if let Some(raw) = vsock_guard.get_mut_socket(port) {
79					if op == Op::Request && raw.state == VsockState::Listen && type_ == Type::Stream
80					{
81						raw.state = VsockState::ReceiveRequest;
82						raw.remote_cid = header_cid;
83						raw.remote_port = header.src_port.to_ne();
84						raw.peer_buf_alloc = header.buf_alloc.to_ne();
85						raw.rx_waker.wake();
86					} else if (raw.state == VsockState::Connected
87						|| raw.state == VsockState::Shutdown)
88						&& type_ == Type::Stream
89						&& op == Op::Rw
90					{
91						if raw.remote_cid == header_cid {
92							raw.buffer.extend_from_slice(data);
93							raw.fwd_cnt =
94								raw.fwd_cnt.wrapping_add(u32::try_from(data.len()).unwrap());
95							raw.peer_fwd_cnt = header.fwd_cnt.to_ne();
96							raw.tx_waker.wake();
97							raw.rx_waker.wake();
98							hdr = Some(*header);
99							fwd_cnt = raw.fwd_cnt;
100						} else {
101							trace!("Receive message from invalid source {header_cid}");
102						}
103					} else if op == Op::CreditUpdate {
104						if raw.remote_cid == header_cid {
105							raw.peer_fwd_cnt = header.fwd_cnt.to_ne();
106							raw.tx_waker.wake();
107						} else {
108							trace!("Receive message from invalid source {header_cid}");
109						}
110					} else if op == Op::Shutdown {
111						if raw.remote_cid == header_cid {
112							raw.state = VsockState::Shutdown;
113						} else {
114							trace!("Receive message from invalid source {header_cid}");
115						}
116					} else if op == Op::Response && type_ == Type::Stream {
117						if raw.remote_cid == header_cid && raw.state == VsockState::Connecting {
118							raw.state = VsockState::Connected;
119						}
120					} else if raw.remote_cid == header_cid {
121						hdr = Some(*header);
122						fwd_cnt = raw.fwd_cnt;
123					}
124				}
125			});
126
127			if let Some(hdr) = hdr {
128				driver_guard.send_packet(HEADER_SIZE, |buffer| {
129					let response = unsafe { &mut *buffer.as_mut_ptr().cast::<Hdr>() };
130
131					response.src_cid = hdr.dst_cid;
132					response.dst_cid = hdr.src_cid;
133					response.src_port = hdr.dst_port;
134					response.dst_port = hdr.src_port;
135					response.len = le32::from_ne(0);
136					response.type_ = hdr.type_;
137					if hdr.op.to_ne() == u16::from(Op::CreditRequest)
138						|| hdr.op.to_ne() == u16::from(Op::Rw)
139					{
140						response.op = le16::from_ne(Op::CreditUpdate.into());
141					} else {
142						// reset connection
143						response.op = le16::from_ne(Op::Rst.into());
144					}
145					response.flags = le32::from_ne(0);
146					response.buf_alloc = le32::from_ne(RAW_SOCKET_BUFFER_SIZE as u32);
147					response.fwd_cnt = le32::from_ne(fwd_cnt);
148				});
149			}
150
151			Poll::Pending
152		} else {
153			Poll::Ready(())
154		}
155	})
156	.await;
157}
158
159pub(crate) struct VsockMap {
160	port_map: BTreeMap<u32, RawSocket>,
161}
162
163impl VsockMap {
164	pub const fn new() -> Self {
165		Self {
166			port_map: BTreeMap::new(),
167		}
168	}
169
170	pub fn bind(&mut self, port: u32) -> io::Result<()> {
171		self.port_map
172			.try_insert(port, RawSocket::new(VsockState::Listen))
173			.map_err(|_| EADDRINUSE)?;
174		Ok(())
175	}
176
177	pub fn connect(&mut self, port: u32, cid: u32) -> io::Result<u32> {
178		for i in u32::MAX / 4..u32::MAX {
179			let mut raw = RawSocket::new(VsockState::Connecting);
180			raw.remote_cid = cid;
181			raw.remote_port = port;
182
183			if self.port_map.try_insert(i, raw).is_ok() {
184				return Ok(i);
185			}
186		}
187
188		Err(io::Error::EBADF)
189	}
190
191	pub fn get_socket(&self, port: u32) -> Option<&RawSocket> {
192		self.port_map.get(&port)
193	}
194
195	pub fn get_mut_socket(&mut self, port: u32) -> Option<&mut RawSocket> {
196		self.port_map.get_mut(&port)
197	}
198
199	pub fn remove_socket(&mut self, port: u32) {
200		let _ = self.port_map.remove(&port);
201	}
202}
203
204pub(crate) fn init() {
205	info!("Try to initialize vsock interface!");
206
207	spawn(vsock_run());
208}