From 4c9b9d7258773d6d0aac6711ba350d9395120e32 Mon Sep 17 00:00:00 2001 From: Lukas Bergdoll Date: Sun, 3 Nov 2024 18:34:45 +0100 Subject: [PATCH] Use more explicit and reliable ptr select in sort impls Using if ... with the intent to avoid branches can be surprising to readers and carries the risk of turning into jumps/branches generated by some future compiler version, breaking crucial optimizations. This commit replaces their usage with the explicit and IR annotated `bool::select_unpredictable`. --- .../core/src/slice/sort/shared/smallsort.rs | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/library/core/src/slice/sort/shared/smallsort.rs b/library/core/src/slice/sort/shared/smallsort.rs index 09f898309bd65..f6dcf42ba6037 100644 --- a/library/core/src/slice/sort/shared/smallsort.rs +++ b/library/core/src/slice/sort/shared/smallsort.rs @@ -387,7 +387,7 @@ unsafe fn swap_if_less(v_base: *mut T, a_pos: usize, b_pos: usize, is_less where F: FnMut(&T, &T) -> bool, { - // SAFETY: the caller must guarantee that `a` and `b` each added to `v_base` yield valid + // SAFETY: the caller must guarantee that `a_pos` and `b_pos` each added to `v_base` yield valid // pointers into `v_base`, and are properly aligned, and part of the same allocation. unsafe { let v_a = v_base.add(a_pos); @@ -404,16 +404,16 @@ where // The equivalent code with a branch would be: // // if should_swap { - // ptr::swap(left, right, 1); + // ptr::swap(v_a, v_b, 1); // } // The goal is to generate cmov instructions here. - let left_swap = if should_swap { v_b } else { v_a }; - let right_swap = if should_swap { v_a } else { v_b }; + let v_a_swap = should_swap.select_unpredictable(v_b, v_a); + let v_b_swap = should_swap.select_unpredictable(v_a, v_b); - let right_swap_tmp = ManuallyDrop::new(ptr::read(right_swap)); - ptr::copy(left_swap, v_a, 1); - ptr::copy_nonoverlapping(&*right_swap_tmp, v_b, 1); + let v_b_swap_tmp = ManuallyDrop::new(ptr::read(v_b_swap)); + ptr::copy(v_a_swap, v_a, 1); + ptr::copy_nonoverlapping(&*v_b_swap_tmp, v_b, 1); } } @@ -640,26 +640,21 @@ pub unsafe fn sort4_stable bool>( // 1, 1 | c b a d let c3 = is_less(&*c, &*a); let c4 = is_less(&*d, &*b); - let min = select(c3, c, a); - let max = select(c4, b, d); - let unknown_left = select(c3, a, select(c4, c, b)); - let unknown_right = select(c4, d, select(c3, b, c)); + let min = c3.select_unpredictable(c, a); + let max = c4.select_unpredictable(b, d); + let unknown_left = c3.select_unpredictable(a, c4.select_unpredictable(c, b)); + let unknown_right = c4.select_unpredictable(d, c3.select_unpredictable(b, c)); // Sort the last two unknown elements. let c5 = is_less(&*unknown_right, &*unknown_left); - let lo = select(c5, unknown_right, unknown_left); - let hi = select(c5, unknown_left, unknown_right); + let lo = c5.select_unpredictable(unknown_right, unknown_left); + let hi = c5.select_unpredictable(unknown_left, unknown_right); ptr::copy_nonoverlapping(min, dst, 1); ptr::copy_nonoverlapping(lo, dst.add(1), 1); ptr::copy_nonoverlapping(hi, dst.add(2), 1); ptr::copy_nonoverlapping(max, dst.add(3), 1); } - - #[inline(always)] - fn select(cond: bool, if_true: *const T, if_false: *const T) -> *const T { - if cond { if_true } else { if_false } - } } /// SAFETY: The caller MUST guarantee that `v_base` is valid for 8 reads and