Skip to content

Commit 3fd9554

Browse files
lukas-code and RalfJung
authored and
Lukas Markeffsky
committed
Apply suggestions from code review
* Remove `const_align_offset` and just call `align_offset` again
* Remove miri shim for `div_exact`
* Add more comments

Co-authored-by: Ralf Jung <post@ralfj.de>
1 parent 9304b84 commit 3fd9554

File tree

5 files changed

+215
-80
lines changed

5 files changed

+215
-80
lines changed

compiler/rustc_const_eval/src/const_eval/machine.rs

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use rustc_hir::def::DefKind;
22
use rustc_middle::mir;
33
use rustc_middle::mir::interpret::PointerArithmetic;
4-
use rustc_middle::ty::layout::{FnAbiOf, LayoutOf};
4+
use rustc_middle::ty::layout::FnAbiOf;
55
use rustc_middle::ty::{self, Ty, TyCtxt};
66
use std::borrow::Borrow;
77
use std::collections::hash_map::Entry;
88
use std::hash::Hash;
9+
use std::ops::ControlFlow;
910

1011
use rustc_data_structures::fx::FxHashMap;
1112
use std::fmt;
@@ -147,9 +148,10 @@ impl interpret::MayLeak for ! {
147148
}
148149

149150
impl<'mir, 'tcx: 'mir> CompileTimeEvalContext<'mir, 'tcx> {
150-
/// "Intercept" a function call to a panic-related function
151-
/// because we have something special to do for it.
152-
/// If this returns successfully (`Ok`), the function should just be evaluated normally.
151+
/// "Intercept" a function call, because we have something special to do for it.
152+
/// All `#[rustc_do_not_const_check]` functions should be hooked here.
153+
/// If this returns `Some`, then evaluation should continue with that function.
154+
/// Otherwise, the function call has been handled and the function has returned.
153155
fn hook_special_const_fn(
154156
&mut self,
155157
instance: ty::Instance<'tcx>,
@@ -158,7 +160,6 @@ impl<'mir, 'tcx: 'mir> CompileTimeEvalContext<'mir, 'tcx> {
158160
dest: &PlaceTy<'tcx>,
159161
ret: Option<mir::BasicBlock>,
160162
) -> InterpResult<'tcx, Option<ty::Instance<'tcx>>> {
161-
// All `#[rustc_do_not_const_check]` functions should be hooked here.
162163
let def_id = instance.def_id();
163164

164165
if Some(def_id) == self.tcx.lang_items().panic_display()
@@ -192,34 +193,27 @@ impl<'mir, 'tcx: 'mir> CompileTimeEvalContext<'mir, 'tcx> {
192193

193194
return Ok(Some(new_instance));
194195
} else if Some(def_id) == self.tcx.lang_items().align_offset_fn() {
195-
// For align_offset, we either call const_align_offset or return usize::MAX directly.
196-
197-
let Some(const_def_id) = self.tcx.lang_items().const_align_offset_fn() else {
198-
bug!("`const_align_offset` must be defined to call `align_offset` in const eval")
199-
};
200-
let const_instance = ty::Instance::resolve(
201-
*self.tcx,
202-
ty::ParamEnv::reveal_all(),
203-
const_def_id,
204-
instance.substs,
205-
)
206-
.unwrap()
207-
.unwrap();
208-
209-
self.align_offset(const_instance, args, dest, ret)?;
210-
211-
return Ok(None);
196+
// For align_offset, we replace the function call if the pointer has no address.
197+
match self.align_offset(instance, args, dest, ret)? {
198+
ControlFlow::Continue(()) => return Ok(Some(instance)),
199+
ControlFlow::Break(()) => return Ok(None),
200+
}
212201
}
213202
Ok(Some(instance))
214203
}
215204

205+
/// `align_offset(ptr, target_align)` needs special handling in const eval, because the pointer
206+
/// may not have an address.
207+
///
208+
/// If the pointer does have a known address we return `CONTINUE` and the function call should
209+
/// proceed as normal. Otherwise we will replace the function call and return `BREAK`.
216210
fn align_offset(
217211
&mut self,
218-
const_instance: ty::Instance<'tcx>,
212+
instance: ty::Instance<'tcx>,
219213
args: &[OpTy<'tcx>],
220214
dest: &PlaceTy<'tcx>,
221215
ret: Option<mir::BasicBlock>,
222-
) -> InterpResult<'tcx> {
216+
) -> InterpResult<'tcx, ControlFlow<()>> {
223217
assert_eq!(args.len(), 2);
224218

