1use byteorder::{ByteOrder, NetworkEndian};
2use core::fmt;
3
4use super::{Error, Result};
5use crate::time::Duration;
6use crate::wire::ip::checksum;
7
8use crate::wire::{Ipv4Address, Ipv4AddressExt};
9
enum_with_unknown! {
    // IGMP v1/v2 message type, stored in the first byte of the packet.
    // Any other byte value decodes to `Message::Unknown(u8)`.
    pub enum Message(u8) {
        // Membership Query (type 0x11).
        MembershipQuery = 0x11,
        // Version 2 Membership Report (type 0x16).
        MembershipReportV2 = 0x16,
        // Leave Group (type 0x17).
        LeaveGroup = 0x17,
        // Version 1 Membership Report (type 0x12).
        MembershipReportV1 = 0x12
    }
}
23
/// A read/write wrapper around an IGMP v1/v2 packet buffer.
///
/// [`Packet::new_unchecked`] performs no validation, and the accessors index
/// the buffer directly. Use [`Packet::new_checked`] to ensure the buffer is
/// long enough for the fixed 8-byte header before reading fields.
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Packet<T: AsRef<[u8]>> {
    // Raw packet bytes, starting at the IGMP type field.
    buffer: T,
}
30
/// Byte offsets of the fixed IGMP v1/v2 header fields.
mod field {
    use crate::wire::field::*;

