hermit/fd/socket/
tcp.rs

1use alloc::boxed::Box;
2use alloc::collections::BTreeSet;
3use alloc::sync::Arc;
4use core::future;
5use core::mem::MaybeUninit;
6use core::sync::atomic::{AtomicU16, Ordering};
7use core::task::Poll;
8
9use async_trait::async_trait;
10use smoltcp::iface;
11use smoltcp::socket::tcp;
12use smoltcp::time::Duration;
13
14use crate::executor::block_on;
15use crate::executor::network::{Handle, NIC};
16use crate::fd::{self, Endpoint, ListenEndpoint, ObjectInterface, PollEvent, SocketOption};
17use crate::{DEFAULT_KEEP_ALIVE_INTERVAL, io};
18
/// `how` value accepted by `Socket::shutdown`: further receives will be disallowed.
pub const SHUT_RD: i32 = 0;
/// `how` value accepted by `Socket::shutdown`: further sends will be disallowed.
pub const SHUT_WR: i32 = 1;
/// `how` value accepted by `Socket::shutdown`: further sends and receives will be disallowed.
pub const SHUT_RDWR: i32 = 2;
/// The default queue size for incoming connections.
///
/// `Socket::accept` passes this value to `Socket::listen` when the socket
/// has not been put into listening mode yet.
pub const DEFAULT_BACKLOG: i32 = 128;
27
/// Returns the next local port for an outgoing connection.
///
/// Ports are handed out sequentially from the dynamic/private range
/// `49152..=65535`. Once `65535` has been handed out the counter starts over
/// at `49152`, so the function never yields port `0` or a port outside the
/// dynamic range, no matter how many connections have been opened.
fn get_ephemeral_port() -> u16 {
	/// First port of the dynamic/private port range.
	const FIRST_EPHEMERAL_PORT: u16 = 49152;

	static NEXT_PORT: AtomicU16 = AtomicU16::new(FIRST_EPHEMERAL_PORT);

	// A plain `fetch_add` would wrap from 65535 to 0 and then walk through
	// the well-known and registered ports; wrap back to the start of the
	// dynamic range instead.
	NEXT_PORT
		.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |port| {
			Some(port.checked_add(1).unwrap_or(FIRST_EPHEMERAL_PORT))
		})
		.expect("the update closure always returns `Some`")
}
33
/// A TCP socket object backed by one or more smoltcp socket handles that are
/// looked up through the global `NIC`.
#[derive(Debug)]
pub struct Socket {
	/// Handles of the underlying smoltcp sockets.
	///
	/// `with`/`with_context` operate on the smallest handle (`first()`);
	/// `listen` adds further handles that act as backlog slots for `accept`.
	handle: BTreeSet<Handle>,
	/// Local port recorded by `bind` and used by `listen`/`accept`.
	port: u16,
	/// When set, `read`, `write` and `accept` return `EAGAIN` instead of waiting.
	is_nonblocking: bool,
	/// Set by `listen` once the socket has been put into listening mode.
	is_listen: bool,
	// FIXME: remove once the ecosystem has migrated away from `AF_INET_OLD`.
	/// Socket domain passed to `new`; reported unchanged by `inet_domain`.
	domain: i32,
}
43
44impl Socket {
45	pub fn new(h: Handle, domain: i32) -> Self {
46		let mut handle = BTreeSet::new();
47		handle.insert(h);
48
49		Self {
50			handle,
51			port: 0,
52			is_nonblocking: false,
53			is_listen: false,
54			domain,
55		}
56	}
57
58	fn with<R>(&self, f: impl FnOnce(&mut tcp::Socket<'_>) -> R) -> R {
59		let mut guard = NIC.lock();
60		let nic = guard.as_nic_mut().unwrap();
61		f(nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.first().unwrap()))
62	}
63
64	fn with_context<R>(&self, f: impl FnOnce(&mut tcp::Socket<'_>, &mut iface::Context) -> R) -> R {
65		let mut guard = NIC.lock();
66		let nic = guard.as_nic_mut().unwrap();
67		let (s, cx) = nic.get_socket_and_context::<tcp::Socket<'_>>(*self.handle.first().unwrap());
68		f(s, cx)
69	}
70
71	async fn close(&self) -> io::Result<()> {
72		future::poll_fn(|_cx| {
73			self.with(|socket| {
74				if socket.is_active() {
75					socket.close();
76					Poll::Ready(Ok(()))
77				} else {
78					Poll::Ready(Err(io::Error::EIO))
79				}
80			})
81		})
82		.await?;
83
84		if self.handle.len() > 1 {
85			let mut guard = NIC.lock();
86			let nic = guard.as_nic_mut().unwrap();
87
88			for handle in self.handle.iter().skip(1) {
89				let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*handle);
90				if socket.is_active() {
91					socket.close();
92				}
93			}
94		}
95
96		future::poll_fn(|cx| {
97			self.with(|socket| {
98				if socket.is_active() {
99					socket.register_send_waker(cx.waker());
100					socket.register_recv_waker(cx.waker());
101					Poll::Pending
102				} else {
103					Poll::Ready(Ok(()))
104				}
105			})
106		})
107		.await
108	}
109
110	async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
111		future::poll_fn(|cx| {
112			self.with(|socket| match socket.state() {
113				tcp::State::Closed | tcp::State::Closing | tcp::State::CloseWait => {
114					let available = PollEvent::POLLOUT
115						| PollEvent::POLLWRNORM
116						| PollEvent::POLLWRBAND
117						| PollEvent::POLLIN
118						| PollEvent::POLLRDNORM
119						| PollEvent::POLLRDBAND;
120
121					let ret = event & available;
122
123					if ret.is_empty() {
124						Poll::Ready(Ok(PollEvent::POLLHUP))
125					} else {
126						Poll::Ready(Ok(ret))
127					}
128				}
129				tcp::State::FinWait1 | tcp::State::FinWait2 | tcp::State::TimeWait => {
130					Poll::Ready(Ok(PollEvent::POLLHUP))
131				}
132				tcp::State::Listen => {
133					socket.register_recv_waker(cx.waker());
134					socket.register_send_waker(cx.waker());
135					Poll::Pending
136				}
137				_ => {
138					let mut available = PollEvent::empty();
139
140					if socket.can_recv() || socket.may_recv() && self.is_listen {
141						// In case, we just establish a fresh connection in non-blocking mode, we try to read data.
142						available.insert(
143							PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND,
144						);
145					}
146
147					if socket.can_send() {
148						available.insert(
149							PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND,
150						);
151					}
152
153					let ret = event & available;
154
155					if ret.is_empty() {
156						if event.intersects(
157							PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND,
158						) {
159							socket.register_recv_waker(cx.waker());
160						}
161
162						if event.intersects(
163							PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND,
164						) {
165							socket.register_send_waker(cx.waker());
166						}
167
168						Poll::Pending
169					} else {
170						Poll::Ready(Ok(ret))
171					}
172				}
173			})
174		})
175		.await
176	}
177
178	async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
179		future::poll_fn(|cx| {
180			self.with(|socket| {
181				let state = socket.state();
182				match state {
183					tcp::State::Closed => Poll::Ready(Ok(0)),
184					tcp::State::FinWait1
185					| tcp::State::FinWait2
186					| tcp::State::Listen
187					| tcp::State::TimeWait => Poll::Ready(Err(io::Error::EIO)),
188					_ => {
189						if socket.can_recv() {
190							Poll::Ready(
191								socket
192									.recv(|data| {
193										let len = core::cmp::min(buffer.len(), data.len());
194										buffer[..len].write_copy_of_slice(&data[..len]);
195										(len, len)
196									})
197									.map_err(|_| io::Error::EIO),
198							)
199						} else if state == tcp::State::CloseWait {
200							// The local end-point has received a connection termination request
201							// and not data are in the receive buffer => return 0 to close the connection
202							Poll::Ready(Ok(0))
203						} else if self.is_nonblocking {
204							Poll::Ready(Err(io::Error::EAGAIN))
205						} else {
206							socket.register_recv_waker(cx.waker());
207							Poll::Pending
208						}
209					}
210				}
211			})
212		})
213		.await
214	}
215
216	async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
217		let mut pos: usize = 0;
218
219		while pos < buffer.len() {
220			let n = future::poll_fn(|cx| {
221				self.with(|socket| {
222					match socket.state() {
223						tcp::State::Closed | tcp::State::Closing | tcp::State::CloseWait => {
224							Poll::Ready(Ok(0))
225						}
226						tcp::State::FinWait1
227						| tcp::State::FinWait2
228						| tcp::State::Listen
229						| tcp::State::TimeWait => Poll::Ready(Err(io::Error::EIO)),
230						_ => {
231							if socket.can_send() {
232								Poll::Ready(
233									socket
234										.send_slice(&buffer[pos..])
235										.map_err(|_| io::Error::EIO),
236								)
237							} else if pos > 0 {
238								// we already send some data => return 0 as signal to stop the
239								// async write
240								Poll::Ready(Ok(0))
241							} else if self.is_nonblocking {
242								Poll::Ready(Err(io::Error::EAGAIN))
243							} else {
244								socket.register_send_waker(cx.waker());
245								Poll::Pending
246							}
247						}
248					}
249				})
250			})
251			.await?;
252
253			if n == 0 {
254				break;
255			}
256
257			pos += n;
258		}
259
260		Ok(pos)
261	}
262
263	async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> {
264		#[allow(irrefutable_let_patterns)]
265		if let ListenEndpoint::Ip(endpoint) = endpoint {
266			self.port = endpoint.port;
267			Ok(())
268		} else {
269			Err(io::Error::EIO)
270		}
271	}
272
273	async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
274		#[allow(irrefutable_let_patterns)]
275		if let Endpoint::Ip(endpoint) = endpoint {
276			self.with_context(|socket, cx| socket.connect(cx, endpoint, get_ephemeral_port()))
277				.map_err(|_| io::Error::EIO)?;
278
279			future::poll_fn(|cx| {
280				self.with(|socket| match socket.state() {
281					tcp::State::Closed | tcp::State::TimeWait => {
282						Poll::Ready(Err(io::Error::EFAULT))
283					}
284					tcp::State::Listen => Poll::Ready(Err(io::Error::EIO)),
285					tcp::State::SynSent | tcp::State::SynReceived => {
286						socket.register_send_waker(cx.waker());
287						Poll::Pending
288					}
289					_ => Poll::Ready(Ok(())),
290				})
291			})
292			.await
293		} else {
294			Err(io::Error::EIO)
295		}
296	}
297
298	async fn accept(&mut self) -> io::Result<(Socket, Endpoint)> {
299		if !self.is_listen {
300			self.listen(DEFAULT_BACKLOG).await?;
301		}
302
303		let connection_handle = future::poll_fn(|cx| {
304			let mut guard = NIC.lock();
305			let nic = guard.as_nic_mut().unwrap();
306			let mut socket_handle = None;
307
308			for handle in self.handle.iter() {
309				let s = nic.get_mut_socket::<tcp::Socket<'_>>(*handle);
310
311				if s.is_active() {
312					socket_handle = Some(*handle);
313					break;
314				}
315			}
316
317			if let Some(handle) = socket_handle {
318				self.handle.remove(&handle);
319				Poll::Ready(Ok(handle))
320			} else if self.is_nonblocking {
321				Poll::Ready(Err(io::Error::EAGAIN))
322			} else {
323				for handle in self.handle.iter() {
324					let s = nic.get_mut_socket::<tcp::Socket<'_>>(*handle);
325					s.register_recv_waker(cx.waker());
326				}
327
328				Poll::Pending
329			}
330		})
331		.await?;
332
333		let mut guard = NIC.lock();
334		let nic = guard.as_nic_mut().map_err(|_| io::Error::EIO)?;
335		let socket = nic.get_mut_socket::<tcp::Socket<'_>>(connection_handle);
336		socket.set_keep_alive(Some(Duration::from_millis(DEFAULT_KEEP_ALIVE_INTERVAL)));
337		let endpoint = Endpoint::Ip(socket.remote_endpoint().unwrap());
338		let nagle_enabled = socket.nagle_enabled();
339
340		// fill up queue for pending connections
341		let new_handle = nic.create_tcp_handle().unwrap();
342		self.handle.insert(new_handle);
343		let socket = nic.get_mut_socket::<tcp::Socket<'_>>(new_handle);
344		socket.set_nagle_enabled(nagle_enabled);
345		socket.listen(self.port).map_err(|_| io::Error::EIO)?;
346
347		let mut handle = BTreeSet::new();
348		handle.insert(connection_handle);
349
350		let socket = Socket {
351			handle,
352			port: self.port,
353			is_nonblocking: self.is_nonblocking,
354			is_listen: false,
355			domain: self.domain,
356		};
357
358		Ok((socket, endpoint))
359	}
360
361	async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
362		Ok(self
363			.with(|socket| socket.remote_endpoint())
364			.map(Endpoint::Ip))
365	}
366
367	async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
368		Ok(self
369			.with(|socket| socket.local_endpoint())
370			.map(Endpoint::Ip))
371	}
372
373	async fn listen(&mut self, backlog: i32) -> io::Result<()> {
374		let nagle_enabled = self.with(|socket| socket.nagle_enabled());
375		let mut guard = NIC.lock();
376		let nic = guard.as_nic_mut().unwrap();
377
378		let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.first().unwrap());
379
380		if socket.is_open() {
381			return Err(io::Error::EIO);
382		}
383
384		if backlog <= 0 {
385			return Err(io::Error::EINVAL);
386		}
387
388		socket.listen(self.port).map_err(|_| io::Error::EIO)?;
389
390		self.is_listen = true;
391
392		for _ in 1..backlog {
393			let handle = nic.create_tcp_handle().unwrap();
394
395			let s = nic.get_mut_socket::<tcp::Socket<'_>>(handle);
396			s.set_nagle_enabled(nagle_enabled);
397			s.listen(self.port).map_err(|_| io::Error::EIO)?;
398
399			self.handle.insert(handle);
400		}
401
402		Ok(())
403	}
404
405	async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> {
406		if opt == SocketOption::TcpNoDelay {
407			let mut guard = NIC.lock();
408			let nic = guard.as_nic_mut().unwrap();
409
410			for i in self.handle.iter() {
411				let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*i);
412				socket.set_nagle_enabled(optval);
413			}
414
415			Ok(())
416		} else {
417			Err(io::Error::EINVAL)
418		}
419	}
420
421	async fn getsockopt(&self, opt: SocketOption) -> io::Result<bool> {
422		if opt == SocketOption::TcpNoDelay {
423			let mut guard = NIC.lock();
424			let nic = guard.as_nic_mut().unwrap();
425			let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.first().unwrap());
426
427			Ok(socket.nagle_enabled())
428		} else {
429			Err(io::Error::EINVAL)
430		}
431	}
432
433	async fn shutdown(&self, how: i32) -> io::Result<()> {
434		match how {
435			SHUT_RD /* Read  */ |
436			SHUT_WR /* Write */ |
437			SHUT_RDWR /* Both */ => Ok(()),
438			_ => Err(io::Error::EINVAL),
439		}
440	}
441
442	async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
443		let status_flags = if self.is_nonblocking {
444			fd::StatusFlags::O_NONBLOCK
445		} else {
446			fd::StatusFlags::empty()
447		};
448
449		Ok(status_flags)
450	}
451
452	async fn set_status_flags(&mut self, status_flags: fd::StatusFlags) -> io::Result<()> {
453		self.is_nonblocking = status_flags.contains(fd::StatusFlags::O_NONBLOCK);
454		Ok(())
455	}
456}
457
458impl Drop for Socket {
459	fn drop(&mut self) {
460		let _ = block_on(self.close(), None);
461
462		let mut guard = NIC.lock();
463		for h in self.handle.iter() {
464			guard.as_nic_mut().unwrap().destroy_socket(*h);
465		}
466	}
467}
468
469#[async_trait]
470impl ObjectInterface for async_lock::RwLock<Socket> {
471	async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
472		self.read().await.poll(event).await
473	}
474
475	async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
476		self.read().await.read(buffer).await
477	}
478
479	async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
480		self.read().await.write(buffer).await
481	}
482
483	async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
484		self.write().await.bind(endpoint).await
485	}
486
487	async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
488		self.read().await.connect(endpoint).await
489	}
490
491	async fn accept(&self) -> io::Result<(Arc<dyn ObjectInterface>, Endpoint)> {
492		let (socket, endpoint) = self.write().await.accept().await?;
493		Ok((Arc::new(async_lock::RwLock::new(socket)), endpoint))
494	}
495
496	async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
497		self.read().await.getpeername().await
498	}
499
500	async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
501		self.read().await.getsockname().await
502	}
503
504	async fn listen(&self, backlog: i32) -> io::Result<()> {
505		self.write().await.listen(backlog).await
506	}
507
508	async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> {
509		self.read().await.setsockopt(opt, optval).await
510	}
511
512	async fn getsockopt(&self, opt: SocketOption) -> io::Result<bool> {
513		self.read().await.getsockopt(opt).await
514	}
515
516	async fn shutdown(&self, how: i32) -> io::Result<()> {
517		self.read().await.shutdown(how).await
518	}
519
520	async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
521		self.read().await.status_flags().await
522	}
523
524	async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
525		self.write().await.set_status_flags(status_flags).await
526	}
527
528	async fn inet_domain(&self) -> io::Result<i32> {
529		let domain = self.read().await.domain;
530		Ok(domain)
531	}
532}