225219
let ptr = self.read_pointer(&args[0])?;
@@ -229,36 +223,40 @@ impl<'mir, 'tcx: 'mir> CompileTimeEvalContext<'mir, 'tcx> {
229223
throw_ub_format!("`align_offset` called with non-power-of-two align: {}", target_align);
230224
}
231225

232-
let addr = match self.ptr_try_get_alloc_id(ptr) {
226+
match self.ptr_try_get_alloc_id(ptr) {
233227
Ok((alloc_id, offset, _extra)) => {
234228
let (_size, alloc_align, _kind) = self.get_alloc_info(alloc_id);
235229

236-
if target_align > alloc_align.bytes() {
230+
if target_align <= alloc_align.bytes() {
231+
// Extract the address relative to the allocation base that is definitely
232+
// sufficiently aligned and call `align_offset` again.
233+
let addr = ImmTy::from_uint(offset.bytes(), args[0].layout).into();
234+
let align = ImmTy::from_uint(target_align, args[1].layout).into();
235+
236+
let fn_abi = self.fn_abi_of_instance(instance, ty::List::empty())?;
237+
self.eval_fn_call(
238+
FnVal::Instance(instance),
239+
(CallAbi::Rust, fn_abi),
240+
&[addr, align],
241+
false,
242+
dest,
243+
ret,
244+
StackPopUnwind::NotAllowed,
245+
)?;
246+
Ok(ControlFlow::BREAK)
247+
} else {
248+
// Not alignable in const, return `usize::MAX`.
237249
let usize_max = Scalar::from_machine_usize(self.machine_usize_max(), self);
238250
self.write_scalar(usize_max, dest)?;
239251
self.return_to_block(ret)?;
240-
return Ok(());
241-
} else {
242-
offset.bytes()
252+
Ok(ControlFlow::BREAK)
243253
}
244254
}
245-
Err(addr) => addr,
246-
};
247-
248-
let usize_layout = self.layout_of(self.tcx.types.usize)?;
249-
let addr = ImmTy::from_uint(addr, usize_layout).into();
250-
let align = ImmTy::from_uint(target_align, usize_layout).into();
251-
252-
let fn_abi = self.fn_abi_of_instance(const_instance, ty::List::empty())?;
253-
self.eval_fn_call(
254-
FnVal::Instance(const_instance),
255-
(CallAbi::Rust, fn_abi),
256-
&[addr, align],
257-
false,
258-
dest,
259-
ret,
260-
StackPopUnwind::NotAllowed,
261-
)
255+
Err(_addr) => {
256+
// The pointer has an address, continue with function call.
257+
Ok(ControlFlow::CONTINUE)
258+
}
259+
}
262260
}
263261

264262
/// See documentation on the `ptr_guaranteed_cmp` intrinsic.