    /// Message type (1 byte).
    pub const TYPE: usize = 0;
    /// Maximum response code (1 byte).
    pub const MAX_RESP_CODE: usize = 1;
    /// Internet checksum (2 bytes, network byte order).
    pub const CHECKSUM: Field = 2..4;
    /// Group address (4 bytes, IPv4). Its end is also the minimum packet length.
    pub const GROUP_ADDRESS: Field = 4..8;
}
39
40impl fmt::Display for Message {
41 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42 match *self {
43 Message::MembershipQuery => write!(f, "membership query"),
44 Message::MembershipReportV2 => write!(f, "version 2 membership report"),
45 Message::LeaveGroup => write!(f, "leave group"),
46 Message::MembershipReportV1 => write!(f, "version 1 membership report"),
47 Message::Unknown(id) => write!(f, "{id}"),
48 }
49 }
50}
51
52impl<T: AsRef<[u8]>> Packet<T> {
56 pub const fn new_unchecked(buffer: T) -> Packet<T> {
58 Packet { buffer }
59 }
60
61 pub fn new_checked(buffer: T) -> Result<Packet<T>> {
66 let packet = Self::new_unchecked(buffer);
67 packet.check_len()?;
68 Ok(packet)
69 }
70
71 pub fn check_len(&self) -> Result<()> {
74 let len = self.buffer.as_ref().len();
75 if len < field::GROUP_ADDRESS.end {
76 Err(Error)
77 } else {
78 Ok(())
79 }
80 }
81
82 pub fn into_inner(self) -> T {
84 self.buffer
85 }
86
87 #[inline]
89 pub fn msg_type(&self) -> Message {
90 let data = self.buffer.as_ref();
91 Message::from(data[field::TYPE])
92 }
93
94 #[inline]
99 pub fn max_resp_code(&self) -> u8 {
100 let data = self.buffer.as_ref();
101 data[field::MAX_RESP_CODE]
102 }
103
104 #[inline]
106 pub fn checksum(&self) -> u16 {
107 let data = self.buffer.as_ref();
108 NetworkEndian::read_u16(&data[field::CHECKSUM])
109 }
110
111 #[inline]
113 pub fn group_addr(&self) -> Ipv4Address {
114 let data = self.buffer.as_ref();
115 Ipv4Address::from_bytes(&data[field::GROUP_ADDRESS])
116 }
117
118 pub fn verify_checksum(&self) -> bool {
123 if cfg!(fuzzing) {
124 return true;
125 }
126
127 let data = self.buffer.as_ref();
128 checksum::data(data) == !0
129 }
130}
131
132impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
133 #[inline]
135 pub fn set_msg_type(&mut self, value: Message) {
136 let data = self.buffer.as_mut();
137 data[field::TYPE] = value.into()
138 }
139
140 #[inline]
143 pub fn set_max_resp_code(&mut self, value: u8) {
144 let data = self.buffer.as_mut();
145 data[field::MAX_RESP_CODE] = value;
146 }
147
148 #[inline]
150 pub fn set_checksum(&mut self, value: u16) {
151 let data = self.buffer.as_mut();
152 NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
153 }
154
155 #[inline]
157 pub fn set_group_address(&mut self, addr: Ipv4Address) {
158 let data = self.buffer.as_mut();
159 data[field::GROUP_ADDRESS].copy_from_slice(&addr.octets());
160 }
161
162 pub fn fill_checksum(&mut self) {
164 self.set_checksum(0);
165 let checksum = {
166 let data = self.buffer.as_ref();
167 !checksum::data(data)
168 };
169 self.set_checksum(checksum)
170 }
171}
172
/// A high-level representation of an IGMP v1/v2 packet.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Repr {
    /// A membership query (type 0x11).
    MembershipQuery {
        /// Maximum time a host may wait before responding, decoded from the
        /// max response code. `emit` writes a zero code for
        /// `IgmpVersion::Version1` regardless of this value.
        max_resp_time: Duration,
        /// Queried group; `parse` also accepts the unspecified address `0.0.0.0`.
        group_addr: Ipv4Address,
        /// Inferred by `parse` from a zero (v1) or non-zero (v2) max response code.
        version: IgmpVersion,
    },
    /// A membership report (type 0x12 for v1, 0x16 for v2).
    MembershipReport {
        /// Group being reported.
        group_addr: Ipv4Address,
        /// Selects which report message type is emitted.
        version: IgmpVersion,
    },
    /// A leave group message (type 0x17).
    LeaveGroup {
        /// Group being left.
        group_addr: Ipv4Address,
    },
}
190
/// IGMP protocol version of a query or report.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum IgmpVersion {
    /// IGMPv1: reports use type 0x12; queries carry a zero max response code.
    Version1,
    /// IGMPv2: reports use type 0x16; queries carry a non-zero max response code.
    Version2,
}
200
201impl Repr {
202 pub fn parse<T>(packet: &Packet<&T>) -> Result<Repr>
205 where
206 T: AsRef<[u8]> + ?Sized,
207 {
208 packet.check_len()?;
209
210 let addr = packet.group_addr();
212 if !addr.is_unspecified() && !addr.is_multicast() {
213 return Err(Error);
214 }
215
216 match packet.msg_type() {
218 Message::MembershipQuery => {
219 let max_resp_time = max_resp_code_to_duration(packet.max_resp_code());
220 let version = if packet.max_resp_code() == 0 {
222 IgmpVersion::Version1
223 } else {
224 IgmpVersion::Version2
225 };
226 Ok(Repr::MembershipQuery {
227 max_resp_time,
228 group_addr: addr,
229 version,
230 })
231 }
232 Message::MembershipReportV2 => Ok(Repr::MembershipReport {
233 group_addr: packet.group_addr(),
234 version: IgmpVersion::Version2,
235 }),
236 Message::LeaveGroup => Ok(Repr::LeaveGroup {
237 group_addr: packet.group_addr(),
238 }),
239 Message::MembershipReportV1 => {
240 Ok(Repr::MembershipReport {
242 group_addr: packet.group_addr(),
243 version: IgmpVersion::Version1,
244 })
245 }
246 _ => Err(Error),
247 }
248 }
249
250 pub const fn buffer_len(&self) -> usize {
252 field::GROUP_ADDRESS.end
254 }
255
256 pub fn emit<T>(&self, packet: &mut Packet<&mut T>)
258 where
259 T: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
260 {
261 match *self {
262 Repr::MembershipQuery {
263 max_resp_time,
264 group_addr,
265 version,
266 } => {
267 packet.set_msg_type(Message::MembershipQuery);
268 match version {
269 IgmpVersion::Version1 => packet.set_max_resp_code(0),
270 IgmpVersion::Version2 => {
271 packet.set_max_resp_code(duration_to_max_resp_code(max_resp_time))
272 }
273 }
274 packet.set_group_address(group_addr);
275 }
276 Repr::MembershipReport {
277 group_addr,
278 version,
279 } => {
280 match version {
281 IgmpVersion::Version1 => packet.set_msg_type(Message::MembershipReportV1),
282 IgmpVersion::Version2 => packet.set_msg_type(Message::MembershipReportV2),
283 };
284 packet.set_max_resp_code(0);
285 packet.set_group_address(group_addr);
286 }
287 Repr::LeaveGroup { group_addr } => {
288 packet.set_msg_type(Message::LeaveGroup);
289 packet.set_group_address(group_addr);
290 }
291 }
292
293 packet.fill_checksum()
294 }
295}
296
297fn max_resp_code_to_duration(value: u8) -> Duration {
298 let value: u64 = value.into();
299 let decisecs = if value < 128 {
300 value
301 } else {
302 let mant = value & 0xF;
303 let exp = (value >> 4) & 0x7;
304 (mant | 0x10) << (exp + 3)
305 };
306 Duration::from_millis(decisecs * 100)
307}
308
309const fn duration_to_max_resp_code(duration: Duration) -> u8 {
310 let decisecs = duration.total_millis() / 100;
311 if decisecs < 128 {
312 decisecs as u8
313 } else if decisecs < 31744 {
314 let mut mant = decisecs >> 3;
315 let mut exp = 0u8;
316 while mant > 0x1F && exp < 0x8 {
317 mant >>= 1;
318 exp += 1;
319 }
320 0x80 | (exp << 4) | (mant as u8 & 0xF)
321 } else {
322 0xFF
323 }
324}
325
326impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
327 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
328 match Repr::parse(self) {
329 Ok(repr) => write!(f, "{repr}"),
330 Err(err) => write!(f, "IGMP ({err})"),
331 }
332 }
333}
334
335impl fmt::Display for Repr {
336 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
337 match *self {
338 Repr::MembershipQuery {
339 max_resp_time,
340 group_addr,
341 version,
342 } => write!(
343 f,
344 "IGMP membership query max_resp_time={max_resp_time} group_addr={group_addr} version={version:?}"
345 ),
346 Repr::MembershipReport {
347 group_addr,
348 version,
349 } => write!(
350 f,
351 "IGMP membership report group_addr={group_addr} version={version:?}"
352 ),
353 Repr::LeaveGroup { group_addr } => {
354 write!(f, "IGMP leave group group_addr={group_addr})")
355 }
356 }
357 }
358}
359
360use crate::wire::pretty_print::{PrettyIndent, PrettyPrint};
361
362impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
363 fn pretty_print(
364 buffer: &dyn AsRef<[u8]>,
365 f: &mut fmt::Formatter,
366 indent: &mut PrettyIndent,
367 ) -> fmt::Result {
368 match Packet::new_checked(buffer) {
369 Err(err) => writeln!(f, "{indent}({err})"),
370 Ok(packet) => writeln!(f, "{indent}{packet}"),
371 }
372 }
373}
374
#[cfg(test)]
mod test {
    use super::*;

