core/slice/rotate.rs
1use crate::mem::{self, MaybeUninit, SizedTypeProperties};
2use crate::{cmp, ptr};
3
/// Backing type for the stack scratch buffer used by `ptr_rotate_memmove` (32 `usize`s).
/// `ptr_rotate` compares the shorter side of a rotation against the size of this type to
/// decide whether the buffer-based strategy can be used.
type BufType = [usize; 32];
5
6/// Rotates the range `[mid-left, mid+right)` such that the element at `mid` becomes the first
7/// element. Equivalently, rotates the range `left` elements to the left or `right` elements to the
8/// right.
9///
10/// # Safety
11///
12/// The specified range must be valid for reading and writing.
13#[inline]
14pub(super) unsafe fn ptr_rotate<T>(left: usize, mid: *mut T, right: usize) {
15    if T::IS_ZST {
16        return;
17    }
18    // abort early if the rotate is a no-op
19    if (left == 0) || (right == 0) {
20        return;
21    }
22    // `T` is not a zero-sized type, so it's okay to divide by its size.
23    if !cfg!(feature = "optimize_for_size")
24        && cmp::min(left, right) <= mem::size_of::<BufType>() / mem::size_of::<T>()
25    {
26        // SAFETY: guaranteed by the caller
27        unsafe { ptr_rotate_memmove(left, mid, right) };
28    } else if !cfg!(feature = "optimize_for_size")
29        && ((left + right < 24) || (mem::size_of::<T>() > mem::size_of::<[usize; 4]>()))
30    {
31        // SAFETY: guaranteed by the caller
32        unsafe { ptr_rotate_gcd(left, mid, right) }
33    } else {
34        // SAFETY: guaranteed by the caller
35        unsafe { ptr_rotate_swap(left, mid, right) }
36    }
37}
38
39/// Algorithm 1 is used if `min(left, right)` is small enough to fit onto a stack buffer. The
40/// `min(left, right)` elements are copied onto the buffer, `memmove` is applied to the others, and
41/// the ones on the buffer are moved back into the hole on the opposite side of where they
42/// originated.
43///
44/// # Safety
45///
46/// The specified range must be valid for reading and writing.
47#[inline]
48unsafe fn ptr_rotate_memmove<T>(left: usize, mid: *mut T, right: usize) {
49    // The `[T; 0]` here is to ensure this is appropriately aligned for T
50    let mut rawarray = MaybeUninit::<(BufType, [T; 0])>::uninit();
51    let buf = rawarray.as_mut_ptr() as *mut T;
52    // SAFETY: `mid-left <= mid-left+right < mid+right`
53    let dim = unsafe { mid.sub(left).add(right) };
54    if left <= right {
55        // SAFETY:
56        //
57        // 1) The `if` condition about the sizes ensures `[mid-left; left]` will fit in
58        //    `buf` without overflow and `buf` was created just above and so cannot be
59        //    overlapped with any value of `[mid-left; left]`
60        // 2) [mid-left, mid+right) are all valid for reading and writing and we don't care
61        //    about overlaps here.
62        // 3) The `if` condition about `left <= right` ensures writing `left` elements to
63        //    `dim = mid-left+right` is valid because:
64        //    - `buf` is valid and `left` elements were written in it in 1)
65        //    - `dim+left = mid-left+right+left = mid+right` and we write `[dim, dim+left)`
66        unsafe {
67            // 1)
68            ptr::copy_nonoverlapping(mid.sub(left), buf, left);
69            // 2)
70            ptr::copy(mid, mid.sub(left), right);
71            // 3)
72            ptr::copy_nonoverlapping(buf, dim, left);
73        }
74    } else {
75        // SAFETY: same reasoning as above but with `left` and `right` reversed
76        unsafe {
77            ptr::copy_nonoverlapping(mid, buf, right);
78            ptr::copy(mid.sub(left), dim, left);
79            ptr::copy_nonoverlapping(buf, mid.sub(left), right);
80        }
81    }
82}
83
/// Algorithm 2 is used for small values of `left + right` or for large `T`. The elements
/// are moved into their final positions one at a time starting at `mid - left` and advancing by
/// `right` steps modulo `left + right`, such that only one temporary is needed. Eventually, we
/// arrive back at `mid - left`. However, if `gcd(left + right, right)` is not 1, the above steps
/// skipped over elements. For example:
/// ```text
/// left = 10, right = 6
/// the `^` indicates an element in its final place
/// 6 7 8 9 10 11 12 13 14 15 . 0 1 2 3 4 5
/// after using one step of the above algorithm (The X will be overwritten at the end of the round,
/// and 12 is stored in a temporary):
/// X 7 8 9 10 11 6 13 14 15 . 0 1 2 3 4 5
///               ^
/// after using another step (now 2 is in the temporary):
/// X 7 8 9 10 11 6 13 14 15 . 0 1 12 3 4 5
///               ^                 ^
/// after the third step (the steps wrap around, and 8 is in the temporary):
/// X 7 2 9 10 11 6 13 14 15 . 0 1 12 3 4 5
///     ^         ^                 ^
/// after 7 more steps, the round ends with the temporary 0 getting put in the X:
/// 0 7 2 9 4 11 6 13 8 15 . 10 1 12 3 14 5
/// ^   ^   ^    ^    ^       ^    ^    ^
/// ```
/// Fortunately, the number of skipped over elements between finalized elements is always equal, so
/// we can just offset our starting position and do more rounds (the total number of rounds is the
/// `gcd(left + right, right)` value). The end result is that all elements are finalized once and
/// only once.
///
/// Algorithm 2 can be vectorized by chunking and performing many rounds at once, but there are too
/// few rounds on average until `left + right` is enormous, and the worst case of a single
/// round is always there.
///
/// If `left` or `right` is zero the rotation is a no-op and the range is not touched.
///
/// # Safety
///
/// The specified range must be valid for reading and writing.
#[inline]
unsafe fn ptr_rotate_gcd<T>(left: usize, mid: *mut T, right: usize) {
    // An empty side means there is nothing to rotate. This has to be filtered out before the
    // rounds below: with `left == 0` the very first step would write to `mid + right`, which is
    // one past the end of the range, and with `right == 0` the index would never advance, so the
    // first round would never end.
    if left == 0 || right == 0 {
        return;
    }
    // Algorithm 2
    // Microbenchmarks indicate that the average performance for random shifts is better all
    // the way until about `left + right == 32`, but the worst case performance breaks even
    // around 16. 24 was chosen as middle ground. If the size of `T` is larger than 4
    // `usize`s, this algorithm also outperforms other algorithms.
    // SAFETY: callers must ensure `mid - left` is valid for reading and writing.
    let x = unsafe { mid.sub(left) };
    // beginning of first round
    // SAFETY: see previous comment.
    let mut tmp: T = unsafe { x.read() };
    let mut i = right;
    // `gcd` can be found before hand by calculating `gcd(left + right, right)`,
    // but it is faster to do one loop which calculates the gcd as a side effect, then
    // doing the rest of the chunk
    let mut gcd = right;
    // benchmarks reveal that it is faster to swap temporaries all the way through instead
    // of reading one temporary once, copying backwards, and then writing that temporary at
    // the very end. This is possibly due to the fact that swapping or replacing temporaries
    // uses only one memory address in the loop instead of needing to manage two.
    loop {
        // [long-safety-expl]
        // SAFETY: callers must ensure `[mid-left, mid+right)`, i.e. `[x, x+left+right)`, are
        // all valid for reading and writing.
        //
        // - `i` starts with `right` so `mid-left <= x+i = x+right = mid-left+right < mid+right`
        //   (the last inequality needs `left > 0`, which was checked above)
        // - `i <= left+right-1` is always true
        //   - if `i < left`, `right` is added so `i < left+right` and on the next
        //     iteration `left` is removed from `i` so it doesn't go further
        //   - if `i >= left`, `left` is removed immediately and so it doesn't go further.
        // - overflows cannot happen for `i` since the function's safety contract asks for
        //   `x+left+right-1 = mid+right-1` to be valid for writing
        // - underflows cannot happen because `i` must be bigger or equal to `left` for
        //   a subtraction of `left` to happen.
        //
        // So `x+i` is valid for reading and writing if the caller respected the contract
        tmp = unsafe { x.add(i).replace(tmp) };
        // instead of incrementing `i` and then checking if it is outside the bounds, we
        // check if `i` will go outside the bounds on the next increment. This prevents
        // any wrapping of pointers or `usize`.
        if i >= left {
            i -= left;
            if i == 0 {
                // end of first round
                // SAFETY: tmp has been read from a valid source and x is valid for writing
                // according to the caller.
                unsafe { x.write(tmp) };
                break;
            }
            // this conditional must be here if `left + right >= 15`
            if i < gcd {
                gcd = i;
            }
        } else {
            i += right;
        }
    }
    // finish the chunk with more rounds
    for start in 1..gcd {
        // SAFETY: `gcd` is at most equal to `right` so all values in `1..gcd` are valid for
        // reading and writing as per the function's safety contract, see [long-safety-expl]
        // above
        tmp = unsafe { x.add(start).read() };
        // [safety-expl-addition]
        //
        // After the first round `gcd` holds `gcd(left + right, right)`, which divides both
        // `left` and `right`. Hence `start < gcd <= left`, so `i = start + right < left + right`
        // and `x+i = mid-left+i` is valid for reading and writing according to the function's
        // safety contract. From there on `i` evolves exactly as in the first round.
        i = start + right;
        loop {
            // SAFETY: see [long-safety-expl] and [safety-expl-addition]
            tmp = unsafe { x.add(i).replace(tmp) };
            if i >= left {
                i -= left;
                if i == start {
                    // SAFETY: see [long-safety-expl] and [safety-expl-addition]
                    unsafe { x.add(start).write(tmp) };
                    break;
                }
            } else {
                i += right;
            }
        }
    }
}
206
/// Algorithm 3 utilizes repeated swapping of `min(left, right)` elements.
///
/// ```text
/// left = 11, right = 4
/// [4 5 6 7 8 9 10 11 12 13 14 . 0 1 2 3]
///                  ^  ^  ^  ^   ^ ^ ^ ^ swapping the right most elements with elements to the left
/// [4 5 6 7 8 9 10 . 0 1 2 3] 11 12 13 14
///        ^ ^ ^  ^   ^ ^ ^ ^ swapping these
/// [4 5 6 . 0 1 2 3] 7 8 9 10 11 12 13 14
/// we cannot swap any more, but a smaller rotation problem is left to solve
/// ```
/// when `left < right` the swapping happens from the left instead.
///
/// If `left` or `right` is zero the rotation is a no-op and the function returns immediately.
///
/// # Safety
///
/// The specified range must be valid for reading and writing.
#[inline]
unsafe fn ptr_rotate_swap<T>(mut left: usize, mut mid: *mut T, mut right: usize) {
    // Each pass shrinks the longer side by whole multiples of the shorter one and leaves a
    // smaller rotation problem behind; the work is done once either side is empty. Testing
    // this at the top (rather than only after a pass) also covers calls that start with an
    // empty side, which would otherwise swap zero elements forever.
    while left != 0 && right != 0 {
        if left >= right {
            // Algorithm 3
            // There is an alternate way of swapping that involves finding where the last swap
            // of this algorithm would be, and swapping using that last chunk instead of swapping
            // adjacent chunks like this algorithm is doing, but this way is still faster.
            loop {
                // SAFETY:
                // `left >= right` so `[mid-right, mid+right)` is valid for reading and writing
                // Subtracting `right` from `mid` each turn is counterbalanced by the
                // subtraction from `left` and the check after it.
                unsafe {
                    ptr::swap_nonoverlapping(mid.sub(right), mid, right);
                    mid = mid.sub(right);
                }
                left -= right;
                if left < right {
                    break;
                }
            }
        } else {
            // Algorithm 3, `left < right`
            loop {
                // SAFETY: `[mid-left, mid+left)` is valid for reading and writing because
                // `left < right` so `mid+left < mid+right`.
                // Adding `left` to `mid` each turn is counterbalanced by the subtraction from
                // `right` and the check after it.
                unsafe {
                    ptr::swap_nonoverlapping(mid.sub(left), mid, left);
                    mid = mid.add(left);
                }
                right -= left;
                if right < left {
                    break;
                }
            }
        }
    }
}
267}