smoltcp/wire/
igmp.rs

1use byteorder::{ByteOrder, NetworkEndian};
2use core::fmt;
3
4use super::{Error, Result};
5use crate::time::Duration;
6use crate::wire::ip::checksum;
7
8use crate::wire::{Ipv4Address, Ipv4AddressExt};
9
enum_with_unknown! {
    /// Internet Group Management Protocol v1/v2 message version/type.
    ///
    /// Each variant's discriminant is the value carried in the one-octet Type
    /// field at the start of the packet. Any other value is represented by the
    /// `Message::Unknown` variant that `enum_with_unknown!` adds.
    pub enum Message(u8) {
        /// Membership Query (`0x11`); shared by IGMPv1 and IGMPv2 queriers
        MembershipQuery = 0x11,
        /// Version 2 Membership Report (`0x16`)
        MembershipReportV2 = 0x16,
        /// Leave Group (`0x17`)
        LeaveGroup = 0x17,
        /// Version 1 Membership Report (`0x12`)
        MembershipReportV1 = 0x12
    }
}
23
/// A read/write wrapper around an Internet Group Management Protocol v1/v2 packet buffer.
///
/// Field accessors index the buffer directly and panic if it is shorter than
/// the 8-octet header; construct with `Packet::new_checked` (or call
/// `Packet::check_len`) to rule that out.
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Packet<T: AsRef<[u8]>> {
    /// The underlying octets; the IGMP header starts at offset 0.
    buffer: T,
}
30
/// Octet offsets of the fixed 8-octet IGMPv1/v2 header.
mod field {
    use crate::wire::field::*;