    static LEAVE_PACKET_BYTES: [u8; 8] = [0x17, 0x00, 0x02, 0x69, 0xe0, 0x00, 0x06, 0x96];
    static REPORT_PACKET_BYTES: [u8; 8] = [0x16, 0x00, 0x08, 0xda, 0xe1, 0x00, 0x00, 0x25];

    #[test]
    fn test_leave_group_deconstruct() {
        let packet = Packet::new_unchecked(&LEAVE_PACKET_BYTES[..]);
        assert_eq!(packet.msg_type(), Message::LeaveGroup);
        assert_eq!(packet.max_resp_code(), 0);
        assert_eq!(packet.checksum(), 0x269);
        assert_eq!(
            packet.group_addr(),
            Ipv4Address::from_bytes(&[224, 0, 6, 150])
        );
        assert!(packet.verify_checksum());
    }

    #[test]
    fn test_report_deconstruct() {
        let packet = Packet::new_unchecked(&REPORT_PACKET_BYTES[..]);
        assert_eq!(packet.msg_type(), Message::MembershipReportV2);
        assert_eq!(packet.max_resp_code(), 0);
        assert_eq!(packet.checksum(), 0x08da);
        assert_eq!(
            packet.group_addr(),
            Ipv4Address::from_bytes(&[225, 0, 0, 37])
        );
        assert!(packet.verify_checksum());
    }

