smoltcp/iface/
fragmentation.rs

1#![allow(unused)]
2
3use core::fmt;
4
5use managed::{ManagedMap, ManagedSlice};
6
7use crate::config::{FRAGMENTATION_BUFFER_SIZE, REASSEMBLY_BUFFER_COUNT, REASSEMBLY_BUFFER_SIZE};
8use crate::storage::Assembler;
9use crate::time::{Duration, Instant};
10use crate::wire::*;
11
12use core::result::Result;
13
14#[cfg(feature = "alloc")]
15type Buffer = alloc::vec::Vec<u8>;
16#[cfg(not(feature = "alloc"))]
17type Buffer = [u8; REASSEMBLY_BUFFER_SIZE];
18
/// Error raised while reassembling a packet: an offset, length or total size was out of
/// bounds, or conflicted with a value set earlier.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct AssemblerError;

impl fmt::Display for AssemblerError {
    /// Renders the error as its type name.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("AssemblerError")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for AssemblerError {}
32
/// Error returned when every packet assembler in a set is already in use.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct AssemblerFullError;

impl fmt::Display for AssemblerFullError {
    /// Renders the error as its type name.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("AssemblerFullError")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for AssemblerFullError {}
46
47/// Holds different fragments of one packet, used for assembling fragmented packets.
48///
49/// The buffer used for the `PacketAssembler` should either be dynamically sized (ex: Vec<u8>)
50/// or should be statically allocated based upon the MTU of the type of packet being
51/// assembled (ex: 1280 for a IPv6 frame).
52#[derive(Debug)]
53pub struct PacketAssembler<K> {
54    key: Option<K>,
55    buffer: Buffer,
56
57    assembler: Assembler,
58    total_size: Option<usize>,
59    expires_at: Instant,
60}
61
62impl<K> PacketAssembler<K> {
63    /// Create a new empty buffer for fragments.
64    pub const fn new() -> Self {
65        Self {
66            key: None,
67
68            #[cfg(feature = "alloc")]
69            buffer: Buffer::new(),
70            #[cfg(not(feature = "alloc"))]
71            buffer: [0u8; REASSEMBLY_BUFFER_SIZE],
72
73            assembler: Assembler::new(),
74            total_size: None,
75            expires_at: Instant::ZERO,
76        }
77    }
78
79    pub(crate) fn reset(&mut self) {
80        self.key = None;
81        self.assembler.clear();
82        self.total_size = None;
83        self.expires_at = Instant::ZERO;
84    }
85
86    /// Set the total size of the packet assembler.
87    pub(crate) fn set_total_size(&mut self, size: usize) -> Result<(), AssemblerError> {
88        if let Some(old_size) = self.total_size {
89            if old_size != size {
90                return Err(AssemblerError);
91            }
92        }
93
94        #[cfg(not(feature = "alloc"))]
95        if self.buffer.len() < size {
96            return Err(AssemblerError);
97        }
98
99        #[cfg(feature = "alloc")]
100        if self.buffer.len() < size {
101            self.buffer.resize(size, 0);
102        }
103
104        self.total_size = Some(size);
105        Ok(())
106    }
107
108    /// Return the instant when the assembler expires.
109    pub(crate) fn expires_at(&self) -> Instant {
110        self.expires_at
111    }
112
113    pub(crate) fn add_with(
114        &mut self,
115        offset: usize,
116        f: impl Fn(&mut [u8]) -> Result<usize, AssemblerError>,
117    ) -> Result<(), AssemblerError> {
118        if self.buffer.len() < offset {
119            return Err(AssemblerError);
120        }
121
122        let len = f(&mut self.buffer[offset..])?;
123        assert!(offset + len <= self.buffer.len());
124
125        net_debug!(
126            "frag assembler: receiving {} octets at offset {}",
127            len,
128            offset
129        );
130
131        self.assembler.add(offset, len);
132        Ok(())
133    }
134
135    /// Add a fragment into the packet that is being reassembled.
136    ///
137    /// # Errors
138    ///
139    /// - Returns [`Error::PacketAssemblerBufferTooSmall`] when trying to add data into the buffer at a non-existing
140    ///   place.
141    pub(crate) fn add(&mut self, data: &[u8], offset: usize) -> Result<(), AssemblerError> {
142        #[cfg(not(feature = "alloc"))]
143        if self.buffer.len() < offset + data.len() {
144            return Err(AssemblerError);
145        }
146
147        #[cfg(feature = "alloc")]
148        if self.buffer.len() < offset + data.len() {
149            self.buffer.resize(offset + data.len(), 0);
150        }
151
152        let len = data.len();
153        self.buffer[offset..][..len].copy_from_slice(data);
154
155        net_debug!(
156            "frag assembler: receiving {} octets at offset {}",
157            len,
158            offset
159        );
160
161        self.assembler.add(offset, data.len());
162        Ok(())
163    }
164
165    /// Get an immutable slice of the underlying packet data, if reassembly complete.
166    /// This will mark the assembler as empty, so that it can be reused.
167    pub(crate) fn assemble(&mut self) -> Option<&'_ [u8]> {
168        if !self.is_complete() {
169            return None;
170        }
171
172        // NOTE: we can unwrap because `is_complete` already checks this.
173        let total_size = self.total_size.unwrap();
174        self.reset();
175        Some(&self.buffer[..total_size])
176    }
177
178    /// Returns `true` when all fragments have been received, otherwise `false`.
179    pub(crate) fn is_complete(&self) -> bool {
180        self.total_size == Some(self.assembler.peek_front())
181    }
182
183    /// Returns `true` when the packet assembler is free to use.
184    fn is_free(&self) -> bool {
185        self.key.is_none()
186    }
187}
188
189/// Set holding multiple [`PacketAssembler`].
190#[derive(Debug)]
191pub struct PacketAssemblerSet<K: Eq + Copy> {
192    assemblers: [PacketAssembler<K>; REASSEMBLY_BUFFER_COUNT],
193}
194
195impl<K: Eq + Copy> PacketAssemblerSet<K> {
196    const NEW_PA: PacketAssembler<K> = PacketAssembler::new();
197
198    /// Create a new set of packet assemblers.
199    pub fn new() -> Self {
200        Self {
201            assemblers: [Self::NEW_PA; REASSEMBLY_BUFFER_COUNT],
202        }
203    }
204
205    /// Get a [`PacketAssembler`] for a specific key.
206    ///
207    /// If it doesn't exist, it is created, with the `expires_at` timestamp.
208    ///
209    /// If the assembler set is full, in which case an error is returned.
210    pub(crate) fn get(
211        &mut self,
212        key: &K,
213        expires_at: Instant,
214    ) -> Result<&mut PacketAssembler<K>, AssemblerFullError> {
215        let mut empty_slot = None;
216        for slot in &mut self.assemblers {
217            if slot.key.as_ref() == Some(key) {
218                return Ok(slot);
219            }
220            if slot.is_free() {
221                empty_slot = Some(slot)
222            }
223        }
224
225        let slot = empty_slot.ok_or(AssemblerFullError)?;
226        slot.key = Some(*key);
227        slot.expires_at = expires_at;
228        Ok(slot)
229    }
230
231    /// Remove all [`PacketAssembler`]s that are expired.
232    pub fn remove_expired(&mut self, timestamp: Instant) {
233        for frag in &mut self.assemblers {
234            if !frag.is_free() && frag.expires_at < timestamp {
235                frag.reset();
236            }
237        }
238    }
239}
240
/// Maximum length of a non-fragmented packet after decompression, counting the IPv6 header
/// and the payload.
///
/// TODO: lower. Should be (6lowpan mtu) - (min 6lowpan header size) + (max ipv6 header size)
pub(crate) const MAX_DECOMPRESSED_LEN: usize = 1_500;
244
/// Key selecting a [`PacketAssembler`] in the reassembly set, tagged by the protocol whose
/// fragments are being reassembled.
///
/// Variant order is significant: it determines the derived `Ord`.
#[cfg(feature = "_proto-fragmentation")]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub(crate) enum FragKey {
    /// Key of an IPv4 packet under reassembly.
    #[cfg(feature = "proto-ipv4-fragmentation")]
    Ipv4(Ipv4FragKey),
    /// Key of a 6LoWPAN packet under reassembly.
    #[cfg(feature = "proto-sixlowpan-fragmentation")]
    Sixlowpan(SixlowpanFragKey),
}
254
/// Receive-side buffers for decompression and fragment reassembly.
pub(crate) struct FragmentsBuffer {
    /// Buffer of `MAX_DECOMPRESSED_LEN` bytes, presumably holding a decompressed 6LoWPAN
    /// packet — confirm at the users of this field.
    #[cfg(feature = "proto-sixlowpan")]
    pub decompress_buf: [u8; MAX_DECOMPRESSED_LEN],