    /// Message type (1 octet), decoded into `Message`.
    pub const TYPE: usize = 0;
    /// Max Resp Code (1 octet); `Repr::parse` only interprets it for queries.
    pub const MAX_RESP_CODE: usize = 1;
    /// Checksum (2 octets, read and written in network byte order).
    pub const CHECKSUM: Field = 2..4;
    /// Group address (4 octets). Its end (8) is also the minimum packet length.
    pub const GROUP_ADDRESS: Field = 4..8;
}
39
40impl fmt::Display for Message {
41    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42        match *self {
43            Message::MembershipQuery => write!(f, "membership query"),
44            Message::MembershipReportV2 => write!(f, "version 2 membership report"),
45            Message::LeaveGroup => write!(f, "leave group"),
46            Message::MembershipReportV1 => write!(f, "version 1 membership report"),
47            Message::Unknown(id) => write!(f, "{id}"),
48        }
49    }
50}
51
52/// Internet Group Management Protocol v1/v2 defined in [RFC 2236].
53///
54/// [RFC 2236]: https://tools.ietf.org/html/rfc2236
55impl<T: AsRef<[u8]>> Packet<T> {
56    /// Imbue a raw octet buffer with IGMPv2 packet structure.
57    pub const fn new_unchecked(buffer: T) -> Packet<T> {
58        Packet { buffer }
59    }
60
61    /// Shorthand for a combination of [new_unchecked] and [check_len].
62    ///
63    /// [new_unchecked]: #method.new_unchecked
64    /// [check_len]: #method.check_len
65    pub fn new_checked(buffer: T) -> Result<Packet<T>> {
66        let packet = Self::new_unchecked(buffer);
67        packet.check_len()?;
68        Ok(packet)
69    }
70
71    /// Ensure that no accessor method will panic if called.
72    /// Returns `Err(Error)` if the buffer is too short.
73    pub fn check_len(&self) -> Result<()> {
74        let len = self.buffer.as_ref().len();
75        if len < field::GROUP_ADDRESS.end {
76            Err(Error)
77        } else {
78            Ok(())
79        }
80    }
81
82    /// Consume the packet, returning the underlying buffer.
83    pub fn into_inner(self) -> T {
84        self.buffer
85    }
86
87    /// Return the message type field.
88    #[inline]
89    pub fn msg_type(&self) -> Message {
90        let data = self.buffer.as_ref();
91        Message::from(data[field::TYPE])
92    }
93
94    /// Return the maximum response time, using the encoding specified in
95    /// [RFC 3376]: 4.1.1. Max Resp Code.
96    ///
97    /// [RFC 3376]: https://tools.ietf.org/html/rfc3376
98    #[inline]
99    pub fn max_resp_code(&self) -> u8 {
100        let data = self.buffer.as_ref();
101        data[field::MAX_RESP_CODE]
102    }
103
104    /// Return the checksum field.
105    #[inline]
106    pub fn checksum(&self) -> u16 {
107        let data = self.buffer.as_ref();
108        NetworkEndian::read_u16(&data[field::CHECKSUM])
109    }
110
111    /// Return the source address field.
112    #[inline]
113    pub fn group_addr(&self) -> Ipv4Address {
114        let data = self.buffer.as_ref();
115        Ipv4Address::from_bytes(&data[field::GROUP_ADDRESS])
116    }
117
118    /// Validate the header checksum.
119    ///
120    /// # Fuzzing
121    /// This function always returns `true` when fuzzing.
122    pub fn verify_checksum(&self) -> bool {
123        if cfg!(fuzzing) {
124            return true;
125        }
126
127        let data = self.buffer.as_ref();
128        checksum::data(data) == !0
129    }
130}
131
132impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
133    /// Set the message type field.
134    #[inline]
135    pub fn set_msg_type(&mut self, value: Message) {
136        let data = self.buffer.as_mut();
137        data[field::TYPE] = value.into()
138    }
139
140    /// Set the maximum response time, using the encoding specified in
141    /// [RFC 3376]: 4.1.1. Max Resp Code.
142    #[inline]
143    pub fn set_max_resp_code(&mut self, value: u8) {
144        let data = self.buffer.as_mut();
145        data[field::MAX_RESP_CODE] = value;
146    }
147
148    /// Set the checksum field.
149    #[inline]
150    pub fn set_checksum(&mut self, value: u16) {
151        let data = self.buffer.as_mut();
152        NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
153    }
154
155    /// Set the group address field
156    #[inline]
157    pub fn set_group_address(&mut self, addr: Ipv4Address) {
158        let data = self.buffer.as_mut();
159        data[field::GROUP_ADDRESS].copy_from_slice(&addr.octets());
160    }
161
162    /// Compute and fill in the header checksum.
163    pub fn fill_checksum(&mut self) {
164        self.set_checksum(0);
165        let checksum = {
166            let data = self.buffer.as_ref();
167            !checksum::data(data)
168        };
169        self.set_checksum(checksum)
170    }
171}
172
/// A high-level representation of an Internet Group Management Protocol v1/v2 header.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Repr {
    /// Membership Query (message type `0x11`).
    MembershipQuery {
        /// Maximum response time, converted to/from the Max Resp Code field
        /// using the RFC 3376 encoding. Emitted as zero for IGMPv1 queries.
        max_resp_time: Duration,
        /// Queried group; either unspecified (`0.0.0.0`) or a multicast address.
        group_addr: Ipv4Address,
        /// `Version1` when the Max Resp Code is zero, `Version2` otherwise.
        version: IgmpVersion,
    },
    /// Membership Report; `version` selects message type `0x12` (v1) or `0x16` (v2).
    MembershipReport {
        /// Reported group; either unspecified (`0.0.0.0`) or a multicast address.
        group_addr: Ipv4Address,
        /// Which report message type is used on the wire.
        version: IgmpVersion,
    },
    /// Leave Group (message type `0x17`).
    LeaveGroup {
        /// Group being left; either unspecified (`0.0.0.0`) or a multicast address.
        group_addr: Ipv4Address,
    },
}
190
/// IGMP protocol version of a membership report or query.
///
/// For reports it selects the message type. For queries, version 1 is
/// signalled on the wire by a zero Max Resp Code.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum IgmpVersion {
    /// IGMPv1 (reports use message type `0x12`)
    Version1,
    /// IGMPv2 (reports use message type `0x16`)
    Version2,
}
200
201impl Repr {
202    /// Parse an Internet Group Management Protocol v1/v2 packet and return
203    /// a high-level representation.
204    pub fn parse<T>(packet: &Packet<&T>) -> Result<Repr>
205    where
206        T: AsRef<[u8]> + ?Sized,
207    {
208        packet.check_len()?;
209
210        // Check if the address is 0.0.0.0 or multicast
211        let addr = packet.group_addr();
212        if !addr.is_unspecified() && !addr.is_multicast() {
213            return Err(Error);
214        }
215
216        // construct a packet based on the Type field
217        match packet.msg_type() {
218            Message::MembershipQuery => {
219                let max_resp_time = max_resp_code_to_duration(packet.max_resp_code());
220                // See RFC 3376: 7.1. Query Version Distinctions
221                let version = if packet.max_resp_code() == 0 {
222                    IgmpVersion::Version1
223                } else {
224                    IgmpVersion::Version2
225                };
226                Ok(Repr::MembershipQuery {
227                    max_resp_time,
228                    group_addr: addr,
229                    version,
230                })
231            }
232            Message::MembershipReportV2 => Ok(Repr::MembershipReport {
233                group_addr: packet.group_addr(),
234                version: IgmpVersion::Version2,
235            }),
236            Message::LeaveGroup => Ok(Repr::LeaveGroup {
237                group_addr: packet.group_addr(),
238            }),
239            Message::MembershipReportV1 => {
240                // for backwards compatibility with IGMPv1
241                Ok(Repr::MembershipReport {
242                    group_addr: packet.group_addr(),
243                    version: IgmpVersion::Version1,
244                })
245            }
246            _ => Err(Error),
247        }
248    }
249
250    /// Return the length of a packet that will be emitted from this high-level representation.
251    pub const fn buffer_len(&self) -> usize {
252        // always 8 bytes
253        field::GROUP_ADDRESS.end
254    }
255
256    /// Emit a high-level representation into an Internet Group Management Protocol v2 packet.
257    pub fn emit<T>(&self, packet: &mut Packet<&mut T>)
258    where
259        T: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
260    {
261        match *self {
262            Repr::MembershipQuery {
263                max_resp_time,
264                group_addr,
265                version,
266            } => {
267                packet.set_msg_type(Message::MembershipQuery);
268                match version {
269                    IgmpVersion::Version1 => packet.set_max_resp_code(0),
270                    IgmpVersion::Version2 => {
271                        packet.set_max_resp_code(duration_to_max_resp_code(max_resp_time))
272                    }
273                }
274                packet.set_group_address(group_addr);
275            }
276            Repr::MembershipReport {
277                group_addr,
278                version,
279            } => {
280                match version {
281                    IgmpVersion::Version1 => packet.set_msg_type(Message::MembershipReportV1),
282                    IgmpVersion::Version2 => packet.set_msg_type(Message::MembershipReportV2),
283                };
284                packet.set_max_resp_code(0);
285                packet.set_group_address(group_addr);
286            }
287            Repr::LeaveGroup { group_addr } => {
288                packet.set_msg_type(Message::LeaveGroup);
289                packet.set_group_address(group_addr);
290            }
291        }
292
293        packet.fill_checksum()
294    }
295}
296
297fn max_resp_code_to_duration(value: u8) -> Duration {
298    let value: u64 = value.into();
299    let decisecs = if value < 128 {
300        value
301    } else {
302        let mant = value & 0xF;
303        let exp = (value >> 4) & 0x7;
304        (mant | 0x10) << (exp + 3)
305    };
306    Duration::from_millis(decisecs * 100)
307}
308
309const fn duration_to_max_resp_code(duration: Duration) -> u8 {
310    let decisecs = duration.total_millis() / 100;
311    if decisecs < 128 {
312        decisecs as u8
313    } else if decisecs < 31744 {
314        let mut mant = decisecs >> 3;
315        let mut exp = 0u8;
316        while mant > 0x1F && exp < 0x8 {
317            mant >>= 1;
318            exp += 1;
319        }
320        0x80 | (exp << 4) | (mant as u8 & 0xF)
321    } else {
322        0xFF
323    }
324}
325
326impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
327    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
328        match Repr::parse(self) {
329            Ok(repr) => write!(f, "{repr}"),
330            Err(err) => write!(f, "IGMP ({err})"),
331        }
332    }
333}
334
335impl fmt::Display for Repr {
336    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
337        match *self {
338            Repr::MembershipQuery {
339                max_resp_time,
340                group_addr,
341                version,
342            } => write!(
343                f,
344                "IGMP membership query max_resp_time={max_resp_time} group_addr={group_addr} version={version:?}"
345            ),
346            Repr::MembershipReport {
347                group_addr,
348                version,
349            } => write!(
350                f,
351                "IGMP membership report group_addr={group_addr} version={version:?}"
352            ),
353            Repr::LeaveGroup { group_addr } => {
354                write!(f, "IGMP leave group group_addr={group_addr})")
355            }
356        }
357    }
358}
359
360use crate::wire::pretty_print::{PrettyIndent, PrettyPrint};
361
362impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
363    fn pretty_print(
364        buffer: &dyn AsRef<[u8]>,
365        f: &mut fmt::Formatter,
366        indent: &mut PrettyIndent,
367    ) -> fmt::Result {
368        match Packet::new_checked(buffer) {
369            Err(err) => writeln!(f, "{indent}({err})"),
370            Ok(packet) => writeln!(f, "{indent}{packet}"),
371        }
372    }
373}
374
#[cfg(test)]
mod test {
    use super::*;

