// smoltcp/wire/udp.rs

1use byteorder::{ByteOrder, NetworkEndian};
2use core::fmt;
3
4use super::{Error, Result};
5use crate::phy::ChecksumCapabilities;
6use crate::wire::ip::checksum;
7use crate::wire::{IpAddress, IpProtocol};
8
/// Zero-copy, read/write view over a byte buffer laid out as a User Datagram
/// Protocol packet.
///
/// Construction performs no validation; the accessors read and write the
/// header fields directly inside `buffer`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Packet<T: AsRef<[u8]>> {
    /// Backing storage: the 8-octet UDP header followed by the payload.
    buffer: T,
}
14
15mod field {
16    #![allow(non_snake_case)]
17
18    use crate::wire::field::*;
19
20    pub const SRC_PORT: Field = 0..2;
21    pub const DST_PORT: Field = 2..4;
22    pub const LENGTH: Field = 4..6;
23    pub const CHECKSUM: Field = 6..8;
24
25    pub const fn PAYLOAD(length: u16) -> Field {
26        CHECKSUM.end..(length as usize)
27    }
28}
29
30pub const HEADER_LEN: usize = field::CHECKSUM.end;
31
32#[allow(clippy::len_without_is_empty)]
33impl<T: AsRef<[u8]>> Packet<T> {
34    /// Imbue a raw octet buffer with UDP packet structure.
35    pub const fn new_unchecked(buffer: T) -> Packet<T> {
36        Packet { buffer }
37    }
38
39    /// Shorthand for a combination of [new_unchecked] and [check_len].
40    ///
41    /// [new_unchecked]: #method.new_unchecked
42    /// [check_len]: #method.check_len
43    pub fn new_checked(buffer: T) -> Result<Packet<T>> {
44        let packet = Self::new_unchecked(buffer);
45        packet.check_len()?;
46        Ok(packet)
47    }
48
49    /// Ensure that no accessor method will panic if called.
50    /// Returns `Err(Error)` if the buffer is too short.
51    /// Returns `Err(Error)` if the length field has a value smaller
52    /// than the header length.
53    ///
54    /// The result of this check is invalidated by calling [set_len].
55    ///
56    /// [set_len]: #method.set_len
57    pub fn check_len(&self) -> Result<()> {
58        let buffer_len = self.buffer.as_ref().len();
59        if buffer_len < HEADER_LEN {
60            Err(Error)
61        } else {
62            let field_len = self.len() as usize;
63            if buffer_len < field_len || field_len < HEADER_LEN {
64                Err(Error)
65            } else {
66                Ok(())
67            }
68        }
69    }
70
71    /// Consume the packet, returning the underlying buffer.
72    pub fn into_inner(self) -> T {
73        self.buffer
74    }
75
76    /// Return the source port field.
77    #[inline]
78    pub fn src_port(&self) -> u16 {
79        let data = self.buffer.as_ref();
80        NetworkEndian::read_u16(&data[field::SRC_PORT])
81    }
82
83    /// Return the destination port field.
84    #[inline]
85    pub fn dst_port(&self) -> u16 {
86        let data = self.buffer.as_ref();
87        NetworkEndian::read_u16(&data[field::DST_PORT])
88    }
89
90    /// Return the length field.
91    #[inline]
92    pub fn len(&self) -> u16 {
93        let data = self.buffer.as_ref();
94        NetworkEndian::read_u16(&data[field::LENGTH])
95    }
96
97    /// Return the checksum field.
98    #[inline]
99    pub fn checksum(&self) -> u16 {
100        let data = self.buffer.as_ref();
101        NetworkEndian::read_u16(&data[field::CHECKSUM])
102    }
103
104    /// Validate the packet checksum.
105    ///
106    /// # Panics
107    /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
108    /// and that family is IPv4 or IPv6.
109    ///
110    /// # Fuzzing
111    /// This function always returns `true` when fuzzing.
112    pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool {
113        if cfg!(fuzzing) {
114            return true;
115        }
116
117        // From the RFC:
118        // > An all zero transmitted checksum value means that the transmitter
119        // > generated no checksum (for debugging or for higher level protocols
120        // > that don't care).
121        if self.checksum() == 0 {
122            return true;
123        }
124
125        let data = self.buffer.as_ref();
126        checksum::combine(&[
127            checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
128            checksum::data(&data[..self.len() as usize]),
129        ]) == !0
130    }
131}
132
133impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> {
134    /// Return a pointer to the payload.
135    #[inline]
136    pub fn payload(&self) -> &'a [u8] {
137        let length = self.len();
138        let data = self.buffer.as_ref();
139        &data[field::PAYLOAD(length)]
140    }
141}
142
143impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
144    /// Set the source port field.
145    #[inline]
146    pub fn set_src_port(&mut self, value: u16) {
147        let data = self.buffer.as_mut();
148        NetworkEndian::write_u16(&mut data[field::SRC_PORT], value)
149    }
150
151    /// Set the destination port field.
152    #[inline]
153    pub fn set_dst_port(&mut self, value: u16) {
154        let data = self.buffer.as_mut();
155        NetworkEndian::write_u16(&mut data[field::DST_PORT], value)
156    }
157
158    /// Set the length field.
159    #[inline]
160    pub fn set_len(&mut self, value: u16) {
161        let data = self.buffer.as_mut();
162        NetworkEndian::write_u16(&mut data[field::LENGTH], value)
163    }
164
165    /// Set the checksum field.
166    #[inline]
167    pub fn set_checksum(&mut self, value: u16) {
168        let data = self.buffer.as_mut();
169        NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
170    }
171
172    /// Compute and fill in the header checksum.
173    ///
174    /// # Panics
175    /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
176    /// and that family is IPv4 or IPv6.
177    pub fn fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress) {
178        self.set_checksum(0);
179        let checksum = {
180            let data = self.buffer.as_ref();
181            !checksum::combine(&[
182                checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
183                checksum::data(&data[..self.len() as usize]),
184            ])
185        };
186        // UDP checksum value of 0 means no checksum; if the checksum really is zero,
187        // use all-ones, which indicates that the remote end must verify the checksum.
188        // Arithmetically, RFC 1071 checksums of all-zeroes and all-ones behave identically,
189        // so no action is necessary on the remote end.
190        self.set_checksum(if checksum == 0 { 0xffff } else { checksum })
191    }
192
193    /// Return a mutable pointer to the payload.
194    #[inline]
195    pub fn payload_mut(&mut self) -> &mut [u8] {
196        let length = self.len();
197        let data = self.buffer.as_mut();
198        &mut data[field::PAYLOAD(length)]
199    }
200}
201
202impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> {
203    fn as_ref(&self) -> &[u8] {
204        self.buffer.as_ref()
205    }
206}
207
/// Buffer-independent form of a User Datagram Protocol header.
///
/// Only the ports are stored; the length and checksum are derived from the
/// payload and the IP addresses when the header is emitted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Repr {
    /// Source port; zero is allowed.
    pub src_port: u16,
    /// Destination port; parsing rejects zero.
    pub dst_port: u16,
}
214
215impl Repr {
216    /// Parse an User Datagram Protocol packet and return a high-level representation.
217    pub fn parse<T>(
218        packet: &Packet<&T>,
219        src_addr: &IpAddress,
220        dst_addr: &IpAddress,
221        checksum_caps: &ChecksumCapabilities,
222    ) -> Result<Repr>
223    where
224        T: AsRef<[u8]> + ?Sized,
225    {
226        packet.check_len()?;
227
228        // Destination port cannot be omitted (but source port can be).
229        if packet.dst_port() == 0 {
230            return Err(Error);
231        }
232        // Valid checksum is expected...
233        if checksum_caps.udp.rx() && !packet.verify_checksum(src_addr, dst_addr) {
234            match (src_addr, dst_addr) {
235                // ... except on UDP-over-IPv4, where it can be omitted.
236                #[cfg(feature = "proto-ipv4")]
237                (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_)) if packet.checksum() == 0 => (),
238                _ => return Err(Error),
239            }
240        }
241
242        Ok(Repr {
243            src_port: packet.src_port(),
244            dst_port: packet.dst_port(),
245        })
246    }
247
248    /// Return the length of the packet header that will be emitted from this high-level representation.
249    pub const fn header_len(&self) -> usize {
250        HEADER_LEN
251    }
252
253    /// Emit a high-level representation into an User Datagram Protocol packet.
254    ///
255    /// This never calculates the checksum, and is intended for internal-use only,
256    /// not for packets that are going to be actually sent over the network. For
257    /// example, when decompressing 6lowpan.
258    pub(crate) fn emit_header<T>(&self, packet: &mut Packet<&mut T>, payload_len: usize)
259    where
260        T: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
261    {
262        packet.set_src_port(self.src_port);
263        packet.set_dst_port(self.dst_port);
264        packet.set_len((HEADER_LEN + payload_len) as u16);
265        packet.set_checksum(0);
266    }
267
268    /// Emit a high-level representation into an User Datagram Protocol packet.
269    pub fn emit<T>(
270        &self,
271        packet: &mut Packet<&mut T>,
272        src_addr: &IpAddress,
273        dst_addr: &IpAddress,
274        payload_len: usize,
275        emit_payload: impl FnOnce(&mut [u8]),
276        checksum_caps: &ChecksumCapabilities,
277    ) where
278        T: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
279    {
280        packet.set_src_port(self.src_port);
281        packet.set_dst_port(self.dst_port);
282        packet.set_len((HEADER_LEN + payload_len) as u16);
283        emit_payload(packet.payload_mut());
284
285        if checksum_caps.udp.tx() {
286            packet.fill_checksum(src_addr, dst_addr)
287        } else {
288            // make sure we get a consistently zeroed checksum,
289            // since implementations might rely on it
290            packet.set_checksum(0);
291        }
292    }
293}
294
295impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
296    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
297        // Cannot use Repr::parse because we don't have the IP addresses.
298        write!(
299            f,
300            "UDP src={} dst={} len={}",
301            self.src_port(),
302            self.dst_port(),
303            self.payload().len()
304        )
305    }
306}
307
#[cfg(feature = "defmt")]
impl<'a, T: AsRef<[u8]> + ?Sized> defmt::Format for Packet<&'a T> {
    /// defmt counterpart of the `Display` impl; `len` is the payload length.
    fn format(&self, fmt: defmt::Formatter) {
        // Repr::parse needs the IP addresses, which are not known here.
        let src = self.src_port();
        let dst = self.dst_port();
        let len = self.payload().len();
        defmt::write!(fmt, "UDP src={} dst={} len={}", src, dst, len);
    }
}
321
322impl fmt::Display for Repr {
323    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
324        write!(f, "UDP src={} dst={}", self.src_port, self.dst_port)
325    }
326}
327
#[cfg(feature = "defmt")]
impl defmt::Format for Repr {
    /// defmt counterpart of the `Display` impl for `Repr`.
    fn format(&self, fmt: defmt::Formatter) {
        let Repr { src_port, dst_port } = *self;
        defmt::write!(fmt, "UDP src={} dst={}", src_port, dst_port);
    }
}
334
335use crate::wire::pretty_print::{PrettyIndent, PrettyPrint};
336
337impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
338    fn pretty_print(
339        buffer: &dyn AsRef<[u8]>,
340        f: &mut fmt::Formatter,
341        indent: &mut PrettyIndent,
342    ) -> fmt::Result {
343        match Packet::new_checked(buffer) {
344            Err(err) => write!(f, "{indent}({err})"),
345            Ok(packet) => write!(f, "{indent}{packet}"),
346        }
347    }
348}
349
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(feature = "proto-ipv4")]
    use crate::wire::Ipv4Address;

    #[cfg(feature = "proto-ipv4")]
    const SRC_ADDR: Ipv4Address = Ipv4Address::new(192, 168, 1, 1);
    #[cfg(feature = "proto-ipv4")]
    const DST_ADDR: Ipv4Address = Ipv4Address::new(192, 168, 1, 2);

    /// Valid datagram 48896 -> 53 with a correct checksum (0x124d).
    #[cfg(feature = "proto-ipv4")]
    static PACKET_BYTES: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff,
    ];

    /// Same datagram with the checksum omitted (all zero).
    #[cfg(feature = "proto-ipv4")]
    static NO_CHECKSUM_PACKET: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0xaa, 0x00, 0x00, 0xff,
    ];

    #[cfg(feature = "proto-ipv4")]
    static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff];

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_deconstruct() {
        let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
        assert_eq!(packet.src_port(), 48896);
        assert_eq!(packet.dst_port(), 53);
        assert_eq!(packet.len(), 12);
        assert_eq!(packet.checksum(), 0x124d);
        assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]);
        assert!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into()));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_construct() {
        let mut bytes = vec![0xa5; 12];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_src_port(48896);
        packet.set_dst_port(53);
        packet.set_len(12);
        packet.set_checksum(0xffff);
        packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]);
        packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into());
        assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]);
    }

    #[test]
    fn test_impossible_len() {
        let mut bytes = vec![0; 12];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_len(4);
        assert_eq!(packet.check_len(), Err(Error));
    }

    #[test]
    fn test_buffer_shorter_than_header() {
        let bytes = [0u8; HEADER_LEN - 1];
        let packet = Packet::new_unchecked(&bytes[..]);
        assert_eq!(packet.check_len(), Err(Error));
        assert!(Packet::new_checked(&bytes[..]).is_err());
    }

    #[test]
    fn test_len_exceeds_buffer() {
        let mut bytes = vec![0; 12];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_len(13);
        assert_eq!(packet.check_len(), Err(Error));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_zero_checksum() {
        let mut bytes = vec![0; 8];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_src_port(1);
        packet.set_dst_port(31881);
        packet.set_len(8);
        packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into());
        assert_eq!(packet.checksum(), 0xffff);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_no_checksum() {
        let mut bytes = vec![0; 8];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_src_port(1);
        packet.set_dst_port(31881);
        packet.set_len(8);
        packet.set_checksum(0);
        assert!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into()));
    }

    #[cfg(feature = "proto-ipv4")]
    fn packet_repr() -> Repr {
        Repr {
            src_port: 48896,
            dst_port: 53,
        }
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_parse() {
        let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
        let repr = Repr::parse(
            &packet,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            &ChecksumCapabilities::default(),
        )
        .unwrap();
        assert_eq!(repr, packet_repr());
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_parse_zero_dst_port() {
        let mut bytes = PACKET_BYTES;
        bytes[2] = 0;
        bytes[3] = 0;
        let packet = Packet::new_unchecked(&bytes[..]);
        let parsed = Repr::parse(
            &packet,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            &ChecksumCapabilities::default(),
        );
        assert_eq!(parsed, Err(Error));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_parse_bad_checksum() {
        // Corrupt one payload octet; the stored checksum no longer matches.
        let mut bytes = PACKET_BYTES;
        bytes[8] ^= 0xff;
        let packet = Packet::new_unchecked(&bytes[..]);
        assert!(!packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into()));
        let parsed = Repr::parse(
            &packet,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            &ChecksumCapabilities::default(),
        );
        assert_eq!(parsed, Err(Error));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_emit() {
        let repr = packet_repr();
        let mut bytes = vec![0xa5; repr.header_len() + PAYLOAD_BYTES.len()];
        let mut packet = Packet::new_unchecked(&mut bytes);
        repr.emit(
            &mut packet,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            PAYLOAD_BYTES.len(),
            |payload| payload.copy_from_slice(&PAYLOAD_BYTES),
            &ChecksumCapabilities::default(),
        );
        assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_emit_without_checksum() {
        // With TX checksums disabled the checksum field must come out zeroed,
        // even though the buffer started out filled with 0xa5.
        let repr = packet_repr();
        let mut bytes = vec![0xa5; repr.header_len() + PAYLOAD_BYTES.len()];
        let mut packet = Packet::new_unchecked(&mut bytes);
        repr.emit(
            &mut packet,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            PAYLOAD_BYTES.len(),
            |payload| payload.copy_from_slice(&PAYLOAD_BYTES),
            &ChecksumCapabilities::ignored(),
        );
        assert_eq!(&*packet.into_inner(), &NO_CHECKSUM_PACKET[..]);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_checksum_omitted() {
        let packet = Packet::new_unchecked(&NO_CHECKSUM_PACKET[..]);
        let repr = Repr::parse(
            &packet,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            &ChecksumCapabilities::default(),
        )
        .unwrap();
        assert_eq!(repr, packet_repr());
    }
}