#![allow(unused)]
use super::IpProtocol;
use super::{Error, Result};
mod field {
    // The field accessors below are named like constants even though one of
    // them is a function, hence the lint override.
    #![allow(non_snake_case)]

    use crate::wire::field::*;

    // Smallest buffer length, in bytes, that `Header::check_len` accepts.
    pub const MIN_HEADER_SIZE: usize = 8;

    // Byte offset of the next-header value.
    pub const NXT_HDR: usize = 0;

    // Byte offset of the length value.
    pub const LENGTH: usize = 1;

    // Byte range of the payload for a header whose length byte is
    // `length_field`. The payload starts right after the two fixed bytes and
    // ends at `(length_field + 1) * 8`, i.e. the length byte counts 8-byte
    // units beyond the first 8 bytes.
    pub const fn PAYLOAD(length_field: u8) -> Field {
        let end = (length_field as usize + 1) * 8;
        2..end
    }
}
/// A read/write wrapper around a byte buffer that holds a header with a
/// two-byte fixed prefix followed by a variable-length payload.
///
/// Layout, as used by the accessors in this file:
/// - byte 0: next-header value, converted to/from `IpProtocol`;
/// - byte 1: length value; the whole header spans `(length + 1) * 8` bytes;
/// - bytes 2 onward (up to the end of the header): payload.
///
/// NOTE(review): this looks like the generic IPv6 extension header layout;
/// confirm against the surrounding module.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Header<T: AsRef<[u8]>> {
// The underlying storage; never resized by this type.
buffer: T,
}
impl<T: AsRef<[u8]>> Header<T> {
pub const fn new_unchecked(buffer: T) -> Self {
Header { buffer }
}
pub fn new_checked(buffer: T) -> Result<Self> {
let header = Self::new_unchecked(buffer);
header.check_len()?;
Ok(header)
}
pub fn check_len(&self) -> Result<()> {
let data = self.buffer.as_ref();
let len = data.len();
if len < field::MIN_HEADER_SIZE {
return Err(Error);
}
let of = field::PAYLOAD(data[field::LENGTH]);
if len < of.end {
return Err(Error);
}
Ok(())
}
pub fn into_inner(self) -> T {
self.buffer
}
pub fn next_header(&self) -> IpProtocol {
let data = self.buffer.as_ref();
IpProtocol::from(data[field::NXT_HDR])
}
pub fn header_len(&self) -> u8 {
let data = self.buffer.as_ref();
data[field::LENGTH]
}
}
impl<'h, T: AsRef<[u8]> + ?Sized> Header<&'h T> {
    /// Return the payload: the bytes after the two fixed bytes, up to the end
    /// of the header as announced by the length byte.
    ///
    /// Bytes in the buffer beyond the announced header size are not included.
    ///
    /// # Panics
    /// Panics if the buffer is shorter than the announced header size.
    pub fn payload(&self) -> &'h [u8] {
        let bytes = self.buffer.as_ref();
        let span = field::PAYLOAD(bytes[field::LENGTH]);
        &bytes[span]
    }
}
impl<T: AsRef<[u8]> + AsMut<[u8]>> Header<T> {
    /// Store `value` in the next-header byte (byte 0).
    ///
    /// # Panics
    /// Panics if the buffer is empty.
    #[inline]
    pub fn set_next_header(&mut self, value: IpProtocol) {
        let raw: u8 = value.into();
        self.buffer.as_mut()[field::NXT_HDR] = raw;
    }

    /// Store `value` in the length byte (byte 1), exactly as given.
    ///
    /// # Panics
    /// Panics if the buffer holds fewer than two bytes.
    #[inline]
    pub fn set_header_len(&mut self, value: u8) {
        self.buffer.as_mut()[field::LENGTH] = value;
    }
}
impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Header<&'a mut T> {
    /// Return a mutable view of the payload: the bytes after the two fixed
    /// bytes, up to the end of the header as announced by the length byte.
    ///
    /// # Panics
    /// Panics if the buffer is shorter than the announced header size.
    #[inline]
    pub fn payload_mut(&mut self) -> &mut [u8] {
        let bytes = self.buffer.as_mut();
        let span = field::PAYLOAD(bytes[field::LENGTH]);
        &mut bytes[span]
    }
}
/// A high-level representation of a header wrapped by [`Header`].
///
/// The `data` slice borrows from the buffer the header was parsed from.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Repr<'a> {
/// The next-header value, taken from byte 0 of the header.
pub next_header: IpProtocol,
/// The raw length byte (byte 1) of the header; the header spans
/// `(length + 1) * 8` bytes in total.
pub length: u8,
/// The header payload, i.e. everything after the two fixed bytes.
pub data: &'a [u8],
}
impl<'a> Repr<'a> {
    /// Parse a header and return its high-level representation.
    ///
    /// The header is validated with [`Header::check_len`] before any field is
    /// read, so a `Header` built with `new_unchecked` over a truncated or
    /// inconsistent buffer yields an error instead of a panic.
    ///
    /// # Errors
    /// Returns `Err(Error)` if the buffer is shorter than the minimum header
    /// size or shorter than the size announced by its length byte.
    pub fn parse<T>(header: &Header<&'a T>) -> Result<Self>
    where
        T: AsRef<[u8]> + ?Sized,
    {
        // The accessors below index into the buffer and would panic on short
        // input; reject such input up front.
        header.check_len()?;

        Ok(Self {
            next_header: header.next_header(),
            length: header.header_len(),
            data: header.payload(),
        })
    }

    /// Return the number of bytes written by [`emit`](Self::emit).
    ///
    /// This is always 2: the next-header byte and the length byte. The
    /// payload is not counted.
    pub const fn header_len(&self) -> usize {
        2
    }

    /// Write the next-header and length bytes into `header`.
    ///
    /// Only the two fixed bytes are written; `self.data` is not copied into
    /// the buffer.
    ///
    /// # Panics
    /// Panics if the underlying buffer holds fewer than two bytes.
    pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, header: &mut Header<&mut T>) {
        header.set_next_header(self.next_header);
        header.set_header_len(self.length);
    }
}
#[cfg(test)]
mod test {
    use super::*;