    /// Leave Group for 224.0.6.150 with a zero Max Resp Code and checksum 0x0269.
    static LEAVE_PACKET_BYTES: [u8; 8] = [0x17, 0x00, 0x02, 0x69, 0xe0, 0x00, 0x06, 0x96];
    /// Version 2 Membership Report for 225.0.0.37 with checksum 0x08da.
    static REPORT_PACKET_BYTES: [u8; 8] = [0x16, 0x00, 0x08, 0xda, 0xe1, 0x00, 0x00, 0x25];

    #[test]
    fn test_leave_group_deconstruct() {
        let packet = Packet::new_unchecked(&LEAVE_PACKET_BYTES[..]);
        assert_eq!(packet.msg_type(), Message::LeaveGroup);
        assert_eq!(packet.max_resp_code(), 0);
        assert_eq!(packet.checksum(), 0x269);
        assert_eq!(
            packet.group_addr(),
            Ipv4Address::from_bytes(&[224, 0, 6, 150])
        );
        assert!(packet.verify_checksum());
    }

    #[test]
    fn test_report_deconstruct() {
        let packet = Packet::new_unchecked(&REPORT_PACKET_BYTES[..]);
        assert_eq!(packet.msg_type(), Message::MembershipReportV2);
        assert_eq!(packet.max_resp_code(), 0);
        assert_eq!(packet.checksum(), 0x08da);
        assert_eq!(
            packet.group_addr(),
            Ipv4Address::from_bytes(&[225, 0, 0, 37])
        );
        assert!(packet.verify_checksum());
    }

