Skip to main content

memory_addresses/
lib.rs

1//! Universal crate for machine address types.
2
3#![no_std]
4
5use core::fmt;
6use core::fmt::Debug;
7use core::iter::FusedIterator;
8use core::marker::PhantomData;
9use core::ops::{
10    Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl,
11    ShlAssign, Shr, ShrAssign, Sub, SubAssign,
12};
13
14pub mod arch;
15#[macro_use]
16pub(crate) mod macros;
17pub(crate) use macros::impl_address;
18
// Select the concrete `PhysAddr`/`VirtAddr` types exported at the crate root.
// An architecture-specific implementation is used only when compiling for that
// architecture *and* the Cargo feature of the same name is enabled; every other
// combination (including a supported architecture with its feature disabled)
// falls back to `arch::fallback`.
cfg_if::cfg_if! {
    if #[cfg(all(target_arch = "x86_64", feature = "x86_64"))] {
        pub use crate::arch::x86_64::{PhysAddr, VirtAddr};
    } else if #[cfg(all(target_arch = "aarch64", feature = "aarch64"))] {
        pub use crate::arch::aarch64::{PhysAddr, VirtAddr};
    } else if #[cfg(all(target_arch = "riscv64", feature = "riscv64"))] {
        pub use crate::arch::riscv64::{PhysAddr, VirtAddr};
    } else {
        pub use crate::arch::fallback::{PhysAddr, VirtAddr};
    }
}
30
/// Trait that marks memory addresses.
///
/// An address must be a wrapper around a numeric value and thus behave like a number.
///
/// The wrapped numeric value has the type [`MemoryAddress::RAW`]. Most operator bounds
/// mix an address with a raw value: for example `addr - addr` yields the raw distance
/// between two addresses, while `addr - raw` yields another address.
pub trait MemoryAddress:
    PartialEq
    + Eq
    + PartialOrd
    + Ord
    + Copy
    + Clone
    + Sized
    // Bitwise operations with a raw operand produce a raw value, not an address.
    + BitAnd<<Self>::RAW, Output = Self::RAW>
    + BitAndAssign<<Self>::RAW>
    + BitOr<<Self>::RAW, Output = Self::RAW>
    + BitOrAssign<<Self>::RAW>
    + BitXor<<Self>::RAW, Output = Self::RAW>
    + BitXorAssign<<Self>::RAW>
    // `addr + raw`; the `Output` type of this addition is not constrained by this trait.
    + Add<<Self>::RAW>
    + AddAssign<<Self>::RAW>
    // `addr - addr` is the raw distance between two addresses ...
    + Sub<Self, Output = Self::RAW>
    // ... while `addr - raw` is another address.
    + Sub<<Self>::RAW, Output = Self>
    + SubAssign<<Self>::RAW>
    // Shifts take a plain `usize` amount and produce an address.
    + Shr<usize, Output = Self>
    + ShrAssign<usize>
    + Shl<usize, Output = Self>
    + ShlAssign<usize>
    // Formatting in binary, hexadecimal, octal and pointer (`{:p}`) style.
    + fmt::Binary
    + fmt::LowerHex
    + fmt::UpperHex
    + fmt::Octal
    + fmt::Pointer
{
    /// Inner address type
    ///
    /// The numeric value wrapped by the address.
    type RAW: Copy
        + PartialEq
        + Eq
        + PartialOrd
        + Ord
        + Not<Output = Self::RAW>
        + Add<Output = Self::RAW>
        + Sub<Output = Self::RAW>
        + BitAnd<Output = Self::RAW>
        + BitOr<Output = Self::RAW>
        + BitXor<Output = Self::RAW>
        + Debug
        // Lets generic code build small constants such as `0` and `1`.
        + From<u8>
        // Used to express address distances as `usize` lengths and counts.
        + TryInto<usize, Error: Debug>
        // Used to turn `usize` step counts into raw offsets.
        + TryFrom<usize, Error: Debug>;

    /// Get the raw underlying address value.
    fn raw(self) -> Self::RAW;

    /// Performs an add operation with `rhs`
    /// returning `None` if the operation overflowed or resulted in an invalid address.
    ///
    /// Which addresses count as invalid is decided by the implementing type.
    fn checked_add(self, rhs: Self::RAW) -> Option<Self>;

    /// Performs a sub operation with `rhs`
    /// returning `None` if the operation overflowed or resulted in an invalid address.
    ///
    /// Which addresses count as invalid is decided by the implementing type.
    fn checked_sub(self, rhs: Self::RAW) -> Option<Self>;
}
91
/// Errors that can occur when constructing an [`AddrRange`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum AddrRangeError {
    /// The requested `end` address lies before the `start` address.
    EndBeforeStart,
}

