1use alloc::collections::BTreeMap;
2use alloc::vec::Vec;
3use core::future;
4use core::task::Poll;
5
6use hermit_sync::InterruptTicketMutex;
7use virtio::vsock::{Hdr, Op, Type};
8use virtio::{le16, le32};
9
10#[cfg(not(feature = "pci"))]
11use crate::arch::kernel::mmio as hardware;
12#[cfg(feature = "pci")]
13use crate::drivers::pci as hardware;
14use crate::executor::{WakerRegistration, spawn};
15use crate::io;
16use crate::io::Error::EADDRINUSE;
17
18pub(crate) static VSOCK_MAP: InterruptTicketMutex<VsockMap> =
19 InterruptTicketMutex::new(VsockMap::new());
20
/// Lifecycle state of a [`RawSocket`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum VsockState {
    /// Registered by `VsockMap::bind`; waiting for a connection request.
    Listen,
    /// A connection request from a peer has arrived on a listening socket.
    ReceiveRequest,
    /// The connection is established.
    Connected,
    /// Registered by `VsockMap::connect`; waiting for the peer's response.
    Connecting,
    /// The peer has ended the connection.
    Shutdown,
}
29
/// Receive buffer size of each [`RawSocket`] in bytes (256 KiB).
///
/// Used as the initial capacity of the socket buffer and advertised to the
/// peer as `buf_alloc` in reply headers.
pub(crate) const RAW_SOCKET_BUFFER_SIZE: usize = 1 << 18;
31
32#[derive(Debug)]
33pub(crate) struct RawSocket {
34 pub remote_cid: u32,
35 pub remote_port: u32,
36 pub fwd_cnt: u32,
37 pub peer_fwd_cnt: u32,
38 pub peer_buf_alloc: u32,
39 pub tx_cnt: u32,
40 pub state: VsockState,
41 pub rx_waker: WakerRegistration,
42 pub tx_waker: WakerRegistration,
43 pub buffer: Vec<u8>,
44}
45
46impl RawSocket {
47 pub fn new(state: VsockState) -> Self {
48 Self {
49 remote_cid: 0,
50 remote_port: 0,
51 fwd_cnt: 0,
52 peer_fwd_cnt: 0,
53 peer_buf_alloc: 0,
54 tx_cnt: 0,
55 state,
56 rx_waker: WakerRegistration::new(),
57 tx_waker: WakerRegistration::new(),
58 buffer: Vec::with_capacity(RAW_SOCKET_BUFFER_SIZE),
59 }
60 }
61}
62
63async fn vsock_run() {
64 future::poll_fn(|_cx| {
65 if let Some(driver) = hardware::get_vsock_driver() {
66 const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
67 let mut driver_guard = driver.lock();
68 let mut hdr: Option<Hdr> = None;
69 let mut fwd_cnt: u32 = 0;
70
71 driver_guard.process_packet(|header, data| {
72 let op = Op::try_from(header.op.to_ne()).unwrap();
73 let port = header.dst_port.to_ne();
74 let type_ = Type::try_from(header.type_.to_ne()).unwrap();
75 let mut vsock_guard = VSOCK_MAP.lock();
76 let header_cid: u32 = header.src_cid.to_ne().try_into().unwrap();
77
78 if let Some(raw) = vsock_guard.get_mut_socket(port) {
79 if op == Op::Request && raw.state == VsockState::Listen && type_ == Type::Stream
80 {
81 raw.state = VsockState::ReceiveRequest;
82 raw.remote_cid = header_cid;
83 raw.remote_port = header.src_port.to_ne();
84 raw.peer_buf_alloc = header.buf_alloc.to_ne();
85 raw.rx_waker.wake();
86 } else if (raw.state == VsockState::Connected
87 || raw.state == VsockState::Shutdown)
88 && type_ == Type::Stream
89 && op == Op::Rw
90 {
91 if raw.remote_cid == header_cid {
92 raw.buffer.extend_from_slice(data);
93 raw.fwd_cnt =
94 raw.fwd_cnt.wrapping_add(u32::try_from(data.len()).unwrap());
95 raw.peer_fwd_cnt = header.fwd_cnt.to_ne();
96 raw.tx_waker.wake();
97 raw.rx_waker.wake();
98 hdr = Some(*header);
99 fwd_cnt = raw.fwd_cnt;
100 } else {
101 trace!("Receive message from invalid source {header_cid}");
102 }
103 } else if op == Op::CreditUpdate {
104 if raw.remote_cid == header_cid {
105 raw.peer_fwd_cnt = header.fwd_cnt.to_ne();
106 raw.tx_waker.wake();
107 } else {
108 trace!("Receive message from invalid source {header_cid}");
109 }
110 } else if op == Op::Shutdown {
111 if raw.remote_cid == header_cid {
112 raw.state = VsockState::Shutdown;
113 } else {
114 trace!("Receive message from invalid source {header_cid}");
115 }
116 } else if op == Op::Response && type_ == Type::Stream {
117 if raw.remote_cid == header_cid && raw.state == VsockState::Connecting {
118 raw.state = VsockState::Connected;
119 }
120 } else if raw.remote_cid == header_cid {
121 hdr = Some(*header);
122 fwd_cnt = raw.fwd_cnt;
123 }
124 }
125 });
126
127 if let Some(hdr) = hdr {
128 driver_guard.send_packet(HEADER_SIZE, |buffer| {
129 let response = unsafe { &mut *buffer.as_mut_ptr().cast::<Hdr>() };
130
131 response.src_cid = hdr.dst_cid;
132 response.dst_cid = hdr.src_cid;
133 response.src_port = hdr.dst_port;
134 response.dst_port = hdr.src_port;
135 response.len = le32::from_ne(0);
136 response.type_ = hdr.type_;
137 if hdr.op.to_ne() == u16::from(Op::CreditRequest)
138 || hdr.op.to_ne() == u16::from(Op::Rw)
139 {
140 response.op = le16::from_ne(Op::CreditUpdate.into());
141 } else {
142 response.op = le16::from_ne(Op::Rst.into());
144 }
145 response.flags = le32::from_ne(0);
146 response.buf_alloc = le32::from_ne(RAW_SOCKET_BUFFER_SIZE as u32);
147 response.fwd_cnt = le32::from_ne(fwd_cnt);
148 });
149 }
150
151 Poll::Pending
152 } else {
153 Poll::Ready(())
154 }
155 })
156 .await;
157}
158
159pub(crate) struct VsockMap {
160 port_map: BTreeMap<u32, RawSocket>,
161}
162
163impl VsockMap {
164 pub const fn new() -> Self {
165 Self {
166 port_map: BTreeMap::new(),
167 }
168 }
169
170 pub fn bind(&mut self, port: u32) -> io::Result<()> {
171 self.port_map
172 .try_insert(port, RawSocket::new(VsockState::Listen))
173 .map_err(|_| EADDRINUSE)?;
174 Ok(())
175 }
176
177 pub fn connect(&mut self, port: u32, cid: u32) -> io::Result<u32> {
178 for i in u32::MAX / 4..u32::MAX {
179 let mut raw = RawSocket::new(VsockState::Connecting);
180 raw.remote_cid = cid;
181 raw.remote_port = port;
182
183 if self.port_map.try_insert(i, raw).is_ok() {
184 return Ok(i);
185 }
186 }
187
188 Err(io::Error::EBADF)
189 }
190
191 pub fn get_socket(&self, port: u32) -> Option<&RawSocket> {
192 self.port_map.get(&port)
193 }
194
195 pub fn get_mut_socket(&mut self, port: u32) -> Option<&mut RawSocket> {
196 self.port_map.get_mut(&port)
197 }
198
199 pub fn remove_socket(&mut self, port: u32) {
200 let _ = self.port_map.remove(&port);
201 }
202}
203
204pub(crate) fn init() {
205 info!("Try to initialize vsock interface!");
206
207 spawn(vsock_run());
208}