    /// Reassembly slots for fragmented packets, keyed by [`FragKey`].
    #[cfg(feature = "_proto-fragmentation")]
    pub assembler: PacketAssemblerSet<FragKey>,

    /// Reassembly timeout; presumably used to compute each assembler's `expires_at`.
    #[cfg(feature = "_proto-fragmentation")]
    pub reassembly_timeout: Duration,
}
265
/// Stateless stand-in for the fragmenter when no fragmentation protocol is enabled.
#[cfg(not(feature = "_proto-fragmentation"))]
pub(crate) struct Fragmenter {}

#[cfg(not(feature = "_proto-fragmentation"))]
impl Fragmenter {
    /// Create the (empty) fragmenter.
    pub(crate) fn new() -> Self {
        Fragmenter {}
    }
}
275
/// State of an outgoing packet that is being transmitted in fragments.
#[cfg(feature = "_proto-fragmentation")]
pub(crate) struct Fragmenter {
    /// The buffer that holds the unfragmented packet.
    ///
    /// NOTE(review): documented historically as holding the 6LoWPAN packet, but this single
    /// buffer sits alongside both the IPv4 and the 6LoWPAN fragmenter state — confirm whether
    /// IPv4 fragmentation uses it as well.
    pub buffer: [u8; FRAGMENTATION_BUFFER_SIZE],
    /// The size of the packet without the IEEE802.15.4 header and the fragmentation headers.
    pub packet_len: usize,
    /// The number of bytes already transmitted; sending is finished once this equals
    /// `packet_len`.
    pub sent_bytes: usize,

