x86_64/addr.rs

//! Physical and virtual addresses manipulation

use core::convert::TryFrom;
use core::fmt;
#[cfg(feature = "step_trait")]
use core::iter::Step;
use core::ops::{Add, AddAssign, Sub, SubAssign};

use crate::structures::paging::page_table::PageTableLevel;
use crate::structures::paging::{PageOffset, PageTableIndex};
use bit_field::BitField;

const ADDRESS_SPACE_SIZE: u64 = 0x1_0000_0000_0000;

/// A canonical 64-bit virtual memory address.
///
/// This is a wrapper type around an `u64`, so it is always 8 bytes, even when compiled
/// on non 64-bit systems. The
/// [`TryFrom`](https://doc.rust-lang.org/std/convert/trait.TryFrom.html) trait can be used for performing conversions
/// between `u64` and `usize`.
///
/// On `x86_64`, only the 48 lower bits of a virtual address can be used. The top 16 bits need
/// to be copies of bit 47, i.e. the most significant bit. Addresses that fulfil this criterion
/// are called “canonical”. This type guarantees that it always represents a canonical address.
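///
/// ## Example
///
/// A minimal usage sketch (assuming the crate-root re-export `x86_64::VirtAddr`, which is
/// not shown in this file):
///
/// ```
/// use x86_64::VirtAddr;
///
/// // The upper 16 bits are copies of bit 47, so this address is canonical.
/// assert!(VirtAddr::try_new(0xffff_8000_0000_0000).is_ok());
/// // Bit 47 is clear but bits 48..64 are set, so this address is not canonical.
/// assert!(VirtAddr::try_new(0xffff_0000_1234_5678).is_err());
/// ```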
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct VirtAddr(u64);

/// A 64-bit physical memory address.
///
/// This is a wrapper type around an `u64`, so it is always 8 bytes, even when compiled
/// on non 64-bit systems. The
/// [`TryFrom`](https://doc.rust-lang.org/std/convert/trait.TryFrom.html) trait can be used for performing conversions
/// between `u64` and `usize`.
///
/// On `x86_64`, only the 52 lower bits of a physical address can be used. The top 12 bits need
/// to be zero. This type guarantees that it always represents a valid physical address.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct PhysAddr(u64);

/// A passed `u64` was not a valid virtual address.
///
/// This means that bits 48 to 64 are neither a correct sign extension of bit 47 nor all
/// null, so automatic sign extension would have overwritten possibly meaningful bits. This
/// likely indicates a bug, for example an invalid address calculation.
///
/// Contains the invalid address.
pub struct VirtAddrNotValid(pub u64);

impl core::fmt::Debug for VirtAddrNotValid {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("VirtAddrNotValid")
            .field(&format_args!("{:#x}", self.0))
            .finish()
    }
}

impl VirtAddr {
    /// Creates a new canonical virtual address.
    ///
    /// The provided address should already be canonical. If you want to check
    /// whether an address is canonical, use [`try_new`](Self::try_new).
    ///
    /// ## Panics
    ///
    /// This function panics if the bits in the range 48 to 64 are invalid
    /// (i.e. are not a proper sign extension of bit 47).
    #[inline]
    pub const fn new(addr: u64) -> VirtAddr {
        // TODO: Replace with .ok().expect(msg) when that works on stable.
        match Self::try_new(addr) {
            Ok(v) => v,
            Err(_) => panic!("virtual address must be sign extended in bits 48 to 64"),
        }
    }

    /// Tries to create a new canonical virtual address.
    ///
    /// This function checks whether the given address is canonical
    /// and returns an error otherwise. An address is canonical
    /// if bits 48 to 64 are a correct sign
    /// extension (i.e. copies of bit 47).
    #[inline]
    pub const fn try_new(addr: u64) -> Result<VirtAddr, VirtAddrNotValid> {
        let v = Self::new_truncate(addr);
        if v.0 == addr {
            Ok(v)
        } else {
            Err(VirtAddrNotValid(addr))
        }
    }

    /// Creates a new canonical virtual address, throwing out bits 48..64.
    ///
    /// This function performs sign extension of bit 47 to make the address
    /// canonical, overwriting bits 48 to 64. If you want to check whether an
    /// address is canonical, use [`new`](Self::new) or [`try_new`](Self::try_new).
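    ///
    /// ## Example
    ///
    /// A small sketch of the truncation behaviour (assuming the crate-root
    /// re-export `x86_64::VirtAddr`):
    ///
    /// ```
    /// use x86_64::VirtAddr;
    ///
    /// // Bit 47 is set, so it is copied into bits 48..64.
    /// assert_eq!(
    ///     VirtAddr::new_truncate(0x0000_8000_0000_0000),
    ///     VirtAddr::new(0xffff_8000_0000_0000)
    /// );
    /// // Bit 47 is clear, so bits 48..64 are simply cleared.
    /// assert_eq!(
    ///     VirtAddr::new_truncate(0x1234_0000_dead_beef),
    ///     VirtAddr::new(0x0000_0000_dead_beef)
    /// );
    /// ```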
    #[inline]
    pub const fn new_truncate(addr: u64) -> VirtAddr {
        // By doing the right shift as a signed operation (on a i64), it will
        // sign extend the value, repeating the leftmost bit.
        VirtAddr(((addr << 16) as i64 >> 16) as u64)
    }

    /// Creates a new virtual address, without any checks.
    ///
    /// ## Safety
    ///
    /// You must make sure bits 48..64 are equal to bit 47. This is not checked.
    #[inline]
    pub const unsafe fn new_unsafe(addr: u64) -> VirtAddr {
        VirtAddr(addr)
    }

    /// Creates a virtual address that points to `0`.
    #[inline]
    pub const fn zero() -> VirtAddr {
        VirtAddr(0)
    }

    /// Converts the address to an `u64`.
    #[inline]
    pub const fn as_u64(self) -> u64 {
        self.0
    }

    /// Creates a virtual address from the given pointer
    #[cfg(target_pointer_width = "64")]
    #[inline]
    pub fn from_ptr<T: ?Sized>(ptr: *const T) -> Self {
        Self::new(ptr as *const () as u64)
    }

    /// Converts the address to a raw pointer.
    #[cfg(target_pointer_width = "64")]
    #[inline]
    pub const fn as_ptr<T>(self) -> *const T {
        self.as_u64() as *const T
    }

    /// Converts the address to a mutable raw pointer.
    #[cfg(target_pointer_width = "64")]
    #[inline]
    pub const fn as_mut_ptr<T>(self) -> *mut T {
        self.as_ptr::<T>() as *mut T
    }

    /// Convenience method for checking if a virtual address is null.
    #[inline]
    pub const fn is_null(self) -> bool {
        self.0 == 0
    }

    /// Aligns the virtual address upwards to the given alignment.
    ///
    /// See the `align_up` function for more information.
    ///
    /// # Panics
    ///
    /// This function panics if the resulting address is higher than
    /// `0xffff_ffff_ffff_ffff`.
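    ///
    /// ## Example
    ///
    /// A minimal sketch (assuming the crate-root re-export `x86_64::VirtAddr`):
    ///
    /// ```
    /// use x86_64::VirtAddr;
    ///
    /// let addr = VirtAddr::new(0x1001);
    /// assert_eq!(addr.align_up(0x1000u64), VirtAddr::new(0x2000));
    /// // Already aligned addresses are returned unchanged.
    /// assert_eq!(addr.align_up(1u64), addr);
    /// ```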
    #[inline]
    pub fn align_up<U>(self, align: U) -> Self
    where
        U: Into<u64>,
    {
        VirtAddr::new_truncate(align_up(self.0, align.into()))
    }

    /// Aligns the virtual address downwards to the given alignment.
    ///
    /// See the `align_down` function for more information.
    #[inline]
    pub fn align_down<U>(self, align: U) -> Self
    where
        U: Into<u64>,
    {
        self.align_down_u64(align.into())
    }

    /// Aligns the virtual address downwards to the given alignment.
    ///
    /// See the `align_down` function for more information.
    #[inline]
    pub(crate) const fn align_down_u64(self, align: u64) -> Self {
        VirtAddr::new_truncate(align_down(self.0, align))
    }

    /// Checks whether the virtual address has the demanded alignment.
    #[inline]
    pub fn is_aligned<U>(self, align: U) -> bool
    where
        U: Into<u64>,
    {
        self.is_aligned_u64(align.into())
    }

    /// Checks whether the virtual address has the demanded alignment.
    #[inline]
    pub(crate) const fn is_aligned_u64(self, align: u64) -> bool {
        self.align_down_u64(align).as_u64() == self.as_u64()
    }

    /// Returns the 12-bit page offset of this virtual address.
    #[inline]
    pub const fn page_offset(self) -> PageOffset {
        PageOffset::new_truncate(self.0 as u16)
    }

    /// Returns the 9-bit level 1 page table index.
    #[inline]
    pub const fn p1_index(self) -> PageTableIndex {
        PageTableIndex::new_truncate((self.0 >> 12) as u16)
    }

    /// Returns the 9-bit level 2 page table index.
    #[inline]
    pub const fn p2_index(self) -> PageTableIndex {
        PageTableIndex::new_truncate((self.0 >> 12 >> 9) as u16)
    }

    /// Returns the 9-bit level 3 page table index.
    #[inline]
    pub const fn p3_index(self) -> PageTableIndex {
        PageTableIndex::new_truncate((self.0 >> 12 >> 9 >> 9) as u16)
    }

    /// Returns the 9-bit level 4 page table index.
    #[inline]
    pub const fn p4_index(self) -> PageTableIndex {
        PageTableIndex::new_truncate((self.0 >> 12 >> 9 >> 9 >> 9) as u16)
    }

    /// Returns the 9-bit page table index for the given level.
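    ///
    /// ## Example
    ///
    /// A small sketch relating this method to the fixed-level accessors above
    /// (assuming the crate-root re-export `x86_64::VirtAddr` and the same
    /// `PageTableLevel` path that this module imports):
    ///
    /// ```
    /// use x86_64::structures::paging::page_table::PageTableLevel;
    /// use x86_64::VirtAddr;
    ///
    /// let addr = VirtAddr::new(0xffff_8123_4567_89ab);
    /// assert_eq!(addr.page_table_index(PageTableLevel::One), addr.p1_index());
    /// assert_eq!(addr.page_table_index(PageTableLevel::Four), addr.p4_index());
    /// ```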
    #[inline]
    pub const fn page_table_index(self, level: PageTableLevel) -> PageTableIndex {
        PageTableIndex::new_truncate((self.0 >> 12 >> ((level as u8 - 1) * 9)) as u16)
    }

    // FIXME: Move this into the `Step` impl, once `Step` is stabilized.
    #[cfg(feature = "step_trait")]
    pub(crate) fn steps_between_impl(start: &Self, end: &Self) -> (usize, Option<usize>) {
        if let Some(steps) = Self::steps_between_u64(start, end) {
            let steps = usize::try_from(steps).ok();
            (steps.unwrap_or(usize::MAX), steps)
        } else {
            (0, None)
        }
    }

    /// An implementation of steps_between that returns u64. Note that this
    /// function always returns the exact bound, so it doesn't need to return a
    /// lower and upper bound like steps_between does.
    #[cfg(any(feature = "instructions", feature = "step_trait"))]
    pub(crate) fn steps_between_u64(start: &Self, end: &Self) -> Option<u64> {
        let mut steps = end.0.checked_sub(start.0)?;

        // Mask away extra bits that appear while jumping the gap.
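        //
        // A worked example (added for clarity, not from the original source):
        // stepping from the last low canonical address 0x0000_7fff_ffff_ffff to
        // the first high canonical address 0xffff_8000_0000_0000 is one step,
        // but the raw subtraction yields 0xffff_0000_0000_0001; keeping only
        // the low 48 bits gives 1.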
        steps &= 0xffff_ffff_ffff;

        Some(steps)
    }

    // FIXME: Move this into the `Step` impl, once `Step` is stabilized.
    #[inline]
    pub(crate) fn forward_checked_impl(start: Self, count: usize) -> Option<Self> {
        Self::forward_checked_u64(start, u64::try_from(count).ok()?)
    }

    /// An implementation of forward_checked that takes u64 instead of usize.
    #[inline]
    pub(crate) fn forward_checked_u64(start: Self, count: u64) -> Option<Self> {
        if count > ADDRESS_SPACE_SIZE {
            return None;
        }

        let mut addr = start.0.checked_add(count)?;

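        // The bits from position 47 upwards distinguish the possible cases of
        // the raw sum (comment added for clarity): 0x0 and 0x1ffff are already
        // canonical, 0x1 means the sum landed in the non-canonical gap and has
        // to be moved to the high half, and 0x2 means the step would go past
        // the last valid address 0xffff_ffff_ffff_ffff.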
        match addr.get_bits(47..) {
            0x1 => {
                // Jump the gap by sign extending the 47th bit.
                addr.set_bits(47.., 0x1ffff);
            }
            0x2 => {
                // Address overflow
                return None;
            }
            _ => {}
        }

        Some(unsafe { Self::new_unsafe(addr) })
    }

    /// An implementation of backward_checked that takes u64 instead of usize.
    #[cfg(feature = "step_trait")]
    #[inline]
    pub(crate) fn backward_checked_u64(start: Self, count: u64) -> Option<Self> {
        if count > ADDRESS_SPACE_SIZE {
            return None;
        }

        let mut addr = start.0.checked_sub(count)?;

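        // Mirror of the cases in `forward_checked_u64` (comment added for
        // clarity): 0x0 and 0x1ffff are already canonical, 0x1fffe means the
        // raw difference landed in the non-canonical gap and has to be moved
        // to the low half, and 0x1fffd means the step would go below address 0.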
        match addr.get_bits(47..) {
            0x1fffe => {
                // Jump the gap by sign extending the 47th bit.
                addr.set_bits(47.., 0);
            }
            0x1fffd => {
                // Address underflow
                return None;
            }
            _ => {}
        }

        Some(unsafe { Self::new_unsafe(addr) })
    }
}

impl fmt::Debug for VirtAddr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_tuple("VirtAddr")
            .field(&format_args!("{:#x}", self.0))
            .finish()
    }
}

impl fmt::Binary for VirtAddr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Binary::fmt(&self.0, f)
    }
}

impl fmt::LowerHex for VirtAddr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::LowerHex::fmt(&self.0, f)
    }
}

impl fmt::Octal for VirtAddr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Octal::fmt(&self.0, f)
    }
}

impl fmt::UpperHex for VirtAddr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::UpperHex::fmt(&self.0, f)
    }
}

impl fmt::Pointer for VirtAddr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Pointer::fmt(&(self.0 as *const ()), f)
    }
}

impl Add<u64> for VirtAddr {
    type Output = Self;
    #[inline]
    fn add(self, rhs: u64) -> Self::Output {
        VirtAddr::new(self.0 + rhs)
    }
}

impl AddAssign<u64> for VirtAddr {
    #[inline]
    fn add_assign(&mut self, rhs: u64) {
        *self = *self + rhs;
    }
}

impl Sub<u64> for VirtAddr {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: u64) -> Self::Output {
        VirtAddr::new(self.0.checked_sub(rhs).unwrap())
    }
}

impl SubAssign<u64> for VirtAddr {
    #[inline]
    fn sub_assign(&mut self, rhs: u64) {
        *self = *self - rhs;
    }
}

impl Sub<VirtAddr> for VirtAddr {
    type Output = u64;
    #[inline]
    fn sub(self, rhs: VirtAddr) -> Self::Output {
        self.as_u64().checked_sub(rhs.as_u64()).unwrap()
    }
}

#[cfg(feature = "step_trait")]
impl Step for VirtAddr {
    #[inline]
    fn steps_between(start: &Self, end: &Self) -> (usize, Option<usize>) {
        Self::steps_between_impl(start, end)
    }

    #[inline]
    fn forward_checked(start: Self, count: usize) -> Option<Self> {
        Self::forward_checked_impl(start, count)
    }

    #[inline]
    fn backward_checked(start: Self, count: usize) -> Option<Self> {
        Self::backward_checked_u64(start, u64::try_from(count).ok()?)
    }
}

/// A passed `u64` was not a valid physical address.
///
/// This means that bits 52 to 64 were not all null.
///
/// Contains the invalid address.
pub struct PhysAddrNotValid(pub u64);

impl core::fmt::Debug for PhysAddrNotValid {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("PhysAddrNotValid")
            .field(&format_args!("{:#x}", self.0))
            .finish()
    }
}

impl PhysAddr {
    /// Creates a new physical address.
    ///
    /// ## Panics
    ///
    /// This function panics if a bit in the range 52 to 64 is set.
    #[inline]
    pub const fn new(addr: u64) -> Self {
        // TODO: Replace with .ok().expect(msg) when that works on stable.
        match Self::try_new(addr) {
            Ok(p) => p,
            Err(_) => panic!("physical addresses must not have any bits in the range 52 to 64 set"),
        }
    }

    /// Creates a new physical address, throwing bits 52..64 away.
    #[inline]
    pub const fn new_truncate(addr: u64) -> PhysAddr {
        PhysAddr(addr % (1 << 52))
    }

    /// Creates a new physical address, without any checks.
    ///
    /// ## Safety
    ///
    /// You must make sure bits 52..64 are zero. This is not checked.
    #[inline]
    pub const unsafe fn new_unsafe(addr: u64) -> PhysAddr {
        PhysAddr(addr)
    }

    /// Tries to create a new physical address.
    ///
    /// Fails if any bits in the range 52 to 64 are set.
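    ///
    /// ## Example
    ///
    /// A small sketch (assuming the crate-root re-export `x86_64::PhysAddr`):
    ///
    /// ```
    /// use x86_64::PhysAddr;
    ///
    /// assert!(PhysAddr::try_new(0x000f_ffff_ffff_ffff).is_ok());
    /// // Bit 52 is set, so this is not a valid physical address.
    /// assert!(PhysAddr::try_new(0x0010_0000_0000_0000).is_err());
    /// ```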
    #[inline]
    pub const fn try_new(addr: u64) -> Result<Self, PhysAddrNotValid> {
        let p = Self::new_truncate(addr);
        if p.0 == addr {
            Ok(p)
        } else {
            Err(PhysAddrNotValid(addr))
        }
    }

    /// Creates a physical address that points to `0`.
    #[inline]
    pub const fn zero() -> PhysAddr {
        PhysAddr(0)
    }

    /// Converts the address to an `u64`.
    #[inline]
    pub const fn as_u64(self) -> u64 {
        self.0
    }

    /// Convenience method for checking if a physical address is null.
    #[inline]
    pub const fn is_null(self) -> bool {
        self.0 == 0
    }

    /// Aligns the physical address upwards to the given alignment.
    ///
    /// See the `align_up` function for more information.
    ///
    /// # Panics
    ///
    /// This function panics if the resulting address has a bit in the range 52
    /// to 64 set.
    #[inline]
    pub fn align_up<U>(self, align: U) -> Self
    where
        U: Into<u64>,
    {
        PhysAddr::new(align_up(self.0, align.into()))
    }

    /// Aligns the physical address downwards to the given alignment.
    ///
    /// See the `align_down` function for more information.
    #[inline]
    pub fn align_down<U>(self, align: U) -> Self
    where
        U: Into<u64>,
    {
        self.align_down_u64(align.into())
    }

    /// Aligns the physical address downwards to the given alignment.
    ///
    /// See the `align_down` function for more information.
    #[inline]
    pub(crate) const fn align_down_u64(self, align: u64) -> Self {
        PhysAddr(align_down(self.0, align))
    }

    /// Checks whether the physical address has the demanded alignment.
    #[inline]
    pub fn is_aligned<U>(self, align: U) -> bool
    where
        U: Into<u64>,
    {
        self.is_aligned_u64(align.into())
    }

    /// Checks whether the physical address has the demanded alignment.
    #[inline]
    pub(crate) const fn is_aligned_u64(self, align: u64) -> bool {
        self.align_down_u64(align).as_u64() == self.as_u64()
    }
}

impl fmt::Debug for PhysAddr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_tuple("PhysAddr")
            .field(&format_args!("{:#x}", self.0))
            .finish()
    }
}

impl fmt::Binary for PhysAddr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Binary::fmt(&self.0, f)
    }
}

impl fmt::LowerHex for PhysAddr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::LowerHex::fmt(&self.0, f)
    }
}

impl fmt::Octal for PhysAddr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Octal::fmt(&self.0, f)
    }
}

impl fmt::UpperHex for PhysAddr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::UpperHex::fmt(&self.0, f)
    }
}

impl fmt::Pointer for PhysAddr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Pointer::fmt(&(self.0 as *const ()), f)
    }
}

impl Add<u64> for PhysAddr {
    type Output = Self;
    #[inline]
    fn add(self, rhs: u64) -> Self::Output {
        PhysAddr::new(self.0 + rhs)
    }
}

impl AddAssign<u64> for PhysAddr {
    #[inline]
    fn add_assign(&mut self, rhs: u64) {
        *self = *self + rhs;
    }
}

impl Sub<u64> for PhysAddr {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: u64) -> Self::Output {
        PhysAddr::new(self.0.checked_sub(rhs).unwrap())
    }
}

impl SubAssign<u64> for PhysAddr {
    #[inline]
    fn sub_assign(&mut self, rhs: u64) {
        *self = *self - rhs;
    }
}

impl Sub<PhysAddr> for PhysAddr {
    type Output = u64;
    #[inline]
    fn sub(self, rhs: PhysAddr) -> Self::Output {
        self.as_u64().checked_sub(rhs.as_u64()).unwrap()
    }
}

/// Align address downwards.
///
/// Returns the greatest `x` with alignment `align` so that `x <= addr`.
///
/// Panics if the alignment is not a power of two.
#[inline]
pub const fn align_down(addr: u64, align: u64) -> u64 {
    assert!(align.is_power_of_two(), "`align` must be a power of two");
    addr & !(align - 1)
}

/// Align address upwards.
///
/// Returns the smallest `x` with alignment `align` so that `x >= addr`.
///
/// Panics if the alignment is not a power of two or if an overflow occurs.
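///
/// ## Example
///
/// A minimal sketch (assuming this module is reachable as `x86_64::addr`):
///
/// ```
/// use x86_64::addr::align_up;
///
/// assert_eq!(align_up(0x1234, 0x1000), 0x2000);
/// assert_eq!(align_up(0x2000, 0x1000), 0x2000);
/// ```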
#[inline]
pub const fn align_up(addr: u64, align: u64) -> u64 {
    assert!(align.is_power_of_two(), "`align` must be a power of two");
    let align_mask = align - 1;
    if addr & align_mask == 0 {
        addr // already aligned
    } else {
        // FIXME: Replace with .expect, once `Option::expect` is const.
        if let Some(aligned) = (addr | align_mask).checked_add(1) {
            aligned
        } else {
            panic!("attempt to add with overflow")
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    pub fn virtaddr_new_truncate() {
        assert_eq!(VirtAddr::new_truncate(0), VirtAddr(0));
        assert_eq!(VirtAddr::new_truncate(1 << 47), VirtAddr(0xfffff << 47));
        assert_eq!(VirtAddr::new_truncate(123), VirtAddr(123));
        assert_eq!(VirtAddr::new_truncate(123 << 47), VirtAddr(0xfffff << 47));
    }

    #[test]
    #[cfg(feature = "step_trait")]
    fn virtaddr_step_forward() {
        assert_eq!(Step::forward(VirtAddr(0), 0), VirtAddr(0));
        assert_eq!(Step::forward(VirtAddr(0), 1), VirtAddr(1));
        assert_eq!(
            Step::forward(VirtAddr(0x7fff_ffff_ffff), 1),
            VirtAddr(0xffff_8000_0000_0000)
        );
        assert_eq!(
            Step::forward(VirtAddr(0xffff_8000_0000_0000), 1),
            VirtAddr(0xffff_8000_0000_0001)
        );
        assert_eq!(
            Step::forward_checked(VirtAddr(0xffff_ffff_ffff_ffff), 1),
            None
        );
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Step::forward(VirtAddr(0x7fff_ffff_ffff), 0x1234_5678_9abd),
            VirtAddr(0xffff_9234_5678_9abc)
        );
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Step::forward(VirtAddr(0x7fff_ffff_ffff), 0x8000_0000_0000),
            VirtAddr(0xffff_ffff_ffff_ffff)
        );
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Step::forward(VirtAddr(0x7fff_ffff_ff00), 0x8000_0000_00ff),
            VirtAddr(0xffff_ffff_ffff_ffff)
        );
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Step::forward_checked(VirtAddr(0x7fff_ffff_ff00), 0x8000_0000_0100),
            None
        );
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Step::forward_checked(VirtAddr(0x7fff_ffff_ffff), 0x8000_0000_0001),
            None
        );
    }

    #[test]
    #[cfg(feature = "step_trait")]
    fn virtaddr_step_backward() {
        assert_eq!(Step::backward(VirtAddr(0), 0), VirtAddr(0));
        assert_eq!(Step::backward_checked(VirtAddr(0), 1), None);
        assert_eq!(Step::backward(VirtAddr(1), 1), VirtAddr(0));
        assert_eq!(
            Step::backward(VirtAddr(0xffff_8000_0000_0000), 1),
            VirtAddr(0x7fff_ffff_ffff)
        );
        assert_eq!(
            Step::backward(VirtAddr(0xffff_8000_0000_0001), 1),
            VirtAddr(0xffff_8000_0000_0000)
        );
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Step::backward(VirtAddr(0xffff_9234_5678_9abc), 0x1234_5678_9abd),
            VirtAddr(0x7fff_ffff_ffff)
        );
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Step::backward(VirtAddr(0xffff_8000_0000_0000), 0x8000_0000_0000),
            VirtAddr(0)
        );
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Step::backward(VirtAddr(0xffff_8000_0000_0000), 0x7fff_ffff_ff01),
            VirtAddr(0xff)
        );
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Step::backward_checked(VirtAddr(0xffff_8000_0000_0000), 0x8000_0000_0001),
            None
        );
    }

    #[test]
    #[cfg(feature = "step_trait")]
    fn virtaddr_steps_between() {
        assert_eq!(
            Step::steps_between(&VirtAddr(0), &VirtAddr(0)),
            (0, Some(0))
        );
        assert_eq!(
            Step::steps_between(&VirtAddr(0), &VirtAddr(1)),
            (1, Some(1))
        );
        assert_eq!(Step::steps_between(&VirtAddr(1), &VirtAddr(0)), (0, None));
        assert_eq!(
            Step::steps_between(
                &VirtAddr(0x7fff_ffff_ffff),
                &VirtAddr(0xffff_8000_0000_0000)
            ),
            (1, Some(1))
        );
        assert_eq!(
            Step::steps_between(
                &VirtAddr(0xffff_8000_0000_0000),
                &VirtAddr(0x7fff_ffff_ffff)
            ),
            (0, None)
        );
        assert_eq!(
            Step::steps_between(
                &VirtAddr(0xffff_8000_0000_0000),
                &VirtAddr(0xffff_8000_0000_0000)
            ),
            (0, Some(0))
        );
        assert_eq!(
            Step::steps_between(
                &VirtAddr(0xffff_8000_0000_0000),
                &VirtAddr(0xffff_8000_0000_0001)
            ),
            (1, Some(1))
        );
        assert_eq!(
            Step::steps_between(
                &VirtAddr(0xffff_8000_0000_0001),
                &VirtAddr(0xffff_8000_0000_0000)
            ),
            (0, None)
        );
        // Make sure that we handle `steps > u32::MAX` correctly on 32-bit
        // targets. On 64-bit targets, `0x1_0000_0000` fits into `usize`, so we
        // can return exact lower and upper bounds. On 32-bit targets,
        // `0x1_0000_0000` doesn't fit into `usize`, so we only return a lower
        // bound of `usize::MAX` and don't return an upper bound.
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Step::steps_between(&VirtAddr(0), &VirtAddr(0x1_0000_0000)),
            (0x1_0000_0000, Some(0x1_0000_0000))
        );
        #[cfg(not(target_pointer_width = "64"))]
        assert_eq!(
            Step::steps_between(&VirtAddr(0), &VirtAddr(0x1_0000_0000)),
            (usize::MAX, None)
        );
    }

    #[test]
    pub fn test_align_up() {
        // align 1
        assert_eq!(align_up(0, 1), 0);
        assert_eq!(align_up(1234, 1), 1234);
        assert_eq!(align_up(0xffff_ffff_ffff_ffff, 1), 0xffff_ffff_ffff_ffff);
        // align 2
        assert_eq!(align_up(0, 2), 0);
        assert_eq!(align_up(1233, 2), 1234);
        assert_eq!(align_up(0xffff_ffff_ffff_fffe, 2), 0xffff_ffff_ffff_fffe);
        // address 0
        assert_eq!(align_up(0, 128), 0);
        assert_eq!(align_up(0, 1), 0);
        assert_eq!(align_up(0, 2), 0);
        assert_eq!(align_up(0, 0x8000_0000_0000_0000), 0);
    }
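
    // An added companion check for `align_down`, mirroring `test_align_up`
    // above; it only uses the `align_down` function defined in this module.
    #[test]
    pub fn test_align_down() {
        // align 1
        assert_eq!(align_down(0, 1), 0);
        assert_eq!(align_down(1234, 1), 1234);
        assert_eq!(align_down(0xffff_ffff_ffff_ffff, 1), 0xffff_ffff_ffff_ffff);
        // align 2
        assert_eq!(align_down(0, 2), 0);
        assert_eq!(align_down(1235, 2), 1234);
        assert_eq!(align_down(0xffff_ffff_ffff_ffff, 2), 0xffff_ffff_ffff_fffe);
        // larger alignments
        assert_eq!(align_down(1234, 128), 1152);
        assert_eq!(
            align_down(0xffff_ffff_ffff_ffff, 0x8000_0000_0000_0000),
            0x8000_0000_0000_0000
        );
    }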

    #[test]
    fn test_virt_addr_align_up() {
        // Make sure the 47th bit is extended.
        assert_eq!(
            VirtAddr::new(0x7fff_ffff_ffff).align_up(2u64),
            VirtAddr::new(0xffff_8000_0000_0000)
        );
    }

    #[test]
    fn test_virt_addr_align_down() {
        // Make sure the 47th bit is extended.
        assert_eq!(
            VirtAddr::new(0xffff_8000_0000_0000).align_down(1u64 << 48),
            VirtAddr::new(0)
        );
    }

    #[test]
    #[should_panic]
    fn test_virt_addr_align_up_overflow() {
        VirtAddr::new(0xffff_ffff_ffff_ffff).align_up(2u64);
    }

    #[test]
    #[should_panic]
    fn test_phys_addr_align_up_overflow() {
        PhysAddr::new(0x000f_ffff_ffff_ffff).align_up(2u64);
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_from_ptr_array() {
        let slice = &[1, 2, 3, 4, 5];
        // Make sure that from_ptr(slice) is the address of the first element
        assert_eq!(
            VirtAddr::from_ptr(slice.as_slice()),
            VirtAddr::from_ptr(&slice[0])
        );
    }
}

#[cfg(kani)]
mod proofs {
    use super::*;

    // The next two proof harnesses prove the correctness of the `forward`
    // implementation of VirtAddr.

    // This harness proves that our implementation can correctly take 0 or 1
    // step starting from any address.
    #[kani::proof]
    fn forward_base_case() {
        let start_raw: u64 = kani::any();
        let Ok(start) = VirtAddr::try_new(start_raw) else {
            return;
        };

        // Adding 0 to any address should always yield the same address.
        let same = Step::forward(start, 0);
        assert!(start == same);

        // Manually calculate the expected address after stepping once.
        let expected = match start_raw {
            // Adding 1 to an address in this range doesn't require a gap jump,
            // so we can just add 1.
            0x0000_0000_0000_0000..=0x0000_7fff_ffff_fffe => Some(start_raw + 1),
            // Adding 1 to this address jumps the gap.
            0x0000_7fff_ffff_ffff => Some(0xffff_8000_0000_0000),
            // The range of non-canonical addresses.
            0x0000_8000_0000_0000..=0xffff_7fff_ffff_ffff => unreachable!(),
            // Adding 1 to an address in this range doesn't require a gap jump,
            // so we can just add 1.
            0xffff_8000_0000_0000..=0xffff_ffff_ffff_fffe => Some(start_raw + 1),
            // Adding 1 to this address causes an overflow.
            0xffff_ffff_ffff_ffff => None,
        };
        if let Some(expected) = expected {
            // Verify that `expected` is a valid address.
            assert!(VirtAddr::try_new(expected).is_ok());
        }
        // Verify `forward_checked`.
        let next = Step::forward_checked(start, 1);
        assert!(next.map(VirtAddr::as_u64) == expected);
    }

    // This harness proves that the result of taking two small steps is the
    // same as taking one combined large step.
    #[kani::proof]
    fn forward_induction_step() {
        let start_raw: u64 = kani::any();
        let Ok(start) = VirtAddr::try_new(start_raw) else {
            return;
        };

        let count1: usize = kani::any();
        let count2: usize = kani::any();
        // If we can take two small steps...
        let Some(next1) = Step::forward_checked(start, count1) else {
            return;
        };
        let Some(next2) = Step::forward_checked(next1, count2) else {
            return;
        };

        // ...then we can also take one combined large step.
        let count_both = count1 + count2;
        let next_both = Step::forward(start, count_both);
        assert!(next2 == next_both);
    }

    // The next two proof harnesses prove the correctness of the `backward`
    // implementation of VirtAddr using the `forward` implementation which
    // we've already proven to be correct.
    // They do this by proving the symmetry between those two functions.

    // This harness proves the correctness of the implementation of `backward`
    // for all inputs for which `forward_checked` succeeds.
    #[kani::proof]
    fn forward_implies_backward() {
        let start_raw: u64 = kani::any();
        let Ok(start) = VirtAddr::try_new(start_raw) else {
            return;
        };
        let count: usize = kani::any();

        // If `forward_checked` succeeds...
        let Some(end) = Step::forward_checked(start, count) else {
            return;
        };

        // ...then `backward` succeeds as well.
        let start2 = Step::backward(end, count);
        assert!(start == start2);
    }

    // This harness proves that for all inputs for which `backward_checked`
    // succeeds, `forward` succeeds as well.
    #[kani::proof]
    fn backward_implies_forward() {
        let end_raw: u64 = kani::any();
        let Ok(end) = VirtAddr::try_new(end_raw) else {
            return;
        };
        let count: usize = kani::any();

        // If `backward_checked` succeeds...
        let Some(start) = Step::backward_checked(end, count) else {
            return;
        };

        // ...then `forward` succeeds as well.
        let end2 = Step::forward(start, count);
        assert!(end == end2);
    }

    // The next two proof harnesses prove the correctness of the
    // `steps_between` implementation of VirtAddr using the `forward`
    // implementation which we've already proven to be correct.
    // They do this by proving the symmetry between those two functions.

    // This harness proves the correctness of the implementation of
    // `steps_between` for all inputs for which `forward_checked` succeeds.
    #[kani::proof]
    fn forward_implies_steps_between() {
        let start: u64 = kani::any();
        let Ok(start) = VirtAddr::try_new(start) else {
            return;
        };
        let count: usize = kani::any();

        // If `forward_checked` succeeds...
        let Some(end) = Step::forward_checked(start, count) else {
            return;
        };

        // ...then `steps_between` succeeds as well.
        assert!(Step::steps_between(&start, &end) == (count, Some(count)));
    }

    // This harness proves that for all inputs for which `steps_between`
    // succeeds, `forward` succeeds as well.
    #[kani::proof]
    fn steps_between_implies_forward() {
        let start: u64 = kani::any();
        let Ok(start) = VirtAddr::try_new(start) else {
            return;
        };
        let end: u64 = kani::any();
        let Ok(end) = VirtAddr::try_new(end) else {
            return;
        };

        // If `steps_between` succeeds...
        let Some(count) = Step::steps_between(&start, &end).1 else {
            return;
        };

        // ...then `forward` succeeds as well.
        assert!(Step::forward(start, count) == end);
    }
}