Skip to content

Commit ee42d1e

Browse files
committed
NaN non-determinism for SIMD intrinsics
1 parent 0f98c0e commit ee42d1e

File tree

2 files changed

+145
-75
lines changed

2 files changed

+145
-75
lines changed

src/tools/miri/src/shims/intrinsics/simd.rs

Lines changed: 88 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@ use rustc_span::{sym, Symbol};
55
use rustc_target::abi::{Endian, HasDataLayout};
66

77
use crate::helpers::{
8-
bool_to_simd_element, check_arg_count, round_to_next_multiple_of, simd_element_to_bool,
8+
bool_to_simd_element, check_arg_count, round_to_next_multiple_of, simd_element_to_bool, ToHost,
9+
ToSoft,
910
};
1011
use crate::*;
1112

13+
#[derive(Copy, Clone)]
14+
pub(crate) enum MinMax {
15+
Min,
16+
Max,
17+
}
18+
1219
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
1320
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
1421
/// Calls the simd intrinsic `intrinsic`; the `simd_` prefix has already been removed.
@@ -67,13 +74,17 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
6774
let op = this.read_immediate(&this.project_index(&op, i)?)?;
6875
let dest = this.project_index(&dest, i)?;
6976
let val = match which {
70-
Op::MirOp(mir_op) => this.wrapping_unary_op(mir_op, &op)?.to_scalar(),
77+
Op::MirOp(mir_op) => {
78+
// This already does NaN adjustments
79+
this.wrapping_unary_op(mir_op, &op)?.to_scalar()
80+
}
7181
Op::Abs => {
7282
// Works for f32 and f64.
7383
let ty::Float(float_ty) = op.layout.ty.kind() else {
7484
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
7585
};
7686
let op = op.to_scalar();
87+
// "Bitwise" operation, no NaN adjustments
7788
match float_ty {
7889
FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
7990
FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
@@ -86,14 +97,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
8697
// FIXME using host floats
8798
match float_ty {
8899
FloatTy::F32 => {
89-
let f = f32::from_bits(op.to_scalar().to_u32()?);
90-
let res = f.sqrt();
91-
Scalar::from_u32(res.to_bits())
100+
let f = op.to_scalar().to_f32()?;
101+
let res = f.to_host().sqrt().to_soft();
102+
let res = this.adjust_nan(res, &[f]);
103+
Scalar::from(res)
92104
}
93105
FloatTy::F64 => {
94-
let f = f64::from_bits(op.to_scalar().to_u64()?);
95-
let res = f.sqrt();
96-
Scalar::from_u64(res.to_bits())
106+
let f = op.to_scalar().to_f64()?;
107+
let res = f.to_host().sqrt().to_soft();
108+
let res = this.adjust_nan(res, &[f]);
109+
Scalar::from(res)
97110
}
98111
}
99112
}
@@ -105,11 +118,13 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
105118
FloatTy::F32 => {
106119
let f = op.to_scalar().to_f32()?;
107120
let res = f.round_to_integral(rounding).value;
121+
let res = this.adjust_nan(res, &[f]);
108122
Scalar::from_f32(res)
109123
}
110124
FloatTy::F64 => {
111125
let f = op.to_scalar().to_f64()?;
112126
let res = f.round_to_integral(rounding).value;
127+
let res = this.adjust_nan(res, &[f]);
113128
Scalar::from_f64(res)
114129
}
115130
}
@@ -157,8 +172,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
157172
enum Op {
158173
MirOp(BinOp),
159174
SaturatingOp(BinOp),
160-
FMax,
161-
FMin,
175+
FMinMax(MinMax),
162176
WrappingOffset,
163177
}
164178
let which = match intrinsic_name {
@@ -178,8 +192,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
178192
"le" => Op::MirOp(BinOp::Le),
179193
"gt" => Op::MirOp(BinOp::Gt),
180194
"ge" => Op::MirOp(BinOp::Ge),
181-
"fmax" => Op::FMax,
182-
"fmin" => Op::FMin,
195+
"fmax" => Op::FMinMax(MinMax::Max),
196+
"fmin" => Op::FMinMax(MinMax::Min),
183197
"saturating_add" => Op::SaturatingOp(BinOp::Add),
184198
"saturating_sub" => Op::SaturatingOp(BinOp::Sub),
185199
"arith_offset" => Op::WrappingOffset,
@@ -192,6 +206,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
192206
let dest = this.project_index(&dest, i)?;
193207
let val = match which {
194208
Op::MirOp(mir_op) => {
209+
// This does NaN adjustments.
195210
let (val, overflowed) = this.overflowing_binary_op(mir_op, &left, &right)?;
196211
if matches!(mir_op, BinOp::Shl | BinOp::Shr) {
197212
// Shifts have extra UB as SIMD operations that the MIR binop does not have.
@@ -225,11 +240,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
225240
let offset_ptr = ptr.wrapping_signed_offset(offset_bytes, this);
226241
Scalar::from_maybe_pointer(offset_ptr, this)
227242
}
228-
Op::FMax => {
229-
fmax_op(&left, &right)?
230-
}
231-
Op::FMin => {
232-
fmin_op(&left, &right)?
243+
Op::FMinMax(op) => {
244+
this.fminmax_op(op, &left, &right)?
233245
}
234246
};
235247
this.write_scalar(val, &dest)?;
@@ -259,18 +271,20 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
259271
};
260272
let val = match float_ty {
261273
FloatTy::F32 => {
262-
let a = f32::from_bits(a.to_u32()?);
263-
let b = f32::from_bits(b.to_u32()?);
264-
let c = f32::from_bits(c.to_u32()?);
265-
let res = a.mul_add(b, c);
266-
Scalar::from_u32(res.to_bits())
274+
let a = a.to_f32()?;
275+
let b = b.to_f32()?;
276+
let c = c.to_f32()?;
277+
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
278+
let res = this.adjust_nan(res, &[a, b, c]);
279+
Scalar::from(res)
267280
}
268281
FloatTy::F64 => {
269-
let a = f64::from_bits(a.to_u64()?);
270-
let b = f64::from_bits(b.to_u64()?);
271-
let c = f64::from_bits(c.to_u64()?);
272-
let res = a.mul_add(b, c);
273-
Scalar::from_u64(res.to_bits())
282+
let a = a.to_f64()?;
283+
let b = b.to_f64()?;
284+
let c = c.to_f64()?;
285+
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
286+
let res = this.adjust_nan(res, &[a, b, c]);
287+
Scalar::from(res)
274288
}
275289
};
276290
this.write_scalar(val, &dest)?;
@@ -295,17 +309,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
295309
enum Op {
296310
MirOp(BinOp),
297311
MirOpBool(BinOp),
298-
Max,
299-
Min,
312+
MinMax(MinMax),
300313
}
301314
let which = match intrinsic_name {
302315
"reduce_and" => Op::MirOp(BinOp::BitAnd),
303316
"reduce_or" => Op::MirOp(BinOp::BitOr),
304317
"reduce_xor" => Op::MirOp(BinOp::BitXor),
305318
"reduce_any" => Op::MirOpBool(BinOp::BitOr),
306319
"reduce_all" => Op::MirOpBool(BinOp::BitAnd),
307-
"reduce_max" => Op::Max,
308-
"reduce_min" => Op::Min,
320+
"reduce_max" => Op::MinMax(MinMax::Max),
321+
"reduce_min" => Op::MinMax(MinMax::Min),
309322
_ => unreachable!(),
310323
};
311324

@@ -325,24 +338,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
325338
let op = imm_from_bool(simd_element_to_bool(op)?);
326339
this.wrapping_binary_op(mir_op, &res, &op)?
327340
}
328-
Op::Max => {
329-
if matches!(res.layout.ty.kind(), ty::Float(_)) {
330-
ImmTy::from_scalar(fmax_op(&res, &op)?, res.layout)
331-
} else {
332-
// Just boring integers, so no NaNs to worry about
333-
if this.wrapping_binary_op(BinOp::Ge, &res, &op)?.to_scalar().to_bool()? {
334-
res
335-
} else {
336-
op
337-
}
338-
}
339-
}
340-
Op::Min => {
341+
Op::MinMax(mmop) => {
341342
if matches!(res.layout.ty.kind(), ty::Float(_)) {
342-
ImmTy::from_scalar(fmin_op(&res, &op)?, res.layout)
343+
ImmTy::from_scalar(this.fminmax_op(mmop, &res, &op)?, res.layout)
343344
} else {
344345
// Just boring integers, so no NaNs to worry about
345-
if this.wrapping_binary_op(BinOp::Le, &res, &op)?.to_scalar().to_bool()? {
346+
let mirop = match mmop {
347+
MinMax::Min => BinOp::Le,
348+
MinMax::Max => BinOp::Ge,
349+
};
350+
if this.wrapping_binary_op(mirop, &res, &op)?.to_scalar().to_bool()? {
346351
res
347352
} else {
348353
op
@@ -709,6 +714,43 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
709714
}
710715
Ok(())
711716
}
717+
718+
fn fminmax_op(
719+
&self,
720+
op: MinMax,
721+
left: &ImmTy<'tcx, Provenance>,
722+
right: &ImmTy<'tcx, Provenance>,
723+
) -> InterpResult<'tcx, Scalar<Provenance>> {
724+
let this = self.eval_context_ref();
725+
assert_eq!(left.layout.ty, right.layout.ty);
726+
let ty::Float(float_ty) = left.layout.ty.kind() else {
727+
bug!("fmax operand is not a float")
728+
};
729+
let left = left.to_scalar();
730+
let right = right.to_scalar();
731+
Ok(match float_ty {
732+
FloatTy::F32 => {
733+
let left = left.to_f32()?;
734+
let right = right.to_f32()?;
735+
let res = match op {
736+
MinMax::Min => left.min(right),
737+
MinMax::Max => left.max(right),
738+
};
739+
let res = this.adjust_nan(res, &[left, right]);
740+
Scalar::from_f32(res)
741+
}
742+
FloatTy::F64 => {
743+
let left = left.to_f64()?;
744+
let right = right.to_f64()?;
745+
let res = match op {
746+
MinMax::Min => left.min(right),
747+
MinMax::Max => left.max(right),
748+
};
749+
let res = this.adjust_nan(res, &[left, right]);
750+
Scalar::from_f64(res)
751+
}
752+
})
753+
}
712754
}
713755

714756
fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
@@ -719,31 +761,3 @@ fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
719761
Endian::Big => vec_len - 1 - idx, // reverse order of bits
720762
}
721763
}
722-
723-
fn fmax_op<'tcx>(
724-
left: &ImmTy<'tcx, Provenance>,
725-
right: &ImmTy<'tcx, Provenance>,
726-
) -> InterpResult<'tcx, Scalar<Provenance>> {
727-
assert_eq!(left.layout.ty, right.layout.ty);
728-
let ty::Float(float_ty) = left.layout.ty.kind() else { bug!("fmax operand is not a float") };
729-
let left = left.to_scalar();
730-
let right = right.to_scalar();
731-
Ok(match float_ty {
732-
FloatTy::F32 => Scalar::from_f32(left.to_f32()?.max(right.to_f32()?)),
733-
FloatTy::F64 => Scalar::from_f64(left.to_f64()?.max(right.to_f64()?)),
734-
})
735-
}
736-
737-
fn fmin_op<'tcx>(
738-
left: &ImmTy<'tcx, Provenance>,
739-
right: &ImmTy<'tcx, Provenance>,
740-
) -> InterpResult<'tcx, Scalar<Provenance>> {
741-
assert_eq!(left.layout.ty, right.layout.ty);
742-
let ty::Float(float_ty) = left.layout.ty.kind() else { bug!("fmin operand is not a float") };
743-
let left = left.to_scalar();
744-
let right = right.to_scalar();
745-
Ok(match float_ty {
746-
FloatTy::F32 => Scalar::from_f32(left.to_f32()?.min(right.to_f32()?)),
747-
FloatTy::F64 => Scalar::from_f64(left.to_f64()?.min(right.to_f64()?)),
748-
})
749-
}