    /// IPv4-specific fragmentation state.
    #[cfg(feature = "proto-ipv4-fragmentation")]
    pub ipv4: Ipv4Fragmenter,
    /// 6LoWPAN-specific fragmentation state.
    #[cfg(feature = "proto-sixlowpan-fragmentation")]
    pub sixlowpan: SixlowpanFragmenter,
}
290
/// IPv4-specific state of an in-progress fragmented transmission.
#[cfg(feature = "proto-ipv4-fragmentation")]
pub(crate) struct Ipv4Fragmenter {
    /// The IPv4 header representation of the packet being fragmented.
    pub repr: Ipv4Repr,
    /// The destination hardware (Ethernet) address.
    #[cfg(feature = "medium-ethernet")]
    pub dst_hardware_addr: EthernetAddress,
    /// The offset of the next fragment.
    pub frag_offset: u16,
    /// The identifier of the stream.
    pub ident: u16,
}
303
/// 6LoWPAN-specific state of an in-progress fragmented transmission.
#[cfg(feature = "proto-sixlowpan-fragmentation")]
pub(crate) struct SixlowpanFragmenter {
    /// The datagram size that is used for the fragmentation headers.
    pub datagram_size: u16,
    /// The datagram tag that is used for the fragmentation headers.
    pub datagram_tag: u16,
    /// Presumably the datagram offset of the next fragment to send.
    /// NOTE(review): unit (bytes vs. 8-octet units) is not visible here — confirm at the caller.
    pub datagram_offset: usize,

    /// The size of the FRAG_N packets.
    pub fragn_size: usize,

    /// The link layer IEEE802.15.4 destination address.
    pub ll_dst_addr: Ieee802154Address,
    /// The link layer IEEE802.15.4 source address.
    pub ll_src_addr: Ieee802154Address,
}
320
#[cfg(feature = "_proto-fragmentation")]
impl Fragmenter {
    /// Create a fragmenter with a zeroed buffer and no packet pending.
    pub(crate) fn new() -> Self {
        Self {
            buffer: [0u8; FRAGMENTATION_BUFFER_SIZE],
            packet_len: 0,
            sent_bytes: 0,

            #[cfg(feature = "proto-ipv4-fragmentation")]
            ipv4: Ipv4Fragmenter {
                repr: Self::blank_ipv4_repr(),
                #[cfg(feature = "medium-ethernet")]
                dst_hardware_addr: EthernetAddress::default(),
                frag_offset: 0,
                ident: 0,
            },

            #[cfg(feature = "proto-sixlowpan-fragmentation")]
            sixlowpan: SixlowpanFragmenter {
                datagram_size: 0,
                datagram_tag: 0,
                datagram_offset: 0,
                fragn_size: 0,
                ll_dst_addr: Ieee802154Address::Absent,
                ll_src_addr: Ieee802154Address::Absent,
            },
        }
    }

    /// Placeholder IPv4 header stored while no IPv4 packet is being fragmented.
    #[cfg(feature = "proto-ipv4-fragmentation")]
    fn blank_ipv4_repr() -> Ipv4Repr {
        Ipv4Repr {
            src_addr: Ipv4Address::new(0, 0, 0, 0),
            dst_addr: Ipv4Address::new(0, 0, 0, 0),
            next_header: IpProtocol::Unknown(0),
            payload_len: 0,
            hop_limit: 0,
        }
    }