    // 8-byte header: next-header byte 6, length byte 0, six payload bytes.
    static REPR_PACKET_PAD4: [u8; 8] = [0x6, 0x0, 0x1, 0x4, 0x0, 0x0, 0x0, 0x0];

    // 16-byte header: next-header byte 6, length byte 1, fourteen payload bytes.
    static REPR_PACKET_PAD12: [u8; 16] = [
        0x06, 0x1, 0x1, 0x0C, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
    ];

    #[test]
    fn test_check_len() {
        // Anything shorter than the 8-byte minimum is rejected.
        for cut in [0, 1, 7] {
            let truncated = Header::new_unchecked(&REPR_PACKET_PAD4[..cut]);
            assert_eq!(truncated.check_len(), Err(Error));
        }

        // Complete headers are accepted.
        let pad4 = Header::new_unchecked(&REPR_PACKET_PAD4);
        assert_eq!(pad4.check_len(), Ok(()));
        let pad12 = Header::new_unchecked(&REPR_PACKET_PAD12);
        assert_eq!(pad12.check_len(), Ok(()));

        // The length byte announces 24 bytes, but only 8 are present.
        let lying: [u8; 8] = [0x06, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0];
        assert_eq!(Header::new_unchecked(&lying).check_len(), Err(Error));
    }

    #[test]
    fn test_header_deconstruct() {
        let pad4 = Header::new_unchecked(&REPR_PACKET_PAD4);
        assert_eq!(pad4.next_header(), IpProtocol::Tcp);
        assert_eq!(pad4.header_len(), 0);
        assert_eq!(pad4.payload(), &REPR_PACKET_PAD4[2..]);

        let pad12 = Header::new_unchecked(&REPR_PACKET_PAD12);
        assert_eq!(pad12.next_header(), IpProtocol::Tcp);
        assert_eq!(pad12.header_len(), 1);
        assert_eq!(pad12.payload(), &REPR_PACKET_PAD12[2..]);
    }

    #[test]
    fn test_overlong() {
        // A trailing byte beyond the announced header size must not leak into
        // the payload, whether it is accessed immutably or mutably.
        for packet in [&REPR_PACKET_PAD4[..], &REPR_PACKET_PAD12[..]] {
            let mut bytes = vec![];
            bytes.extend_from_slice(packet);
            bytes.push(0);

            let expected = packet[2..].len();
            assert_eq!(Header::new_unchecked(&bytes).payload().len(), expected);
            assert_eq!(
                Header::new_unchecked(&mut bytes).payload_mut().len(),
                expected
            );
        }
    }

    #[test]
    fn test_header_len_overflow() {
        // A length byte that announces more data than the buffer holds must
        // make the checked constructor fail.
        for packet in [&REPR_PACKET_PAD4[..], &REPR_PACKET_PAD12[..]] {
            let mut bytes = vec![];
            bytes.extend_from_slice(packet);

            let too_long = bytes.len() as u8 + 1;
            Header::new_unchecked(&mut bytes).set_header_len(too_long);

            assert_eq!(Header::new_checked(&bytes).unwrap_err(), Error);
        }
    }

    #[test]
    fn test_repr_parse_valid() {
        let parsed = Repr::parse(&Header::new_unchecked(&REPR_PACKET_PAD4)).unwrap();
        let expected = Repr {
            next_header: IpProtocol::Tcp,
            length: 0,
            data: &REPR_PACKET_PAD4[2..],
        };
        assert_eq!(parsed, expected);

        let parsed = Repr::parse(&Header::new_unchecked(&REPR_PACKET_PAD12)).unwrap();
        let expected = Repr {
            next_header: IpProtocol::Tcp,
            length: 1,
            data: &REPR_PACKET_PAD12[2..],
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_repr_emit() {
        // `emit` writes only the two fixed bytes, so a 2-byte buffer suffices.
        let pad4 = Repr {
            next_header: IpProtocol::Tcp,
            length: 0,
            data: &REPR_PACKET_PAD4[2..],
        };
        let mut out = [0u8; 2];
        let mut writer = Header::new_unchecked(&mut out);
        pad4.emit(&mut writer);
        assert_eq!(writer.into_inner(), &REPR_PACKET_PAD4[..2]);

        let pad12 = Repr {
            next_header: IpProtocol::Tcp,
            length: 1,
            data: &REPR_PACKET_PAD12[2..],
        };
        let mut out = [0u8; 2];
        let mut writer = Header::new_unchecked(&mut out);
        pad12.emit(&mut writer);
        assert_eq!(writer.into_inner(), &REPR_PACKET_PAD12[..2]);
    }
}