impl fmt::Display for AddrRangeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::EndBeforeStart => "Range end address can't be smaller than range start address",
        };
        f.write_str(message)
    }
}
109
110/// A memory range.
111pub struct AddrRange<T: MemoryAddress> {
112    /// Starting address
113    pub start: T,
114    /// End address (exclusive)
115    pub end: T,
116}
117impl<T: MemoryAddress> AddrRange<T> {
118    /// Construct a new memory range from `start` (inclusive) to `end` (exclusive).
119    pub fn new(start: T, end: T) -> Result<Self, AddrRangeError> {
120        if end < start {
121            return Err(AddrRangeError::EndBeforeStart);
122        }
123        Ok(Self { start, end })
124    }
125
126    /// Produces an [`AddrIter`] to iterate over this memory range.
127    pub fn iter(&self) -> AddrIter<T> {
128        AddrIter {
129            current: self.start,
130            end: Some(self.end),
131            _phantom: PhantomData,
132        }
133    }
134
135    /// Check, wether `element` is part of the address range.
136    pub fn contains(&self, element: &T) -> bool {
137        element.raw() >= self.start.raw() && element.raw() < self.end.raw()
138    }
139
140    /// Amount of addresses in the range.
141    ///
142    /// `AddrRange`s are half open, so the range from `0x0` to `0x1` has length 1.
143    pub fn length(&self) -> usize {
144        (self.end.raw() - self.start.raw())
145            .try_into()
146            .expect("address range is larger than the architecture's usize")
147    }
148}
149
/// An iterator over a memory range
///
/// The marker type `I` decides whether the end address is yielded:
/// [`NonInclusive`] (the default, produced by [`AddrRange::iter`] and by `From<Range>`)
/// excludes it, while [`Inclusive`] (produced by `From<RangeInclusive>`) includes it.
// `IterInclusivity` is deliberately private, which makes the bound on `I` more private
// than this public type; the `private_bounds` lint is silenced on purpose.
#[allow(private_bounds)]
pub struct AddrIter<T: MemoryAddress, I: IterInclusivity = NonInclusive> {
    // Front cursor: the next address handed out by `next`.
    current: T,
    // Back bound: excluded for `NonInclusive`, included for `Inclusive`.
    end: Option<T>, // None here indicates that this is exhausted
    // Zero-sized carrier for the inclusivity marker `I`.
    _phantom: PhantomData<I>,
}
157
// Deliberately private: users may need to know that this trait exists (it shows up in
// the bounds of `AddrIter`), but they never need to implement or call it themselves.
trait IterInclusivity: 'static {
    /// Returns `true` when no address remains between the front `cursor` and the
    /// back `bound`.
    fn exhausted<T: Ord>(cursor: &T, bound: &T) -> bool;
}

/// This marks [AddrIter] as acting as non-inclusive: the end address is not yielded.
///
/// This is the behaviour when using [AddrRange::iter]; such an iterator can also be
/// constructed from a `start..end` range using [From].
///
///```
/// # use memory_addresses::AddrIter;
/// let start = memory_addresses::PhysAddr::new(0);
/// let end = memory_addresses::PhysAddr::new(0x1000);
///
/// for i in AddrIter::from(start..end) {
///    // ...
/// }
/// assert_eq!(AddrIter::from(start..end).last(), Some(memory_addresses::PhysAddr::new(0xfff)))
/// ```
pub enum NonInclusive {}

impl IterInclusivity for NonInclusive {
    fn exhausted<T: Ord>(cursor: &T, bound: &T) -> bool {
        // The bound itself is never yielded, so reaching it means we are done.
        bound <= cursor
    }
}

/// This marks [AddrIter] as acting as inclusive: the end address is yielded too.
///
/// The inclusive variant of [AddrIter] can be constructed from a `start..=end` range
/// using [From].
///
///```
/// # use memory_addresses::AddrIter;
/// let start = memory_addresses::PhysAddr::new(0);
/// let end = memory_addresses::PhysAddr::new(0x1000);
///
/// for i in AddrIter::from(start..=end) {
///    // ...
/// }
/// assert_eq!(AddrIter::from(start..=end).last(), Some(memory_addresses::PhysAddr::new(0x1000)))
/// ```
pub enum Inclusive {}