    /// Return `true` when everything is transmitted.
    #[inline]
    pub(crate) fn finished(&self) -> bool {
        self.packet_len == self.sent_bytes
    }

    /// Returns `true` when there is nothing to transmit.
    #[inline]
    pub(crate) fn is_empty(&self) -> bool {
        self.packet_len == 0
    }

    /// Reset all fragmentation state to the values produced by [`Fragmenter::new`].
    ///
    /// The contents of `buffer` are not cleared.
    pub(crate) fn reset(&mut self) {
        self.packet_len = 0;
        self.sent_bytes = 0;

        #[cfg(feature = "proto-ipv4-fragmentation")]
        {
            self.ipv4.repr = Self::blank_ipv4_repr();
            #[cfg(feature = "medium-ethernet")]
            {
                self.ipv4.dst_hardware_addr = EthernetAddress::default();
            }
            // Also clear the per-packet counters, so a reused fragmenter never carries the
            // offset or identifier of the previous packet.
            self.ipv4.frag_offset = 0;
            self.ipv4.ident = 0;
        }

        #[cfg(feature = "proto-sixlowpan-fragmentation")]
        {
            self.sixlowpan.datagram_size = 0;
            self.sixlowpan.datagram_tag = 0;
            self.sixlowpan.datagram_offset = 0;
            self.sixlowpan.fragn_size = 0;
            self.sixlowpan.ll_dst_addr = Ieee802154Address::Absent;
            self.sixlowpan.ll_src_addr = Ieee802154Address::Absent;
        }
    }
}
398
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
    struct Key {
        id: usize,
    }

    #[test]
    fn packet_assembler_overlap() {
        let mut p_assembler = PacketAssembler::<Key>::new();

        p_assembler.set_total_size(5).unwrap();

        let data = b"Rust";
        // Both fragments fit in the 5-byte packet; fail loudly if either is rejected instead
        // of silently discarding the `Result`.
        p_assembler.add(&data[..], 0).unwrap();
        p_assembler.add(&data[..], 1).unwrap();

        assert_eq!(p_assembler.assemble(), Some(&b"RRust"[..]))
    }

    #[test]
    fn packet_assembler_assemble() {
        let mut p_assembler = PacketAssembler::<Key>::new();

        let data = b"Hello World!";

        p_assembler.set_total_size(data.len()).unwrap();

        p_assembler.add(b"Hello ", 0).unwrap();
        assert_eq!(p_assembler.assemble(), None);

        p_assembler.add(b"World!", b"Hello ".len()).unwrap();

        assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..]));
    }

    #[test]
    fn packet_assembler_out_of_order_assemble() {
        let mut p_assembler = PacketAssembler::<Key>::new();

        let data = b"Hello World!";

        p_assembler.set_total_size(data.len()).unwrap();

        p_assembler.add(b"World!", b"Hello ".len()).unwrap();
        assert_eq!(p_assembler.assemble(), None);

        p_assembler.add(b"Hello ", 0).unwrap();

        assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..]));
    }

    #[test]
    fn packet_assembler_set() {
        let key = Key { id: 1 };

        let mut set = PacketAssemblerSet::new();

        assert!(set.get(&key, Instant::ZERO).is_ok());
    }

    #[test]
    fn packet_assembler_set_full() {
        let mut set = PacketAssemblerSet::new();
        for i in 0..REASSEMBLY_BUFFER_COUNT {
            set.get(&Key { id: i }, Instant::ZERO).unwrap();
        }
        // Use a key guaranteed not to be in the set, whatever REASSEMBLY_BUFFER_COUNT is
        // configured to; a fixed id would already be present for larger counts.
        assert!(set
            .get(&Key { id: REASSEMBLY_BUFFER_COUNT }, Instant::ZERO)
            .is_err());
    }

    #[test]
    fn packet_assembler_set_assembling_many() {
        let mut set = PacketAssemblerSet::new();

        let key = Key { id: 0 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assert_eq!(assr.assemble(), None);
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        // Test that `.assemble()` effectively deletes it.
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assert_eq!(assr.assemble(), None);
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        let key = Key { id: 1 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        let key = Key { id: 2 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        let key = Key { id: 2 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.set_total_size(2).unwrap();
        assr.add(&[0x00], 0).unwrap();
        assert_eq!(assr.assemble(), None);
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.add(&[0x01], 1).unwrap();
        assert_eq!(assr.assemble(), Some(&[0x00, 0x01][..]));
    }
}