1use alloc::boxed::Box;
2use alloc::collections::BTreeSet;
3use alloc::sync::Arc;
4use core::future;
5use core::mem::MaybeUninit;
6use core::sync::atomic::{AtomicU16, Ordering};
7use core::task::Poll;
8
9use async_trait::async_trait;
10use smoltcp::iface;
11use smoltcp::socket::tcp;
12use smoltcp::time::Duration;
13
14use crate::executor::block_on;
15use crate::executor::network::{Handle, NIC};
16use crate::fd::{self, Endpoint, ListenEndpoint, ObjectInterface, PollEvent, SocketOption};
17use crate::{DEFAULT_KEEP_ALIVE_INTERVAL, io};
18
/// `how` value accepted by `shutdown` (mirrors POSIX `SHUT_RD`).
pub const SHUT_RD: i32 = 0;
/// `how` value accepted by `shutdown` (mirrors POSIX `SHUT_WR`).
pub const SHUT_WR: i32 = 1;
/// `how` value accepted by `shutdown` (mirrors POSIX `SHUT_RDWR`).
pub const SHUT_RDWR: i32 = 2;
/// Backlog used when `accept` is called on a socket that is not listening yet.
pub const DEFAULT_BACKLOG: i32 = 128;
27
/// Returns the next local port to use for an outgoing connection.
///
/// Ports are handed out sequentially from the dynamic range `49152..=65535`.
/// When the end of the range is reached, the counter wraps back to 49152
/// instead of overflowing to port 0 (not a valid local port) and then into
/// the well-known/registered ranges.
///
/// NOTE(review): the port is not checked against ports already in use.
fn get_ephemeral_port() -> u16 {
    const FIRST_EPHEMERAL_PORT: u16 = 49152;
    static LOCAL_ENDPOINT: AtomicU16 = AtomicU16::new(FIRST_EPHEMERAL_PORT);

    LOCAL_ENDPOINT
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |port| {
            Some(if port == u16::MAX {
                FIRST_EPHEMERAL_PORT
            } else {
                port + 1
            })
        })
        // The closure never returns `None`, so this is always `Ok(previous)`.
        .unwrap_or_else(|previous| previous)
}
33
/// A TCP socket backed by one or more smoltcp sockets registered with the global `NIC`.
#[derive(Debug)]
pub struct Socket {
    /// smoltcp handles owned by this socket. The first handle is the one used
    /// for data transfer and state queries. A listening socket additionally
    /// keeps its backlog of listening handles here. All of them are destroyed
    /// when the `Socket` is dropped.
    handle: BTreeSet<Handle>,
    /// Local port recorded by `bind` and used by `listen`/`accept`; 0 until bound.
    port: u16,
    /// When set, operations that would wait return `EAGAIN` instead.
    is_nonblocking: bool,
    /// Set once `listen` has put the socket into listening mode.
    is_listen: bool,
    /// Domain passed to `new` and reported by `inet_domain`
    /// (presumably the address family — confirm against callers).
    domain: i32,
}
43
44impl Socket {
45 pub fn new(h: Handle, domain: i32) -> Self {
46 let mut handle = BTreeSet::new();
47 handle.insert(h);
48
49 Self {
50 handle,
51 port: 0,
52 is_nonblocking: false,
53 is_listen: false,
54 domain,
55 }
56 }
57
58 fn with<R>(&self, f: impl FnOnce(&mut tcp::Socket<'_>) -> R) -> R {
59 let mut guard = NIC.lock();
60 let nic = guard.as_nic_mut().unwrap();
61 f(nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.first().unwrap()))
62 }
63
64 fn with_context<R>(&self, f: impl FnOnce(&mut tcp::Socket<'_>, &mut iface::Context) -> R) -> R {
65 let mut guard = NIC.lock();
66 let nic = guard.as_nic_mut().unwrap();
67 let (s, cx) = nic.get_socket_and_context::<tcp::Socket<'_>>(*self.handle.first().unwrap());
68 f(s, cx)
69 }
70
71 async fn close(&self) -> io::Result<()> {
72 future::poll_fn(|_cx| {
73 self.with(|socket| {
74 if socket.is_active() {
75 socket.close();
76 Poll::Ready(Ok(()))
77 } else {
78 Poll::Ready(Err(io::Error::EIO))
79 }
80 })
81 })
82 .await?;
83
84 if self.handle.len() > 1 {
85 let mut guard = NIC.lock();
86 let nic = guard.as_nic_mut().unwrap();
87
88 for handle in self.handle.iter().skip(1) {
89 let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*handle);
90 if socket.is_active() {
91 socket.close();
92 }
93 }
94 }
95
96 future::poll_fn(|cx| {
97 self.with(|socket| {
98 if socket.is_active() {
99 socket.register_send_waker(cx.waker());
100 socket.register_recv_waker(cx.waker());
101 Poll::Pending
102 } else {
103 Poll::Ready(Ok(()))
104 }
105 })
106 })
107 .await
108 }
109
110 async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
111 future::poll_fn(|cx| {
112 self.with(|socket| match socket.state() {
113 tcp::State::Closed | tcp::State::Closing | tcp::State::CloseWait => {
114 let available = PollEvent::POLLOUT
115 | PollEvent::POLLWRNORM
116 | PollEvent::POLLWRBAND
117 | PollEvent::POLLIN
118 | PollEvent::POLLRDNORM
119 | PollEvent::POLLRDBAND;
120
121 let ret = event & available;
122
123 if ret.is_empty() {
124 Poll::Ready(Ok(PollEvent::POLLHUP))
125 } else {
126 Poll::Ready(Ok(ret))
127 }
128 }
129 tcp::State::FinWait1 | tcp::State::FinWait2 | tcp::State::TimeWait => {
130 Poll::Ready(Ok(PollEvent::POLLHUP))
131 }
132 tcp::State::Listen => {
133 socket.register_recv_waker(cx.waker());
134 socket.register_send_waker(cx.waker());
135 Poll::Pending
136 }
137 _ => {
138 let mut available = PollEvent::empty();
139
140 if socket.can_recv() || socket.may_recv() && self.is_listen {
141 available.insert(
143 PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND,
144 );
145 }
146
147 if socket.can_send() {
148 available.insert(
149 PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND,
150 );
151 }
152
153 let ret = event & available;
154
155 if ret.is_empty() {
156 if event.intersects(
157 PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND,
158 ) {
159 socket.register_recv_waker(cx.waker());
160 }
161
162 if event.intersects(
163 PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND,
164 ) {
165 socket.register_send_waker(cx.waker());
166 }
167
168 Poll::Pending
169 } else {
170 Poll::Ready(Ok(ret))
171 }
172 }
173 })
174 })
175 .await
176 }
177
178 async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
179 future::poll_fn(|cx| {
180 self.with(|socket| {
181 let state = socket.state();
182 match state {
183 tcp::State::Closed => Poll::Ready(Ok(0)),
184 tcp::State::FinWait1
185 | tcp::State::FinWait2
186 | tcp::State::Listen
187 | tcp::State::TimeWait => Poll::Ready(Err(io::Error::EIO)),
188 _ => {
189 if socket.can_recv() {
190 Poll::Ready(
191 socket
192 .recv(|data| {
193 let len = core::cmp::min(buffer.len(), data.len());
194 buffer[..len].write_copy_of_slice(&data[..len]);
195 (len, len)
196 })
197 .map_err(|_| io::Error::EIO),
198 )
199 } else if state == tcp::State::CloseWait {
200 Poll::Ready(Ok(0))
203 } else if self.is_nonblocking {
204 Poll::Ready(Err(io::Error::EAGAIN))
205 } else {
206 socket.register_recv_waker(cx.waker());
207 Poll::Pending
208 }
209 }
210 }
211 })
212 })
213 .await
214 }
215
216 async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
217 let mut pos: usize = 0;
218
219 while pos < buffer.len() {
220 let n = future::poll_fn(|cx| {
221 self.with(|socket| {
222 match socket.state() {
223 tcp::State::Closed | tcp::State::Closing | tcp::State::CloseWait => {
224 Poll::Ready(Ok(0))
225 }
226 tcp::State::FinWait1
227 | tcp::State::FinWait2
228 | tcp::State::Listen
229 | tcp::State::TimeWait => Poll::Ready(Err(io::Error::EIO)),
230 _ => {
231 if socket.can_send() {
232 Poll::Ready(
233 socket
234 .send_slice(&buffer[pos..])
235 .map_err(|_| io::Error::EIO),
236 )
237 } else if pos > 0 {
238 Poll::Ready(Ok(0))
241 } else if self.is_nonblocking {
242 Poll::Ready(Err(io::Error::EAGAIN))
243 } else {
244 socket.register_send_waker(cx.waker());
245 Poll::Pending
246 }
247 }
248 }
249 })
250 })
251 .await?;
252
253 if n == 0 {
254 break;
255 }
256
257 pos += n;
258 }
259
260 Ok(pos)
261 }
262
263 async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> {
264 #[allow(irrefutable_let_patterns)]
265 if let ListenEndpoint::Ip(endpoint) = endpoint {
266 self.port = endpoint.port;
267 Ok(())
268 } else {
269 Err(io::Error::EIO)
270 }
271 }
272
273 async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
274 #[allow(irrefutable_let_patterns)]
275 if let Endpoint::Ip(endpoint) = endpoint {
276 self.with_context(|socket, cx| socket.connect(cx, endpoint, get_ephemeral_port()))
277 .map_err(|_| io::Error::EIO)?;
278
279 future::poll_fn(|cx| {
280 self.with(|socket| match socket.state() {
281 tcp::State::Closed | tcp::State::TimeWait => {
282 Poll::Ready(Err(io::Error::EFAULT))
283 }
284 tcp::State::Listen => Poll::Ready(Err(io::Error::EIO)),
285 tcp::State::SynSent | tcp::State::SynReceived => {
286 socket.register_send_waker(cx.waker());
287 Poll::Pending
288 }
289 _ => Poll::Ready(Ok(())),
290 })
291 })
292 .await
293 } else {
294 Err(io::Error::EIO)
295 }
296 }
297
298 async fn accept(&mut self) -> io::Result<(Socket, Endpoint)> {
299 if !self.is_listen {
300 self.listen(DEFAULT_BACKLOG).await?;
301 }
302
303 let connection_handle = future::poll_fn(|cx| {
304 let mut guard = NIC.lock();
305 let nic = guard.as_nic_mut().unwrap();
306 let mut socket_handle = None;
307
308 for handle in self.handle.iter() {
309 let s = nic.get_mut_socket::<tcp::Socket<'_>>(*handle);
310
311 if s.is_active() {
312 socket_handle = Some(*handle);
313 break;
314 }
315 }
316
317 if let Some(handle) = socket_handle {
318 self.handle.remove(&handle);
319 Poll::Ready(Ok(handle))
320 } else if self.is_nonblocking {
321 Poll::Ready(Err(io::Error::EAGAIN))
322 } else {
323 for handle in self.handle.iter() {
324 let s = nic.get_mut_socket::<tcp::Socket<'_>>(*handle);
325 s.register_recv_waker(cx.waker());
326 }
327
328 Poll::Pending
329 }
330 })
331 .await?;
332
333 let mut guard = NIC.lock();
334 let nic = guard.as_nic_mut().map_err(|_| io::Error::EIO)?;
335 let socket = nic.get_mut_socket::<tcp::Socket<'_>>(connection_handle);
336 socket.set_keep_alive(Some(Duration::from_millis(DEFAULT_KEEP_ALIVE_INTERVAL)));
337 let endpoint = Endpoint::Ip(socket.remote_endpoint().unwrap());
338 let nagle_enabled = socket.nagle_enabled();
339
340 let new_handle = nic.create_tcp_handle().unwrap();
342 self.handle.insert(new_handle);
343 let socket = nic.get_mut_socket::<tcp::Socket<'_>>(new_handle);
344 socket.set_nagle_enabled(nagle_enabled);
345 socket.listen(self.port).map_err(|_| io::Error::EIO)?;
346
347 let mut handle = BTreeSet::new();
348 handle.insert(connection_handle);
349
350 let socket = Socket {
351 handle,
352 port: self.port,
353 is_nonblocking: self.is_nonblocking,
354 is_listen: false,
355 domain: self.domain,
356 };
357
358 Ok((socket, endpoint))
359 }
360
361 async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
362 Ok(self
363 .with(|socket| socket.remote_endpoint())
364 .map(Endpoint::Ip))
365 }
366
367 async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
368 Ok(self
369 .with(|socket| socket.local_endpoint())
370 .map(Endpoint::Ip))
371 }
372
373 async fn listen(&mut self, backlog: i32) -> io::Result<()> {
374 let nagle_enabled = self.with(|socket| socket.nagle_enabled());
375 let mut guard = NIC.lock();
376 let nic = guard.as_nic_mut().unwrap();
377
378 let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.first().unwrap());
379
380 if socket.is_open() {
381 return Err(io::Error::EIO);
382 }
383
384 if backlog <= 0 {
385 return Err(io::Error::EINVAL);
386 }
387
388 socket.listen(self.port).map_err(|_| io::Error::EIO)?;
389
390 self.is_listen = true;
391
392 for _ in 1..backlog {
393 let handle = nic.create_tcp_handle().unwrap();
394
395 let s = nic.get_mut_socket::<tcp::Socket<'_>>(handle);
396 s.set_nagle_enabled(nagle_enabled);
397 s.listen(self.port).map_err(|_| io::Error::EIO)?;
398
399 self.handle.insert(handle);
400 }
401
402 Ok(())
403 }
404
405 async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> {
406 if opt == SocketOption::TcpNoDelay {
407 let mut guard = NIC.lock();
408 let nic = guard.as_nic_mut().unwrap();
409
410 for i in self.handle.iter() {
411 let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*i);
412 socket.set_nagle_enabled(optval);
413 }
414
415 Ok(())
416 } else {
417 Err(io::Error::EINVAL)
418 }
419 }
420
421 async fn getsockopt(&self, opt: SocketOption) -> io::Result<bool> {
422 if opt == SocketOption::TcpNoDelay {
423 let mut guard = NIC.lock();
424 let nic = guard.as_nic_mut().unwrap();
425 let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.first().unwrap());
426
427 Ok(socket.nagle_enabled())
428 } else {
429 Err(io::Error::EINVAL)
430 }
431 }
432
433 async fn shutdown(&self, how: i32) -> io::Result<()> {
434 match how {
435 SHUT_RD |
436 SHUT_WR |
437 SHUT_RDWR => Ok(()),
438 _ => Err(io::Error::EINVAL),
439 }
440 }
441
442 async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
443 let status_flags = if self.is_nonblocking {
444 fd::StatusFlags::O_NONBLOCK
445 } else {
446 fd::StatusFlags::empty()
447 };
448
449 Ok(status_flags)
450 }
451
452 async fn set_status_flags(&mut self, status_flags: fd::StatusFlags) -> io::Result<()> {
453 self.is_nonblocking = status_flags.contains(fd::StatusFlags::O_NONBLOCK);
454 Ok(())
455 }
456}
457
458impl Drop for Socket {
459 fn drop(&mut self) {
460 let _ = block_on(self.close(), None);
461
462 let mut guard = NIC.lock();
463 for h in self.handle.iter() {
464 guard.as_nic_mut().unwrap().destroy_socket(*h);
465 }
466 }
467}
468
469#[async_trait]
470impl ObjectInterface for async_lock::RwLock<Socket> {
471 async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
472 self.read().await.poll(event).await
473 }
474
475 async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
476 self.read().await.read(buffer).await
477 }
478
479 async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
480 self.read().await.write(buffer).await
481 }
482
483 async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
484 self.write().await.bind(endpoint).await
485 }
486
487 async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
488 self.read().await.connect(endpoint).await
489 }
490
491 async fn accept(&self) -> io::Result<(Arc<dyn ObjectInterface>, Endpoint)> {
492 let (socket, endpoint) = self.write().await.accept().await?;
493 Ok((Arc::new(async_lock::RwLock::new(socket)), endpoint))
494 }
495
496 async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
497 self.read().await.getpeername().await
498 }
499
500 async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
501 self.read().await.getsockname().await
502 }
503
504 async fn listen(&self, backlog: i32) -> io::Result<()> {
505 self.write().await.listen(backlog).await
506 }
507
508 async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> {
509 self.read().await.setsockopt(opt, optval).await
510 }
511
512 async fn getsockopt(&self, opt: SocketOption) -> io::Result<bool> {
513 self.read().await.getsockopt(opt).await
514 }
515
516 async fn shutdown(&self, how: i32) -> io::Result<()> {
517 self.read().await.shutdown(how).await
518 }
519
520 async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
521 self.read().await.status_flags().await
522 }
523
524 async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
525 self.write().await.set_status_flags(status_flags).await
526 }
527
528 async fn inet_domain(&self) -> io::Result<i32> {
529 let domain = self.read().await.domain;
530 Ok(domain)
531 }
532}