    #[test]
    fn test_leave_group_parse() {
        let packet = Packet::new_unchecked(&LEAVE_PACKET_BYTES[..]);
        assert_eq!(
            Repr::parse(&packet).unwrap(),
            Repr::LeaveGroup {
                group_addr: Ipv4Address::from_bytes(&[224, 0, 6, 150]),
            }
        );
    }

    #[test]
    fn test_report_parse() {
        let packet = Packet::new_unchecked(&REPORT_PACKET_BYTES[..]);
        assert_eq!(
            Repr::parse(&packet).unwrap(),
            Repr::MembershipReport {
                group_addr: Ipv4Address::from_bytes(&[225, 0, 0, 37]),
                version: IgmpVersion::Version2,
            }
        );
    }

    #[test]
    fn test_leave_construct() {
        let mut bytes = vec![0xa5; 8];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_msg_type(Message::LeaveGroup);
        packet.set_max_resp_code(0);
        packet.set_group_address(Ipv4Address::from_bytes(&[224, 0, 6, 150]));
        packet.fill_checksum();
        assert_eq!(&*packet.into_inner(), &LEAVE_PACKET_BYTES[..]);
    }

    #[test]
    fn test_report_construct() {
        let mut bytes = vec![0xa5; 8];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_msg_type(Message::MembershipReportV2);
        packet.set_max_resp_code(0);
        packet.set_group_address(Ipv4Address::from_bytes(&[225, 0, 0, 37]));
        packet.fill_checksum();
        assert_eq!(&*packet.into_inner(), &REPORT_PACKET_BYTES[..]);
    }

    #[test]
    fn test_query_emit_parse_roundtrip() {
        let repr = Repr::MembershipQuery {
            max_resp_time: Duration::from_millis(10_000),
            group_addr: Ipv4Address::from_bytes(&[224, 0, 0, 1]),
            version: IgmpVersion::Version2,
        };
        let mut bytes = vec![0u8; repr.buffer_len()];
        repr.emit(&mut Packet::new_unchecked(&mut bytes));

        let packet = Packet::new_unchecked(&bytes);
        assert!(packet.verify_checksum());
        // 10 s is 100 tenths of a second, which fits the linear encoding.
        assert_eq!(packet.max_resp_code(), 100);
        assert_eq!(Repr::parse(&packet).unwrap(), repr);
    }

    #[test]
    fn max_resp_time_to_duration_and_back() {
        // Every code must survive decode -> encode unchanged.
        for code in 0..=u8::MAX {
            let duration = max_resp_code_to_duration(code);
            assert_eq!(duration_to_max_resp_code(duration), code);
        }
    }

    #[test]
    fn duration_to_max_resp_time_max() {
        // 31744 tenths of a second is the largest encodable value (code 0xFF);
        // anything at or above it must saturate.
        for duration in 31744..65536 {
            let time = duration_to_max_resp_code(Duration::from_millis(duration * 100));
            assert_eq!(time, 0xFF);
        }
    }
}