    #[test]
    fn test_leave_construct() {
        let mut bytes = vec![0xa5; 8];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_msg_type(Message::LeaveGroup);
        packet.set_max_resp_code(0);
        packet.set_group_address(Ipv4Address::from_bytes(&[224, 0, 6, 150]));
        packet.fill_checksum();
        assert_eq!(&*packet.into_inner(), &LEAVE_PACKET_BYTES[..]);
    }

    #[test]
    fn test_report_construct() {
        let mut bytes = vec![0xa5; 8];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_msg_type(Message::MembershipReportV2);
        packet.set_max_resp_code(0);
        packet.set_group_address(Ipv4Address::from_bytes(&[225, 0, 0, 37]));
        packet.fill_checksum();
        assert_eq!(&*packet.into_inner(), &REPORT_PACKET_BYTES[..]);
    }

    #[test]
    fn test_leave_parse() {
        let packet = Packet::new_unchecked(&LEAVE_PACKET_BYTES[..]);
        assert_eq!(
            Repr::parse(&packet).ok(),
            Some(Repr::LeaveGroup {
                group_addr: Ipv4Address::from_bytes(&[224, 0, 6, 150]),
            })
        );
    }

    #[test]
    fn test_report_emit() {
        let repr = Repr::MembershipReport {
            group_addr: Ipv4Address::from_bytes(&[225, 0, 0, 37]),
            version: IgmpVersion::Version2,
        };
        // Pre-fill with junk so every header byte must be written by `emit`.
        let mut bytes = vec![0xa5; repr.buffer_len()];
        let mut packet = Packet::new_unchecked(&mut bytes);
        repr.emit(&mut packet);
        assert_eq!(&*packet.into_inner(), &REPORT_PACKET_BYTES[..]);
    }

    #[test]
    fn test_query_round_trip() {
        let repr = Repr::MembershipQuery {
            max_resp_time: Duration::from_millis(10_000),
            group_addr: Ipv4Address::from_bytes(&[0, 0, 0, 0]),
            version: IgmpVersion::Version2,
        };
        let mut bytes = vec![0xa5; repr.buffer_len()];
        let mut packet = Packet::new_unchecked(&mut bytes);
        repr.emit(&mut packet);

        let packet = Packet::new_unchecked(&bytes[..]);
        assert_eq!(packet.max_resp_code(), 100);
        assert!(packet.verify_checksum());
        assert_eq!(Repr::parse(&packet).ok(), Some(repr));
    }

    #[test]
    fn test_parse_rejects_unicast_group() {
        let bytes: [u8; 8] = [0x16, 0x00, 0x00, 0x00, 10, 0, 0, 1];
        assert!(Repr::parse(&Packet::new_unchecked(&bytes[..])).is_err());
    }

    #[test]
    fn test_parse_rejects_unknown_type() {
        // 0x22 is an IGMPv3 report, which this v1/v2 parser does not handle.
        let bytes: [u8; 8] = [0x22, 0x00, 0x00, 0x00, 224, 0, 0, 1];
        assert!(Repr::parse(&Packet::new_unchecked(&bytes[..])).is_err());
    }

    #[test]
    fn test_new_checked_rejects_short_buffer() {
        let bytes: [u8; 3] = [0x17, 0x00, 0x02];
        assert!(Packet::new_checked(&bytes[..]).is_err());
    }

    #[test]
    fn max_resp_time_to_duration_and_back() {
        for code in 0..=u8::MAX {
            let duration = max_resp_code_to_duration(code);
            assert_eq!(
                duration_to_max_resp_code(duration),
                code,
                "round trip failed for code {code:#04x}"
            );
        }
    }

    #[test]
    fn duration_to_max_resp_time_max() {
        for decisecs in 31744..65536 {
            let code = duration_to_max_resp_code(Duration::from_millis(decisecs * 100));
            assert_eq!(code, 0xFF);
        }
    }
}