src/tools/miri/tests/pass/float_nan.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#![feature(float_gamma)]
1+
#![feature(float_gamma, portable_simd, core_intrinsics, platform_intrinsics)]
22
use std::collections::HashSet;
33
use std::fmt;
44
use std::hash::Hash;
@@ -535,6 +535,61 @@ fn test_casts() {
535535
);
536536
}
537537

538+
fn test_simd() {
539+
use std::intrinsics::simd::*;
540+
use std::simd::*;
541+
542+
extern "platform-intrinsic" {
543+
fn simd_fsqrt<T>(x: T) -> T;
544+
fn simd_ceil<T>(x: T) -> T;
545+
fn simd_fma<T>(x: T, y: T, z: T) -> T;
546+
}
547+
548+
let nan = F32::nan(Neg, Quiet, 0).as_f32();
549+
check_all_outcomes(
550+
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
551+
|| F32::from(unsafe { simd_div(f32x4::splat(0.0), f32x4::splat(0.0)) }[0]),
552+
);
553+
check_all_outcomes(
554+
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
555+
|| F32::from(unsafe { simd_fmin(f32x4::splat(nan), f32x4::splat(nan)) }[0]),
556+
);
557+
check_all_outcomes(
558+
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
559+
|| F32::from(unsafe { simd_fmax(f32x4::splat(nan), f32x4::splat(nan)) }[0]),
560+
);
561+
check_all_outcomes(
562+
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
563+
|| {
564+
F32::from(
565+
unsafe { simd_fma(f32x4::splat(nan), f32x4::splat(nan), f32x4::splat(nan)) }[0],
566+
)
567+
},
568+
);
569+
check_all_outcomes(
570+
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
571+
|| F32::from(unsafe { simd_reduce_add_ordered::<_, f32>(f32x4::splat(nan), nan) }),
572+
);
573+
check_all_outcomes(
574+
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
575+
|| F32::from(unsafe { simd_reduce_max::<_, f32>(f32x4::splat(nan)) }),
576+
);
577+
check_all_outcomes(
578+
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
579+
|| F32::from(unsafe { simd_fsqrt(f32x4::splat(nan)) }[0]),
580+
);
581+
check_all_outcomes(
582+
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
583+
|| F32::from(unsafe { simd_ceil(f32x4::splat(nan)) }[0]),
584+
);
585+
586+
// Casts
587+
check_all_outcomes(
588+
HashSet::from_iter([F64::nan(Pos, Quiet, 0), F64::nan(Neg, Quiet, 0)]),
589+
|| F64::from(unsafe { simd_cast::<f32x4, f64x4>(f32x4::splat(nan)) }[0]),
590+
);
591+
}
592+
538593
fn main() {
539594
// Check our constants against std, just to be sure.
540595
// We add 1 since our numbers are the number of bits stored
@@ -546,4 +601,5 @@ fn main() {
546601
test_f32();
547602
test_f64();
548603
test_casts();
604+
test_simd();
549605
}

0 commit comments

Comments
 (0)