Skip to content

Commit 6e88e96

Browse files
gnzlbgKodrAus
authored andcommitted
Support repr(simd) on ADTs containing a single array field
This PR allows using `#[repr(simd)]` on ADTs containing a single array field: ```rust #[repr(simd)] struct S0([f32; 4]); #[repr(simd)] struct S1<const N: usize>([f32; N]); #[repr(simd)] struct S2<T, const N: usize>([T; N]); ``` This should allow experimenting with portable packed SIMD abstractions on nightly that make use of const generics.
1 parent 9d78d1d commit 6e88e96

15 files changed

+427
-160
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 72 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,23 @@ fn generic_simd_intrinsic(
740740
llret_ty: &'ll Type,
741741
span: Span,
742742
) -> Result<&'ll Value, ()> {
743+
// Given a SIMD vector type `x` return the element type and the number of
744+
// elements in the vector.
745+
fn simd_ty_and_len(bx: &Builder<'a, 'll, 'tcx>, simd_ty: Ty<'tcx>) -> (Ty<'tcx>, u64) {
746+
let ty = if let ty::Adt(_def, _substs) = simd_ty.kind() {
747+
let f0_ty = bx.layout_of(simd_ty).field(bx, 0).ty;
748+
if let ty::Array(element_ty, _) = f0_ty.kind() { element_ty } else { f0_ty }
749+
} else {
750+
bug!("should only be called with a SIMD type")
751+
};
752+
let count = if let abi::Abi::Vector { count, .. } = bx.layout_of(simd_ty).abi {
753+
count
754+
} else {
755+
bug!("should only be called with a SIMD type")
756+
};
757+
(ty, count)
758+
}
759+
743760
// macros for error handling:
744761
macro_rules! emit_error {
745762
($msg: tt) => {
@@ -792,7 +809,7 @@ fn generic_simd_intrinsic(
792809
_ => return_error!("`{}` is not an integral type", in_ty),
793810
};
794811
require_simd!(arg_tys[1], "argument");
795-
let v_len = arg_tys[1].simd_size(tcx);
812+
let (_, v_len) = simd_ty_and_len(bx, arg_tys[1]);
796813
require!(
797814
// Allow masks for vectors with fewer than 8 elements to be
798815
// represented with a u8 or i8.
@@ -812,8 +829,6 @@ fn generic_simd_intrinsic(
812829
// every intrinsic below takes a SIMD vector as its first argument
813830
require_simd!(arg_tys[0], "input");
814831
let in_ty = arg_tys[0];
815-
let in_elem = arg_tys[0].simd_type(tcx);
816-
let in_len = arg_tys[0].simd_size(tcx);
817832

818833
let comparison = match name {
819834
sym::simd_eq => Some(hir::BinOpKind::Eq),
@@ -825,14 +840,15 @@ fn generic_simd_intrinsic(
825840
_ => None,
826841
};
827842

843+
let (in_elem, in_len) = simd_ty_and_len(bx, arg_tys[0]);
828844
if let Some(cmp_op) = comparison {
829845
require_simd!(ret_ty, "return");
830846

831-
let out_len = ret_ty.simd_size(tcx);
847+
let (out_ty, out_len) = simd_ty_and_len(bx, ret_ty);
832848
require!(
833849
in_len == out_len,
834850
"expected return type with length {} (same as input type `{}`), \
835-
found `{}` with length {}",
851+
found `{}` with length {}",
836852
in_len,
837853
in_ty,
838854
ret_ty,
@@ -842,7 +858,7 @@ fn generic_simd_intrinsic(
842858
bx.type_kind(bx.element_type(llret_ty)) == TypeKind::Integer,
843859
"expected return type with integer elements, found `{}` with non-integer `{}`",
844860
ret_ty,
845-
ret_ty.simd_type(tcx)
861+
out_ty
846862
);
847863

848864
return Ok(compare_simd_types(
@@ -862,7 +878,7 @@ fn generic_simd_intrinsic(
862878

863879
require_simd!(ret_ty, "return");
864880

865-
let out_len = ret_ty.simd_size(tcx);
881+
let (out_ty, out_len) = simd_ty_and_len(bx, ret_ty);
866882
require!(
867883
out_len == n,
868884
"expected return type of length {}, found `{}` with length {}",
@@ -871,13 +887,13 @@ fn generic_simd_intrinsic(
871887
out_len
872888
);
873889
require!(
874-
in_elem == ret_ty.simd_type(tcx),
890+
in_elem == out_ty,
875891
"expected return element type `{}` (element of input `{}`), \
876-
found `{}` with element type `{}`",
892+
found `{}` with element type `{}`",
877893
in_elem,
878894
in_ty,
879895
ret_ty,
880-
ret_ty.simd_type(tcx)
896+
out_ty
881897
);
882898

883899
let total_len = u128::from(in_len) * 2;
@@ -946,7 +962,7 @@ fn generic_simd_intrinsic(
946962
let m_elem_ty = in_elem;
947963
let m_len = in_len;
948964
require_simd!(arg_tys[1], "argument");
949-
let v_len = arg_tys[1].simd_size(tcx);
965+
let (_, v_len) = simd_ty_and_len(bx, arg_tys[1]);
950966
require!(
951967
m_len == v_len,
952968
"mismatched lengths: mask length `{}` != other vector length `{}`",
@@ -1171,25 +1187,27 @@ fn generic_simd_intrinsic(
11711187
require_simd!(ret_ty, "return");
11721188

11731189
// Of the same length:
1190+
let (_, out_len) = simd_ty_and_len(bx, arg_tys[1]);
1191+
let (_, out_len2) = simd_ty_and_len(bx, arg_tys[2]);
11741192
require!(
1175-
in_len == arg_tys[1].simd_size(tcx),
1193+
in_len == out_len,
11761194
"expected {} argument with length {} (same as input type `{}`), \
1177-
found `{}` with length {}",
1195+
found `{}` with length {}",
11781196
"second",
11791197
in_len,
11801198
in_ty,
11811199
arg_tys[1],
1182-
arg_tys[1].simd_size(tcx)
1200+
out_len
11831201
);
11841202
require!(
1185-
in_len == arg_tys[2].simd_size(tcx),
1203+
in_len == out_len2,
11861204
"expected {} argument with length {} (same as input type `{}`), \
1187-
found `{}` with length {}",
1205+
found `{}` with length {}",
11881206
"third",
11891207
in_len,
11901208
in_ty,
11911209
arg_tys[2],
1192-
arg_tys[2].simd_size(tcx)
1210+
out_len2
11931211
);
11941212

11951213
// The return type must match the first argument type
@@ -1213,39 +1231,40 @@ fn generic_simd_intrinsic(
12131231

12141232
// The second argument must be a simd vector with an element type that's a pointer
12151233
// to the element type of the first argument
1216-
let (pointer_count, underlying_ty) = match arg_tys[1].simd_type(tcx).kind() {
1217-
ty::RawPtr(p) if p.ty == in_elem => {
1218-
(ptr_count(arg_tys[1].simd_type(tcx)), non_ptr(arg_tys[1].simd_type(tcx)))
1219-
}
1234+
let (element_ty0, _) = simd_ty_and_len(bx, arg_tys[0]);
1235+
let (element_ty1, _) = simd_ty_and_len(bx, arg_tys[1]);
1236+
let (pointer_count, underlying_ty) = match element_ty1.kind() {
1237+
ty::RawPtr(p) if p.ty == in_elem => (ptr_count(element_ty1), non_ptr(element_ty1)),
12201238
_ => {
12211239
require!(
12221240
false,
12231241
"expected element type `{}` of second argument `{}` \
1224-
to be a pointer to the element type `{}` of the first \
1225-
argument `{}`, found `{}` != `*_ {}`",
1226-
arg_tys[1].simd_type(tcx),
1242+
to be a pointer to the element type `{}` of the first \
1243+
argument `{}`, found `{}` != `*_ {}`",
1244+
element_ty1,
12271245
arg_tys[1],
12281246
in_elem,
12291247
in_ty,
1230-
arg_tys[1].simd_type(tcx),
1248+
element_ty1,
12311249
in_elem
12321250
);
12331251
unreachable!();
12341252
}
12351253
};
12361254
assert!(pointer_count > 0);
1237-
assert_eq!(pointer_count - 1, ptr_count(arg_tys[0].simd_type(tcx)));
1238-
assert_eq!(underlying_ty, non_ptr(arg_tys[0].simd_type(tcx)));
1255+
assert_eq!(pointer_count - 1, ptr_count(element_ty0));
1256+
assert_eq!(underlying_ty, non_ptr(element_ty0));
12391257

12401258
// The element type of the third argument must be a signed integer type of any width:
1241-
match arg_tys[2].simd_type(tcx).kind() {
1259+
let (element_ty2, _) = simd_ty_and_len(bx, arg_tys[2]);
1260+
match element_ty2.kind() {
12421261
ty::Int(_) => (),
12431262
_ => {
12441263
require!(
12451264
false,
12461265
"expected element type `{}` of third argument `{}` \
12471266
to be a signed integer type",
1248-
arg_tys[2].simd_type(tcx),
1267+
element_ty2,
12491268
arg_tys[2]
12501269
);
12511270
}
@@ -1297,25 +1316,27 @@ fn generic_simd_intrinsic(
12971316
require_simd!(arg_tys[2], "third");
12981317

12991318
// Of the same length:
1319+
let (_, element_len1) = simd_ty_and_len(bx, arg_tys[1]);
1320+
let (_, element_len2) = simd_ty_and_len(bx, arg_tys[2]);
13001321
require!(
1301-
in_len == arg_tys[1].simd_size(tcx),
1322+
in_len == element_len1,
13021323
"expected {} argument with length {} (same as input type `{}`), \
1303-
found `{}` with length {}",
1324+
found `{}` with length {}",
13041325
"second",
13051326
in_len,
13061327
in_ty,
13071328
arg_tys[1],
1308-
arg_tys[1].simd_size(tcx)
1329+
element_len1
13091330
);
13101331
require!(
1311-
in_len == arg_tys[2].simd_size(tcx),
1332+
in_len == element_len2,
13121333
"expected {} argument with length {} (same as input type `{}`), \
1313-
found `{}` with length {}",
1334+
found `{}` with length {}",
13141335
"third",
13151336
in_len,
13161337
in_ty,
13171338
arg_tys[2],
1318-
arg_tys[2].simd_size(tcx)
1339+
element_len2
13191340
);
13201341

13211342
// This counts how many pointers
@@ -1336,39 +1357,42 @@ fn generic_simd_intrinsic(
13361357

13371358
// The second argument must be a simd vector with an element type that's a pointer
13381359
// to the element type of the first argument
1339-
let (pointer_count, underlying_ty) = match arg_tys[1].simd_type(tcx).kind() {
1360+
let (element_ty0, _element_len0) = simd_ty_and_len(bx, arg_tys[0]);
1361+
let (element_ty1, _element_len1) = simd_ty_and_len(bx, arg_tys[1]);
1362+
let (element_ty2, _element_len2) = simd_ty_and_len(bx, arg_tys[2]);
1363+
let (pointer_count, underlying_ty) = match element_ty1.kind() {
13401364
ty::RawPtr(p) if p.ty == in_elem && p.mutbl == hir::Mutability::Mut => {
1341-
(ptr_count(arg_tys[1].simd_type(tcx)), non_ptr(arg_tys[1].simd_type(tcx)))
1365+
(ptr_count(element_ty1), non_ptr(element_ty1))
13421366
}
13431367
_ => {
13441368
require!(
13451369
false,
13461370
"expected element type `{}` of second argument `{}` \
1347-
to be a pointer to the element type `{}` of the first \
1348-
argument `{}`, found `{}` != `*mut {}`",
1349-
arg_tys[1].simd_type(tcx),
1371+
to be a pointer to the element type `{}` of the first \
1372+
argument `{}`, found `{}` != `*mut {}`",
1373+
element_ty1,
13501374
arg_tys[1],
13511375
in_elem,
13521376
in_ty,
1353-
arg_tys[1].simd_type(tcx),
1377+
element_ty1,
13541378
in_elem
13551379
);
13561380
unreachable!();
13571381
}
13581382
};
13591383
assert!(pointer_count > 0);
1360-
assert_eq!(pointer_count - 1, ptr_count(arg_tys[0].simd_type(tcx)));
1361-
assert_eq!(underlying_ty, non_ptr(arg_tys[0].simd_type(tcx)));
1384+
assert_eq!(pointer_count - 1, ptr_count(element_ty0));
1385+
assert_eq!(underlying_ty, non_ptr(element_ty0));
13621386

13631387
// The element type of the third argument must be a signed integer type of any width:
1364-
match arg_tys[2].simd_type(tcx).kind() {
1388+
match element_ty2.kind() {
13651389
ty::Int(_) => (),
13661390
_ => {
13671391
require!(
13681392
false,
13691393
"expected element type `{}` of third argument `{}` \
1370-
to be a signed integer type",
1371-
arg_tys[2].simd_type(tcx),
1394+
be a signed integer type",
1395+
element_ty2,
13721396
arg_tys[2]
13731397
);
13741398
}
@@ -1565,7 +1589,7 @@ unsupported {} from `{}` with element `{}` of size `{}` to `{}`"#,
15651589

15661590
if name == sym::simd_cast {
15671591
require_simd!(ret_ty, "return");
1568-
let out_len = ret_ty.simd_size(tcx);
1592+
let (out_elem, out_len) = simd_ty_and_len(bx, ret_ty);
15691593
require!(
15701594
in_len == out_len,
15711595
"expected return type with length {} (same as input type `{}`), \
@@ -1576,8 +1600,6 @@ unsupported {} from `{}` with element `{}` of size `{}` to `{}`"#,
15761600
out_len
15771601
);
15781602
// casting cares about nominal type, not just structural type
1579-
let out_elem = ret_ty.simd_type(tcx);
1580-
15811603
if in_elem == out_elem {
15821604
return Ok(args[0].immediate());
15831605
}
@@ -1693,7 +1715,7 @@ unsupported {} from `{}` with element `{}` of size `{}` to `{}`"#,
16931715
return_error!(
16941716
"expected element type `{}` of vector type `{}` \
16951717
to be a signed or unsigned integer type",
1696-
arg_tys[0].simd_type(tcx),
1718+
simd_ty_and_len(bx, arg_tys[0]).0,
16971719
arg_tys[0]
16981720
);
16991721
}

0 commit comments

Comments
 (0)