compiler/rustc_hir/src/lang_items.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,7 @@ language_item_table! {
283283
MaybeUninit, sym::maybe_uninit, maybe_uninit, Target::Union, GenericRequirement::None;
284284

285285
/// Align offset for stride != 1; must not panic.
286-
AlignOffset, sym::align_offset, align_offset_fn, Target::Fn, GenericRequirement::Exact(1);
287-
ConstAlignOffset, sym::const_align_offset, const_align_offset_fn, Target::Fn, GenericRequirement::Exact(1);
286+
AlignOffset, sym::align_offset, align_offset_fn, Target::Fn, GenericRequirement::None;
288287

289288
Termination, sym::termination, termination, Target::Trait, GenericRequirement::None;
290289

compiler/rustc_span/src/symbol.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,6 @@ symbols! {
510510
concat_macro,
511511
conservative_impl_trait,
512512
console,
513-
const_align_offset,
514513
const_allocate,
515514
const_async_blocks,
516515
const_compare_raw_pointers,

library/core/src/ptr/mod.rs

Lines changed: 170 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,41 +1559,20 @@ pub unsafe fn write_volatile<T>(dst: *mut T, src: T) {
15591559
///
15601560
/// # Safety
15611561
/// `a` must be a power of two.
1562-
#[lang = "align_offset"]
1563-
#[rustc_do_not_const_check]
1564-
#[cfg(not(bootstrap))]
1565-
pub(crate) const unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
1566-
// SAFETY: Caller ensures that `a` is a power of two.
1567-
unsafe { const_align_offset::<T>(p.addr(), a) }
1568-
}
1569-
1570-
#[lang = "align_offset"]
1571-
#[cfg(bootstrap)]
1572-
pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
1573-
// SAFETY: Caller ensures that `a` is a power of two.
1574-
unsafe { const_align_offset::<T>(p.addr(), a) }
1575-
}
1576-
1577-
/// Align address `addr`.
1578-
///
1579-
/// Calculate offset (in terms of elements of `size_of::<T>()` stride) that has to be applied
1580-
/// to address `addr` so that `addr` would get aligned to `a`.
15811562
///
1582-
/// Note: This implementation has been carefully tailored to not panic. It is UB for this to panic.
1563+
/// # Notes
1564+
/// This implementation has been carefully tailored to not panic. It is UB for this to panic.
15831565
/// The only real change that can be made here is change of `INV_TABLE_MOD_16` and associated
15841566
/// constants.
15851567
///
1586-
/// # Safety
1587-
/// `a` must be a power of two.
1588-
///
15891568
/// If we ever decide to make it possible to call the intrinsic with `a` that is not a
15901569
/// power-of-two, it will probably be more prudent to just change to a naive implementation rather
15911570
/// than trying to adapt this to accommodate that change.
15921571
///
15931572
/// Any questions go to @nagisa.
1594-
#[cfg_attr(not(bootstrap), lang = "const_align_offset")]
1595-
#[rustc_allow_const_fn_unstable(const_exact_div)]
1596-
pub(crate) const unsafe fn const_align_offset<T: Sized>(addr: usize, a: usize) -> usize {
1573+
#[lang = "align_offset"]
1574+
#[cfg(not(bootstrap))]
1575+
pub(crate) const unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
15971576
// FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
15981577
// 1, where the method versions of these operations are not inlined.
15991578
use intrinsics::{
@@ -1650,6 +1629,171 @@ pub(crate) const unsafe fn const_align_offset<T: Sized>(addr: usize, a: usize) -
16501629
}
16511630
}
16521631

1632+
let stride = mem::size_of::<T>();
1633+
1634+
// SAFETY: At runtime transmuting a pointer to `usize` is always safe, because they have the
1635+
// same layout. During const eval we hook this function to ensure that the pointer always has
1636+
// an address (only the standard library can do this).
1637+
let addr = unsafe { mem::transmute(p) };
1638+
1639+
// SAFETY: `a` is a power-of-two, therefore non-zero.
1640+
let a_minus_one = unsafe { unchecked_sub(a, 1) };
1641+
1642+
if stride == 0 {
1643+
// SPECIAL_CASE: handle 0-sized types. No matter how many times we step, the address will
1644+
// stay the same, so no offset will be able to align the pointer unless it is already
1645+
// aligned. This branch _will_ be optimized out as `stride` is known at compile-time.
1646+
let p_mod_a = addr & a_minus_one;
1647+
return if p_mod_a == 0 { 0 } else { usize::MAX };
1648+
}
1649+
1650+
// SAFETY: `stride == 0` case has been handled by the special case above.
1651+
let a_mod_stride = unsafe { unchecked_rem(a, stride) };
1652+
if a_mod_stride == 0 {
1653+
// SPECIAL_CASE: In cases where the `a` is divisible by `stride`, byte offset to align a
1654+
// pointer can be computed more simply through `-p (mod a)`. In the off-chance the byte
1655+
// offset is not a multiple of `stride`, the input pointer was misaligned and no pointer
1656+
// offset will be able to produce a `p` aligned to the specified `a`.
1657+
//
1658+
// The naive `-p (mod a)` equation inhibits LLVM's ability to select instructions
1659+
// like `lea`. We compute `(round_up_to_next_alignment(p, a) - p)` instead. This
1660+
// redistributes operations around the load-bearing, but pessimizing `and` instruction
1661+
// sufficiently for LLVM to be able to utilize the various optimizations it knows about.
1662+
//
1663+
// LLVM handles the branch here particularly nicely. If this branch needs to be evaluated
1664+
// at runtime, it will produce a mask `if addr_mod_stride == 0 { 0 } else { usize::MAX }`
1665+
// in a branch-free way and then bitwise-OR it with whatever result the `-p mod a`
1666+
// computation produces.
1667+
1668+
// SAFETY: `stride == 0` case has been handled by the special case above.
1669+
let addr_mod_stride = unsafe { unchecked_rem(addr, stride) };
1670+
1671+
return if addr_mod_stride == 0 {
1672+
let aligned_address = wrapping_add(addr, a_minus_one) & wrapping_sub(0, a);
1673+
let byte_offset = wrapping_sub(aligned_address, addr);
1674+
// SAFETY: `stride` is non-zero. This is guaranteed to divide exactly as well, because
1675+
// addr has been verified to be aligned to the original type’s alignment requirements.
1676+
unsafe { exact_div(byte_offset, stride) }
1677+
} else {
1678+
usize::MAX
1679+
};
1680+
}
1681+
1682+
// GENERAL_CASE: From here on we’re handling the very general case where `addr` may be
1683+
// misaligned, there isn’t an obvious relationship between `stride` and `a` that we can take an
1684+
// advantage of, etc. This case produces machine code that isn’t particularly high quality,
1685+
// compared to the special cases above. The code produced here is still within the realm of
1686+
// miracles, given the situations this case has to deal with.
1687+
1688+
// SAFETY: a is power-of-two hence non-zero. stride == 0 case is handled above.
1689+
let gcdpow = unsafe { cttz_nonzero(stride).min(cttz_nonzero(a)) };
1690+
// SAFETY: gcdpow has an upper-bound that’s at most the number of bits in a usize.
1691+
let gcd = unsafe { unchecked_shl(1usize, gcdpow) };
1692+
// SAFETY: gcd is always greater or equal to 1.
1693+
if addr & unsafe { unchecked_sub(gcd, 1) } == 0 {
1694+
// This branch solves for the following linear congruence equation:
1695+
//
1696+
// ` p + so = 0 mod a `
1697+
//
1698+
// `p` here is the pointer value, `s` - stride of `T`, `o` offset in `T`s, and `a` - the
1699+
// requested alignment.
1700+
//
1701+
// With `g = gcd(a, s)`, and the above condition asserting that `p` is also divisible by
1702+
// `g`, we can denote `a' = a/g`, `s' = s/g`, `p' = p/g`, then this becomes equivalent to:
1703+
//
1704+
// ` p' + s'o = 0 mod a' `
1705+
// ` o = (a' - (p' mod a')) * (s'^-1 mod a') `
1706+
//
1707+
// The first term is "the relative alignment of `p` to `a`" (divided by the `g`), the
1708+
// second term is "how does incrementing `p` by `s` bytes change the relative alignment of
1709+
// `p`" (again divided by `g`). Division by `g` is necessary to make the inverse well
1710+
// formed if `a` and `s` are not co-prime.
1711+
//
1712+
// Furthermore, the result produced by this solution is not "minimal", so it is necessary
1713+
// to take the result `o mod lcm(s, a)`. This `lcm(s, a)` is the same as `a'`.
1714+
1715+
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
1716+
// `a`.
1717+
let a2 = unsafe { unchecked_shr(a, gcdpow) };
1718+
// SAFETY: `a2` is non-zero. Shifting `a` by `gcdpow` cannot shift out any of the set bits
1719+
// in `a` (of which it has exactly one).
1720+
let a2minus1 = unsafe { unchecked_sub(a2, 1) };
1721+
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
1722+
// `a`.
1723+
let s2 = unsafe { unchecked_shr(stride & a_minus_one, gcdpow) };
1724+
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
1725+
// `a`. Furthermore, the subtraction cannot overflow, because `a2 = a >> gcdpow` will
1726+
// always be strictly greater than `(p % a) >> gcdpow`.
1727+
let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(addr & a_minus_one, gcdpow)) };
1728+
// SAFETY: `a2` is a power-of-two, as proven above. `s2` is strictly less than `a2`
1729+
// because `(s % a) >> gcdpow` is strictly less than `a >> gcdpow`.
1730+
return wrapping_mul(minusp2, unsafe { mod_inv(s2, a2) }) & a2minus1;
1731+
}
1732+
1733+
// Cannot be aligned at all.
1734+
usize::MAX
1735+
}
1736+
1737+
#[lang = "align_offset"]
1738+
#[cfg(bootstrap)]
1739+
pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
1740+
// FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
1741+
// 1, where the method versions of these operations are not inlined.
1742+
use intrinsics::{
1743+
cttz_nonzero, exact_div, unchecked_rem, unchecked_shl, unchecked_shr, unchecked_sub,
1744+
wrapping_add, wrapping_mul, wrapping_sub,
1745+
};
1746+
1747+
/// Calculate multiplicative modular inverse of `x` modulo `m`.
1748+
///
1749+
/// This implementation is tailored for `align_offset` and has following preconditions:
1750+
///
1751+
/// * `m` is a power-of-two;
1752+
/// * `x < m`; (if `x ≥ m`, pass in `x % m` instead)
1753+
///
1754+
/// Implementation of this function shall not panic. Ever.
1755+
#[inline]
1756+
unsafe fn mod_inv(x: usize, m: usize) -> usize {
1757+
/// Multiplicative modular inverse table modulo 2⁴ = 16.
1758+
///
1759+
/// Note, that this table does not contain values where inverse does not exist (i.e., for
1760+
/// `0⁻¹ mod 16`, `2⁻¹ mod 16`, etc.)
1761+
const INV_TABLE_MOD_16: [u8; 8] = [1, 11, 13, 7, 9, 3, 5, 15];
1762+
/// Modulo for which the `INV_TABLE_MOD_16` is intended.
1763+
const INV_TABLE_MOD: usize = 16;
1764+
/// INV_TABLE_MOD²
1765+
const INV_TABLE_MOD_SQUARED: usize = INV_TABLE_MOD * INV_TABLE_MOD;
1766+
1767+
let table_inverse = INV_TABLE_MOD_16[(x & (INV_TABLE_MOD - 1)) >> 1] as usize;
1768+
// SAFETY: `m` is required to be a power-of-two, hence non-zero.
1769+
let m_minus_one = unsafe { unchecked_sub(m, 1) };
1770+
if m <= INV_TABLE_MOD {
1771+
table_inverse & m_minus_one
1772+
} else {
1773+
// We iterate "up" using the following formula:
1774+
//
1775+
// $$ xy ≡ 1 (mod 2ⁿ) → xy (2 - xy) ≡ 1 (mod 2²ⁿ) $$
1776+
//
1777+
// until 2²ⁿ ≥ m. Then we can reduce to our desired `m` by taking the result `mod m`.
1778+
let mut inverse = table_inverse;
1779+
let mut going_mod = INV_TABLE_MOD_SQUARED;
1780+
loop {
1781+
// y = y * (2 - xy) mod n
1782+
//
1783+
// Note, that we use wrapping operations here intentionally – the original formula
1784+
// uses e.g., subtraction `mod n`. It is entirely fine to do them `mod
1785+
// usize::MAX` instead, because we take the result `mod n` at the end
1786+
// anyway.
1787+
inverse = wrapping_mul(inverse, wrapping_sub(2usize, wrapping_mul(x, inverse)));
1788+
if going_mod >= m {
1789+
return inverse & m_minus_one;
1790+
}
1791+
going_mod = wrapping_mul(going_mod, going_mod);
1792+
}
1793+
}
1794+
}
1795+
1796+
let addr = p.addr();
16531797
let stride = mem::size_of::<T>();
16541798
// SAFETY: `a` is a power-of-two, therefore non-zero.
16551799
let a_minus_one = unsafe { unchecked_sub(a, 1) };

src/tools/miri/src/shims/intrinsics/mod.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
357357
}
358358

359359
// Other
360-
"exact_div" => {
361-
let [num, denom] = check_arg_count(args)?;
362-
this.exact_div(&this.read_immediate(num)?, &this.read_immediate(denom)?, dest)?;
363-
}
364-
365360
"breakpoint" => {
366361
let [] = check_arg_count(args)?;
367362
// normally this would raise a SIGTRAP, which aborts if no debugger is connected

0 commit comments

Comments (0)