impl IterInclusivity for Inclusive {
    fn exhausted<T: Ord>(cursor: &T, bound: &T) -> bool {
        // The bound is still yielded when the cursor sits exactly on it.
        bound < cursor
    }
}
207
208impl<T: MemoryAddress, I: IterInclusivity> Iterator for AddrIter<T, I> {
209    type Item = T;
210    fn next(&mut self) -> Option<Self::Item> {
211        if I::exhausted(&self.current, &self.end?) {
212            None
213        } else {
214            let ret = Some(self.current);
215            self.current += 1.into();
216            ret
217        }
218    }
219
220    fn size_hint(&self) -> (usize, Option<usize>) {
221        let Some(end) = self.end else {
222            return (0, Some(0));
223        };
224        let ni_count = (end - self.current)
225            .try_into()
226            .expect("address range is larger than the architecture's usize");
227
228        // I whis there was a nicer way to do this.
229        // This will determine whether I is `Inclusive` or `NonInclusive` and take the correct branch.
230        // The compiler can determine which branch will be taken at compile time so this doesnt get checked at runtime.
231        if core::any::TypeId::of::<I>() == core::any::TypeId::of::<NonInclusive>() {
232            (ni_count, Some(ni_count))
233        } else if core::any::TypeId::of::<I>() == core::any::TypeId::of::<Inclusive>() {
234            (ni_count + 1, Some(ni_count + 1))
235        } else {
236            // Explicitly panic when I is not expected.
237            // This should never be possible but it doesnt do any harm.
238            unreachable!()
239        }
240    }
241
242    fn last(self) -> Option<Self::Item> {
243        self.max()
244    }
245
246    fn nth(&mut self, n: usize) -> Option<Self::Item> {
247        let Ok(n): Result<T::RAW, _> = n.try_into() else {
248            // Fail to cast indicates that n > T::RAW::MAX, so we explicitly exhaust self.
249            self.end.take();
250            return None;
251        };
252        match self.current.checked_add(n) {
253            Some(n) => self.current = n,
254            None if self.current.raw() < n => {
255                self.end.take();
256                return None;
257            }
258            None => panic!("Attempted to iterate over invalid address"),
259        }
260        if I::exhausted(&self.current, &self.end?) {
261            return None;
262        }
263        Some(self.current)
264    }
265
266    fn max(self) -> Option<Self::Item>
267    where
268        Self: Sized,
269        Self::Item: Ord,
270    {
271        let end = self.end?;
272        if core::any::TypeId::of::<I>() == core::any::TypeId::of::<NonInclusive>() {
273            let Some(ret) = end.checked_sub(1.into()) else {
274                if end.raw() == 0u8.into() || end == self.current {
275                    return None; // underflow (which is ok) or exhausted.
276                } else {
277                    panic!("Attempted to iterate over invalid address")
278                }
279            };
280            Some(ret)
281        } else if core::any::TypeId::of::<I>() == core::any::TypeId::of::<Inclusive>() {
282            Some(end)
283        } else {
284            // Explicitly panic when I is not expected.
285            // This should never be possible but it doesnt do any harm.
286            unreachable!()
287        }
288    }
289
290    fn min(self) -> Option<Self::Item> {
291        Some(self.current)
292    }
293
294    fn is_sorted(self) -> bool
295    where
296        Self: Sized,
297        Self::Item: PartialOrd,
298    {
299        true
300    }
301}
302
303impl<T: MemoryAddress> DoubleEndedIterator for AddrIter<T, NonInclusive> {
304    fn next_back(&mut self) -> Option<Self::Item> {
305        if NonInclusive::exhausted(&self.current, &self.end?) {
306            None
307        } else {
308            let one: T::RAW = 1u8.into();
309            self.end = Some(self.end? - one);
310            self.end
311        }
312    }
313    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
314        if n == 0 {
315            return self.next_back(); // Avoids sub-with-overflow below
316        }
317        let Ok(n): Result<T::RAW, _> = n.try_into() else {
318            // Fail to cast indicates that n > T::RAW::MAX, so we explicitly exhaust self.
319            self.end.take();
320            return None;
321        };
322        let Some(ret) = self.end?.checked_sub(n) else {
323            if self.end?.raw() < n {
324                panic!("Attempted to iterate over invalid address")
325            }
326            self.end.take();
327            return None;
328        };
329        self.end = Some(ret);
330        self.next_back()
331    }
332}
333
334impl<T: MemoryAddress> DoubleEndedIterator for AddrIter<T, Inclusive> {
335    fn next_back(&mut self) -> Option<Self::Item> {
336        if Inclusive::exhausted(&self.current, &self.end?) {
337            None
338        } else {
339            let ret = self.end?;
340
341            // We need to be able to step back to `0`.
342            // We return `0` when self.end is currently `0`.
343            // But then we subtract `0` by `1` triggering a sub-with-overflow
344            // When we trigger a sub-with-overflow we return early and dont decrement `self.end`
345            // The next call to self.next() will return as exhausted and the
346            let Some(step) = self.end?.checked_sub(1.into()) else {
347                // Check if this was an underflow or a non-canonical address
348                // Panic on non-canonical
349                // We can eat the overhead here because this branch is rare
350                if self.end?.raw() != 0u8.into() {
351                    panic!("Attempted to iterate over invalid address")
352                }
353                self.end = None;
354                return Some(ret);
355            };
356            self.end = Some(step);
357            Some(ret)
358        }
359    }
360
361    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
362        if n == 0 {
363            return self.next_back();
364        }
365        let Ok(n): Result<T::RAW, _> = n.try_into() else {
366            // Fail to cast indicates that n > T::RAW::MAX, so we explicitly exhaust self.
367            self.end.take();
368            return None;
369        };
370
371        let Some(ret) = self.end?.checked_sub(n) else {
372            if self.end?.raw() < n {
373                panic!("Attempted to iterate over invalid address")
374            }
375            self.end.take();
376            return None;
377        };
378        self.end = Some(ret);
379        self.end
380    }
381}
382
383impl<T: MemoryAddress> ExactSizeIterator for AddrIter<T, Inclusive> {
384    fn len(&self) -> usize {
385        let Some(end) = self.end else { return 0 };
386        (end - self.current)
387            .try_into()
388            .expect("address range is larger than the architecture's usize")
389            + 1
390    }
391}
392
393impl<T: MemoryAddress> ExactSizeIterator for AddrIter<T, NonInclusive> {
394    fn len(&self) -> usize {
395        let Some(end) = self.end else { return 0 };
396        (end - self.current)
397            .try_into()
398            .expect("address range is larger than the architecture's usize")
399    }
400}
401
402impl<T: MemoryAddress> FusedIterator for AddrIter<T> {}
403
404impl<T: MemoryAddress> From<core::ops::Range<T>> for AddrIter<T, NonInclusive> {
405    fn from(range: core::ops::Range<T>) -> Self {
406        Self {
407            current: range.start,
408            end: Some(range.end),
409            _phantom: PhantomData,
410        }
411    }
412}
413
414impl<T: MemoryAddress> From<core::ops::RangeInclusive<T>> for AddrIter<T, Inclusive> {
415    fn from(range: core::ops::RangeInclusive<T>) -> Self {
416        Self {
417            current: *range.start(),
418            end: Some(*range.end()),
419            _phantom: PhantomData,
420        }
421    }
422}
423
// Unit tests. They use the `VirtAddr`/`PhysAddr` pulled in through `use super::*`,
// i.e. whichever architecture backend the `cfg_if!` re-export selected for this build.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    pub fn virtaddr_new_truncate() {
        assert_eq!(VirtAddr::new_truncate(0), VirtAddr::new(0));
        assert_eq!(VirtAddr::new_truncate(123), VirtAddr::new(123));
    }

    #[test]
    fn test_from_ptr_array() {
        let slice = &[1, 2, 3, 4, 5];
        // Make sure that from_ptr(slice) is the address of the first element
        assert_eq!(VirtAddr::from_ptr(slice), VirtAddr::from_ptr(&slice[0]));
    }

    // Exercises `AddrRange` and its non-inclusive `AddrIter` from both ends.
    #[test]
    #[allow(clippy::iter_nth_zero)]
    fn test_addr_range() {
        let r = AddrRange::new(VirtAddr::new(0x0), VirtAddr::new(0x3)).unwrap();
        // Half-open: `start` is contained, `end` (0x3) is not.
        assert!(r.contains(&VirtAddr::new(0x0)));
        assert!(r.contains(&VirtAddr::new(0x1)));
        assert!(!r.contains(&VirtAddr::new(0x3)));
        let mut i = r.iter();
        assert_eq!(i.next().unwrap(), VirtAddr::new(0x0));
        assert_eq!(i.next().unwrap(), VirtAddr::new(0x1));
        assert_eq!(i.next().unwrap(), VirtAddr::new(0x2));
        assert!(i.next().is_none());

        for (i, a) in r.iter().enumerate() {
            assert_eq!(a.raw() as usize, i);
        }

        // Each `nth` call uses a fresh iterator, so only the returned element is checked.
        assert_eq!(r.iter().nth(0), Some(VirtAddr::new(0x0)));
        assert_eq!(r.iter().nth(1), Some(VirtAddr::new(0x1)));
        assert_eq!(r.iter().nth(2), Some(VirtAddr::new(0x2)));
        assert_eq!(r.iter().nth(3), None);

        {
            // Draining from the back down to the lowest address.
            let mut range = r.iter();
            assert_eq!(range.next_back(), Some(VirtAddr::new(0x2)));
            assert_eq!(range.next_back(), Some(VirtAddr::new(0x1)));
            assert_eq!(range.next_back(), Some(VirtAddr::new(0x0)));
            assert_eq!(range.next_back(), None);
            assert_eq!(range.next(), None);

            // Alternating front/back: the two cursors must meet without overlap.
            let mut range = r.iter();
            assert_eq!(range.next(), Some(VirtAddr::new(0x0)));
            assert_eq!(range.next_back(), Some(VirtAddr::new(0x2)));
            assert_eq!(range.next(), Some(VirtAddr::new(0x1)));
            assert_eq!(range.next_back(), None);

            assert_eq!(r.iter().nth_back(0), Some(VirtAddr::new(0x2)));
            assert_eq!(r.iter().nth_back(1), Some(VirtAddr::new(0x1)));
            assert_eq!(r.iter().nth_back(2), Some(VirtAddr::new(0x0)));
            assert_eq!(r.iter().nth_back(3), None);
        }

        // A range that does not start at zero.
        let r = AddrRange::new(PhysAddr::new(0x2), PhysAddr::new(0x4)).unwrap();
        let mut i = r.iter();
        assert_eq!(i.next().unwrap(), PhysAddr::new(0x2));
        assert_eq!(i.next().unwrap(), PhysAddr::new(0x3));
        assert!(i.next().is_none());

        // 0x2 + 0x3
        assert_eq!(r.iter().map(|a| a.raw() as usize).sum::<usize>(), 0x5);
    }

    // Exercises the inclusive `AddrIter` built from a `RangeInclusive`.
    #[test]
    #[allow(clippy::iter_nth_zero)]
    fn test_iter_incl() {
        let range = VirtAddr::new(0x0)..=VirtAddr::new(0x3);
        let mut i = AddrIter::from(range.clone());
        assert_eq!(i.next().unwrap(), VirtAddr::new(0x0));
        assert_eq!(i.next().unwrap(), VirtAddr::new(0x1));
        assert_eq!(i.next().unwrap(), VirtAddr::new(0x2));
        assert_eq!(i.next().unwrap(), VirtAddr::new(0x3));

        // Stepping back all the way to address 0 must not underflow.
        let mut i = AddrIter::from(range.clone());
        assert_eq!(i.next_back(), Some(VirtAddr::new(0x3)));
        assert_eq!(i.next_back(), Some(VirtAddr::new(0x2)));
        assert_eq!(i.next_back(), Some(VirtAddr::new(0x1)));
        assert_eq!(i.next_back(), Some(VirtAddr::new(0x0)));
        assert_eq!(i.next_back(), None);

        // Alternating back/front: the two cursors must meet without overlap.
        let mut i = AddrIter::from(range.clone());
        assert_eq!(i.next_back(), Some(VirtAddr::new(0x3)));
        assert_eq!(i.next(), Some(VirtAddr::new(0x0)));
        assert_eq!(i.next_back(), Some(VirtAddr::new(0x2)));
        assert_eq!(i.next(), Some(VirtAddr::new(0x1)));
        assert_eq!(i.next_back(), None);

        // Each `nth` call uses a fresh iterator; the end address (0x3) is reachable.
        assert_eq!(
            AddrIter::from(range.clone()).nth(0),
            Some(VirtAddr::new(0x0))
        );
        assert_eq!(
            AddrIter::from(range.clone()).nth(1),
            Some(VirtAddr::new(0x1))
        );
        assert_eq!(
            AddrIter::from(range.clone()).nth(2),
            Some(VirtAddr::new(0x2))
        );
        assert_eq!(
            AddrIter::from(range.clone()).nth(3),
            Some(VirtAddr::new(0x3))
        );
        assert_eq!(AddrIter::from(range.clone()).nth(4), None);
    }

    // `len`, `size_hint` and the actual element count must agree for inclusive ranges.
    #[test]
    fn iterator_assert_sizes() {
        let range_incl = VirtAddr::new(0x0)..=VirtAddr::new(0x3);
        assert_eq!(
            AddrIter::from(range_incl.clone()).count(),
            AddrIter::from(range_incl.clone()).len()
        );
        assert_eq!(
            AddrIter::from(range_incl.clone()).count(),
            AddrIter::from(range_incl.clone()).size_hint().0
        );
        assert_eq!(
            AddrIter::from(range_incl.clone()).count(),
            AddrIter::from(range_incl.clone()).size_hint().1.unwrap()
        );
    }
}